Add inference for the seq2seq model.

xuming06 2018-03-30 01:25:36 +08:00
parent 8558d658ee
commit 52dd8f17b3
6 changed files with 379 additions and 41 deletions

View File

@@ -21,24 +21,26 @@ class FCEReader(Reader):
                          dataset_copies=dataset_copies)
         self.dropout_prob = dropout_prob
         self.replacement_prob = replacement_prob
-        self.UNKNOW_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]
+        self.UNKNOWN_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]

     def read_samples_by_string(self, path):
-        for tokens in self.read_tokens(path):
-            source = []
-            target = []
-            for token in tokens:
-                target.append(token)
-                if self.config.enable_data_dropout:
-                    # Random dropout words from the input
-                    dropout_token = (token in FCEReader.DROPOUT_TOKENS and random.random() < self.dropout_prob)
-                    replace_token = (token in FCEReader.REPLACEMENTS and random.random() < self.replacement_prob)
-                    if replace_token:
-                        source.append(FCEReader.REPLACEMENTS[tokens])
-                    elif not dropout_token:
-                        source.append(token)
-                else:
-                    source.append(token)
-            yield source, target
+        with open(path, 'r', encoding='utf-8') as f:
+            line_src = f.readline()
+            line_dst = f.readline()
+            if line_src and line_dst:
+                source = line_src.lower()[5:].strip().split()
+                target = line_dst.lower()[5:].strip().split()
+                if self.config.enable_data_dropout:
+                    new_source = []
+                    for token in source:
+                        # Random dropout words from the input
+                        dropout_token = (token in FCEReader.DROPOUT_TOKENS and random.random() < self.dropout_prob)
+                        replace_token = (token in FCEReader.REPLACEMENTS and random.random() < self.replacement_prob)
+                        if replace_token:
+                            new_source.append(FCEReader.REPLACEMENTS[source])
+                        elif not dropout_token:
+                            new_source.append(token)
+                    source = new_source
+                yield source, target

     def unknown_token(self):
@@ -46,9 +48,9 @@ class FCEReader(Reader):

     def read_tokens(self, path):
         i = 0
-        with open(path, 'r', 'utf-8') as f:
+        with open(path, 'r', encoding='utf-8') as f:
             for line in f:
-                if line in f:
-                    if i % 2 == 1:
-                        yield line.lower()[5:].strip().split()
-                    i += 1
+                if i % 2 == 1:
+                    if line:
+                        yield line.lower()[5:].strip().split()
+                i += 1
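
A note on the source-side noising in the first hunk above: `FCEReader.REPLACEMENTS[source]` indexes the replacement table with the whole source list rather than the current token, which looks unintended (a list is not a hashable dict key); `REPLACEMENTS[token]` is presumably what is meant. For reference, a minimal standalone sketch of the intended augmentation, not part of this commit; the DROPOUT_TOKENS and REPLACEMENTS values below are illustrative assumptions, not the repo's constants:

# Standalone sketch of the source-side noising (values below are illustrative).
import random

DROPOUT_TOKENS = {'a', 'an', 'the'}
REPLACEMENTS = {'there': 'their', 'their': 'there', 'then': 'than', 'than': 'then'}

def noise_source(tokens, dropout_prob=0.25, replacement_prob=0.25):
    noised = []
    for token in tokens:
        if token in REPLACEMENTS and random.random() < replacement_prob:
            noised.append(REPLACEMENTS[token])  # swap in a common confusion
        elif token in DROPOUT_TOKENS and random.random() < dropout_prob:
            continue  # drop the token entirely
        else:
            noised.append(token)
    return noised

print(noise_source("i will see you at the station then".split()))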

View File

@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com>
# Brief:
import sys
from collections import defaultdict

import numpy as np
import tensorflow as tf

import seq2seq_config
from fce_reader import FCEReader
from reader import EOS_ID
from train import create_model


def decode(sess, model, data_reader, data_to_decode,
           corrective_tokens=None, verbose=True):
    """
    Infer the correction sentence
    :param sess:
    :param model:
    :param data_reader:
    :param data_to_decode: an iterable of token lists representing the input
        data we want to decode
    :param corrective_tokens
    :param verbose:
    :return:
    """
    model.batch_size = 1
    corrective_tokens_mask = np.zeros(model.target_vocab_size)
    corrective_tokens_mask[EOS_ID] = 1.0
    if corrective_tokens is None:
        corrective_tokens = set()
    for tokens in corrective_tokens:
        for token in tokens:
            corrective_tokens_mask[data_reader.convert_token_to_id(token)] = 1.0
    for tokens in data_to_decode:
        token_ids = [data_reader.convert_token_to_id(token) for token in tokens]
        # Which bucket does it belong to?
        matching_buckets = [b for b in range(len(model.buckets))
                            if model.buckets[b][0] > len(token_ids)]
        if not matching_buckets:
            # The input string has more tokens than the largest bucket, so we
            # have to skip it.
            continue
        bucket_id = min(matching_buckets)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = model.step(
            sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
            True, corrective_tokens=corrective_tokens_mask)
        oov_input_tokens = [token for token in tokens if
                            data_reader.is_unknown_token(token)]
        outputs = []
        next_oov_token_idx = 0
        for logit in output_logits:
            max_likelihood_token_id = int(np.argmax(logit, axis=1))
            # Check if this logit most likely points to the EOS identifier.
            if max_likelihood_token_id == EOS_ID:
                break
            token = data_reader.convert_id_to_token(max_likelihood_token_id)
            if data_reader.is_unknown_token(token):
                # Replace the "unknown" token with the most probable OOV
                # token from the input.
                if next_oov_token_idx < len(oov_input_tokens):
                    # If we still have OOV input tokens available,
                    # pick the next available one.
                    token = oov_input_tokens[next_oov_token_idx]
                    # Advance to the next OOV input token.
                    next_oov_token_idx += 1
                else:
                    # If we've already used all OOV input tokens,
                    # then we just leave the token as "UNK"
                    pass
            outputs.append(token)
        if verbose:
            decoded_sentence = " ".join(outputs)
            print("Input: {}".format(" ".join(tokens)))
            print("Output: {}\n".format(decoded_sentence))
        yield outputs
def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
                    verbose=True):
    """Used with InteractiveSession in IPython """
    return next(decode(sess, model, data_reader, [sentence.split()],
                       corrective_tokens=corrective_tokens, verbose=verbose))


def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
                      max_samples=None):
    """Evaluates the accuracy and BLEU score of the given model."""
    import nltk  # Loading here to avoid having to bundle it in lambda.
    # Build a collection of "baseline" and model-based hypotheses, where the
    # baseline is just the (potentially errant) source sequence.
    baseline_hypotheses = defaultdict(list)  # The model's input
    model_hypotheses = defaultdict(list)  # The actual model's predictions
    targets = defaultdict(list)  # Groundtruth
    errors = []
    n_samples_by_bucket = defaultdict(int)
    n_correct_model_by_bucket = defaultdict(int)
    n_correct_baseline_by_bucket = defaultdict(int)
    n_samples = 0
    # Evaluate the model against all samples in the test data set.
    for source, target in data_reader.read_samples_by_string(test_path):
        matching_buckets = [i for i, bucket in enumerate(model.buckets) if
                            len(source) < bucket[0]]
        if not matching_buckets:
            continue
        bucket_id = matching_buckets[0]
        decoding = next(
            decode(sess, model, data_reader, [source],
                   corrective_tokens=corrective_tokens, verbose=False))
        model_hypotheses[bucket_id].append(decoding)
        if decoding == target:
            n_correct_model_by_bucket[bucket_id] += 1
        else:
            errors.append((decoding, target))
        baseline_hypotheses[bucket_id].append(source)
        if source == target:
            n_correct_baseline_by_bucket[bucket_id] += 1
        # nltk.corpus_bleu expects a list of one or more reference
        # translations per sample, so we wrap the target list in another list
        targets[bucket_id].append([target])
        n_samples_by_bucket[bucket_id] += 1
        n_samples += 1
        if max_samples is not None and n_samples > max_samples:
            break
    # Measure the corpus BLEU score and accuracy for the model and baseline
    # across all buckets.
    for bucket_id in targets.keys():
        baseline_bleu_score = nltk.translate.bleu_score.corpus_bleu(
            targets[bucket_id], baseline_hypotheses[bucket_id])
        model_bleu_score = nltk.translate.bleu_score.corpus_bleu(
            targets[bucket_id], model_hypotheses[bucket_id])
        print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id]))
        print("\tBaseline BLEU = {:.4f}\n\tModel BLEU = {:.4f}".format(
            baseline_bleu_score, model_bleu_score))
        print("\tBaseline Accuracy: {:.4f}".format(
            1.0 * n_correct_baseline_by_bucket[bucket_id] /
            n_samples_by_bucket[bucket_id]))
        print("\tModel Accuracy: {:.4f}".format(
            1.0 * n_correct_model_by_bucket[bucket_id] /
            n_samples_by_bucket[bucket_id]))
    return errors
def main(_):
    print('Correcting error...')
    # Set the model path.
    model_path = seq2seq_config.model_path
    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
    if seq2seq_config.enable_decode_sentence:
        # Correct user's sentences.
        with tf.Session() as session:
            model = create_model(session, True, model_path, config=seq2seq_config)
            print("Enter a sentence you'd like to correct")
            correct_new_sentence = input()
            while correct_new_sentence.lower() != 'no':
                decode_sentence(session, model=model, data_reader=data_reader,
                                sentence=correct_new_sentence,
                                corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
                print("Enter a sentence you'd like to correct or press NO")
                correct_new_sentence = input()
    elif seq2seq_config.enable_test_decode:
        # Decode test sentences.
        with tf.Session() as session:
            model = create_model(session, True, model_path, config=seq2seq_config)
            print("Loaded model. Beginning decoding.")
            decodings = decode(session, model=model, data_reader=data_reader,
                               data_to_decode=data_reader.read_tokens(seq2seq_config.test_path),
                               corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
            # Write the decoded tokens to stdout.
            for tokens in decodings:
                sys.stdout.flush()


if __name__ == "__main__":
    tf.app.run()
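
The decode loop above does greedy argmax decoding per position, stops at EOS, and patches any UNK in the output with out-of-vocabulary tokens copied from the input, in order. A toy, model-free sketch of just that UNK-patching step (the names here are illustrative, not the repo's API):

# Toy sketch of the UNK -> OOV copy step, independent of the model.
def copy_oov_tokens(input_tokens, output_tokens, is_unknown):
    """Replace each UNK in the output with the next OOV token from the input."""
    oov_inputs = [t for t in input_tokens if is_unknown(t)]
    next_idx = 0
    patched = []
    for token in output_tokens:
        if is_unknown(token) and next_idx < len(oov_inputs):
            token = oov_inputs[next_idx]
            next_idx += 1
        patched.append(token)
    return patched

vocab = {'i', 'saw', 'him', 'yesterday', 'UNK'}
is_unknown = lambda t: t == 'UNK' or t not in vocab
print(copy_oov_tokens(['i', 'saw', 'xiaoming', 'yesterday'],
                      ['i', 'saw', 'UNK', 'yesterday'], is_unknown))
# -> ['i', 'saw', 'xiaoming', 'yesterday']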

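`evaluate_accuracy` wraps each target in an extra list because `nltk.translate.bleu_score.corpus_bleu` expects, for every hypothesis, a list of one or more reference translations. A tiny illustration of that call shape on toy data (NLTK may warn about missing higher-order n-grams on sentences this short):

# Illustration of the corpus_bleu call shape used in evaluate_accuracy (toy data).
import nltk

hypotheses = [['the', 'cat', 'sat', 'on', 'the', 'mat'],
              ['he', 'went', 'to', 'school']]
# One *list of references* per hypothesis, hence the extra nesting.
references = [[['the', 'cat', 'sat', 'on', 'the', 'mat']],
              [['he', 'went', 'to', 'school', 'yesterday']]]
print("corpus BLEU = {:.4f}".format(
    nltk.translate.bleu_score.corpus_bleu(references, hypotheses)))
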
View File

@@ -12,9 +12,6 @@ PAD_TOKEN = 'PAD'
 EOS_TOKEN = 'EOS'
 GO_TOKEN = 'GO'

-import tensorflow as tf
-tf.contrib.seq2seq

 class Reader:
     def __init__(self, config, train_path=None, token_2_id=None,
@@ -34,7 +31,7 @@ class Reader:
         # Get max_vocabulary size words
         count_pairs = sorted(token_counts.items(), key=lambda k: (-k[1], k[0]))
         vocab, _ = list(zip(*count_pairs))
+        vocab = list(vocab)
         # Insert the special tokens to the beginning
         vocab[0:0] = special_tokens
         full_token_id = list(zip(vocab, range(len(vocab))))
@@ -71,7 +68,7 @@ class Reader:
         :param token:
         :return:
         """
-        token_id = token if token in self.token_2_id else self.unknow_token()
+        token_id = token if token in self.token_2_id else self.unknown_token()
         return self.token_2_id[token_id]

     def convert_id_2_token(self, id):
@@ -82,13 +79,13 @@ class Reader:
         """
         return self.id_2_token[id]

-    def is_unknow_token(self, token):
+    def is_unknown_token(self, token):
         """
         True if the given token is out of vocabulary
         :param token:
         :return:
         """
-        return token not in self.token_2_id or token == self.unknow_token()
+        return token not in self.token_2_id or token == self.unknown_token()

     def sentence_2_token_ids(self, sentence):
         """
@@ -106,14 +103,13 @@ class Reader:
         """
         return [self.convert_id_2_token(w) for w in word_ids]

     def read_samples(self, path):
         """
         Read sample of path's data
         :param path:
         :return: generate list
         """
-        for source_words, target_words in self.read_sampless_by_string(path):
+        for source_words, target_words in self.read_samples_by_string(path):
             source = [self.convert_token_2_id(w) for w in source_words]
             target = [self.convert_token_2_id(w) for w in target_words]
             target.append(EOS_ID)
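
For context on the `vocab = list(vocab)` addition above: `zip(*count_pairs)` yields a tuple, while the later slice assignment `vocab[0:0] = special_tokens` requires a list. A small sketch of the same vocabulary construction with made-up counts (the special-token names and their ordering are assumptions):

# Sketch of the frequency-sorted vocabulary with special tokens prepended.
from collections import Counter

token_counts = Counter(['the', 'cat', 'the', 'dog', 'cat', 'the'])
special_tokens = ['PAD', 'GO', 'EOS', 'UNKNOWN']

count_pairs = sorted(token_counts.items(), key=lambda k: (-k[1], k[0]))
vocab, _ = list(zip(*count_pairs))
vocab = list(vocab)            # zip() yields a tuple; slice assignment needs a list
vocab[0:0] = special_tokens    # special tokens take the lowest ids
token_2_id = dict(zip(vocab, range(len(vocab))))
print(token_2_id)
# {'PAD': 0, 'GO': 1, 'EOS': 2, 'UNKNOWN': 3, 'the': 4, 'cat': 5, 'dog': 6}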

View File

@@ -5,7 +5,7 @@ train_path = '../data/en/fce/fce_train.txt'  # Training data path.
 val_path = '../data/en/fce/fce_val.txt'  # Validation data path.
 test_path = '../data/en/fce/fce_test.txt'
-output_path = './output'  # Path of the model saved, default is output_path/model
+model_path = './output_model'  # Path of the model saved, default is output_path/model
 enable_data_dropout = False
 num_steps = 3000  # Number of steps to train.
 decode_sentence = False  # Whether we should decode sentences of the user.
@@ -13,10 +13,10 @@ decode_sentence = False  # Whether we should decode sentences of the user.
 # FCEConfig
 buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]  # use a number of buckets and pad to the closest one for efficiency.
-steps_per_checkpoint = 20
-max_steps = 100
-max_vocabulary_size = 10000
+steps_per_checkpoint = 100
+max_steps = 10000
+max_vocab_size = 10000
 size = 128
 num_layers = 2
@@ -27,3 +27,11 @@ learning_rate_decay_factor = 0.99
 use_lstm = False
 use_rms_prop = False
+
+enable_decode_sentence = True  # Test with input error sentence
+enable_test_decode = True  # Test with test set
+
+import os
+
+if not os.path.exists(model_path):
+    os.makedirs(model_path)
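
The `buckets` setting pads each (source, target) pair to the smallest bucket that fits, and sentences longer than the largest bucket are skipped at decode time. A short sketch of that assignment, assuming the same strictly-greater length check used by the inference code in this commit:

# Sketch of bucket assignment for the buckets configured above.
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]

def pick_bucket(source_len):
    matching = [i for i, (src_size, _) in enumerate(buckets) if src_size > source_len]
    return min(matching) if matching else None  # None -> too long, skipped

print(pick_bucket(12))   # 1 -> bucket (15, 15)
print(pick_bucket(45))   # None -> longer than the largest bucket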

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com>
# Brief:
import os

import tensorflow as tf


def get_ckpt_path(model_path):
    ckpt = tf.train.get_checkpoint_state(model_path)
    ckpt_path = ""
    if ckpt:
        ckpt_file = ckpt.model_checkpoint_path.split('/')[-1]
        ckpt_path = os.path.join(model_path, ckpt_file)
    return ckpt_path
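
`get_ckpt_path` re-derives the checkpoint file name relative to `model_path`, so the checkpoint still resolves even if the path recorded in the checkpoint state file no longer matches the local directory layout. A hedged usage sketch in TF 1.x style, with a toy variable so the Saver has something to restore (not part of this commit):

# Usage sketch: restore the latest checkpoint if one exists, else initialize.
import tensorflow as tf
from tf_util import get_ckpt_path

counter = tf.Variable(0, name='counter')  # toy variable for the Saver
saver = tf.train.Saver()

with tf.Session() as sess:
    ckpt_path = get_ckpt_path('./output_model')  # "" if no checkpoint yet
    if ckpt_path:
        saver.restore(sess, ckpt_path)
    else:
        sess.run(tf.global_variables_initializer())
    print(sess.run(counter))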

View File

@@ -4,17 +4,16 @@
 import math
 import os
-import shutil
 import sys
 import time
-from collections import defaultdict

 import numpy as np
 import tensorflow as tf

-from reader import EOS_ID
-from fce_reader import FCEReader
-from corrector_model import CorrectorModel
 import seq2seq_config
+from corrector_model import CorrectorModel
+from fce_reader import FCEReader
+from tf_util import get_ckpt_path


 def create_model(session, forward_only, model_path, config=seq2seq_config):
@@ -27,6 +26,119 @@ def create_model(session, forward_only, model_path, config=seq2seq_config):
     :return:
     """
     model = CorrectorModel(
-        config.max_vocabulary_size,
-        config.max_vocabulary_size,
-    )
+        config.max_vocab_size,
+        config.max_vocab_size,
+        config.buckets,
+        config.size,
+        config.num_layers,
+        config.max_gradient_norm,
+        config.batch_size,
+        config.learning_rate,
+        config.learning_rate_decay_factor,
+        config.use_lstm,
+        forward_only=forward_only,
+        config=config)
+    ckpt_path = get_ckpt_path(model_path)
+    if ckpt_path:
+        print("Read model parameters from %s" % ckpt_path)
+        model.saver.restore(session, ckpt_path)
+    else:
+        print('Create model...')
+        session.run(tf.global_variables_initializer())
+    return model
+
+
+def train(data_reader, train_path, test_path, model_path):
+    print('Read data, train:{0}, test:{1}'.format(train_path, test_path))
+    config = data_reader.config
+    train_data = data_reader.build_dataset(train_path)
+    test_data = data_reader.build_dataset(test_path)
+    with tf.Session() as sess:
+        # Create model
+        print('Create %d layers of %d units.' % (config.num_layers, config.size))
+        model = create_model(sess, False, model_path, config=config)
+        # Read data into buckets
+        train_bucket_sizes = [len(train_data[b]) for b in range(len(config.buckets))]
+        print("Training bucket sizes:{}".format(train_bucket_sizes))
+        train_total_size = float(sum(train_bucket_sizes))
+        print("Total train size:{}".format(train_total_size))
+        # Bucket scale
+        train_buckets_scale = [
+            sum(train_bucket_sizes[:i + 1]) / train_total_size
+            for i in range(len(train_bucket_sizes))]
+        # This is the training loop.
+        step_time, loss = 0.0, 0.0
+        current_step = 0
+        previous_losses = []
+        while current_step < config.max_steps:
+            # Choose a bucket according to data distribution. We pick a random
+            # number in [0, 1] and use the corresponding interval in
+            # train_buckets_scale.
+            random_number_01 = np.random.random_sample()
+            bucket_id = min([i for i in range(len(train_buckets_scale))
+                             if train_buckets_scale[i] > random_number_01])
+            # Get a batch and make a step.
+            start_time = time.time()
+            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+                train_data, bucket_id)
+            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                         target_weights, bucket_id, False)
+            step_time += (time.time() - start_time) / config.steps_per_checkpoint
+            loss += step_loss / config.steps_per_checkpoint
+            current_step += 1
+            # Once in a while, we save checkpoint, print statistics, and run
+            # evals.
+            if current_step % config.steps_per_checkpoint == 0:
+                # Print statistics for the previous epoch.
+                perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
+                print("global step %d learning rate %.4f step-time %.2f "
+                      "perplexity %.2f" % (
+                          model.global_step.eval(), model.learning_rate.eval(),
+                          step_time, perplexity))
+                # Decrease learning rate if no improvement was seen over last
+                # 3 times.
+                if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
+                    sess.run(model.learning_rate_decay_op)
+                previous_losses.append(loss)
+                # Save checkpoint and zero timer and loss.
+                checkpoint_path = os.path.join(model_path, "translate.ckpt")
+                model.saver.save(sess, checkpoint_path,
+                                 global_step=model.global_step)
+                step_time, loss = 0.0, 0.0
+                # Run evals on development set and print their perplexity.
+                for bucket_id in range(len(config.buckets)):
+                    if len(test_data[bucket_id]) == 0:
+                        print("  eval: empty bucket %d" % (bucket_id))
+                        continue
+                    encoder_inputs, decoder_inputs, target_weights = \
+                        model.get_batch(test_data, bucket_id)
+                    _, eval_loss, _ = model.step(sess, encoder_inputs,
+                                                 decoder_inputs,
+                                                 target_weights, bucket_id,
+                                                 True)
+                    eval_ppx = math.exp(
+                        float(eval_loss)) if eval_loss < 300 else float("inf")
+                    print("  eval: bucket %d perplexity %.2f" % (
+                        bucket_id, eval_ppx))
+                sys.stdout.flush()
+
+
+def main(_):
+    print('Training model...')
+    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
+    train(data_reader,
+          seq2seq_config.train_path,
+          seq2seq_config.val_path,
+          seq2seq_config.model_path)
+
+
+if __name__ == "__main__":
+    tf.app.run()
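
The training loop above picks a bucket with probability proportional to how many training pairs it holds, via the cumulative `train_buckets_scale` list. A standalone sketch of that sampling with made-up bucket sizes:

# Standalone sketch of the bucket sampling used in the training loop.
import numpy as np

train_bucket_sizes = [200, 500, 200, 100]           # samples per bucket (made up)
total = float(sum(train_bucket_sizes))
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / total
                       for i in range(len(train_bucket_sizes))]
print(train_buckets_scale)                          # [0.2, 0.7, 0.9, 1.0]

counts = [0] * len(train_bucket_sizes)
for _ in range(10000):
    r = np.random.random_sample()
    bucket_id = min(i for i in range(len(train_buckets_scale))
                    if train_buckets_scale[i] > r)
    counts[bucket_id] += 1
print(counts)  # roughly proportional to [200, 500, 200, 100]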