update rnn_crf and seq2seq with model train.
This commit is contained in:
parent
e482c736a3
commit
1d1d5f4aa6
|
@ -50,7 +50,6 @@ def save_preds(preds, test_ids, X_test, ids_word_dict,
|
||||||
for j in range(len(sent_ids)):
|
for j in range(len(sent_ids)):
|
||||||
if sent_ids[j] != 0:
|
if sent_ids[j] != 0:
|
||||||
label.append(preds[i][j])
|
label.append(preds[i][j])
|
||||||
print(label)
|
|
||||||
continue_error = False
|
continue_error = False
|
||||||
has_error = False
|
has_error = False
|
||||||
current_error = 0
|
current_error = 0
|
||||||
|
@ -80,7 +79,9 @@ def save_preds(preds, test_ids, X_test, ids_word_dict,
|
||||||
|
|
||||||
|
|
||||||
def is_error_label_id(label_id, label_ids_dict):
|
def is_error_label_id(label_id, label_ids_dict):
|
||||||
return label_id != label_ids_dict['O']
|
# return label_id != label_ids_dict['O']
|
||||||
|
return label_id == label_ids_dict['M'] or label_id == label_ids_dict['R'] or label_id == label_ids_dict[
|
||||||
|
'S'] or label_id == label_ids_dict['W']
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -30,7 +30,8 @@ def parse_xml_file(path):
|
||||||
error_type_change = 'B-' + error_type
|
error_type_change = 'B-' + error_type
|
||||||
else:
|
else:
|
||||||
error_type_change = 'I-' + error_type
|
error_type_change = 'I-' + error_type
|
||||||
locate_dict[i] = error_type_change
|
# locate_dict[i] = error_type_change
|
||||||
|
locate_dict[i] = error_type
|
||||||
# Segment with pos
|
# Segment with pos
|
||||||
word_seq, pos_seq = segment(text, cut_type='char', pos=True)
|
word_seq, pos_seq = segment(text, cut_type='char', pos=True)
|
||||||
word_arr, label_arr = [], []
|
word_arr, label_arr = [], []
|
||||||
|
@ -69,7 +70,8 @@ def parse_txt_file(input_path, truth_path):
|
||||||
error_type_change = 'B-' + error_type
|
error_type_change = 'B-' + error_type
|
||||||
else:
|
else:
|
||||||
error_type_change = 'I-' + error_type
|
error_type_change = 'I-' + error_type
|
||||||
locate_dict[i] = error_type_change
|
# locate_dict[i] = error_type_change
|
||||||
|
locate_dict[i] = error_type
|
||||||
# for i in range(int(start_off) - 1, int(end_off)):
|
# for i in range(int(start_off) - 1, int(end_off)):
|
||||||
# locate_dict[i] = error_type
|
# locate_dict[i] = error_type
|
||||||
if text_id in truth_dict:
|
if text_id in truth_dict:
|
||||||
|
|
|
@ -6,17 +6,19 @@ import os
|
||||||
output_dir = './output'
|
output_dir = './output'
|
||||||
|
|
||||||
# CGED chinese corpus
|
# CGED chinese corpus
|
||||||
train_paths = ['../data/cn/CGED/CGED18_HSK_TrainingSet.xml',
|
train_paths = [
|
||||||
'../data/cn/CGED/CGED17_HSK_TrainingSet.xml',
|
'../data/cn/CGED/CGED18_HSK_TrainingSet.xml',
|
||||||
'../data/cn/CGED/CGED16_HSK_TrainingSet.xml',
|
'../data/cn/CGED/CGED17_HSK_TrainingSet.xml',
|
||||||
# '../data/cn/CGED/sample_HSK_TrainingSet.xml',
|
'../data/cn/CGED/CGED16_HSK_TrainingSet.xml',
|
||||||
]
|
# '../data/cn/CGED/sample_HSK_TrainingSet.xml',
|
||||||
|
]
|
||||||
train_word_path = output_dir + '/train_words.txt'
|
train_word_path = output_dir + '/train_words.txt'
|
||||||
train_label_path = output_dir + '/train_labels.txt'
|
train_label_path = output_dir + '/train_labels.txt'
|
||||||
test_paths = {'../data/cn/CGED/CGED16_HSK_Test_Input.txt': '../data/cn/CGED/CGED16_HSK_Test_Truth.txt',
|
test_paths = {
|
||||||
'../data/cn/CGED/CGED17_HSK_Test_Input.txt': '../data/cn/CGED/CGED17_HSK_Test_Truth.txt',
|
'../data/cn/CGED/CGED16_HSK_Test_Input.txt': '../data/cn/CGED/CGED16_HSK_Test_Truth.txt',
|
||||||
# '../data/cn/CGED/sample_HSK_Test_Input.txt': '../data/cn/CGED/sample_HSK_Test_Truth.txt',
|
'../data/cn/CGED/CGED17_HSK_Test_Input.txt': '../data/cn/CGED/CGED17_HSK_Test_Truth.txt',
|
||||||
}
|
# '../data/cn/CGED/sample_HSK_Test_Input.txt': '../data/cn/CGED/sample_HSK_Test_Truth.txt',
|
||||||
|
}
|
||||||
test_word_path = output_dir + '/test_words.txt'
|
test_word_path = output_dir + '/test_words.txt'
|
||||||
test_label_path = output_dir + '/test_labels.txt'
|
test_label_path = output_dir + '/test_labels.txt'
|
||||||
test_id_path = output_dir + '/test_ids.txt'
|
test_id_path = output_dir + '/test_ids.txt'
|
||||||
|
@ -31,7 +33,7 @@ embedding_dim = 100
|
||||||
rnn_hidden_dim = 200
|
rnn_hidden_dim = 200
|
||||||
maxlen = 300
|
maxlen = 300
|
||||||
cutoff_frequency = 5
|
cutoff_frequency = 5
|
||||||
dropout = 0.5
|
dropout = 0.25
|
||||||
save_model_path = output_dir + '/rnn_crf_model.h5' # Path of the model saved, default is output_path/model
|
save_model_path = output_dir + '/rnn_crf_model.h5' # Path of the model saved, default is output_path/model
|
||||||
|
|
||||||
# infer
|
# infer
|
||||||
|
|
|
@ -16,7 +16,7 @@ def create_model(word_dict, label_dict, embedding_dim=100, rnn_hidden_dim=200, d
|
||||||
# build model
|
# build model
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
# embedding
|
# embedding
|
||||||
model.add(Embedding(len(word_dict), embedding_dim))
|
model.add(Embedding(len(word_dict), embedding_dim, mask_zero=True))
|
||||||
# bilstm
|
# bilstm
|
||||||
model.add(Bidirectional(LSTM(rnn_hidden_dim // 2, return_sequences=True,
|
model.add(Bidirectional(LSTM(rnn_hidden_dim // 2, return_sequences=True,
|
||||||
recurrent_dropout=dropout)))
|
recurrent_dropout=dropout)))
|
||||||
|
@ -25,6 +25,7 @@ def create_model(word_dict, label_dict, embedding_dim=100, rnn_hidden_dim=200, d
|
||||||
model.add(crf)
|
model.add(crf)
|
||||||
# loss
|
# loss
|
||||||
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
|
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
|
||||||
|
model.summary()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,11 +45,15 @@ def train(train_word_path=None,
|
||||||
# read data to index
|
# read data to index
|
||||||
word_ids = vectorize_data(train_word_path, word_ids_dict)
|
word_ids = vectorize_data(train_word_path, word_ids_dict)
|
||||||
label_ids = vectorize_data(train_label_path, label_ids_dict)
|
label_ids = vectorize_data(train_label_path, label_ids_dict)
|
||||||
|
max_len = np.max([len(i) for i in word_ids])
|
||||||
|
print('max_len:', max_len)
|
||||||
# pad sequence
|
# pad sequence
|
||||||
word_seq = pad_sequence(word_ids, maxlen=maxlen)
|
word_seq = pad_sequence(word_ids, maxlen=maxlen)
|
||||||
label_seq = pad_sequence(label_ids, maxlen=maxlen)
|
label_seq = pad_sequence(label_ids, maxlen=maxlen)
|
||||||
# reshape label for crf model use
|
# reshape label for crf model use
|
||||||
label_seq = np.reshape(label_seq, (label_seq.shape[0], label_seq.shape[1], 1))
|
label_seq = np.reshape(label_seq, (label_seq.shape[0], label_seq.shape[1], 1))
|
||||||
|
print(word_seq.shape)
|
||||||
|
print(label_seq.shape)
|
||||||
logger.info("Data loaded.")
|
logger.info("Data loaded.")
|
||||||
# model
|
# model
|
||||||
logger.info("Training BILSTM_CRF model...")
|
logger.info("Training BILSTM_CRF model...")
|
||||||
|
|
|
@ -3,36 +3,22 @@
|
||||||
# Brief: Use CGED corpus
|
# Brief: Use CGED corpus
|
||||||
import os
|
import os
|
||||||
|
|
||||||
output_dir = './output'
|
|
||||||
model_path = './output/cged_model' # Path of the model saved, default is output_path/model
|
|
||||||
|
|
||||||
# CGED chinese corpus
|
# CGED chinese corpus
|
||||||
raw_train_paths = ['../data/cn/CGED/CGED18_HSK_TrainingSet.xml',
|
raw_train_paths = [
|
||||||
'../data/cn/CGED/CGED17_HSK_TrainingSet.xml',
|
# '../data/cn/CGED/CGED18_HSK_TrainingSet.xml',
|
||||||
'../data/cn/CGED/CGED16_HSK_TrainingSet.xml',
|
'../data/cn/CGED/CGED17_HSK_TrainingSet.xml',
|
||||||
# '../data/cn/CGED/sample_HSK_TrainingSet.xml',
|
# '../data/cn/CGED/CGED16_HSK_TrainingSet.xml',
|
||||||
]
|
# '../data/cn/CGED/sample_HSK_TrainingSet.xml',
|
||||||
|
]
|
||||||
|
output_dir = './output'
|
||||||
train_path = output_dir + '/train.txt' # Training data path.
|
train_path = output_dir + '/train.txt' # Training data path.
|
||||||
test_path = output_dir + '/test.txt' # Validation data path.
|
test_path = output_dir + '/test.txt' # Validation data path.
|
||||||
num_steps = 3000 # Number of steps to train.
|
|
||||||
decode_sentence = False # Whether we should decode sentences of the user.
|
|
||||||
|
|
||||||
# Config
|
# config
|
||||||
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
|
|
||||||
steps_per_checkpoint = 100
|
|
||||||
max_steps = 10000
|
|
||||||
max_vocab_size = 10000
|
|
||||||
size = 512
|
|
||||||
num_layers = 4
|
|
||||||
max_gradient_norm = 5.0
|
|
||||||
batch_size = 128
|
batch_size = 128
|
||||||
learning_rate = 0.5
|
epochs = 10
|
||||||
learning_rate_decay_factor = 0.99
|
rnn_hidden_dim = 200
|
||||||
use_lstm = False
|
save_model_path = output_dir + '/cged_seq2seq_model.h5' # Path of the model saved
|
||||||
use_rms_prop = False
|
|
||||||
|
|
||||||
enable_decode_sentence = False # Test with input error sentence
|
if not os.path.exists(output_dir):
|
||||||
enable_test_decode = True # Test with test set
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
os.makedirs(model_path)
|
|
||||||
|
|
|
@ -14,12 +14,11 @@ class FCEReader(Reader):
|
||||||
DROPOUT_TOKENS = {"a", "an", "the", "'ll", "'s", "'m", "'ve"}
|
DROPOUT_TOKENS = {"a", "an", "the", "'ll", "'s", "'m", "'ve"}
|
||||||
REPLACEMENTS = {"there": "their", "their": "there", "then": "than", "than": "then"}
|
REPLACEMENTS = {"there": "their", "their": "there", "then": "than", "than": "then"}
|
||||||
|
|
||||||
def __init__(self, config, train_path=None, token_2_id=None,
|
def __init__(self, train_path=None, token_2_id=None,
|
||||||
dropout_prob=0.25, replacement_prob=0.25, dataset_copies=2):
|
dropout_prob=0.25, replacement_prob=0.25):
|
||||||
super(FCEReader, self).__init__(
|
super(FCEReader, self).__init__(
|
||||||
config, train_path=train_path, token_2_id=token_2_id,
|
train_path=train_path, token_2_id=token_2_id,
|
||||||
special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, FCEReader.UNKNOWN_TOKEN],
|
special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, FCEReader.UNKNOWN_TOKEN])
|
||||||
dataset_copies=dataset_copies)
|
|
||||||
self.dropout_prob = dropout_prob
|
self.dropout_prob = dropout_prob
|
||||||
self.replacement_prob = replacement_prob
|
self.replacement_prob = replacement_prob
|
||||||
self.UNKNOWN_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]
|
self.UNKNOWN_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]
|
||||||
|
@ -33,19 +32,6 @@ class FCEReader(Reader):
|
||||||
break
|
break
|
||||||
source = line_src.lower()[5:].strip().split()
|
source = line_src.lower()[5:].strip().split()
|
||||||
target = line_dst.lower()[5:].strip().split()
|
target = line_dst.lower()[5:].strip().split()
|
||||||
if self.config.enable_special_error:
|
|
||||||
new_source = []
|
|
||||||
for token in source:
|
|
||||||
# Random dropout words from the input
|
|
||||||
dropout_token = (token in FCEReader.DROPOUT_TOKENS and
|
|
||||||
random.random() < self.dropout_prob)
|
|
||||||
replace_token = (token in FCEReader.REPLACEMENTS and
|
|
||||||
random.random() < self.replacement_prob)
|
|
||||||
if replace_token:
|
|
||||||
new_source.append(FCEReader.REPLACEMENTS[source])
|
|
||||||
elif not dropout_token:
|
|
||||||
new_source.append(token)
|
|
||||||
source = new_source
|
|
||||||
yield source, target
|
yield source, target
|
||||||
|
|
||||||
def unknown_token(self):
|
def unknown_token(self):
|
||||||
|
@ -68,11 +54,10 @@ class CGEDReader(Reader):
|
||||||
"""
|
"""
|
||||||
UNKNOWN_TOKEN = 'UNK'
|
UNKNOWN_TOKEN = 'UNK'
|
||||||
|
|
||||||
def __init__(self, config, train_path=None, token_2_id=None, dataset_copies=2):
|
def __init__(self, train_path=None, token_2_id=None):
|
||||||
super(CGEDReader, self).__init__(
|
super(CGEDReader, self).__init__(
|
||||||
config, train_path=train_path, token_2_id=token_2_id,
|
train_path=train_path, token_2_id=token_2_id,
|
||||||
special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, CGEDReader.UNKNOWN_TOKEN],
|
special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, CGEDReader.UNKNOWN_TOKEN])
|
||||||
dataset_copies=dataset_copies)
|
|
||||||
self.UNKNOWN_ID = self.token_2_id[CGEDReader.UNKNOWN_TOKEN]
|
self.UNKNOWN_ID = self.token_2_id[CGEDReader.UNKNOWN_TOKEN]
|
||||||
|
|
||||||
def read_samples_by_string(self, path):
|
def read_samples_by_string(self, path):
|
||||||
|
@ -98,3 +83,12 @@ class CGEDReader(Reader):
|
||||||
if line and len(line) > 5:
|
if line and len(line) > 5:
|
||||||
yield line.lower()[5:].strip().split()
|
yield line.lower()[5:].strip().split()
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_vocab(input_texts):
|
||||||
|
vocab = set()
|
||||||
|
for line in input_texts:
|
||||||
|
for char in line:
|
||||||
|
if char not in vocab:
|
||||||
|
vocab.add(char)
|
||||||
|
return sorted(list(vocab))
|
||||||
|
|
|
@ -1,415 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Author: XuMing <xuming624@qq.com>
|
|
||||||
# Brief:
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import embedding_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.ops import nn_ops
|
|
||||||
|
|
||||||
from pycorrector.seq2seq import seq2seq
|
|
||||||
from pycorrector.seq2seq.reader import PAD_ID, GO_ID
|
|
||||||
|
|
||||||
|
|
||||||
class CorrectorModel(object):
|
|
||||||
"""Sequence-to-sequence model used to correct grammatical errors in text.
|
|
||||||
|
|
||||||
NOTE: mostly copied from TensorFlow's seq2seq_model.py; only modifications
|
|
||||||
are:
|
|
||||||
- the introduction of RMSProp as an optional optimization algorithm
|
|
||||||
- the introduction of a "projection bias" that biases decoding towards
|
|
||||||
selecting tokens that appeared in the input
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
|
|
||||||
num_layers, max_gradient_norm, batch_size, learning_rate,
|
|
||||||
learning_rate_decay_factor, use_lstm=False,
|
|
||||||
num_samples=512, forward_only=False, config=None,
|
|
||||||
corrective_tokens_mask=None):
|
|
||||||
"""Create the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_vocab_size: size of the source vocabulary.
|
|
||||||
target_vocab_size: size of the target vocabulary.
|
|
||||||
buckets: a list of pairs (I, O), where I specifies maximum input
|
|
||||||
length that will be processed in that bucket, and O specifies
|
|
||||||
maximum output length. Training instances that have longer than I
|
|
||||||
or outputs longer than O will be pushed to the next bucket and
|
|
||||||
padded accordingly. We assume that the list is sorted, e.g., [(2,
|
|
||||||
4), (8, 16)].
|
|
||||||
size: number of units in each layer of the model.
|
|
||||||
num_layers: number of layers in the model.
|
|
||||||
max_gradient_norm: gradients will be clipped to maximally this norm.
|
|
||||||
batch_size: the size of the batches used during training;
|
|
||||||
the model construction is independent of batch_size, so it can be
|
|
||||||
changed after initialization if this is convenient, e.g.,
|
|
||||||
for decoding.
|
|
||||||
learning_rate: learning rate to start with.
|
|
||||||
learning_rate_decay_factor: decay learning rate by this much when
|
|
||||||
needed.
|
|
||||||
use_lstm: if true, we use LSTM cells instead of GRU cells.
|
|
||||||
num_samples: number of samples for sampled softmax.
|
|
||||||
forward_only: if set, we do not construct the backward pass in the
|
|
||||||
model.
|
|
||||||
"""
|
|
||||||
self.source_vocab_size = source_vocab_size
|
|
||||||
self.target_vocab_size = target_vocab_size
|
|
||||||
self.buckets = buckets
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
|
|
||||||
self.learning_rate_decay_op = self.learning_rate.assign(
|
|
||||||
self.learning_rate * learning_rate_decay_factor)
|
|
||||||
self.global_step = tf.Variable(0, trainable=False)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# Feeds for inputs.
|
|
||||||
self.encoder_inputs = []
|
|
||||||
self.decoder_inputs = []
|
|
||||||
self.target_weights = []
|
|
||||||
for i in range(buckets[-1][0]): # Last bucket is the biggest one.
|
|
||||||
self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
|
|
||||||
name="encoder{0}".format(
|
|
||||||
i)))
|
|
||||||
for i in range(buckets[-1][1] + 1):
|
|
||||||
self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
|
|
||||||
name="decoder{0}".format(
|
|
||||||
i)))
|
|
||||||
self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
|
|
||||||
name="weight{0}".format(
|
|
||||||
i)))
|
|
||||||
|
|
||||||
# One hot encoding of corrective tokens.
|
|
||||||
corrective_tokens_tensor = tf.constant(corrective_tokens_mask if
|
|
||||||
corrective_tokens_mask else
|
|
||||||
np.zeros(self.target_vocab_size),
|
|
||||||
shape=[self.target_vocab_size],
|
|
||||||
dtype=tf.float32)
|
|
||||||
batched_corrective_tokens = tf.stack(
|
|
||||||
[corrective_tokens_tensor] * self.batch_size)
|
|
||||||
self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \
|
|
||||||
tf.placeholder(
|
|
||||||
tf.float32,
|
|
||||||
shape=[None, None],
|
|
||||||
name="corrective_tokens")
|
|
||||||
|
|
||||||
# Our targets are decoder inputs shifted by one.
|
|
||||||
targets = [self.decoder_inputs[i + 1]
|
|
||||||
for i in range(len(self.decoder_inputs) - 1)]
|
|
||||||
# If we use sampled softmax, we need an output projection.
|
|
||||||
output_projection = None
|
|
||||||
softmax_loss_function = None
|
|
||||||
# Sampled softmax only makes sense if we sample less than vocabulary
|
|
||||||
# size.
|
|
||||||
if num_samples > 0 and num_samples < self.target_vocab_size:
|
|
||||||
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
|
|
||||||
w_t = tf.transpose(w)
|
|
||||||
b = tf.get_variable("proj_b", [self.target_vocab_size])
|
|
||||||
|
|
||||||
output_projection = (w, b)
|
|
||||||
|
|
||||||
def sampled_loss(labels, logits):
|
|
||||||
labels = tf.reshape(labels, [-1, 1])
|
|
||||||
return tf.nn.sampled_softmax_loss(w_t, b, labels, logits,
|
|
||||||
num_samples,
|
|
||||||
self.target_vocab_size)
|
|
||||||
|
|
||||||
softmax_loss_function = sampled_loss
|
|
||||||
|
|
||||||
# Create the internal multi-layer cell for our RNN.
|
|
||||||
single_cell = tf.nn.rnn_cell.GRUCell(size)
|
|
||||||
if use_lstm:
|
|
||||||
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
|
|
||||||
cell = single_cell
|
|
||||||
if num_layers > 1:
|
|
||||||
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
|
|
||||||
|
|
||||||
# The seq2seq function: we use embedding for the input and attention.
|
|
||||||
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
|
|
||||||
"""
|
|
||||||
|
|
||||||
:param encoder_inputs: list of length equal to the input bucket
|
|
||||||
length of 1-D tensors (of length equal to the batch size) whose
|
|
||||||
elements consist of the token index of each sample in the batch
|
|
||||||
at a given index in the input.
|
|
||||||
:param decoder_inputs:
|
|
||||||
:param do_decode:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if do_decode:
|
|
||||||
# Modify bias here to bias the model towards selecting words
|
|
||||||
# present in the input sentence.
|
|
||||||
input_bias = self.build_input_bias(encoder_inputs,
|
|
||||||
batch_corrective_tokens_mask)
|
|
||||||
|
|
||||||
# Redefined seq2seq to allow for the injection of a special
|
|
||||||
# decoding function that
|
|
||||||
return seq2seq.embedding_attention_seq2seq(
|
|
||||||
encoder_inputs, decoder_inputs, cell,
|
|
||||||
num_encoder_symbols=source_vocab_size,
|
|
||||||
num_decoder_symbols=target_vocab_size,
|
|
||||||
embedding_size=size,
|
|
||||||
output_projection=output_projection,
|
|
||||||
feed_previous=do_decode,
|
|
||||||
loop_fn_factory=
|
|
||||||
apply_input_bias_and_extract_argmax_fn_factory(input_bias))
|
|
||||||
else:
|
|
||||||
return seq2seq.embedding_attention_seq2seq(
|
|
||||||
encoder_inputs, decoder_inputs, cell,
|
|
||||||
num_encoder_symbols=source_vocab_size,
|
|
||||||
num_decoder_symbols=target_vocab_size,
|
|
||||||
embedding_size=size,
|
|
||||||
output_projection=output_projection,
|
|
||||||
feed_previous=do_decode)
|
|
||||||
|
|
||||||
# Training outputs and losses.
|
|
||||||
if forward_only:
|
|
||||||
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
|
|
||||||
self.encoder_inputs, self.decoder_inputs, targets,
|
|
||||||
self.target_weights, buckets,
|
|
||||||
lambda x, y: seq2seq_f(x, y, True),
|
|
||||||
softmax_loss_function=softmax_loss_function)
|
|
||||||
|
|
||||||
if output_projection is not None:
|
|
||||||
for b in range(len(buckets)):
|
|
||||||
# We need to apply the same input bias used during model
|
|
||||||
# evaluation when decoding.
|
|
||||||
input_bias = self.build_input_bias(
|
|
||||||
self.encoder_inputs[:buckets[b][0]],
|
|
||||||
batch_corrective_tokens_mask)
|
|
||||||
self.outputs[b] = [
|
|
||||||
project_and_apply_input_bias(output, output_projection,
|
|
||||||
input_bias)
|
|
||||||
for output in self.outputs[b]]
|
|
||||||
else:
|
|
||||||
self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
|
|
||||||
self.encoder_inputs, self.decoder_inputs, targets,
|
|
||||||
self.target_weights, buckets,
|
|
||||||
lambda x, y: seq2seq_f(x, y, False),
|
|
||||||
softmax_loss_function=softmax_loss_function)
|
|
||||||
|
|
||||||
# Gradients and SGD update operation for training the model.
|
|
||||||
params = tf.trainable_variables()
|
|
||||||
if not forward_only:
|
|
||||||
self.gradient_norms = []
|
|
||||||
self.updates = []
|
|
||||||
opt = tf.train.RMSPropOptimizer(0.001) if self.config.use_rms_prop \
|
|
||||||
else tf.train.GradientDescentOptimizer(self.learning_rate)
|
|
||||||
# opt = tf.train.AdamOptimizer()
|
|
||||||
|
|
||||||
for b in range(len(buckets)):
|
|
||||||
gradients = tf.gradients(self.losses[b], params)
|
|
||||||
clipped_gradients, norm = tf.clip_by_global_norm(
|
|
||||||
gradients, max_gradient_norm)
|
|
||||||
self.gradient_norms.append(norm)
|
|
||||||
self.updates.append(opt.apply_gradients(
|
|
||||||
zip(clipped_gradients, params),
|
|
||||||
global_step=self.global_step))
|
|
||||||
|
|
||||||
self.saver = tf.train.Saver(tf.global_variables())
|
|
||||||
|
|
||||||
def build_input_bias(self, encoder_inputs, batch_corrective_tokens_mask):
|
|
||||||
packed_one_hot_inputs = tf.one_hot(indices=tf.stack(
|
|
||||||
encoder_inputs, axis=1), depth=self.target_vocab_size)
|
|
||||||
return tf.maximum(batch_corrective_tokens_mask,
|
|
||||||
tf.reduce_max(packed_one_hot_inputs,
|
|
||||||
reduction_indices=1))
|
|
||||||
|
|
||||||
def step(self, session, encoder_inputs, decoder_inputs, target_weights,
|
|
||||||
bucket_id, forward_only, corrective_tokens=None):
|
|
||||||
"""Run a step of the model feeding the given inputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: tensorflow session to use.
|
|
||||||
encoder_inputs: list of numpy int vectors to feed as encoder inputs.
|
|
||||||
decoder_inputs: list of numpy int vectors to feed as decoder inputs.
|
|
||||||
target_weights: list of numpy float vectors to feed as target weights.
|
|
||||||
bucket_id: which bucket of the model to use.
|
|
||||||
forward_only: whether to do the backward step or only forward.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A triple consisting of gradient norm (or None if we did not do
|
|
||||||
backward), average perplexity, and the outputs.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if length of encoder_inputs, decoder_inputs, or
|
|
||||||
target_weights disagrees with bucket size for the specified
|
|
||||||
bucket_id.
|
|
||||||
"""
|
|
||||||
# Check if the sizes match.
|
|
||||||
encoder_size, decoder_size = self.buckets[bucket_id]
|
|
||||||
if len(encoder_inputs) != encoder_size:
|
|
||||||
raise ValueError("Encoder length must be equal to the one in bucket,"
|
|
||||||
" %d != %d." % (len(encoder_inputs), encoder_size))
|
|
||||||
if len(decoder_inputs) != decoder_size:
|
|
||||||
raise ValueError("Decoder length must be equal to the one in bucket,"
|
|
||||||
" %d != %d." % (len(decoder_inputs), decoder_size))
|
|
||||||
if len(target_weights) != decoder_size:
|
|
||||||
raise ValueError("Weights length must be equal to the one in bucket,"
|
|
||||||
" %d != %d." % (len(target_weights), decoder_size))
|
|
||||||
|
|
||||||
# Input feed: encoder inputs, decoder inputs, target_weights,
|
|
||||||
# as provided.
|
|
||||||
input_feed = {}
|
|
||||||
for l in range(encoder_size):
|
|
||||||
input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
|
|
||||||
for l in range(decoder_size):
|
|
||||||
input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
|
|
||||||
input_feed[self.target_weights[l].name] = target_weights[l]
|
|
||||||
|
|
||||||
corrective_tokens_vector = (corrective_tokens if
|
|
||||||
corrective_tokens is not None else
|
|
||||||
np.zeros(self.target_vocab_size))
|
|
||||||
batch_corrective_tokens = np.repeat([corrective_tokens_vector],
|
|
||||||
self.batch_size, axis=0)
|
|
||||||
input_feed[self.batch_corrective_tokens_mask.name] = (
|
|
||||||
batch_corrective_tokens)
|
|
||||||
|
|
||||||
# Since our targets are decoder inputs shifted by one, we need one more.
|
|
||||||
last_target = self.decoder_inputs[decoder_size].name
|
|
||||||
input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
|
|
||||||
|
|
||||||
# Output feed: depends on whether we do a backward step or not.
|
|
||||||
if not forward_only:
|
|
||||||
output_feed = [self.updates[bucket_id], # Update Op that does SGD.
|
|
||||||
self.gradient_norms[bucket_id], # Gradient norm.
|
|
||||||
self.losses[bucket_id]] # Loss for this batch.
|
|
||||||
else:
|
|
||||||
output_feed = [self.losses[bucket_id]] # Loss for this batch.
|
|
||||||
for l in range(decoder_size): # Output logits.
|
|
||||||
output_feed.append(self.outputs[bucket_id][l])
|
|
||||||
|
|
||||||
outputs = session.run(output_feed, input_feed)
|
|
||||||
if not forward_only:
|
|
||||||
# Gradient norm, loss, no outputs.
|
|
||||||
return outputs[1], outputs[2], None
|
|
||||||
else:
|
|
||||||
# No gradient norm, loss, outputs.
|
|
||||||
return None, outputs[0], outputs[1:]
|
|
||||||
|
|
||||||
def get_batch(self, data, bucket_id):
|
|
||||||
"""Get a random batch of data from the specified bucket, prepare for
|
|
||||||
step.
|
|
||||||
|
|
||||||
To feed data in step(..) it must be a list of batch-major vectors, while
|
|
||||||
data here contains single length-major cases. So the main logic of this
|
|
||||||
function is to re-index data cases to be in the proper format for
|
|
||||||
feeding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: a tuple of size len(self.buckets) in which each element contains
|
|
||||||
lists of pairs of input and output data that we use to create a
|
|
||||||
batch.
|
|
||||||
bucket_id: integer, which bucket to get the batch for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The triple (encoder_inputs, decoder_inputs, target_weights) for
|
|
||||||
the constructed batch that has the proper format to call step(...)
|
|
||||||
later.
|
|
||||||
"""
|
|
||||||
encoder_size, decoder_size = self.buckets[bucket_id]
|
|
||||||
encoder_inputs, decoder_inputs = [], []
|
|
||||||
|
|
||||||
# Get a random batch of encoder and decoder inputs from data,
|
|
||||||
# pad them if needed, reverse encoder inputs and add GO to decoder.
|
|
||||||
for _ in range(self.batch_size):
|
|
||||||
encoder_input, decoder_input = random.choice(data[bucket_id])
|
|
||||||
|
|
||||||
# Encoder inputs are padded and then reversed.
|
|
||||||
encoder_pad = [PAD_ID] * (
|
|
||||||
encoder_size - len(encoder_input))
|
|
||||||
encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
|
|
||||||
|
|
||||||
# Decoder inputs get an extra "GO" symbol, and are padded then.
|
|
||||||
decoder_pad_size = decoder_size - len(decoder_input) - 1
|
|
||||||
decoder_inputs.append([GO_ID] + decoder_input +
|
|
||||||
[PAD_ID] * decoder_pad_size)
|
|
||||||
|
|
||||||
# Now we create batch-major vectors from the data selected above.
|
|
||||||
batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
|
|
||||||
|
|
||||||
# Batch encoder inputs are just re-indexed encoder_inputs.
|
|
||||||
for length_idx in range(encoder_size):
|
|
||||||
batch_encoder_inputs.append(
|
|
||||||
np.array([encoder_inputs[batch_idx][length_idx]
|
|
||||||
for batch_idx in range(self.batch_size)],
|
|
||||||
dtype=np.int32))
|
|
||||||
|
|
||||||
# Batch decoder inputs are re-indexed decoder_inputs, we create weights.
|
|
||||||
for length_idx in range(decoder_size):
|
|
||||||
batch_decoder_inputs.append(
|
|
||||||
np.array([decoder_inputs[batch_idx][length_idx]
|
|
||||||
for batch_idx in range(self.batch_size)],
|
|
||||||
dtype=np.int32))
|
|
||||||
|
|
||||||
# Create target_weights to be 0 for targets that are padding.
|
|
||||||
batch_weight = np.ones(self.batch_size, dtype=np.float32)
|
|
||||||
for batch_idx in range(self.batch_size):
|
|
||||||
# We set weight to 0 if the corresponding target is a PAD
|
|
||||||
# symbol. The corresponding target is decoder_input shifted by 1
|
|
||||||
# forward.
|
|
||||||
if length_idx < decoder_size - 1:
|
|
||||||
target = decoder_inputs[batch_idx][length_idx + 1]
|
|
||||||
if length_idx == decoder_size - 1 or target == PAD_ID:
|
|
||||||
batch_weight[batch_idx] = 0.0
|
|
||||||
batch_weights.append(batch_weight)
|
|
||||||
return batch_encoder_inputs, batch_decoder_inputs, batch_weights
|
|
||||||
|
|
||||||
|
|
||||||
def project_and_apply_input_bias(logits, output_projection, input_bias):
|
|
||||||
if output_projection is not None:
|
|
||||||
logits = nn_ops.xw_plus_b(
|
|
||||||
logits, output_projection[0], output_projection[1])
|
|
||||||
|
|
||||||
# Apply softmax to ensure all tokens have a positive value.
|
|
||||||
probs = tf.nn.softmax(logits)
|
|
||||||
|
|
||||||
# Apply input bias, which is a mask of shape [batch, vocab len]
|
|
||||||
# where each token from the input in addition to all "corrective"
|
|
||||||
# tokens are set to 1.0.
|
|
||||||
return tf.multiply(probs, input_bias)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
|
|
||||||
"""
|
|
||||||
|
|
||||||
:param encoder_inputs: list of length equal to the input bucket
|
|
||||||
length of 1-D tensors (of length equal to the batch size) whose
|
|
||||||
elements consist of the token index of each sample in the batch
|
|
||||||
at a given index in the input.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def fn_factory(embedding, output_projection=None, update_embedding=True):
|
|
||||||
"""Get a loop_function that extracts the previous symbol and embeds it.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: embedding tensor for symbols.
|
|
||||||
output_projection: None or a pair (W, B). If provided, each fed previous
|
|
||||||
output will first be multiplied by W and added B.
|
|
||||||
update_embedding: Boolean; if False, the gradients will not propagate
|
|
||||||
through the embeddings.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A loop function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def loop_function(prev, _):
|
|
||||||
prev = project_and_apply_input_bias(prev, output_projection,
|
|
||||||
input_bias)
|
|
||||||
|
|
||||||
prev_symbol = math_ops.argmax(prev, 1)
|
|
||||||
# Note that gradients will not propagate through the second
|
|
||||||
# parameter of embedding_lookup.
|
|
||||||
emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
|
|
||||||
if not update_embedding:
|
|
||||||
emb_prev = array_ops.stop_gradient(emb_prev)
|
|
||||||
return emb_prev, prev_symbol
|
|
||||||
|
|
||||||
return loop_function
|
|
||||||
|
|
||||||
return fn_factory
|
|
|
@ -1,35 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Author: XuMing <xuming624@qq.com>
|
|
||||||
# Brief: Use FCE english corpus
|
|
||||||
import os
|
|
||||||
|
|
||||||
# FCE english corpus
|
|
||||||
train_path = '../data/en/fce/fce_train.txt' # Training data path.
|
|
||||||
val_path = '../data/en/fce/fce_val.txt' # Validation data path.
|
|
||||||
test_path = '../data/en/fce/fce_test.txt'
|
|
||||||
|
|
||||||
model_path = './output/fce_model' # Path of the model saved, default is output_path/model
|
|
||||||
enable_special_error = False
|
|
||||||
num_steps = 3000 # Number of steps to train.
|
|
||||||
decode_sentence = False # Whether we should decode sentences of the user.
|
|
||||||
|
|
||||||
# Config
|
|
||||||
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
|
|
||||||
steps_per_checkpoint = 100
|
|
||||||
max_steps = 2000
|
|
||||||
max_vocab_size = 10000
|
|
||||||
size = 512
|
|
||||||
num_layers = 1
|
|
||||||
max_gradient_norm = 5.0
|
|
||||||
batch_size = 64
|
|
||||||
learning_rate = 0.5
|
|
||||||
learning_rate_decay_factor = 0.99
|
|
||||||
use_lstm = False
|
|
||||||
use_rms_prop = False
|
|
||||||
|
|
||||||
enable_decode_sentence = False # Test with input error sentence
|
|
||||||
enable_test_decode = True # Test with test set
|
|
||||||
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
os.makedirs(model_path)
|
|
|
@ -2,205 +2,135 @@
|
||||||
# Author: XuMing <xuming624@qq.com>
|
# Author: XuMing <xuming624@qq.com>
|
||||||
# Brief:
|
# Brief:
|
||||||
|
|
||||||
import sys
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
from keras.layers import Input
|
||||||
|
from keras.models import Model, load_model
|
||||||
|
|
||||||
import cged_config
|
from pycorrector.seq2seq import cged_config as config
|
||||||
from corpus_reader import CGEDReader
|
from pycorrector.seq2seq.corpus_reader import GO_TOKEN
|
||||||
from reader import EOS_ID
|
from pycorrector.seq2seq.corpus_reader import CGEDReader
|
||||||
from utils.text_utils import segment
|
from pycorrector.utils.io_utils import get_logger
|
||||||
from train import create_model
|
|
||||||
|
from pycorrector.seq2seq.reader import EOS_TOKEN
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def decode(sess, model, data_reader, data_to_decode,
|
def decode_sequence(model, rnn_hidden_dim,input_token_index,
|
||||||
corrective_tokens=None, verbose=True):
|
num_decoder_tokens, target_token_index,encoder_input_data,
|
||||||
"""
|
reverse_target_char_index, max_decoder_seq_length):
|
||||||
Infer the correction sentence
|
# construct the encoder and decoder
|
||||||
:param sess:
|
encoder_inputs = model.input[0] # input_1
|
||||||
:param model:
|
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output # lstm_1
|
||||||
:param data_reader:
|
encoder_states = [state_h_enc, state_c_enc]
|
||||||
:param data_to_decode: an iterable of token lists representing the input
|
encoder_model = Model(encoder_inputs, encoder_states)
|
||||||
data we want to decode
|
|
||||||
:param corrective_tokens
|
|
||||||
:param verbose:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
model.batch_size = 1
|
|
||||||
corrective_tokens_mask = np.zeros(model.target_vocab_size)
|
|
||||||
corrective_tokens_mask[EOS_ID] = 1.0
|
|
||||||
|
|
||||||
if corrective_tokens is None:
|
decoder_inputs = model.input[1] # input_2
|
||||||
corrective_tokens = set()
|
decoder_state_input_h = Input(shape=(rnn_hidden_dim,), name='input_3')
|
||||||
for tokens in corrective_tokens:
|
decoder_state_input_c = Input(shape=(rnn_hidden_dim,), name='input_4')
|
||||||
for token in tokens:
|
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
|
||||||
corrective_tokens_mask[data_reader.convert_token_2_id(token)] = 1.0
|
decoder_lstm = model.layers[3]
|
||||||
|
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
|
||||||
|
decoder_inputs, initial_state=decoder_states_inputs)
|
||||||
|
decoder_states = [state_h_dec, state_c_dec]
|
||||||
|
decoder_dense = model.layers[4]
|
||||||
|
decoder_outputs = decoder_dense(decoder_outputs)
|
||||||
|
decoder_model = Model(
|
||||||
|
[decoder_inputs] + decoder_states_inputs,
|
||||||
|
[decoder_outputs] + decoder_states)
|
||||||
|
|
||||||
for tokens in data_to_decode:
|
# Reverse-lookup token index to decode sequences back to
|
||||||
token_ids = [data_reader.convert_token_2_id(token) for token in tokens]
|
# something readable.
|
||||||
|
reverse_input_char_index = dict(
|
||||||
|
(i, char) for char, i in input_token_index.items())
|
||||||
|
reverse_target_char_index = dict(
|
||||||
|
(i, char) for char, i in target_token_index.items())
|
||||||
|
|
||||||
# Which bucket does it belong to?
|
# Encode the input as state vectors.
|
||||||
matching_buckets = [b for b in range(len(model.buckets))
|
states_value = encoder_model.predict(encoder_input_data)
|
||||||
if model.buckets[b][0] > len(token_ids)]
|
|
||||||
if not matching_buckets:
|
|
||||||
# The input string has more tokens than the largest bucket, so we
|
|
||||||
# have to skip it.
|
|
||||||
continue
|
|
||||||
|
|
||||||
bucket_id = min(matching_buckets)
|
# Generate empty target sequence of length 1.
|
||||||
|
target_seq = np.zeros((1, 1, num_decoder_tokens))
|
||||||
|
# Populate the first character of target sequence with the start character.
|
||||||
|
# target_seq[0, 0, target_token_index[first_char]] = 1.
|
||||||
|
|
||||||
# Get a 1-element batch to feed the sentence to the model.
|
# Sampling loop for a batch of sequences
|
||||||
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
|
# (to simplify, here we assume a batch of size 1).
|
||||||
{bucket_id: [(token_ids, [])]}, bucket_id)
|
stop_condition = False
|
||||||
|
decoded_sentence = ''
|
||||||
|
|
||||||
# Get output logits for the sentence.
|
while not stop_condition:
|
||||||
_, _, output_logits = model.step(
|
output_tokens, h, c = decoder_model.predict(
|
||||||
sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
|
[target_seq] + states_value)
|
||||||
True, corrective_tokens=corrective_tokens_mask)
|
|
||||||
|
|
||||||
oov_input_tokens = [token for token in tokens if
|
# Sample a token
|
||||||
data_reader.is_unknown_token(token)]
|
sampled_token_index = np.argmax(output_tokens[0, -1, :])
|
||||||
outputs = []
|
sampled_char = reverse_target_char_index[sampled_token_index]
|
||||||
next_oov_token_idx = 0
|
decoded_sentence += sampled_char
|
||||||
|
|
||||||
for logit in output_logits:
|
# Exit condition: either hit max length
|
||||||
max_likelihood_token_id = int(np.argmax(logit, axis=1))
|
# or find stop character.
|
||||||
# Check if this logit most likely points to the EOS identifier.
|
if (sampled_char == EOS_TOKEN or
|
||||||
if max_likelihood_token_id == EOS_ID:
|
len(decoded_sentence) > max_decoder_seq_length):
|
||||||
break
|
stop_condition = True
|
||||||
|
|
||||||
token = data_reader.convert_id_2_token(max_likelihood_token_id)
|
# Update the target sequence (of length 1).
|
||||||
if data_reader.is_unknown_token(token):
|
target_seq = np.zeros((1, 1, num_decoder_tokens))
|
||||||
# Replace the "unknown" token with the most probable OOV
|
target_seq[0, 0, sampled_token_index] = 1.
|
||||||
# token from the input.
|
|
||||||
if next_oov_token_idx < len(oov_input_tokens):
|
# Update states
|
||||||
# If we still have OOV input tokens available,
|
states_value = [h, c]
|
||||||
# pick the next available one.
|
|
||||||
token = oov_input_tokens[next_oov_token_idx]
|
return decoded_sentence
|
||||||
# Advance to the next OOV input token.
|
|
||||||
next_oov_token_idx += 1
|
|
||||||
else:
|
|
||||||
# If we've already used all OOV input tokens,
|
|
||||||
# then we just leave the token as "UNK"
|
|
||||||
pass
|
|
||||||
outputs.append(token)
|
|
||||||
if verbose:
|
|
||||||
decoded_sentence = " ".join(outputs)
|
|
||||||
print("Input: {}".format(" ".join(tokens)))
|
|
||||||
print("Output: {}\n".format(decoded_sentence))
|
|
||||||
yield outputs
|
|
||||||
|
|
||||||
|
|
||||||
def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
|
def infer(train_path=None,
|
||||||
verbose=True):
|
test_path=None,
|
||||||
"""Used with InteractiveSession in IPython """
|
save_model_path=None,
|
||||||
return next(decode(sess, model, data_reader, [segment(sentence, 'char')],
|
rnn_hidden_dim=200):
|
||||||
corrective_tokens=corrective_tokens, verbose=verbose))
|
data_reader = CGEDReader(train_path)
|
||||||
|
input_texts, target_texts = data_reader.build_dataset(test_path)
|
||||||
|
|
||||||
|
input_characters = data_reader.read_vocab(input_texts)
|
||||||
|
target_characters = data_reader.read_vocab(target_texts)
|
||||||
|
num_encoder_tokens = len(input_characters)
|
||||||
|
num_decoder_tokens = len(target_characters)
|
||||||
|
max_encoder_seq_len = max([len(text) for text in input_texts])
|
||||||
|
max_decoder_seq_len = max([len(text) for text in target_texts])
|
||||||
|
|
||||||
def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
|
print('num of samples:', len(input_texts))
|
||||||
max_samples=None):
|
print('num of unique input tokens:', num_encoder_tokens)
|
||||||
"""Evaluates the accuracy and BLEU score of the given model."""
|
print('num of unique output tokens:', num_decoder_tokens)
|
||||||
|
print('max sequence length for inputs:', max_encoder_seq_len)
|
||||||
|
print('max sequence length for outputs:', max_decoder_seq_len)
|
||||||
|
|
||||||
import nltk # Loading here to avoid having to bundle it in lambda.
|
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
|
||||||
|
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])
|
||||||
|
|
||||||
# Build a collection of "baseline" and model-based hypotheses, where the
|
encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_len, num_encoder_tokens), dtype='float32')
|
||||||
# baseline is just the (potentially errant) source sequence.
|
|
||||||
baseline_hypotheses = defaultdict(list) # The model's input
|
|
||||||
model_hypotheses = defaultdict(list) # The actual model's predictions
|
|
||||||
targets = defaultdict(list) # Groundtruth
|
|
||||||
|
|
||||||
errors = []
|
# one hot representation
|
||||||
|
for i, input_text in enumerate(input_texts):
|
||||||
|
for t, char in enumerate(input_text):
|
||||||
|
encoder_input_data[i, t, input_token_index[char]] = 1.0
|
||||||
|
logger.info("Data loaded.")
|
||||||
|
|
||||||
n_samples_by_bucket = defaultdict(int)
|
# model
|
||||||
n_correct_model_by_bucket = defaultdict(int)
|
logger.info("Infer seq2seq model...")
|
||||||
n_correct_baseline_by_bucket = defaultdict(int)
|
model = load_model(save_model_path)
|
||||||
n_samples = 0
|
|
||||||
|
|
||||||
# Evaluate the model against all samples in the test data set.
|
decoded_sentences = decode_sequence(model, encoder_input_data, )
|
||||||
for source, target in data_reader.read_samples_by_string(test_path):
|
for seq_index in input_text:
|
||||||
matching_buckets = [i for i, bucket in enumerate(model.buckets) if
|
print('-')
|
||||||
len(source) < bucket[0]]
|
print('Input sentence:', input_texts[seq_index])
|
||||||
if not matching_buckets:
|
print('Decoded sentence:', decoded_sentences[seq_index])
|
||||||
continue
|
|
||||||
|
|
||||||
bucket_id = matching_buckets[0]
|
logger.info("Infer has finished.")
|
||||||
|
|
||||||
decoding = next(
|
|
||||||
decode(sess, model, data_reader, [source],
|
|
||||||
corrective_tokens=corrective_tokens, verbose=False))
|
|
||||||
model_hypotheses[bucket_id].append(decoding)
|
|
||||||
if decoding == target:
|
|
||||||
n_correct_model_by_bucket[bucket_id] += 1
|
|
||||||
else:
|
|
||||||
errors.append((decoding, target))
|
|
||||||
|
|
||||||
baseline_hypotheses[bucket_id].append(source)
|
|
||||||
if source == target:
|
|
||||||
n_correct_baseline_by_bucket[bucket_id] += 1
|
|
||||||
|
|
||||||
# nltk.corpus_bleu expects a list of one or more reference
|
|
||||||
# translations per sample, so we wrap the target list in another list
|
|
||||||
targets[bucket_id].append([target])
|
|
||||||
|
|
||||||
n_samples_by_bucket[bucket_id] += 1
|
|
||||||
n_samples += 1
|
|
||||||
|
|
||||||
if max_samples is not None and n_samples > max_samples:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Measure the corpus BLEU score and accuracy for the model and baseline
|
|
||||||
# across all buckets.
|
|
||||||
for bucket_id in targets.keys():
|
|
||||||
baseline_bleu_score = nltk.translate.bleu_score.corpus_bleu(
|
|
||||||
targets[bucket_id], baseline_hypotheses[bucket_id])
|
|
||||||
model_bleu_score = nltk.translate.bleu_score.corpus_bleu(
|
|
||||||
targets[bucket_id], model_hypotheses[bucket_id])
|
|
||||||
print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id]))
|
|
||||||
print("\tBaseline BLEU = {:.4f}\n\tModel BLEU = {:.4f}".format(
|
|
||||||
baseline_bleu_score, model_bleu_score))
|
|
||||||
print("\tBaseline Accuracy: {:.4f}".format(
|
|
||||||
1.0 * n_correct_baseline_by_bucket[bucket_id] /
|
|
||||||
n_samples_by_bucket[bucket_id]))
|
|
||||||
print("\tModel Accuracy: {:.4f}".format(
|
|
||||||
1.0 * n_correct_model_by_bucket[bucket_id] /
|
|
||||||
n_samples_by_bucket[bucket_id]))
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
|
||||||
print('Correcting error...')
|
|
||||||
# Set the model path.
|
|
||||||
model_path = cged_config.model_path
|
|
||||||
data_reader = CGEDReader(cged_config, cged_config.train_path)
|
|
||||||
|
|
||||||
if cged_config.enable_decode_sentence:
|
|
||||||
# Correct user's sentences.
|
|
||||||
with tf.Session() as session:
|
|
||||||
model = create_model(session, True, model_path, config=cged_config)
|
|
||||||
print("Enter a sentence you'd like to correct")
|
|
||||||
correct_new_sentence = input()
|
|
||||||
while correct_new_sentence.lower() != 'no':
|
|
||||||
decode_sentence(session, model=model, data_reader=data_reader,
|
|
||||||
sentence=correct_new_sentence,
|
|
||||||
corrective_tokens=data_reader.read_tokens(cged_config.train_path))
|
|
||||||
print("Enter a sentence you'd like to correct or press NO")
|
|
||||||
correct_new_sentence = input()
|
|
||||||
elif cged_config.enable_test_decode:
|
|
||||||
# Decode test sentences.
|
|
||||||
with tf.Session() as session:
|
|
||||||
model = create_model(session, True, model_path, config=cged_config)
|
|
||||||
print("Loaded model. Beginning decoding.")
|
|
||||||
decodings = decode(session, model=model, data_reader=data_reader,
|
|
||||||
data_to_decode=data_reader.read_tokens(cged_config.test_path, is_infer=True),
|
|
||||||
corrective_tokens=data_reader.read_tokens(cged_config.train_path))
|
|
||||||
# Write the decoded tokens to stdout.
|
|
||||||
for tokens in decodings:
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.app.run()
|
infer(train_path=config.train_path,
|
||||||
|
test_path=config.test_path,
|
||||||
|
save_model_path=config.save_model_path,
|
||||||
|
rnn_hidden_dim=config.rnn_hidden_dim)
|
||||||
|
|
|
@ -5,7 +5,7 @@ from xml.dom import minidom
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import cged_config as config
|
import pycorrector.seq2seq.cged_config as config
|
||||||
from utils.text_utils import segment
|
from utils.text_utils import segment
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,8 @@ GO_TOKEN = 'GO'
|
||||||
|
|
||||||
|
|
||||||
class Reader:
|
class Reader:
|
||||||
def __init__(self, config, train_path=None, token_2_id=None,
|
def __init__(self, train_path=None, token_2_id=None,
|
||||||
special_tokens=(), dataset_copies=1):
|
special_tokens=()):
|
||||||
self.config = config
|
|
||||||
self.dataset_copies = dataset_copies
|
|
||||||
# Vocabulary
|
|
||||||
max_vocab_size = config.max_vocab_size
|
|
||||||
if train_path is None:
|
if train_path is None:
|
||||||
self.token_2_id = token_2_id
|
self.token_2_id = token_2_id
|
||||||
else:
|
else:
|
||||||
|
@ -36,7 +32,7 @@ class Reader:
|
||||||
vocab[0:0] = special_tokens
|
vocab[0:0] = special_tokens
|
||||||
full_token_id = list(zip(vocab, range(len(vocab))))
|
full_token_id = list(zip(vocab, range(len(vocab))))
|
||||||
self.full_token_2_id = dict(full_token_id)
|
self.full_token_2_id = dict(full_token_id)
|
||||||
self.token_2_id = dict(full_token_id[:max_vocab_size])
|
self.token_2_id = dict(full_token_id)
|
||||||
self.id_2_token = {v: k for k, v in self.token_2_id.items()}
|
self.id_2_token = {v: k for k, v in self.token_2_id.items()}
|
||||||
|
|
||||||
def read_tokens(self, path):
|
def read_tokens(self, path):
|
||||||
|
@ -115,13 +111,21 @@ class Reader:
|
||||||
target.append(EOS_ID)
|
target.append(EOS_ID)
|
||||||
yield source, target
|
yield source, target
|
||||||
|
|
||||||
|
def read_samples_tokens(self, path):
|
||||||
|
"""
|
||||||
|
Read sample of path's data
|
||||||
|
:param path:
|
||||||
|
:return: generate list
|
||||||
|
"""
|
||||||
|
for source_words, target_words in self.read_samples_by_string(path):
|
||||||
|
target = target_words
|
||||||
|
target.append(EOS_TOKEN)
|
||||||
|
yield source_words, target
|
||||||
|
|
||||||
def build_dataset(self, path):
|
def build_dataset(self, path):
|
||||||
dataset = [[] for _ in self.config.buckets]
|
print('Read data, path:{0}'.format(path))
|
||||||
# Copy the data set for different dropouts
|
sources, targets = [], []
|
||||||
for _ in range(self.dataset_copies):
|
for source, target in self.read_samples_tokens(path):
|
||||||
for source, target in self.read_samples(path):
|
sources.append(source)
|
||||||
for bucket_id, (source_size, target_size) in enumerate(self.config.buckets):
|
targets.append(target)
|
||||||
if len(source) < source_size and len(target) < target_size:
|
return sources, targets
|
||||||
dataset[bucket_id].append([source, target])
|
|
||||||
break
|
|
||||||
return dataset
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,144 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Author: XuMing <xuming624@qq.com>
|
||||||
|
# Brief:
|
||||||
|
from keras.layers import Input, LSTM, Dense
|
||||||
|
from keras.models import Model
|
||||||
|
from keras.callbacks import LambdaCallback
|
||||||
|
from keras.callbacks import ModelCheckpoint
|
||||||
|
from pycorrector.seq2seq.reader import EOS_TOKEN
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(num_encoder_tokens, num_decoder_tokens, rnn_hidden_dim=200):
|
||||||
|
# Define an input sequence and process it.
|
||||||
|
encoder_inputs = Input(shape=(None, num_encoder_tokens))
|
||||||
|
encoder = LSTM(rnn_hidden_dim, return_state=True)
|
||||||
|
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
|
||||||
|
# We discard `encoder_outputs` and only keep the states.
|
||||||
|
encoder_states = [state_h, state_c]
|
||||||
|
|
||||||
|
# Set up the decoder, using `encoder_states` as initial state.
|
||||||
|
decoder_inputs = Input(shape=(None, num_decoder_tokens))
|
||||||
|
# We set up our decoder to return full output sequences,
|
||||||
|
# and to return internal states as well. We don't use the
|
||||||
|
# return states in the training model, but we will use them in inference.
|
||||||
|
decoder_lstm = LSTM(rnn_hidden_dim, return_sequences=True, return_state=True)
|
||||||
|
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
|
||||||
|
initial_state=encoder_states)
|
||||||
|
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
|
||||||
|
decoder_outputs = decoder_dense(decoder_outputs)
|
||||||
|
|
||||||
|
# Define the model that will turn
|
||||||
|
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
|
||||||
|
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
|
||||||
|
|
||||||
|
# Run training
|
||||||
|
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def eval(num_encoder_tokens, num_decoder_tokens, rnn_hidden_dim, input_token_index, target_token_index,
|
||||||
|
max_decoder_seq_length, encoder_input_data, input_texts):
|
||||||
|
# Define an input sequence and process it.
|
||||||
|
encoder_inputs = Input(shape=(None, num_encoder_tokens))
|
||||||
|
encoder = LSTM(rnn_hidden_dim, return_state=True)
|
||||||
|
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
|
||||||
|
# We discard `encoder_outputs` and only keep the states.
|
||||||
|
encoder_states = [state_h, state_c]
|
||||||
|
|
||||||
|
# Set up the decoder, using `encoder_states` as initial state.
|
||||||
|
decoder_inputs = Input(shape=(None, num_decoder_tokens))
|
||||||
|
# We set up our decoder to return full output sequences,
|
||||||
|
# and to return internal states as well. We don't use the
|
||||||
|
# return states in the training model, but we will use them in inference.
|
||||||
|
decoder_lstm = LSTM(rnn_hidden_dim, return_sequences=True, return_state=True)
|
||||||
|
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
|
||||||
|
initial_state=encoder_states)
|
||||||
|
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
|
||||||
|
decoder_outputs = decoder_dense(decoder_outputs)
|
||||||
|
|
||||||
|
# Define sampling models
|
||||||
|
encoder_model = Model(encoder_inputs, encoder_states)
|
||||||
|
|
||||||
|
decoder_state_input_h = Input(shape=(rnn_hidden_dim,))
|
||||||
|
decoder_state_input_c = Input(shape=(rnn_hidden_dim,))
|
||||||
|
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
|
||||||
|
decoder_outputs, state_h, state_c = decoder_lstm(
|
||||||
|
decoder_inputs, initial_state=decoder_states_inputs)
|
||||||
|
decoder_states = [state_h, state_c]
|
||||||
|
decoder_outputs = decoder_dense(decoder_outputs)
|
||||||
|
decoder_model = Model(
|
||||||
|
[decoder_inputs] + decoder_states_inputs,
|
||||||
|
[decoder_outputs] + decoder_states)
|
||||||
|
|
||||||
|
# Reverse-lookup token index to decode sequences back to
|
||||||
|
# something readable.
|
||||||
|
reverse_input_char_index = dict(
|
||||||
|
(i, char) for char, i in input_token_index.items())
|
||||||
|
reverse_target_char_index = dict(
|
||||||
|
(i, char) for char, i in target_token_index.items())
|
||||||
|
|
||||||
|
def decode_sequence(input_seq, seq_index):
|
||||||
|
# Encode the input as state vectors.
|
||||||
|
states_value = encoder_model.predict(input_seq)
|
||||||
|
|
||||||
|
# Generate empty target sequence of length 1.
|
||||||
|
target_seq = np.zeros((1, 1, num_decoder_tokens))
|
||||||
|
# Populate the first character of target sequence with the start character.
|
||||||
|
first_char = input_texts[seq_index][0]
|
||||||
|
print('first char:', first_char)
|
||||||
|
target_seq[0, 0, target_token_index[first_char]] = 1.
|
||||||
|
|
||||||
|
# Sampling loop for a batch of sequences
|
||||||
|
# (to simplify, here we assume a batch of size 1).
|
||||||
|
stop_condition = False
|
||||||
|
decoded_sentence = first_char
|
||||||
|
while not stop_condition:
|
||||||
|
output_tokens, h, c = decoder_model.predict(
|
||||||
|
[target_seq] + states_value)
|
||||||
|
|
||||||
|
# Sample a token
|
||||||
|
sampled_token_index = np.argmax(output_tokens[0, -1, :])
|
||||||
|
sampled_char = reverse_target_char_index[sampled_token_index]
|
||||||
|
if sampled_char != EOS_TOKEN:
|
||||||
|
decoded_sentence += sampled_char
|
||||||
|
|
||||||
|
# Exit condition: either hit max length
|
||||||
|
# or find stop character.
|
||||||
|
if (sampled_char == EOS_TOKEN or
|
||||||
|
len(decoded_sentence) > max_decoder_seq_length):
|
||||||
|
stop_condition = True
|
||||||
|
|
||||||
|
# Update the target sequence (of length 1).
|
||||||
|
target_seq = np.zeros((1, 1, num_decoder_tokens))
|
||||||
|
target_seq[0, 0, sampled_token_index] = 1.
|
||||||
|
|
||||||
|
# Update states
|
||||||
|
states_value = [h, c]
|
||||||
|
|
||||||
|
return decoded_sentence
|
||||||
|
|
||||||
|
for seq_index in range(10):
|
||||||
|
# Take one sequence (part of the training set)
|
||||||
|
# for trying out decoding.
|
||||||
|
input_seq = encoder_input_data[seq_index: seq_index + 1]
|
||||||
|
decoded_sentence = decode_sequence(input_seq, seq_index)
|
||||||
|
|
||||||
|
print('Input sentence:', input_texts[seq_index])
|
||||||
|
print('Decoded sentence:', decoded_sentence)
|
||||||
|
print('-')
|
||||||
|
|
||||||
|
|
||||||
|
def callback(save_model_path, logger=None):
|
||||||
|
# Print the batch number at the beginning of every batch.
|
||||||
|
if logger:
|
||||||
|
batch_print_callback = LambdaCallback(
|
||||||
|
on_batch_begin=lambda batch, logs: logger.info('batch: %d' % batch))
|
||||||
|
else:
|
||||||
|
batch_print_callback = LambdaCallback(
|
||||||
|
on_batch_begin=lambda batch, logs: print(batch))
|
||||||
|
# define the checkpoint, save model
|
||||||
|
checkpoint = ModelCheckpoint(save_model_path,
|
||||||
|
save_best_only=True,
|
||||||
|
verbose=1)
|
||||||
|
return [batch_print_callback, checkpoint]
|
|
@ -2,143 +2,76 @@
|
||||||
# Author: XuMing <xuming624@qq.com>
|
# Author: XuMing <xuming624@qq.com>
|
||||||
# Brief: Train seq2seq model for text grammar error correction
|
# Brief: Train seq2seq model for text grammar error correction
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from pycorrector.seq2seq import cged_config
|
from pycorrector.seq2seq import cged_config as config
|
||||||
from pycorrector.seq2seq.corpus_reader import CGEDReader
|
from pycorrector.seq2seq.corpus_reader import CGEDReader
|
||||||
from pycorrector.seq2seq.corrector_model import CorrectorModel
|
from pycorrector.seq2seq.seq2seq_model import create_model, callback, eval
|
||||||
from pycorrector.utils.tf_utils import get_ckpt_path
|
from pycorrector.utils.io_utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_model(session, forward_only, model_path, config=cged_config):
|
def train(train_path=None,
|
||||||
"""
|
save_model_path=None,
|
||||||
Create model and load parameters
|
batch_size=64,
|
||||||
:param session:
|
epochs=10,
|
||||||
:param forward_only:
|
rnn_hidden_dim=200):
|
||||||
:param model_path:
|
|
||||||
:param config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
model = CorrectorModel(
|
|
||||||
config.max_vocab_size,
|
|
||||||
config.max_vocab_size,
|
|
||||||
config.buckets,
|
|
||||||
config.size,
|
|
||||||
config.num_layers,
|
|
||||||
config.max_gradient_norm,
|
|
||||||
config.batch_size,
|
|
||||||
config.learning_rate,
|
|
||||||
config.learning_rate_decay_factor,
|
|
||||||
config.use_lstm,
|
|
||||||
forward_only=forward_only,
|
|
||||||
config=config)
|
|
||||||
ckpt_path = get_ckpt_path(model_path)
|
|
||||||
if ckpt_path:
|
|
||||||
print("Read model parameters from %s" % ckpt_path)
|
|
||||||
model.saver.restore(session, ckpt_path)
|
|
||||||
else:
|
|
||||||
print('Create model...')
|
|
||||||
session.run(tf.global_variables_initializer())
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def train(data_reader, train_path, test_path, model_path):
|
|
||||||
print('Read data, train:{0}, test:{1}'.format(train_path, test_path))
|
|
||||||
config = data_reader.config
|
|
||||||
train_data = data_reader.build_dataset(train_path)
|
|
||||||
test_data = data_reader.build_dataset(test_path)
|
|
||||||
|
|
||||||
with tf.Session() as sess:
|
|
||||||
# Create model
|
|
||||||
print('Create %d layers of %d units.' % (config.num_layers, config.size))
|
|
||||||
model = create_model(sess, False, model_path, config=config)
|
|
||||||
# Read data into buckets
|
|
||||||
train_bucket_sizes = [len(train_data[b]) for b in range(len(config.buckets))]
|
|
||||||
print("Training bucket sizes:{}".format(train_bucket_sizes))
|
|
||||||
train_total_size = float(sum(train_bucket_sizes))
|
|
||||||
print("Total train size:{}".format(train_total_size))
|
|
||||||
|
|
||||||
# Bucket scale
|
|
||||||
train_buckets_scale = [
|
|
||||||
sum(train_bucket_sizes[:i + 1]) / train_total_size
|
|
||||||
for i in range(len(train_bucket_sizes))]
|
|
||||||
|
|
||||||
# This is the training loop.
|
|
||||||
step_time, loss = 0.0, 0.0
|
|
||||||
current_step = 0
|
|
||||||
previous_losses = []
|
|
||||||
while current_step < config.max_steps:
|
|
||||||
# Choose a bucket according to data distribution. We pick a random
|
|
||||||
# number in [0, 1] and use the corresponding interval in
|
|
||||||
# train_buckets_scale.
|
|
||||||
random_number_01 = np.random.random_sample()
|
|
||||||
bucket_id = min([i for i in range(len(train_buckets_scale))
|
|
||||||
if train_buckets_scale[i] > random_number_01])
|
|
||||||
|
|
||||||
# Get a batch and make a step.
|
|
||||||
start_time = time.time()
|
|
||||||
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
|
|
||||||
train_data, bucket_id)
|
|
||||||
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
|
|
||||||
target_weights, bucket_id, False)
|
|
||||||
step_time += (time.time() - start_time) / config.steps_per_checkpoint
|
|
||||||
loss += step_loss / config.steps_per_checkpoint
|
|
||||||
current_step += 1
|
|
||||||
|
|
||||||
# Once in a while, we save checkpoint, print statistics, and run
|
|
||||||
# evals.
|
|
||||||
if current_step % config.steps_per_checkpoint == 0:
|
|
||||||
# Print statistics for the previous epoch.
|
|
||||||
perplexity = math.exp(float(loss)) if loss < 300 else float(
|
|
||||||
"inf")
|
|
||||||
print("global step %d learning rate %.4f step-time %.2f "
|
|
||||||
"perplexity %.2f" % (
|
|
||||||
model.global_step.eval(), model.learning_rate.eval(),
|
|
||||||
step_time, perplexity))
|
|
||||||
# Decrease learning rate if no improvement was seen over last
|
|
||||||
# 3 times.
|
|
||||||
if len(previous_losses) > 2 and loss > max(
|
|
||||||
previous_losses[-3:]):
|
|
||||||
sess.run(model.learning_rate_decay_op)
|
|
||||||
previous_losses.append(loss)
|
|
||||||
# Save checkpoint and zero timer and loss.
|
|
||||||
checkpoint_path = os.path.join(model_path, "translate.ckpt")
|
|
||||||
model.saver.save(sess, checkpoint_path,
|
|
||||||
global_step=model.global_step)
|
|
||||||
step_time, loss = 0.0, 0.0
|
|
||||||
# Run evals on development set and print their perplexity.
|
|
||||||
for bucket_id in range(len(config.buckets)):
|
|
||||||
if len(test_data[bucket_id]) == 0:
|
|
||||||
print(" eval: empty bucket %d" % bucket_id)
|
|
||||||
continue
|
|
||||||
encoder_inputs, decoder_inputs, target_weights = \
|
|
||||||
model.get_batch(test_data, bucket_id)
|
|
||||||
_, eval_loss, _ = model.step(sess, encoder_inputs,
|
|
||||||
decoder_inputs,
|
|
||||||
target_weights, bucket_id,
|
|
||||||
True)
|
|
||||||
eval_ppx = math.exp(
|
|
||||||
float(eval_loss)) if eval_loss < 300 else float(
|
|
||||||
"inf")
|
|
||||||
print(" eval: bucket %d perplexity %.2f" % (
|
|
||||||
bucket_id, eval_ppx))
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
|
||||||
print('Training model...')
|
print('Training model...')
|
||||||
data_reader = CGEDReader(cged_config, cged_config.train_path)
|
data_reader = CGEDReader(train_path)
|
||||||
train(data_reader,
|
input_texts, target_texts = data_reader.build_dataset(train_path)
|
||||||
cged_config.train_path,
|
print('input_texts:', input_texts[0])
|
||||||
cged_config.test_path,
|
print('target_texts:', target_texts[0])
|
||||||
cged_config.model_path)
|
|
||||||
|
input_characters = data_reader.read_vocab(input_texts)
|
||||||
|
target_characters = data_reader.read_vocab(target_texts)
|
||||||
|
num_encoder_tokens = len(input_characters)
|
||||||
|
num_decoder_tokens = len(target_characters)
|
||||||
|
max_encoder_seq_len = max([len(text) for text in input_texts])
|
||||||
|
max_decoder_seq_len = max([len(text) for text in target_texts])
|
||||||
|
|
||||||
|
print('num of samples:', len(input_texts))
|
||||||
|
print('num of unique input tokens:', num_encoder_tokens)
|
||||||
|
print('num of unique output tokens:', num_decoder_tokens)
|
||||||
|
print('max sequence length for inputs:', max_encoder_seq_len)
|
||||||
|
print('max sequence length for outputs:', max_decoder_seq_len)
|
||||||
|
|
||||||
|
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
|
||||||
|
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])
|
||||||
|
|
||||||
|
encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_len, num_encoder_tokens), dtype='float32')
|
||||||
|
decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_len, num_decoder_tokens), dtype='float32')
|
||||||
|
decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_len, num_decoder_tokens), dtype='float32')
|
||||||
|
|
||||||
|
# one hot representation
|
||||||
|
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
|
||||||
|
for t, char in enumerate(input_text):
|
||||||
|
encoder_input_data[i, t, input_token_index[char]] = 1.0
|
||||||
|
for t, char in enumerate(target_text):
|
||||||
|
# decoder_target_data is a head of decoder_input_data by one timestep
|
||||||
|
decoder_input_data[i, t, target_token_index[char]] = 1.0
|
||||||
|
if t > 0:
|
||||||
|
decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
|
||||||
|
logger.info("Data loaded.")
|
||||||
|
|
||||||
|
# model
|
||||||
|
logger.info("Training seq2seq model...")
|
||||||
|
model = create_model(num_encoder_tokens, num_decoder_tokens, rnn_hidden_dim)
|
||||||
|
# save
|
||||||
|
callbacks_list = callback(save_model_path, logger)
|
||||||
|
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
epochs=epochs,
|
||||||
|
callbacks=callbacks_list)
|
||||||
|
logger.info("Training has finished.")
|
||||||
|
|
||||||
|
eval(num_encoder_tokens, num_decoder_tokens, rnn_hidden_dim, input_token_index, target_token_index,
|
||||||
|
max_decoder_seq_len, encoder_input_data, input_texts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.app.run() # CPU, i5, about ten hours
|
train(train_path=config.train_path,
|
||||||
|
save_model_path=config.save_model_path,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
epochs=config.epochs,
|
||||||
|
rnn_hidden_dim=config.rnn_hidden_dim)
|
||||||
|
|
|
@ -135,7 +135,10 @@ def segment(sentence, cut_type='word', pos=False, None_flag='O'):
|
||||||
return word_seq, pos_seq
|
return word_seq, pos_seq
|
||||||
elif cut_type == 'char':
|
elif cut_type == 'char':
|
||||||
word_seq = list(sentence)
|
word_seq = list(sentence)
|
||||||
pos_seq = [None_flag for _ in word_seq]
|
pos_seq = []
|
||||||
|
for w in word_seq:
|
||||||
|
w_p = posseg.lcut(w)
|
||||||
|
pos_seq.append(w_p[0].flag)
|
||||||
return word_seq, pos_seq
|
return word_seq, pos_seq
|
||||||
else:
|
else:
|
||||||
if cut_type == 'word':
|
if cut_type == 'word':
|
||||||
|
|
Loading…
Reference in New Issue