Update config.

This commit is contained in:
xuming06 2018-03-30 16:49:55 +08:00
parent 10f2385aca
commit fe63f7a1fb
4 changed files with 16 additions and 12 deletions

View File

@ -5,7 +5,7 @@
# CGED chinese corpus # CGED chinese corpus
train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' # Training data path. train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' # Training data path.
val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml' # Validation data path. val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml' # Validation data path.
test_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' test_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml'
model_path = './output/cged_model' # Path of the model saved, default is output_path/model model_path = './output/cged_model' # Path of the model saved, default is output_path/model
enable_special_error = False enable_special_error = False
@ -15,12 +15,12 @@ decode_sentence = False # Whether we should decode sentences of the user.
# Config # Config
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency. buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
steps_per_checkpoint = 100 steps_per_checkpoint = 100
max_steps = 2000 max_steps = 10000
max_vocab_size = 10000 max_vocab_size = 10000
size = 512 size = 512
num_layers = 4 num_layers = 4
max_gradient_norm = 5.0 max_gradient_norm = 5.0
batch_size = 64 batch_size = 128
learning_rate = 0.5 learning_rate = 0.5
learning_rate_decay_factor = 0.99 learning_rate_decay_factor = 0.99
use_lstm = False use_lstm = False

View File

@ -93,11 +93,16 @@ class CGEDReader(Reader):
def unknown_token(self): def unknown_token(self):
return CGEDReader.UNKNOWN_TOKEN return CGEDReader.UNKNOWN_TOKEN
def read_tokens(self, path): def read_tokens(self, path, is_infer=False):
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
dom_tree = minidom.parse(f) dom_tree = minidom.parse(f)
docs = dom_tree.documentElement.getElementsByTagName('DOC') docs = dom_tree.documentElement.getElementsByTagName('DOC')
for doc in docs: for doc in docs:
if is_infer:
# Input the error text
sentence = doc.getElementsByTagName('TEXT')[0]. \
childNodes[0].data.strip()
else:
# Input the correct text # Input the correct text
sentence = doc.getElementsByTagName('CORRECTION')[0]. \ sentence = doc.getElementsByTagName('CORRECTION')[0]. \
childNodes[0].data.strip() childNodes[0].data.strip()

View File

@ -7,7 +7,7 @@ train_path = '../data/en/fce/fce_train.txt' # Training data path.
val_path = '../data/en/fce/fce_val.txt' # Validation data path. val_path = '../data/en/fce/fce_val.txt' # Validation data path.
test_path = '../data/en/fce/fce_test.txt' test_path = '../data/en/fce/fce_test.txt'
model_path = './output_model' # Path of the model saved, default is output_path/model model_path = './output/fce_model' # Path of the model saved, default is output_path/model
enable_special_error = False enable_special_error = False
num_steps = 3000 # Number of steps to train. num_steps = 3000 # Number of steps to train.
decode_sentence = False # Whether we should decode sentences of the user. decode_sentence = False # Whether we should decode sentences of the user.

View File

@ -8,11 +8,10 @@ from collections import defaultdict
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import fce_config
import cged_config import cged_config
from corpus_reader import FCEReader
from corpus_reader import CGEDReader from corpus_reader import CGEDReader
from reader import EOS_ID from reader import EOS_ID
from text_util import segment
from train import create_model from train import create_model
@ -97,7 +96,7 @@ def decode(sess, model, data_reader, data_to_decode,
def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(), def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
verbose=True): verbose=True):
"""Used with InteractiveSession in IPython """ """Used with InteractiveSession in IPython """
return next(decode(sess, model, data_reader, [sentence.split()], return next(decode(sess, model, data_reader, [segment(sentence, 'char')],
corrective_tokens=corrective_tokens, verbose=verbose)) corrective_tokens=corrective_tokens, verbose=verbose))
@ -196,7 +195,7 @@ def main(_):
model = create_model(session, True, model_path, config=cged_config) model = create_model(session, True, model_path, config=cged_config)
print("Loaded model. Beginning decoding.") print("Loaded model. Beginning decoding.")
decodings = decode(session, model=model, data_reader=data_reader, decodings = decode(session, model=model, data_reader=data_reader,
data_to_decode=data_reader.read_tokens(cged_config.test_path), data_to_decode=data_reader.read_tokens(cged_config.test_path, is_infer=True),
corrective_tokens=data_reader.read_tokens(cged_config.train_path)) corrective_tokens=data_reader.read_tokens(cged_config.train_path))
# Write the decoded tokens to stdout. # Write the decoded tokens to stdout.
for tokens in decodings: for tokens in decodings: