update config.
This commit is contained in:
parent
10f2385aca
commit
fe63f7a1fb
|
@ -5,7 +5,7 @@
|
||||||
# CGED chinese corpus
|
# CGED chinese corpus
|
||||||
train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' # Training data path.
|
train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' # Training data path.
|
||||||
val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml' # Validation data path.
|
val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml' # Validation data path.
|
||||||
test_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml'
|
test_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml'
|
||||||
|
|
||||||
model_path = './output/cged_model' # Path where the model is saved; default is output_path/model.
|
model_path = './output/cged_model' # Path where the model is saved; default is output_path/model.
|
||||||
enable_special_error = False
|
enable_special_error = False
|
||||||
|
@ -15,12 +15,12 @@ decode_sentence = False # Whether we should decode sentences of the user.
|
||||||
# Config
|
# Config
|
||||||
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
|
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
|
||||||
steps_per_checkpoint = 100
|
steps_per_checkpoint = 100
|
||||||
max_steps = 2000
|
max_steps = 10000
|
||||||
max_vocab_size = 10000
|
max_vocab_size = 10000
|
||||||
size = 512
|
size = 512
|
||||||
num_layers = 4
|
num_layers = 4
|
||||||
max_gradient_norm = 5.0
|
max_gradient_norm = 5.0
|
||||||
batch_size = 64
|
batch_size = 128
|
||||||
learning_rate = 0.5
|
learning_rate = 0.5
|
||||||
learning_rate_decay_factor = 0.99
|
learning_rate_decay_factor = 0.99
|
||||||
use_lstm = False
|
use_lstm = False
|
||||||
|
|
|
@ -93,11 +93,16 @@ class CGEDReader(Reader):
|
||||||
def unknown_token(self):
|
def unknown_token(self):
|
||||||
return CGEDReader.UNKNOWN_TOKEN
|
return CGEDReader.UNKNOWN_TOKEN
|
||||||
|
|
||||||
def read_tokens(self, path):
|
def read_tokens(self, path, is_infer=False):
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
dom_tree = minidom.parse(f)
|
dom_tree = minidom.parse(f)
|
||||||
docs = dom_tree.documentElement.getElementsByTagName('DOC')
|
docs = dom_tree.documentElement.getElementsByTagName('DOC')
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
if is_infer:
|
||||||
|
# Input the erroneous text
|
||||||
|
sentence = doc.getElementsByTagName('TEXT')[0]. \
|
||||||
|
childNodes[0].data.strip()
|
||||||
|
else:
|
||||||
# Input the correct text
|
# Input the correct text
|
||||||
sentence = doc.getElementsByTagName('CORRECTION')[0]. \
|
sentence = doc.getElementsByTagName('CORRECTION')[0]. \
|
||||||
childNodes[0].data.strip()
|
childNodes[0].data.strip()
|
||||||
|
|
|
@ -7,7 +7,7 @@ train_path = '../data/en/fce/fce_train.txt' # Training data path.
|
||||||
val_path = '../data/en/fce/fce_val.txt' # Validation data path.
|
val_path = '../data/en/fce/fce_val.txt' # Validation data path.
|
||||||
test_path = '../data/en/fce/fce_test.txt'
|
test_path = '../data/en/fce/fce_test.txt'
|
||||||
|
|
||||||
model_path = './output_model' # Path where the model is saved; default is output_path/model.
|
model_path = './output/fce_model' # Path where the model is saved; default is output_path/model.
|
||||||
enable_special_error = False
|
enable_special_error = False
|
||||||
num_steps = 3000 # Number of steps to train.
|
num_steps = 3000 # Number of steps to train.
|
||||||
decode_sentence = False # Whether we should decode sentences of the user.
|
decode_sentence = False # Whether we should decode sentences of the user.
|
||||||
|
|
|
@ -8,11 +8,10 @@ from collections import defaultdict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
import fce_config
|
|
||||||
import cged_config
|
import cged_config
|
||||||
from corpus_reader import FCEReader
|
|
||||||
from corpus_reader import CGEDReader
|
from corpus_reader import CGEDReader
|
||||||
from reader import EOS_ID
|
from reader import EOS_ID
|
||||||
|
from text_util import segment
|
||||||
from train import create_model
|
from train import create_model
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ def decode(sess, model, data_reader, data_to_decode,
|
||||||
def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
|
def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
|
||||||
verbose=True):
|
verbose=True):
|
||||||
"""Used with InteractiveSession in IPython """
|
"""Used with InteractiveSession in IPython """
|
||||||
return next(decode(sess, model, data_reader, [sentence.split()],
|
return next(decode(sess, model, data_reader, [segment(sentence, 'char')],
|
||||||
corrective_tokens=corrective_tokens, verbose=verbose))
|
corrective_tokens=corrective_tokens, verbose=verbose))
|
||||||
|
|
||||||
|
|
||||||
|
@ -196,7 +195,7 @@ def main(_):
|
||||||
model = create_model(session, True, model_path, config=cged_config)
|
model = create_model(session, True, model_path, config=cged_config)
|
||||||
print("Loaded model. Beginning decoding.")
|
print("Loaded model. Beginning decoding.")
|
||||||
decodings = decode(session, model=model, data_reader=data_reader,
|
decodings = decode(session, model=model, data_reader=data_reader,
|
||||||
data_to_decode=data_reader.read_tokens(cged_config.test_path),
|
data_to_decode=data_reader.read_tokens(cged_config.test_path, is_infer=True),
|
||||||
corrective_tokens=data_reader.read_tokens(cged_config.train_path))
|
corrective_tokens=data_reader.read_tokens(cged_config.train_path))
|
||||||
# Write the decoded tokens to stdout.
|
# Write the decoded tokens to stdout.
|
||||||
for tokens in decodings:
|
for tokens in decodings:
|
||||||
|
|
Loading…
Reference in New Issue