Add Chinese correction check using a seq2seq model.

This commit is contained in:
xuming06 2018-03-30 15:07:36 +08:00
parent e5f4b0af60
commit 10f2385aca
8 changed files with 244 additions and 48 deletions

View File

@ -0,0 +1,116 @@
<ROOT>
<DOC>
<TEXT id="200405109523200554_2_1x1">
他们知不道吸烟对未成年年的影响会造成的各种害处。
</TEXT>
<CORRECTION>
他们不知道吸烟对未成年人会造成的各种伤害。
</CORRECTION>
<ERROR start_off="3" end_off="4" type="W"></ERROR>
<ERROR start_off="12" end_off="12" type="S"></ERROR>
<ERROR start_off="13" end_off="15" type="R"></ERROR>
<ERROR start_off="22" end_off="23" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200505109634201470_2_4x2">
从此,父母亲就会教咱们爬行、走路、叫爸爸妈妈。到我们长大了,我们开始从妈妈爸爸身上模仿行为,譬如学习爸妈走路时的高雅步姿,坐姿、礼貌、习惯……渐渐地你会发觉一直好象是自己好奇,觉得有趣才会照着做,模仿着双亲,但不知不觉间他们影响到孩子们的不再是表面的行为,思想,心态、待人接物上我们都领受了不少,这的确会影响我们的成长。
</TEXT>
<CORRECTION>
从此,父母亲就会教我们爬行、走路、叫爸爸妈妈。到我们长大了,我们开始从妈妈爸爸身上模仿行为,譬如学习爸妈走路时的高雅步姿,坐姿、礼貌、习惯……渐渐地你会发觉好象是自己一直好奇,觉得有趣才会照着做,模仿着双亲,但不知不觉间他们影响到孩子们的不再是表面的行为,思想,心态、待人接物上我们都领受了不少,这的确会影响我们的成长。
</CORRECTION>
<ERROR start_off="10" end_off="11" type="S"></ERROR>
<ERROR start_off="79" end_off="85" type="W"></ERROR>
</DOC>
<DOC>
<TEXT id="200510111523200014_2_1x1">
有些不喜欢流行歌曲的人也说流行歌曲能引起不好的作用。
</TEXT>
<CORRECTION>
有些不喜欢流行歌曲的人也说流行歌曲能引起不好的后果。
</CORRECTION>
<ERROR start_off="24" end_off="25" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200505204525200257_2_2x2">
如果它呈出不太香的颜色,那就意味着它颜色的来源——你,就是教给它呈出那样的味道的。如果你将一块“白布”成功地染上的话,会出什么样的颜色呢?
</TEXT>
<CORRECTION>
如果它呈出不太美的颜色,那就意味着它颜色的来源——你,就是教给它呈出那样的颜色的人没教好。如果你将一块“白布”成功地染上颜色的话,会出什么样的颜色呢?
</CORRECTION>
<ERROR start_off="8" end_off="8" type="S"></ERROR>
<ERROR start_off="34" end_off="35" type="S"></ERROR>
<ERROR start_off="41" end_off="41" type="M"></ERROR>
<ERROR start_off="57" end_off="57" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200405109525200464_2_5x1">
这都是他们自己引起的,埋怨什么呢?
</TEXT>
<CORRECTION>
这都是他们自己造成的,埋怨什么呢?
</CORRECTION>
<ERROR start_off="8" end_off="9" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200310576525200063_2_6x1">
他长大以后突然产生要跟女孩交玩儿的念头。
</TEXT>
<CORRECTION>
他长大以后突然产生要跟女孩儿玩耍的念头。
</CORRECTION>
<ERROR start_off="14" end_off="16" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200307271523200070_2_2x1">
可是我觉得饥饿是用科学机技来能解决问题,所以我认为吃“绿色食品”是还是重要的问题。
</TEXT>
<CORRECTION>
可是我觉得饥饿是用科学技术能解决的问题,所以我认为吃“绿色食品”是更重要的问题。
</CORRECTION>
<ERROR start_off="12" end_off="13" type="S"></ERROR>
<ERROR start_off="14" end_off="14" type="R"></ERROR>
<ERROR start_off="18" end_off="18" type="M"></ERROR>
<ERROR start_off="34" end_off="35" type="S"></ERROR>
</DOC>
<DOC>
<TEXT id="200405109523100546_2_1x1">
在韩国最近很流行不允许的电视节目,这节目说公共场所抽烟是不道德的行为。
</TEXT>
<CORRECTION>
在韩国最近不允许抽烟的电视节目很流行,这些节目说在公共场所抽烟是不道德的行为。
</CORRECTION>
<ERROR start_off="6" end_off="16" type="W"></ERROR>
<ERROR start_off="12" end_off="12" type="M"></ERROR>
<ERROR start_off="19" end_off="19" type="M"></ERROR>
<ERROR start_off="22" end_off="22" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200510302523100195_2_9x2">
如果他喜欢听什么,就能听什么。因为这种现象,因韩流而得到的经济利益也很多。
</TEXT>
<CORRECTION>
他喜欢听什么,就能听什么。因为这种现象,韩国因韩流而得到的经济利益也很多。
</CORRECTION>
<ERROR start_off="1" end_off="2" type="R"></ERROR>
<ERROR start_off="23" end_off="23" type="M"></ERROR>
</DOC>
<DOC>
<TEXT id="200505109922201218_2_2x1">
从环境里小孩子能快速地学或模仿他所见到的事物。
</TEXT>
<CORRECTION>
小孩子能从环境里快速地学或模仿他所见到的事物。
</CORRECTION>
<ERROR start_off="1" end_off="8" type="W"></ERROR>
</DOC>
<DOC>
<TEXT id="200505522522250013_2_7x1">
认识到结婚过程不满六个月,也可以说我的故事中我是主动的。
</TEXT>
<CORRECTION>
认识到结婚的过程不满六个月,也可以说在我的故事中我是主动的。
</CORRECTION>
<ERROR start_off="6" end_off="6" type="M"></ERROR>
<ERROR start_off="18" end_off="18" type="M"></ERROR>
</DOC>
</ROOT>

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com>
# Brief: Configuration for the CGED (Chinese Grammatical Error Diagnosis) corpus.
import os

# CGED chinese corpus paths
train_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml'  # Training data path.
val_path = '../data/cn/CGED/CGED18_HSK_TestingSet.xml'  # Validation data path.
# NOTE(review): test_path points at the *training* set (same file as train_path)
# while val_path uses the testing set — looks like a copy/paste slip; confirm
# whether this should be CGED18_HSK_TestingSet.xml before relying on test metrics.
test_path = '../data/cn/CGED/CGED18_HSK_TrainingSet.xml'
model_path = './output/cged_model'  # Path of the model saved, default is output_path/model
enable_special_error = False
num_steps = 3000  # Number of steps to train.
decode_sentence = False  # Whether we should decode sentences of the user.

# Model / training config
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)]  # use a number of buckets and pad to the closest one for efficiency.
steps_per_checkpoint = 100
max_steps = 2000
max_vocab_size = 10000
size = 512  # presumably the RNN layer size — verify against the model code.
num_layers = 4
max_gradient_norm = 5.0
batch_size = 64
learning_rate = 0.5
learning_rate_decay_factor = 0.99
use_lstm = False
use_rms_prop = False
enable_decode_sentence = False  # Test with input error sentence
enable_test_decode = True  # Test with test set

# Ensure the model output directory exists at import time.
if not os.path.exists(model_path):
    os.makedirs(model_path)

View File

@ -1,9 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com> # Author: XuMing <xuming624@qq.com>
# Brief: # Brief: Corpus for model
import random import random
from reader import Reader, PAD_TOKEN, EOS_TOKEN, GO_TOKEN from reader import Reader, PAD_TOKEN, EOS_TOKEN, GO_TOKEN
from xml.dom import minidom
from text_util import segment
class FCEReader(Reader): class FCEReader(Reader):
@ -29,7 +31,7 @@ class FCEReader(Reader):
while True: while True:
line_src = f.readline() line_src = f.readline()
line_dst = f.readline() line_dst = f.readline()
if not line_src: if not line_src or len(line_src) < 5:
break break
source = line_src.lower()[5:].strip().split() source = line_src.lower()[5:].strip().split()
target = line_dst.lower()[5:].strip().split() target = line_dst.lower()[5:].strip().split()
@ -57,6 +59,46 @@ class FCEReader(Reader):
for line in f: for line in f:
# Input the correct text, which start with 0 # Input the correct text, which start with 0
if i % 2 == 1: if i % 2 == 1:
if line: if line and len(line) > 5:
yield line.lower()[5:].strip().split() yield line.lower()[5:].strip().split()
i += 1 i += 1
class CGEDReader(Reader):
    """Reader for the CGED (Chinese Grammatical Error Diagnosis) XML corpus.

    Each <DOC> element holds a <TEXT> node with the erroneous sentence and a
    <CORRECTION> node with the corrected sentence; both are tokenized at the
    character level via ``segment(..., cut_type='char')``.
    """

    # Placeholder token for out-of-vocabulary characters.
    UNKNOWN_TOKEN = 'UNK'

    def __init__(self, config, train_path=None, token_2_id=None, dataset_copies=2):
        super(CGEDReader, self).__init__(
            config, train_path=train_path, token_2_id=token_2_id,
            special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN, CGEDReader.UNKNOWN_TOKEN],
            dataset_copies=dataset_copies)
        self.UNKNOWN_ID = self.token_2_id[CGEDReader.UNKNOWN_TOKEN]

    @staticmethod
    def _first_tag_text(doc, tag):
        """Return the stripped text of the first <tag> child element of *doc*."""
        return doc.getElementsByTagName(tag)[0].childNodes[0].data.strip()

    def read_samples_by_string(self, path):
        """Yield (source, target) pairs of char-token lists from the XML at *path*."""
        with open(path, 'r', encoding='utf-8') as xml_file:
            tree = minidom.parse(xml_file)
        for doc in tree.documentElement.getElementsByTagName('DOC'):
            erroneous = self._first_tag_text(doc, 'TEXT')
            corrected = self._first_tag_text(doc, 'CORRECTION')
            yield segment(erroneous, cut_type='char'), segment(corrected, cut_type='char')

    def unknown_token(self):
        """Return the token used for unknown characters."""
        return CGEDReader.UNKNOWN_TOKEN

    def read_tokens(self, path):
        """Yield char-token lists of the corrected sentences only from *path*."""
        with open(path, 'r', encoding='utf-8') as xml_file:
            tree = minidom.parse(xml_file)
        for doc in tree.documentElement.getElementsByTagName('DOC'):
            # Only the correct text is used for token collection.
            yield segment(self._first_tag_text(doc, 'CORRECTION'), cut_type='char')

View File

@ -92,9 +92,9 @@ class CorrectorModel(object):
[corrective_tokens_tensor] * self.batch_size) [corrective_tokens_tensor] * self.batch_size)
self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \ self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \
tf.placeholder( tf.placeholder(
tf.float32, tf.float32,
shape=[None, None], shape=[None, None],
name="corrective_tokens") name="corrective_tokens")
# Our targets are decoder inputs shifted by one. # Our targets are decoder inputs shifted by one.
targets = [self.decoder_inputs[i + 1] targets = [self.decoder_inputs[i + 1]
@ -116,6 +116,7 @@ class CorrectorModel(object):
return tf.nn.sampled_softmax_loss(w_t, b, labels, logits, return tf.nn.sampled_softmax_loss(w_t, b, labels, logits,
num_samples, num_samples,
self.target_vocab_size) self.target_vocab_size)
softmax_loss_function = sampled_loss softmax_loss_function = sampled_loss
# Create the internal multi-layer cell for our RNN. # Create the internal multi-layer cell for our RNN.
@ -242,17 +243,14 @@ class CorrectorModel(object):
# Check if the sizes match. # Check if the sizes match.
encoder_size, decoder_size = self.buckets[bucket_id] encoder_size, decoder_size = self.buckets[bucket_id]
if len(encoder_inputs) != encoder_size: if len(encoder_inputs) != encoder_size:
raise ValueError( raise ValueError("Encoder length must be equal to the one in bucket,"
"Encoder length must be equal to the one in bucket," " %d != %d." % (len(encoder_inputs), encoder_size))
" %d != %d." % (len(encoder_inputs), encoder_size))
if len(decoder_inputs) != decoder_size: if len(decoder_inputs) != decoder_size:
raise ValueError( raise ValueError("Decoder length must be equal to the one in bucket,"
"Decoder length must be equal to the one in bucket," " %d != %d." % (len(decoder_inputs), decoder_size))
" %d != %d." % (len(decoder_inputs), decoder_size))
if len(target_weights) != decoder_size: if len(target_weights) != decoder_size:
raise ValueError( raise ValueError("Weights length must be equal to the one in bucket,"
"Weights length must be equal to the one in bucket," " %d != %d." % (len(target_weights), decoder_size))
" %d != %d." % (len(target_weights), decoder_size))
# Input feed: encoder inputs, decoder inputs, target_weights, # Input feed: encoder inputs, decoder inputs, target_weights,
# as provided. # as provided.
@ -263,9 +261,8 @@ class CorrectorModel(object):
input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
input_feed[self.target_weights[l].name] = target_weights[l] input_feed[self.target_weights[l].name] = target_weights[l]
# TODO: learn corrective tokens during training corrective_tokens_vector = (corrective_tokens if
corrective_tokens_vector = (corrective_tokens corrective_tokens is not None else
if corrective_tokens is not None else
np.zeros(self.target_vocab_size)) np.zeros(self.target_vocab_size))
batch_corrective_tokens = np.repeat([corrective_tokens_vector], batch_corrective_tokens = np.repeat([corrective_tokens_vector],
self.batch_size, axis=0) self.batch_size, axis=0)
@ -400,6 +397,7 @@ def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
Returns: Returns:
A loop function. A loop function.
""" """
def loop_function(prev, _): def loop_function(prev, _):
prev = project_and_apply_input_bias(prev, output_projection, prev = project_and_apply_input_bias(prev, output_projection,
input_bias) input_bias)
@ -411,7 +409,7 @@ def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
if not update_embedding: if not update_embedding:
emb_prev = array_ops.stop_gradient(emb_prev) emb_prev = array_ops.stop_gradient(emb_prev)
return emb_prev, prev_symbol return emb_prev, prev_symbol
return loop_function return loop_function
return fn_factory return fn_factory

View File

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Author: XuMing <xuming624@qq.com> # Author: XuMing <xuming624@qq.com>
# Brief: # Brief: Use FCE english corpus
# FCE english corpus
train_path = '../data/en/fce/fce_train.txt' # Training data path. train_path = '../data/en/fce/fce_train.txt' # Training data path.
val_path = '../data/en/fce/fce_val.txt' # Validation data path. val_path = '../data/en/fce/fce_val.txt' # Validation data path.
test_path = '../data/en/fce/fce_test.txt' test_path = '../data/en/fce/fce_test.txt'
@ -10,25 +12,21 @@ enable_special_error = False
num_steps = 3000 # Number of steps to train. num_steps = 3000 # Number of steps to train.
decode_sentence = False # Whether we should decode sentences of the user. decode_sentence = False # Whether we should decode sentences of the user.
# FCEConfig # Config
buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency. buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] # use a number of buckets and pad to the closest one for efficiency.
steps_per_checkpoint = 100 steps_per_checkpoint = 100
max_steps = 2000 max_steps = 2000
max_vocab_size = 10000 max_vocab_size = 10000
size = 512
size = 128 num_layers = 1
num_layers = 2
max_gradient_norm = 5.0 max_gradient_norm = 5.0
batch_size = 64 batch_size = 64
learning_rate = 0.5 learning_rate = 0.5
learning_rate_decay_factor = 0.99 learning_rate_decay_factor = 0.99
use_lstm = False use_lstm = False
use_rms_prop = False use_rms_prop = False
enable_decode_sentence = True # Test with input error sentence enable_decode_sentence = False # Test with input error sentence
enable_test_decode = True # Test with test set enable_test_decode = True # Test with test set
import os import os

View File

@ -8,8 +8,10 @@ from collections import defaultdict
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import seq2seq_config import fce_config
from fce_reader import FCEReader import cged_config
from corpus_reader import FCEReader
from corpus_reader import CGEDReader
from reader import EOS_ID from reader import EOS_ID
from train import create_model from train import create_model
@ -173,29 +175,29 @@ def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
def main(_): def main(_):
print('Correcting error...') print('Correcting error...')
# Set the model path. # Set the model path.
model_path = seq2seq_config.model_path model_path = cged_config.model_path
data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path) data_reader = CGEDReader(cged_config, cged_config.train_path)
if seq2seq_config.enable_decode_sentence: if cged_config.enable_decode_sentence:
# Correct user's sentences. # Correct user's sentences.
with tf.Session() as session: with tf.Session() as session:
model = create_model(session, True, model_path, config=seq2seq_config) model = create_model(session, True, model_path, config=cged_config)
print("Enter a sentence you'd like to correct") print("Enter a sentence you'd like to correct")
correct_new_sentence = input() correct_new_sentence = input()
while correct_new_sentence.lower() != 'no': while correct_new_sentence.lower() != 'no':
decode_sentence(session, model=model, data_reader=data_reader, decode_sentence(session, model=model, data_reader=data_reader,
sentence=correct_new_sentence, sentence=correct_new_sentence,
corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path)) corrective_tokens=data_reader.read_tokens(cged_config.train_path))
print("Enter a sentence you'd like to correct or press NO") print("Enter a sentence you'd like to correct or press NO")
correct_new_sentence = input() correct_new_sentence = input()
elif seq2seq_config.enable_test_decode: elif cged_config.enable_test_decode:
# Decode test sentences. # Decode test sentences.
with tf.Session() as session: with tf.Session() as session:
model = create_model(session, True, model_path, config=seq2seq_config) model = create_model(session, True, model_path, config=cged_config)
print("Loaded model. Beginning decoding.") print("Loaded model. Beginning decoding.")
decodings = decode(session, model=model, data_reader=data_reader, decodings = decode(session, model=model, data_reader=data_reader,
data_to_decode=data_reader.read_tokens(seq2seq_config.test_path), data_to_decode=data_reader.read_tokens(cged_config.test_path),
corrective_tokens=data_reader.read_tokens(seq2seq_config.train_path)) corrective_tokens=data_reader.read_tokens(cged_config.train_path))
# Write the decoded tokens to stdout. # Write the decoded tokens to stdout.
for tokens in decodings: for tokens in decodings:
sys.stdout.flush() sys.stdout.flush()

View File

@ -10,13 +10,15 @@ import time
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import seq2seq_config import cged_config
import fce_config
from corrector_model import CorrectorModel from corrector_model import CorrectorModel
from fce_reader import FCEReader from corpus_reader import FCEReader
from corpus_reader import CGEDReader
from tf_util import get_ckpt_path from tf_util import get_ckpt_path
def create_model(session, forward_only, model_path, config=seq2seq_config): def create_model(session, forward_only, model_path, config=cged_config):
""" """
Create model and load parameters Create model and load parameters
:param session: :param session:
@ -115,7 +117,7 @@ def train(data_reader, train_path, test_path, model_path):
# Run evals on development set and print their perplexity. # Run evals on development set and print their perplexity.
for bucket_id in range(len(config.buckets)): for bucket_id in range(len(config.buckets)):
if len(test_data[bucket_id]) == 0: if len(test_data[bucket_id]) == 0:
print(" eval: empty bucket %d" % (bucket_id)) print(" eval: empty bucket %d" % bucket_id)
continue continue
encoder_inputs, decoder_inputs, target_weights = \ encoder_inputs, decoder_inputs, target_weights = \
model.get_batch(test_data, bucket_id) model.get_batch(test_data, bucket_id)
@ -133,11 +135,11 @@ def train(data_reader, train_path, test_path, model_path):
def main(_): def main(_):
print('Training model...') print('Training model...')
data_reader = FCEReader(seq2seq_config, seq2seq_config.train_path) data_reader = CGEDReader(cged_config, cged_config.train_path)
train(data_reader, train(data_reader,
seq2seq_config.train_path, cged_config.train_path,
seq2seq_config.val_path, cged_config.val_path,
seq2seq_config.model_path) cged_config.model_path)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -113,14 +113,17 @@ def simplified2traditional(sentence):
return sentence return sentence
def segment(sentence): def segment(sentence, cut_type='word'):
""" """
切词 切词
:param sentence: :param sentence:
:param cut_type: 'word' use jieba.lcut; 'char' use list(sentence)
:return: list :return: list
""" """
import logging import logging
jieba.default_logger.setLevel(logging.ERROR) jieba.default_logger.setLevel(logging.ERROR)
if cut_type == 'char':
return list(sentence)
return jieba.lcut(sentence) return jieba.lcut(sentence)