Add Chinese correction check by seq2seq model.

parent e5f4b0af60
commit 10f2385aca
@ -0,0 +1,116 @@
<ROOT>
<DOC>
<TEXT id="200405109523200554_2_1x1">
他们知不道吸烟对未成年年的影响会造成的各种害处。
</TEXT>
<CORRECTION>
他们不知道吸烟对未成年人会造成的各种伤害。
</CORRECTION>
<ERROR start_off="3" end_off="4" type="W"></ERROR>
<ERROR start_off="12" end_off="12" type="S"></ERROR>
<ERROR start_off="13" end_off="15" type="R"></ERROR>
<ERROR start_off="22" end_off="23" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200505109634201470_2_4x2">
从此,父母亲就会教咱们爬行、走路、叫爸爸妈妈。到我们长大了,我们开始从妈妈爸爸身上模仿行为,譬如学习爸妈走路时的高雅步姿,坐姿、礼貌、习惯……渐渐地你会发觉一直好象是自己好奇,觉得有趣才会照着做,模仿着双亲,但不知不觉间他们影响到孩子们的不再是表面的行为,思想,心态、待人接物上我们都领受了不少,这的确会影响我们的成长。
</TEXT>
<CORRECTION>
从此,父母亲就会教我们爬行、走路、叫爸爸妈妈。到我们长大了,我们开始从妈妈爸爸身上模仿行为,譬如学习爸妈走路时的高雅步姿,坐姿、礼貌、习惯……渐渐地你会发觉好象是自己一直好奇,觉得有趣才会照着做,模仿着双亲,但不知不觉间他们影响到孩子们的不再是表面的行为,思想,心态、待人接物上我们都领受了不少,这的确会影响我们的成长。
</CORRECTION>
<ERROR start_off="10" end_off="11" type="S"></ERROR>
<ERROR start_off="79" end_off="85" type="W"></ERROR>
</DOC>
<DOC>
<TEXT id="200510111523200014_2_1x1">
有些不喜欢流行歌曲的人也说流行歌曲能引起不好的作用。
</TEXT>
<CORRECTION>
有些不喜欢流行歌曲的人也说流行歌曲能引起不好的后果。
</CORRECTION>
<ERROR start_off="24" end_off="25" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200505204525200257_2_2x2">
如果它呈出不太香的颜色,那就意味着它颜色的来源——你,就是教给它呈出那样的味道的。如果你将一块“白布”成功地染上的话,会出什么样的颜色呢?
</TEXT>
<CORRECTION>
如果它呈出不太美的颜色,那就意味着它颜色的来源——你,就是教给它呈出那样的颜色的人没教好。如果你将一块“白布”成功地染上颜色的话,会出什么样的颜色呢?
</CORRECTION>
<ERROR start_off="8" end_off="8" type="S"></ERROR>
<ERROR start_off="34" end_off="35" type="S"></ERROR>
<ERROR start_off="41" end_off="41" type="M"></ERROR>
<ERROR start_off="57" end_off="57" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200405109525200464_2_5x1">
这都是他们自己引起的,埋怨什么呢?
</TEXT>
<CORRECTION>
这都是他们自己造成的,埋怨什么呢?
</CORRECTION>
<ERROR start_off="8" end_off="9" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200310576525200063_2_6x1">
他长大以后突然产生要跟女孩交玩儿的念头。
</TEXT>
<CORRECTION>
他长大以后突然产生要跟女孩儿玩耍的念头。
</CORRECTION>
<ERROR start_off="14" end_off="16" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200307271523200070_2_2x1">
可是我觉得饥饿是用科学机技来能解决问题,所以我认为吃“绿色食品”是还是重要的问题。
</TEXT>
<CORRECTION>
可是我觉得饥饿是用科学技术能解决的问题,所以我认为吃“绿色食品”是更重要的问题。
</CORRECTION>
<ERROR start_off="12" end_off="13" type="S"></ERROR>
<ERROR start_off="14" end_off="14" type="R"></ERROR>
<ERROR start_off="18" end_off="18" type="M"></ERROR>
<ERROR start_off="34" end_off="35" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200405109523100546_2_1x1">
在韩国最近很流行不允许的电视节目,这节目说公共场所抽烟是不道德的行为。
</TEXT>
<CORRECTION>
在韩国最近不允许抽烟的电视节目很流行,这些节目说在公共场所抽烟是不道德的行为。
</CORRECTION>
<ERROR start_off="6" end_off="16" type="W"></ERROR>
<ERROR start_off="12" end_off="12" type="M"></ERROR>
<ERROR start_off="19" end_off="19" type="M"></ERROR>
<ERROR start_off="22" end_off="22" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200510302523100195_2_9x2">
如果他喜欢听什么,就能听什么。因为这种现象,因韩流而得到的经济利益也很多。
</TEXT>
<CORRECTION>
他喜欢听什么,就能听什么。因为这种现象,韩国因韩流而得到的经济利益也很多。
</CORRECTION>
<ERROR start_off="1" end_off="2" type="R"></ERROR>
<ERROR start_off="23" end_off="23" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200505109922201218_2_2x1">
从环境里小孩子能快速地学或模仿他所见到的事物。
</TEXT>
<CORRECTION>
小孩子能从环境里快速地学或模仿他所见到的事物。
</CORRECTION>
<ERROR start_off="1" end_off="8" type="W"></ERROR>
</DOC>
<DOC>
<TEXT id="200505522522250013_2_7x1">
认识到结婚过程不满六个月,也可以说在我的故事中我是主动的。
</TEXT>
<CORRECTION>
认识到结婚的过程不满六个月,也可以说在我的故事中我是主动的。
</CORRECTION>
<ERROR start_off="6" end_off="6" type="M"></ERROR>
<ERROR start_off="18" end_off="18" type="M"></ERROR>
</DOC>
</ROOT>
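Note on the sample above: each DOC pairs a learner sentence (TEXT) with its corrected form (CORRECTION), and each ERROR flags a span of the original sentence; the offsets appear to be 1-based, inclusive character positions into TEXT, and the type codes follow the CGED shared-task convention (R redundant word, M missing word, S word selection, W word order). A minimal sketch of reading one document, not part of this commit and with an illustrative file path:

# Illustrative sketch only: print the spans flagged by <ERROR> in the first <DOC>.
# Assumes start_off/end_off are 1-based, inclusive character offsets into <TEXT>.
from xml.dom import minidom

dom = minidom.parse('CGED18_HSK_TrainingSet.xml')  # illustrative path
doc = dom.documentElement.getElementsByTagName('DOC')[0]
text = doc.getElementsByTagName('TEXT')[0].childNodes[0].data.strip()
correction = doc.getElementsByTagName('CORRECTION')[0].childNodes[0].data.strip()
for err in doc.getElementsByTagName('ERROR'):
    start = int(err.getAttribute('start_off'))
    end = int(err.getAttribute('end_off'))
    print(err.getAttribute('type'), text[start - 1:end])
print('corrected:', correction)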
@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com>
# Brief: Use CGED corpus

# CGED chinese corpus
train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml' # Training data path.
val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml' # Validation data path.
test_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml'

model_path = './output/cged_model' # Path of the model saved, default is output_path/model
enable_special_error = False
num_steps = 3000 # Number of steps to train.
decode_sentence = False # Whether we should decode sentences of the user.

# Config
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
steps_per_checkpoint = 100
max_steps = 2000
max_vocab_size = 10000
size = 512
num_layers = 4
max_gradient_norm = 5.0
batch_size = 64
learning_rate = 0.5
learning_rate_decay_factor = 0.99
use_lstm = False
use_rms_prop = False

enable_decode_sentence = False # Test with input error sentence
enable_test_decode = True # Test with test set

import os

if not os.path.exists(model_path):
    os.makedirs(model_path)
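Note on the buckets setting above: it follows the usual seq2seq bucketing scheme, where each (source, target) pair is trained in the smallest bucket whose sizes fit it and is padded up to those sizes so batches share a fixed shape. A minimal sketch of that selection, assuming a PAD id of 0; the helper name is invented for illustration and is not part of this commit:

# Illustrative sketch only: pick the smallest bucket that fits a pair and pad to it.
PAD_ID = 0  # assumed padding id
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]

def choose_bucket_and_pad(source_ids, target_ids):
    for bucket_id, (src_size, tgt_size) in enumerate(buckets):
        if len(source_ids) <= src_size and len(target_ids) <= tgt_size:
            src = source_ids + [PAD_ID] * (src_size - len(source_ids))
            tgt = target_ids + [PAD_ID] * (tgt_size - len(target_ids))
            return bucket_id, src, tgt
    return None  # longer than the largest bucket; such pairs are usually skipped

A real reader would typically also reserve room in the decoder bucket for the GO/EOS markers imported elsewhere in this commit.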
@ -1,9 +1,11 @@
 # -*- coding: utf-8 -*-
 # Author: XuMing <xuming624@qq.com>
-# Brief:
+# Brief: Corpus for model
 import random

 from reader import Reader, PAD_TOKEN, EOS_TOKEN, GO_TOKEN
+from xml.dom import minidom
+from text_util import segment


 class FCEReader(Reader):
@ -29,7 +31,7 @@ class FCEReader(Reader):
             while True:
                 line_src = f.readline()
                 line_dst = f.readline()
-                if not line_src:
+                if not line_src or len(line_src) < 5:
                     break
                 source = line_src.lower()[5:].strip().split()
                 target = line_dst.lower()[5:].strip().split()
@ -57,6 +59,46 @@ class FCEReader(Reader):
             for line in f:
                 # Input the correct text, which start with 0
                 if i % 2 == 1:
-                    if line:
+                    if line and len(line) > 5:
                         yield line.lower()[5:].strip().split()
                 i += 1
+
+
+class CGEDReader(Reader):
+    """
+    Read CGED data set
+    """
+    UNKNOWN_TOKEN = 'UNK'
+
+    def __init__(self, config, train_path=None, token_2_id=None, dataset_copies=2):
+        super(CGEDReader, self).__init__(
+            config, train_path=train_path, token_2_id=token_2_id,
+            special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, CGEDReader.UNKNOWN_TOKEN],
+            dataset_copies=dataset_copies)
+        self.UNKNOWN_ID = self.token_2_id[CGEDReader.UNKNOWN_TOKEN]
+
+    def read_samples_by_string(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            dom_tree = minidom.parse(f)
+            docs = dom_tree.documentElement.getElementsByTagName('DOC')
+            for doc in docs:
+                source_text = doc.getElementsByTagName('TEXT')[0]. \
+                    childNodes[0].data.strip()
+                target_text = doc.getElementsByTagName('CORRECTION')[0]. \
+                    childNodes[0].data.strip()
+                source = segment(source_text, cut_type='char')
+                target = segment(target_text, cut_type='char')
+                yield source, target
+
+    def unknown_token(self):
+        return CGEDReader.UNKNOWN_TOKEN
+
+    def read_tokens(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            dom_tree = minidom.parse(f)
+            docs = dom_tree.documentElement.getElementsByTagName('DOC')
+            for doc in docs:
+                # Input the correct text
+                sentence = doc.getElementsByTagName('CORRECTION')[0]. \
+                    childNodes[0].data.strip()
+                yield segment(sentence, cut_type='char')
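For reference, a minimal usage sketch of the new CGEDReader, mirroring how the training and decoding code in this commit constructs it; the printed join is only for display and this snippet is not part of the commit:

# Illustrative sketch only: iterate character-level (source, target) pairs from the CGED XML.
import cged_config
from corpus_reader import CGEDReader

reader = CGEDReader(cged_config, cged_config.train_path)
for source_chars, target_chars in reader.read_samples_by_string(cged_config.train_path):
    print(''.join(source_chars), '->', ''.join(target_chars))
    break  # first pair only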
@ -92,9 +92,9 @@ class CorrectorModel(object):
             [corrective_tokens_tensor] * self.batch_size)
         self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \
             tf.placeholder(
                 tf.float32,
                 shape=[None, None],
                 name="corrective_tokens")

         # Our targets are decoder inputs shifted by one.
         targets = [self.decoder_inputs[i + 1]
@ -116,6 +116,7 @@ class CorrectorModel(object):
             return tf.nn.sampled_softmax_loss(w_t, b, labels, logits,
                                               num_samples,
                                               self.target_vocab_size)

         softmax_loss_function = sampled_loss

         # Create the internal multi-layer cell for our RNN.
@ -242,17 +243,14 @@ class CorrectorModel(object):
         # Check if the sizes match.
         encoder_size, decoder_size = self.buckets[bucket_id]
         if len(encoder_inputs) != encoder_size:
-            raise ValueError(
-                "Encoder length must be equal to the one in bucket,"
-                " %d != %d." % (len(encoder_inputs), encoder_size))
+            raise ValueError("Encoder length must be equal to the one in bucket,"
+                             " %d != %d." % (len(encoder_inputs), encoder_size))
         if len(decoder_inputs) != decoder_size:
-            raise ValueError(
-                "Decoder length must be equal to the one in bucket,"
-                " %d != %d." % (len(decoder_inputs), decoder_size))
+            raise ValueError("Decoder length must be equal to the one in bucket,"
+                             " %d != %d." % (len(decoder_inputs), decoder_size))
         if len(target_weights) != decoder_size:
-            raise ValueError(
-                "Weights length must be equal to the one in bucket,"
-                " %d != %d." % (len(target_weights), decoder_size))
+            raise ValueError("Weights length must be equal to the one in bucket,"
+                             " %d != %d." % (len(target_weights), decoder_size))

         # Input feed: encoder inputs, decoder inputs, target_weights,
         # as provided.
@ -263,9 +261,8 @@ class CorrectorModel(object):
             input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
             input_feed[self.target_weights[l].name] = target_weights[l]

-        # TODO: learn corrective tokens during training
-        corrective_tokens_vector = (corrective_tokens
-                                    if corrective_tokens is not None else
-                                    np.zeros(self.target_vocab_size))
+        corrective_tokens_vector = (corrective_tokens if
+                                    corrective_tokens is not None else
+                                    np.zeros(self.target_vocab_size))
         batch_corrective_tokens = np.repeat([corrective_tokens_vector],
                                             self.batch_size, axis=0)
@ -400,6 +397,7 @@ def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
     Returns:
         A loop function.
     """
+
     def loop_function(prev, _):
         prev = project_and_apply_input_bias(prev, output_projection,
                                             input_bias)
@ -411,7 +409,7 @@ def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
            if not update_embedding:
                emb_prev = array_ops.stop_gradient(emb_prev)
            return emb_prev, prev_symbol

        return loop_function

    return fn_factory
@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 # Author: XuMing <xuming624@qq.com>
-# Brief:
+# Brief: Use FCE english corpus
+
+# FCE english corpus
 train_path = '../data/en/fce/fce_train.txt' # Training data path.
 val_path = '../data/en/fce/fce_val.txt' # Validation data path.
 test_path = '../data/en/fce/fce_test.txt'
@ -10,25 +12,21 @@ enable_special_error = False
 num_steps = 3000 # Number of steps to train.
 decode_sentence = False # Whether we should decode sentences of the user.

-# FCEConfig
+# Config
 buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.

 steps_per_checkpoint = 100
 max_steps = 2000

 max_vocab_size = 10000
-
-size = 128
-num_layers = 2
+size = 512
+num_layers = 1

 max_gradient_norm = 5.0
 batch_size = 64
 learning_rate = 0.5
 learning_rate_decay_factor = 0.99

 use_lstm = False
 use_rms_prop = False

-enable_decode_sentence = True # Test with input error sentence
+enable_decode_sentence = False # Test with input error sentence
 enable_test_decode = True # Test with test set

 import os
@ -8,8 +8,10 @@ from collections import defaultdict
 import numpy as np
 import tensorflow as tf

-import seq2seq_config
-from fce_reader import FCEReader
+import fce_config
+import cged_config
+from corpus_reader import FCEReader
+from corpus_reader import CGEDReader
 from reader import EOS_ID
 from train import create_model

@ -173,29 +175,29 @@ def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
 def main(_):
     print('Correcting error...')
     # Set the model path.
-    model_path = seq2seq_config.model_path
-    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
+    model_path = cged_config.model_path
+    data_reader = CGEDReader(cged_config, cged_config.train_path)

-    if seq2seq_config.enable_decode_sentence:
+    if cged_config.enable_decode_sentence:
         # Correct user's sentences.
         with tf.Session() as session:
-            model = create_model(session, True, model_path, config=seq2seq_config)
+            model = create_model(session, True, model_path, config=cged_config)
             print("Enter a sentence you'd like to correct")
             correct_new_sentence = input()
             while correct_new_sentence.lower() != 'no':
                 decode_sentence(session, model=model, data_reader=data_reader,
                                 sentence=correct_new_sentence,
-                                corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
+                                corrective_tokens=data_reader.read_tokens(cged_config.train_path))
                 print("Enter a sentence you'd like to correct or press NO")
                 correct_new_sentence = input()
-    elif seq2seq_config.enable_test_decode:
+    elif cged_config.enable_test_decode:
         # Decode test sentences.
         with tf.Session() as session:
-            model = create_model(session, True, model_path, config=seq2seq_config)
+            model = create_model(session, True, model_path, config=cged_config)
             print("Loaded model. Beginning decoding.")
             decodings = decode(session, model=model, data_reader=data_reader,
-                               data_to_decode=data_reader.read_tokens(seq2seq_config.test_path),
-                               corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path))
+                               data_to_decode=data_reader.read_tokens(cged_config.test_path),
+                               corrective_tokens=data_reader.read_tokens(cged_config.train_path))
             # Write the decoded tokens to stdout.
             for tokens in decodings:
                 sys.stdout.flush()
@ -10,13 +10,15 @@ import time
 import numpy as np
 import tensorflow as tf

-import seq2seq_config
+import cged_config
+import fce_config
 from corrector_model import CorrectorModel
-from fce_reader import FCEReader
+from corpus_reader import FCEReader
+from corpus_reader import CGEDReader
 from tf_util import get_ckpt_path


-def create_model(session, forward_only, model_path, config=seq2seq_config):
+def create_model(session, forward_only, model_path, config=cged_config):
     """
     Create model and load parameters
     :param session:
@ -115,7 +117,7 @@ def train(data_reader, train_path, test_path, model_path):
             # Run evals on development set and print their perplexity.
             for bucket_id in range(len(config.buckets)):
                 if len(test_data[bucket_id]) == 0:
-                    print(" eval: empty bucket %d" % (bucket_id))
+                    print(" eval: empty bucket %d" % bucket_id)
                     continue
                 encoder_inputs, decoder_inputs, target_weights = \
                     model.get_batch(test_data, bucket_id)
@ -133,11 +135,11 @@ def train(data_reader, train_path, test_path, model_path):

 def main(_):
     print('Training model...')
-    data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path)
+    data_reader = CGEDReader(cged_config, cged_config.train_path)
     train(data_reader,
-          seq2seq_config.train_path,
-          seq2seq_config.val_path,
-          seq2seq_config.model_path)
+          cged_config.train_path,
+          cged_config.val_path,
+          cged_config.model_path)


 if __name__ == "__main__":
@ -113,14 +113,17 @@ def simplified2traditional(sentence):
     return sentence


-def segment(sentence):
+def segment(sentence, cut_type='word'):
     """
     切词
     :param sentence:
+    :param cut_type: 'word' use jieba.lcut; 'char' use list(sentence)
     :return: list
     """
     import logging
     jieba.default_logger.setLevel(logging.ERROR)
+    if cut_type == 'char':
+        return list(sentence)
     return jieba.lcut(sentence)
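For reference, the two cut_type modes of the updated segment() differ as follows; this is an illustrative call, not part of the commit, and assumes jieba is installed:

# Illustrative only: word-level vs. character-level segmentation.
from text_util import segment

sentence = '吸烟对未成年人会造成伤害'
print(segment(sentence, cut_type='word'))  # word tokens from jieba.lcut
print(segment(sentence, cut_type='char'))  # one token per character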
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue