Add inference for the seq2seq model.

parent 8558d658ee
commit 52dd8f17b3
fce_reader.py
@@ -21,24 +21,26 @@ class FCEReader(Reader):
                          dataset_copies=dataset_copies)
         self.dropout_prob = dropout_prob
         self.replacement_prob = replacement_prob
-        self.UNKNOW_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]
+        self.UNKNOWN_ID = self.token_2_id[FCEReader.UNKNOWN_TOKEN]

     def read_samples_by_string(self, path):
-        for tokens in self.read_tokens(path):
-            source = []
-            target = []
-            for token in tokens:
-                target.append(token)
+        with open(path, 'r', encoding='utf-8') as f:
+            line_src = f.readline()
+            line_dst = f.readline()
+            if line_src and line_dst:
+                source = line_src.lower()[5:].strip().split()
+                target = line_dst.lower()[5:].strip().split()
                 if self.config.enable_data_dropout:
+                    new_source = []
+                    for token in source:
                         # Randomly drop out words from the input
                         dropout_token = (token in FCEReader.DROPOUT_TOKENS and random.random() < self.dropout_prob)
                         replace_token = (token in FCEReader.REPLACEMENTS and random.random() < self.replacement_prob)
                         if replace_token:
-                            source.append(FCEReader.REPLACEMENTS[tokens])
+                            new_source.append(FCEReader.REPLACEMENTS[token])
                         elif not dropout_token:
-                            source.append(token)
-                        else:
-                            source.append(token)
+                            new_source.append(token)
+                    source = new_source
                 yield source, target

     def unknown_token(self):
@@ -46,9 +48,9 @@ class FCEReader(Reader):

     def read_tokens(self, path):
         i = 0
-        with open(path, 'r', 'utf-8') as f:
+        with open(path, 'r', encoding='utf-8') as f:
             for line in f:
-                if line in f:
                 if i % 2 == 1:
+                    if line:
                     yield line.lower()[5:].strip().split()
                 i += 1
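
The rewritten read_samples_by_string pairs each source line with the target line that follows it and strips a fixed 5-character prefix before tokenizing. A minimal sketch of the expected file layout and the resulting tokens (the "src: "/"dst: " prefixes are an assumption for illustration; the real FCE files may use a different 5-character marker):

    # Sketch: the two-line sample format assumed by read_samples_by_string.
    with open('fce_sample.txt', 'w', encoding='utf-8') as f:
        f.write('src: I has a apple\n')    # hypothetical source line
        f.write('dst: I have an apple\n')  # hypothetical target line

    with open('fce_sample.txt', 'r', encoding='utf-8') as f:
        line_src = f.readline()
        line_dst = f.readline()
        if line_src and line_dst:
            source = line_src.lower()[5:].strip().split()
            target = line_dst.lower()[5:].strip().split()
            print(source)  # ['i', 'has', 'a', 'apple']
            print(target)  # ['i', 'have', 'an', 'apple']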
New file — seq2seq inference script (filename not shown in this view; assumed infer.py below)
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# Author: XuMing <xuming624@qq.com>
+# Brief:
+
+import sys
+from collections import defaultdict
+
+import numpy as np
+import tensorflow as tf
+
+import seq2seq_config
+from fce_reader import FCEReader
+from reader import EOS_ID
+from train import create_model
+
+
+def decode(sess, model, data_reader, data_to_decode,
+           corrective_tokens=None, verbose=True):
+    """
+    Infer the corrected sentence
+    :param sess:
+    :param model:
+    :param data_reader:
+    :param data_to_decode: an iterable of token lists representing the input
+        data we want to decode
+    :param corrective_tokens:
+    :param verbose:
+    :return:
+    """
+    model.batch_size = 1
+    corrective_tokens_mask = np.zeros(model.target_vocab_size)
+    corrective_tokens_mask[EOS_ID] = 1.0
+
+    if corrective_tokens is None:
+        corrective_tokens = set()
+    for tokens in corrective_tokens:
+        for token in tokens:
+            corrective_tokens_mask[data_reader.convert_token_2_id(token)] = 1.0
+
+    for tokens in data_to_decode:
+        token_ids = [data_reader.convert_token_2_id(token) for token in tokens]
+
+        # Which bucket does it belong to?
+        matching_buckets = [b for b in range(len(model.buckets))
+                            if model.buckets[b][0] > len(token_ids)]
+        if not matching_buckets:
+            # The input string has more tokens than the largest bucket, so we
+            # have to skip it.
+            continue
+        bucket_id = min(matching_buckets)
+
+        # Get a 1-element batch to feed the sentence to the model.
+        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+            {bucket_id: [(token_ids, [])]}, bucket_id)
+
+        # Get output logits for the sentence.
+        _, _, output_logits = model.step(
+            sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
+            True, corrective_tokens=corrective_tokens_mask)
+
+        oov_input_tokens = [token for token in tokens if
+                            data_reader.is_unknown_token(token)]
+        outputs = []
+        next_oov_token_idx = 0
+
+        for logit in output_logits:
+            max_likelihood_token_id = int(np.argmax(logit, axis=1))
+            # Check if this logit most likely points to the EOS identifier.
+            if max_likelihood_token_id == EOS_ID:
+                break
+
+            token = data_reader.convert_id_2_token(max_likelihood_token_id)
+            if data_reader.is_unknown_token(token):
+                # Replace the "unknown" token with the most probable OOV
+                # token from the input.
+                if next_oov_token_idx < len(oov_input_tokens):
+                    # If we still have OOV input tokens available,
+                    # pick the next available one.
+                    token = oov_input_tokens[next_oov_token_idx]
+                    # Advance to the next OOV input token.
+                    next_oov_token_idx += 1
+                else:
+                    # If we've already used all OOV input tokens,
+                    # then we just leave the token as "UNK".
+                    pass
+            outputs.append(token)
+
+        if verbose:
+            decoded_sentence = " ".join(outputs)
+            print("Input: {}".format(" ".join(tokens)))
+            print("Output: {}\n".format(decoded_sentence))
+        yield outputs
+
+
+def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
+                    verbose=True):
+    """Used with InteractiveSession in IPython."""
+    return next(decode(sess, model, data_reader, [sentence.split()],
+                       corrective_tokens=corrective_tokens, verbose=verbose))
+
+
+def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
+                      max_samples=None):
+    """Evaluates the accuracy and BLEU score of the given model."""
+    import nltk  # Loading here to avoid having to bundle it in lambda.
+
+    # Build a collection of "baseline" and model-based hypotheses, where the
+    # baseline is just the (potentially errant) source sequence.
+    baseline_hypotheses = defaultdict(list)  # The model's input
+    model_hypotheses = defaultdict(list)  # The actual model's predictions
+    targets = defaultdict(list)  # Ground truth
+
+    errors = []
+
+    n_samples_by_bucket = defaultdict(int)
+    n_correct_model_by_bucket = defaultdict(int)
+    n_correct_baseline_by_bucket = defaultdict(int)
+    n_samples = 0
+
+    # Evaluate the model against all samples in the test data set.
+    for source, target in data_reader.read_samples_by_string(test_path):
+        matching_buckets = [i for i, bucket in enumerate(model.buckets) if
+                            len(source) < bucket[0]]
+        if not matching_buckets:
+            continue
+        bucket_id = matching_buckets[0]
+
+        decoding = next(
+            decode(sess, model, data_reader, [source],
+                   corrective_tokens=corrective_tokens, verbose=False))
+        model_hypotheses[bucket_id].append(decoding)
+        if decoding == target:
+            n_correct_model_by_bucket[bucket_id] += 1
+        else:
+            errors.append((decoding, target))
+
+        baseline_hypotheses[bucket_id].append(source)
+        if source == target:
+            n_correct_baseline_by_bucket[bucket_id] += 1
+
+        # nltk.corpus_bleu expects a list of one or more reference
+        # translations per sample, so we wrap the target list in another list.
+        targets[bucket_id].append([target])
+
+        n_samples_by_bucket[bucket_id] += 1
+        n_samples += 1
+
+        if max_samples is not None and n_samples > max_samples:
+            break
+
+    # Measure the corpus BLEU score and accuracy for the model and baseline
+    # across all buckets.
+    for bucket_id in targets.keys():
+        baseline_bleu_score = nltk.translate.bleu_score.corpus_bleu(
+            targets[bucket_id], baseline_hypotheses[bucket_id])
+        model_bleu_score = nltk.translate.bleu_score.corpus_bleu(
+            targets[bucket_id], model_hypotheses[bucket_id])
+        print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id]))
+        print("\tBaseline BLEU = {:.4f}\n\tModel BLEU = {:.4f}".format(
+            baseline_bleu_score, model_bleu_score))
+        print("\tBaseline Accuracy: {:.4f}".format(
+            1.0 * n_correct_baseline_by_bucket[bucket_id] /
+            n_samples_by_bucket[bucket_id]))
+        print("\tModel Accuracy: {:.4f}".format(
+            1.0 * n_correct_model_by_bucket[bucket_id] /
+            n_samples_by_bucket[bucket_id]))
+
+    return errors
+
+
+def main(_):
+    print('Correcting error...')
+    # Set the model path.
+    model_path = seq2seq_config.model_path
+    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
+
+    if seq2seq_config.enable_decode_sentence:
+        # Correct user's sentences interactively.
+        with tf.Session() as session:
+            model = create_model(session, True, model_path, config=seq2seq_config)
+            print("Enter a sentence you'd like to correct")
+            correct_new_sentence = input()
+            while correct_new_sentence.lower() != 'no':
+                decode_sentence(session, model=model, data_reader=data_reader,
+                                sentence=correct_new_sentence,
+                                corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
+                print("Enter a sentence you'd like to correct, or type 'no' to quit")
+                correct_new_sentence = input()
+    elif seq2seq_config.enable_test_decode:
+        # Decode test sentences.
+        with tf.Session() as session:
+            model = create_model(session, True, model_path, config=seq2seq_config)
+            print("Loaded model. Beginning decoding.")
+            decodings = decode(session, model=model, data_reader=data_reader,
+                               data_to_decode=data_reader.read_tokens(seq2seq_config.test_path),
+                               corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
+            # Write the decoded tokens to stdout.
+            for tokens in decodings:
+                sys.stdout.flush()
+
+
+if __name__ == "__main__":
+    tf.app.run()
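
decode_sentence wraps decode for a single input; a usage sketch for correcting one sentence with a restored model (the module name infer is an assumption, since the new file's name is not shown in this view):

    # Usage sketch — assumes the new file above is saved as infer.py.
    import tensorflow as tf

    import seq2seq_config
    from fce_reader import FCEReader
    from train import create_model
    from infer import decode_sentence  # assumed module name

    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
    with tf.Session() as session:
        # forward_only=True restores the model for inference only.
        model = create_model(session, True, seq2seq_config.model_path,
                             config=seq2seq_config)
        tokens = decode_sentence(session, model, data_reader,
                                 'this sentence have an error')
        print(' '.join(tokens))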
reader.py
@@ -12,9 +12,6 @@ PAD_TOKEN = 'PAD'
 EOS_TOKEN = 'EOS'
 GO_TOKEN = 'GO'

-import tensorflow as tf
-
-tf.contrib.seq2seq

 class Reader:
     def __init__(self, config, train_path=None, token_2_id=None,
@@ -34,7 +31,7 @@ class Reader:
         # Get max_vocabulary size words
         count_pairs = sorted(token_counts.items(), key=lambda k: (-k[1], k[0]))
         vocab, _ = list(zip(*count_pairs))
-
+        vocab = list(vocab)
         # Insert the special tokens to the beginning
         vocab[0:0] = special_tokens
         full_token_id = list(zip(vocab, range(len(vocab))))
@@ -71,7 +68,7 @@ class Reader:
         :param token:
         :return:
         """
-        token_id = token if token in self.token_2_id else self.unknow_token()
+        token_id = token if token in self.token_2_id else self.unknown_token()
         return self.token_2_id[token_id]

     def convert_id_2_token(self, id):
@@ -82,13 +79,13 @@ class Reader:
         """
         return self.id_2_token[id]

-    def is_unknow_token(self, token):
+    def is_unknown_token(self, token):
         """
         True if the given token is out of vocabulary
         :param token:
         :return:
         """
-        return token not in self.token_2_id or token == self.unknow_token()
+        return token not in self.token_2_id or token == self.unknown_token()

     def sentence_2_token_ids(self, sentence):
         """
@@ -106,14 +103,13 @@ class Reader:
         """
         return [self.convert_id_2_token(w) for w in word_ids]

-
     def read_samples(self, path):
         """
         Read samples from the file at path
         :param path:
         :return: generator of id lists
         """
-        for source_words, target_words in self.read_sampless_by_string(path):
+        for source_words, target_words in self.read_samples_by_string(path):
             source = [self.convert_token_2_id(w) for w in source_words]
             target = [self.convert_token_2_id(w) for w in target_words]
             target.append(EOS_ID)
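
The unknow→unknown renames keep the out-of-vocabulary fallback consistent across Reader; a standalone behavior sketch of the lookup with a hypothetical two-entry vocabulary, mirroring the method bodies above:

    # Standalone sketch of the renamed helpers' behavior (hypothetical vocab).
    token_2_id = {'UNK': 0, 'hello': 1}

    def unknown_token():
        return 'UNK'

    def convert_token_2_id(token):
        # Fall back to the unknown token for OOV words, as in Reader.
        token_id = token if token in token_2_id else unknown_token()
        return token_2_id[token_id]

    print(convert_token_2_id('hello'))  # 1
    print(convert_token_2_id('world'))  # 0 (OOV falls back to 'UNK')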
seq2seq_config.py
@@ -5,7 +5,7 @@ train_path = '../data/en/fce/fce_train.txt'  # Training data path.
 val_path = '../data/en/fce/fce_val.txt'  # Validation data path.
 test_path = '../data/en/fce/fce_test.txt'

-output_path = './output'  # Path of the model saved, default is output_path/model
+model_path = './output_model'  # Directory where the model checkpoints are saved.
 enable_data_dropout = False
 num_steps = 3000  # Number of steps to train.
 decode_sentence = False  # Whether we should decode sentences of the user.
@@ -13,10 +13,10 @@ decode_sentence = False  # Whether we should decode sentences of the user.
 # FCEConfig
 buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]  # Use a number of buckets and pad to the closest one for efficiency.

-steps_per_checkpoint = 20
-max_steps = 100
+steps_per_checkpoint = 100
+max_steps = 10000

-max_vocabulary_size = 10000
+max_vocab_size = 10000

 size = 128
 num_layers = 2
@@ -27,3 +27,11 @@ learning_rate_decay_factor = 0.99

 use_lstm = False
 use_rms_prop = False
+
+enable_decode_sentence = True  # Test with an input error sentence.
+enable_test_decode = True  # Test with the test set.
+
+import os
+
+if not os.path.exists(model_path):
+    os.makedirs(model_path)
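
Both inference and training pick the smallest bucket whose source size exceeds the input length, and skip anything longer than the largest bucket; a short sketch of that rule using the buckets defined above (mirrors the matching_buckets logic, not a new API):

    # Sketch of the bucket-selection rule.
    buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]

    def pick_bucket_id(token_ids):
        matching = [b for b, (src_size, _) in enumerate(buckets)
                    if src_size > len(token_ids)]
        return min(matching) if matching else None  # None: sample is skipped

    print(pick_bucket_id(['tok'] * 12))  # 1 -> bucket (15, 15)
    print(pick_bucket_id(['tok'] * 50))  # None, longer than (40, 40)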
tf_util.py (new file)
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Author: XuMing <xuming624@qq.com>
+# Brief:
+import os
+
+import tensorflow as tf
+
+
+def get_ckpt_path(model_path):
+    ckpt = tf.train.get_checkpoint_state(model_path)
+    ckpt_path = ""
+    if ckpt:
+        ckpt_file = ckpt.model_checkpoint_path.split('/')[-1]
+        ckpt_path = os.path.join(model_path, ckpt_file)
+    return ckpt_path
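
get_ckpt_path resolves the newest checkpoint recorded in a model directory's checkpoint state, returning an empty string when nothing has been saved yet; a brief usage sketch (the example path value is illustrative):

    # Usage sketch: restore only if a checkpoint already exists.
    ckpt_path = get_ckpt_path('./output_model')
    if ckpt_path:
        # e.g. './output_model/translate.ckpt-100' after some training
        print('restoring from', ckpt_path)
    else:
        print('no checkpoint yet; training from scratch')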
train.py
@@ -4,17 +4,16 @@

 import math
 import os
-import shutil
 import sys
 import time
-from collections import defaultdict

 import numpy as np
 import tensorflow as tf
-from reader import EOS_ID
-from fce_reader import FCEReader
-from corrector_model import CorrectorModel

 import seq2seq_config
+from corrector_model import CorrectorModel
+from fce_reader import FCEReader
+from tf_util import get_ckpt_path


 def create_model(session, forward_only, model_path, config=seq2seq_config):
@@ -27,6 +26,119 @@ def create_model(session, forward_only, model_path, config=seq2seq_config):
     :return:
     """
     model = CorrectorModel(
-        config.max_vocabulary_size,
-        config.max_vocabulary_size,
-    )
+        config.max_vocab_size,
+        config.max_vocab_size,
+        config.buckets,
+        config.size,
+        config.num_layers,
+        config.max_gradient_norm,
+        config.batch_size,
+        config.learning_rate,
+        config.learning_rate_decay_factor,
+        config.use_lstm,
+        forward_only=forward_only,
+        config=config)
+    ckpt_path = get_ckpt_path(model_path)
+    if ckpt_path:
+        print("Read model parameters from %s" % ckpt_path)
+        model.saver.restore(session, ckpt_path)
+    else:
+        print('Create model...')
+        session.run(tf.global_variables_initializer())
+    return model
+
+
+def train(data_reader, train_path, test_path, model_path):
+    print('Read data, train:{0}, test:{1}'.format(train_path, test_path))
+    config = data_reader.config
+    train_data = data_reader.build_dataset(train_path)
+    test_data = data_reader.build_dataset(test_path)
+
+    with tf.Session() as sess:
+        # Create model.
+        print('Create %d layers of %d units.' % (config.num_layers, config.size))
+        model = create_model(sess, False, model_path, config=config)
+        # Read data into buckets.
+        train_bucket_sizes = [len(train_data[b]) for b in range(len(config.buckets))]
+        print("Training bucket sizes:{}".format(train_bucket_sizes))
+        train_total_size = float(sum(train_bucket_sizes))
+        print("Total train size:{}".format(train_total_size))
+
+        # Bucket scale.
+        train_buckets_scale = [
+            sum(train_bucket_sizes[:i + 1]) / train_total_size
+            for i in range(len(train_bucket_sizes))]
+
+        # This is the training loop.
+        step_time, loss = 0.0, 0.0
+        current_step = 0
+        previous_losses = []
+        while current_step < config.max_steps:
+            # Choose a bucket according to data distribution. We pick a random
+            # number in [0, 1] and use the corresponding interval in
+            # train_buckets_scale.
+            random_number_01 = np.random.random_sample()
+            bucket_id = min([i for i in range(len(train_buckets_scale))
+                             if train_buckets_scale[i] > random_number_01])
+
+            # Get a batch and make a step.
+            start_time = time.time()
+            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+                train_data, bucket_id)
+            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                         target_weights, bucket_id, False)
+            step_time += (time.time() - start_time) / config.steps_per_checkpoint
+            loss += step_loss / config.steps_per_checkpoint
+            current_step += 1
+
+            # Once in a while, we save a checkpoint, print statistics, and run
+            # evals.
+            if current_step % config.steps_per_checkpoint == 0:
+                # Print statistics for the previous epoch.
+                perplexity = math.exp(float(loss)) if loss < 300 else float(
+                    "inf")
+                print("global step %d learning rate %.4f step-time %.2f "
+                      "perplexity %.2f" % (
+                          model.global_step.eval(), model.learning_rate.eval(),
+                          step_time, perplexity))
+                # Decrease learning rate if no improvement was seen over the
+                # last 3 checkpoints.
+                if len(previous_losses) > 2 and loss > max(
+                        previous_losses[-3:]):
+                    sess.run(model.learning_rate_decay_op)
+                previous_losses.append(loss)
+                # Save checkpoint and zero timer and loss.
+                checkpoint_path = os.path.join(model_path, "translate.ckpt")
+                model.saver.save(sess, checkpoint_path,
+                                 global_step=model.global_step)
+                step_time, loss = 0.0, 0.0
+                # Run evals on development set and print their perplexity.
+                for bucket_id in range(len(config.buckets)):
+                    if len(test_data[bucket_id]) == 0:
+                        print("  eval: empty bucket %d" % (bucket_id))
+                        continue
+                    encoder_inputs, decoder_inputs, target_weights = \
+                        model.get_batch(test_data, bucket_id)
+                    _, eval_loss, _ = model.step(sess, encoder_inputs,
+                                                 decoder_inputs,
+                                                 target_weights, bucket_id,
+                                                 True)
+                    eval_ppx = math.exp(
+                        float(eval_loss)) if eval_loss < 300 else float(
+                        "inf")
+                    print("  eval: bucket %d perplexity %.2f" % (
+                        bucket_id, eval_ppx))
+                sys.stdout.flush()
+
+
+def main(_):
+    print('Training model...')
+    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
+    train(data_reader,
+          seq2seq_config.train_path,
+          seq2seq_config.val_path,
+          seq2seq_config.model_path)
+
+
+if __name__ == "__main__":
+    tf.app.run()
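
Both entry points route through tf.app.run(), so training and inference run as plain scripts (again assuming the new inference file is saved as infer.py):

    python train.py  # trains and checkpoints to seq2seq_config.model_path
    python infer.py  # with enable_decode_sentence = True, starts the interactive loop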