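Handwritten Chinese Article Recognition (1): Problem Description
https://blog.csdn.net/foreseerwang/article/details/80833749
Handwritten Chinese Article Recognition (2): Building the Sample Set
https://blog.csdn.net/foreseerwang/article/details/80842498
Handwritten Chinese Article Recognition (3): Data Feeding
https://blog.csdn.net/foreseerwang/article/details/80914473
Handwritten Chinese Article Recognition (4): Building the Model
https://blog.csdn.net/foreseerwang/article/details/81076936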
This is the last post in the series. It covers model training and the validation results.
Model training code:
import os
import time

import numpy as np
import tensorflow as tf

# DataIterator, build_graph, FLAGS and logger come from the earlier posts in this series.


def train(train_files, valid_files):
    print('Begin training')
    time0 = time.time()
    train_feeder = DataIterator(train_files, istrain=True)
    valid_feeder = DataIterator(valid_files, istrain=False)
    model_name = 'chinese-rec-model'
    train_start = time.time()

    # Training input pipeline: images, labels, sequence lengths and padding mask.
    trn_dataset = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
    train_image_batch = tf.reshape(
        trn_dataset[0], [-1, FLAGS.image_size, FLAGS.image_size, FLAGS.image_channel])
    train_label_batch = tf.reshape(trn_dataset[1], [-1])
    train_len_batch = trn_dataset[2]
    train_mask_batch = tf.reshape(trn_dataset[3], [-1])

    # Validation input pipeline: same layout, no shuffling.
    val_dataset = valid_feeder.input_pipeline(batch_size=FLAGS.batch_size, shuffle=False)
    valid_image_batch = tf.reshape(
        val_dataset[0], [-1, FLAGS.image_size, FLAGS.image_size, FLAGS.image_channel])
    valid_label_batch = tf.reshape(val_dataset[1], [-1])
    valid_len_batch = val_dataset[2]
    valid_mask_batch = tf.reshape(val_dataset[3], [-1])

    # Two graphs over the same variables: dropout on for training, off for validation.
    graphTRN = build_graph(top_k=1, images=train_image_batch, labels=train_label_batch,
                           seq_lens=train_len_batch, mask=train_mask_batch,
                           keep_prob=0.8, is_training=True)
    graphVAL = build_graph(top_k=1, images=valid_image_batch, labels=valid_label_batch,
                           seq_lens=valid_len_batch, mask=valid_mask_batch, keep_prob=1,
                           is_training=False, reuse_variables=True)

    # Per-batch accuracy and mask size for one validation round.
    accuracy_val = np.zeros([FLAGS.eval_steps // 10])
    mask_val = np.zeros([FLAGS.eval_steps // 10])
    valid_accuracy = []

    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        start_step = 0
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
                start_step += int(ckpt.split('-')[-1])

        logger.info(':::Training Start:::')
        try:
            i = 0
            while True:
                i += 1
                start_time = time.time()
                _, loss_val, train_summary, step = sess.run(
                    [graphTRN['train_op'], graphTRN['loss'],
                     graphTRN['merged_summary_op'], graphTRN['global_step']])
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(
                    step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break

                # Periodically run eval_steps // 10 validation batches.
                if (step % FLAGS.eval_steps == 1) and (step > 300):
                    # Serialized summaries are bytes; start from b'' so that
                    # concatenation also works under Python 3.
                    valid_summary = b''
                    for iter_num in range(FLAGS.eval_steps // 10):
                        accuracy_val[iter_num], valid_summary_val, mask_batch = \
                            sess.run([graphVAL['accuracy'], graphVAL['merged_summary_op'],
                                      graphVAL['mask']])
                        valid_summary = valid_summary + valid_summary_val
                        mask_val[iter_num] = np.sum(mask_batch)
                    # Weight each batch's accuracy by its mask size.
                    accuracy_valid = np.sum(accuracy_val * mask_val) / np.sum(mask_val)
                    valid_accuracy.append(accuracy_valid)
                    test_writer.add_summary(valid_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_valid))
                    logger.info('{0} seconds passed'
                                .format(time.time() - time0))
                    logger.info('===============Eval a batch=======================')

                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name),
                               global_step=graphTRN['global_step'])
        except tf.errors.OutOfRangeError:
            # Raised when the training pipeline runs out of data.
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name),
                       global_step=graphTRN['global_step'])

    train_end = time.time()
    print('total training time: %f' % (train_end - train_start))
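To make the timing of evaluation and checkpointing in this loop concrete, here is a small pure-Python sketch that replays only the step arithmetic, using the default values of max_steps, eval_steps and save_steps from the flags listed later in this post. The function name `schedule` is hypothetical, and the sketch assumes a fresh run in which the reported global step counts 1, 2, 3, and so on.

```python
def schedule(max_steps, eval_steps, save_steps, eval_min_step=300):
    """Return the steps at which the loop evaluates and saves (hypothetical helper)."""
    eval_at, save_at = [], []
    step = 0
    while True:
        step += 1                  # assumption: one training step adds 1 to global_step
        if step > max_steps:       # same exit test as in the training loop
            break
        if step % eval_steps == 1 and step > eval_min_step:
            eval_at.append(step)
        if step % save_steps == 1:
            save_at.append(step)
    return eval_at, save_at


eval_at, save_at = schedule(max_steps=200002, eval_steps=1000, save_steps=100)
print(len(eval_at), eval_at[0], eval_at[-1])   # 200 1001 200001
print(len(save_at), save_at[0], save_at[-1])   # 2001 1 200001
print(1000 // 10)                              # 100 validation batches per evaluation
```

Under these assumptions the last evaluation and the last checkpoint both fall on step 200001, which is presumably why max_steps is set slightly above 200000.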
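In my view, two points deserve attention here:
1. Monitoring results on the validation set during training is the sensible approach. If the code is correct, the loss on the training set keeps falling until the model overfits, so only the validation results reflect how well the model actually performs. For how this is handled, see graphTRN and graphVAL in the code, together with the use of tf.get_variable_scope().reuse_variables() mentioned in the previous post;
2. The code uses Dataset, whose output can be passed directly to the model function as each batch's data, so there is no need to build batches by hand or to use placeholders. I find this approach simpler and clearer.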
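The validation monitoring in the training loop combines per-batch accuracies as sum(accuracy * mask) / sum(mask). The sketch below shows in plain Python why that weighting matters. It assumes that each batch accuracy is measured over the unmasked (real, non-padding) characters only; the function name and the numbers are made up for illustration.

```python
def weighted_accuracy(batch_acc, batch_valid_counts):
    """Combine per-batch accuracies, weighting each by its number of real characters."""
    total = sum(batch_valid_counts)
    if total == 0:
        return 0.0
    return sum(a * n for a, n in zip(batch_acc, batch_valid_counts)) / total


# Illustrative numbers: a batch with 200 real characters at 90% accuracy
# and a batch with 50 real characters at 100% accuracy.
batch_acc = [0.90, 1.00]
batch_valid_counts = [200, 50]

weighted = weighted_accuracy(batch_acc, batch_valid_counts)
plain_mean = sum(batch_acc) / len(batch_acc)
print(round(weighted, 4))    # 0.92
print(round(plain_mean, 4))  # 0.95
```

The plain mean overstates the result because the short batch counts as much as the long one; the weighted form equals the accuracy over all 250 characters.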
Validation code:
def validation(filenames):
    print('Begin validation')
    test_feeder = DataIterator(filenames, istrain=False)
    # num_epochs=1: the pipeline raises OutOfRangeError after one pass over the data.
    dataset = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1,
                                         shuffle=False)
    val_image_batch = tf.reshape(
        dataset[0], [-1, FLAGS.image_size, FLAGS.image_size, FLAGS.image_channel])
    val_label_batch = tf.reshape(dataset[1], [-1])
    val_len_batch = dataset[2]
    val_mask_batch = tf.reshape(dataset[3], [-1])

    graph = build_graph(top_k=3, images=val_image_batch, labels=val_label_batch,
                        seq_lens=val_len_batch, mask=val_mask_batch,
                        keep_prob=1.0, is_training=False)

    groundtruth = []
    seq_lens = []

    with tf.Session() as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Start validation:::')
        try:
            i = 0
            acc_top_1 = 0.0
            while True:
                i += 1
                start_time = time.time()
                batch_labels, acc_1, seq_lens_batch = \
                    sess.run([graph['labels'], graph['accuracy'], graph['seq_lens']])
                groundtruth += batch_labels.tolist()
                # Weight each batch's accuracy by its number of labels.
                acc_top_1 += acc_1 * len(batch_labels)
                seq_lens += seq_lens_batch.tolist()
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1)"
                            .format(i, end_time - start_time, acc_1))
        except tf.errors.OutOfRangeError:
            # The whole validation set has been consumed.
            logger.info('==================Validation Finished================')
            logger.info('size of acc_top_1: {0}; size of labels: {1}'
                        .format(np.asarray(acc_top_1).shape, np.asarray(groundtruth).shape))
            acc_top_1 = acc_top_1 / len(groundtruth)
            logger.info('top 1 accuracy {0}'.format(acc_top_1))
    return {'groundtruth': groundtruth, 'seq_lens': seq_lens}
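With the training code above in place, this part is quite simple, because it is essentially the same code that monitors the validation set during training.
Before giving the final validation results, here are the parameters used: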
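The control flow of the validation function (read batches until the one-epoch pipeline is exhausted, then average) can be mimicked without TensorFlow. The following is only an analogy: a Python generator stands in for the Dataset pipeline, StopIteration stands in for tf.errors.OutOfRangeError, and the helper names and toy data are invented for the illustration.

```python
def one_epoch_batches(items, batch_size):
    """Yield consecutive batches exactly once; the last batch may be smaller."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def evaluate(labels, predictions, batch_size):
    """Accumulate batch accuracy weighted by batch length, as the validation loop does."""
    label_batches = one_epoch_batches(labels, batch_size)
    pred_batches = one_epoch_batches(predictions, batch_size)
    seen, acc_sum = 0, 0.0
    while True:
        try:
            y = next(label_batches)
            p = next(pred_batches)
        except StopIteration:      # plays the role of OutOfRangeError
            break
        batch_acc = sum(a == b for a, b in zip(y, p)) / len(y)
        acc_sum += batch_acc * len(y)
        seen += len(y)
    return acc_sum / seen


labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
preds = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]   # the last two predictions are wrong
result = evaluate(labels, preds, batch_size=8)
print(round(result, 4))  # 0.8
```

With a batch size of 8 the second batch holds only 2 items, so multiplying each batch accuracy by the batch length before dividing by the total gives 0.8, whereas a plain average of the two batch accuracies would give 0.5.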
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to randomly adjust contrast")
tf.app.flags.DEFINE_integer('charset_size_long', 7356, "Choose the first `charset_size` characters only.")
tf.app.flags.DEFINE_integer('charset_size_short', 4000, "Choose the first `charset_size` characters only.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_integer('image_channel', 1, "channel number")
tf.app.flags.DEFINE_boolean('gray', True, "whether to convert RGB to gray")
tf.app.flags.DEFINE_integer('max_steps', 200002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 1000, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 100, "the steps to save")
tf.app.flags.DEFINE_integer('sent_len_max', 30, "maximum sentence length")
tf.app.flags.DEFINE_integer('char_vec_len', 1024, "the length of character vector")
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint_0626v13/', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', '../data/hwdb_by_char_Train/', 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', '../data/hwdb_by_char_Test/', 'the test dataset dir')
tf.app.flags.DEFINE_string('inferenc_dir', './inference/', 'the inference dir')
tf.app.flags.DEFINE_string('log_dir', './log', 'the logging dir')
tf.app.flags.DEFINE_string('sample_dir', './sample', 'the sample dir')
tf.app.flags.DEFINE_string('train_hwdb_dir', '../hwdb/hwdb_by_char_Train_gbk', 'the train hwdb dir')
tf.app.flags.DEFINE_string('test_hwdb_dir', '../hwdb/hwdb_by_char_Test_gbk', 'the test hwdb dir')
tf.app.flags.DEFINE_string('char_dict_filename', 'char_dict_gbk20180518', 'dict: char->number')
tf.app.flags.DEFINE_string('char_rvs_dict_filename', 'char_dict_gbk_rvs20180518', 'dict: number->char')
tf.app.flags.DEFINE_boolean('short_dict', True, 'whether to use short dict')
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('batch_size', 8, 'Batch size for training and validation')
tf.app.flags.DEFINE_string('mode', 'validation', 'Running mode. One of {"train", "validation", "inference"}')
tf.app.flags.DEFINE_integer('lstm_size', 512, 'LSTM size')
tf.app.flags.DEFINE_integer('num_layers', 2, 'LSTM layers number')
tf.app.flags.DEFINE_boolean('viterbi', True, 'whether to use the Viterbi algorithm')
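Three of these parameters matter most:
- FLAGS.short_dict: whether to use the short character set, i.e. a set of 4000 Chinese characters rather than the full 7356-character hwdb set. These 4000 characters are the most frequently used ones according to statistics over more than a million articles, and they account for over 95% of the Chinese characters in those articles. Any Chinese character outside these 4000 is mapped to <UNK>. The model was trained and validated with the short character set;
- FLAGS.viterbi: whether to use the Viterbi algorithm; how it is used was described in the previous post. The model was trained and validated both with and without Viterbi. Note that with Viterbi enabled, training takes more than 20 times as long as with it disabled (on a single GTX 1080 Ti, about 1.5 hours per 1000 batches), and recognition accuracy on the validation set improves by about 1.4 percentage points;
- FLAGS.mode: set to 'train' for training and to 'validation' for validation.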
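The short character set idea can be sketched in a few lines of plain Python: count character frequencies over a corpus, keep the most frequent ones, and map everything else to <UNK>. This is an illustration under stated assumptions, not the code used in this project: the helper names, the toy corpus and the choice of reserving id 0 for <UNK> are all hypothetical.

```python
from collections import Counter

UNK = '<UNK>'


def build_short_dict(corpus, size):
    """Map the `size` most frequent characters to ids; id 0 is reserved for <UNK>."""
    counts = Counter(ch for text in corpus for ch in text)
    char_to_id = {UNK: 0}          # assumption: <UNK> gets its own id
    for ch, _ in counts.most_common(size):
        char_to_id[ch] = len(char_to_id)
    return char_to_id, counts


def encode(text, char_to_id):
    """Replace every character outside the short dict with the <UNK> id."""
    return [char_to_id.get(ch, char_to_id[UNK]) for ch in text]


def coverage(counts, char_to_id):
    """Fraction of all character occurrences that the short dict covers."""
    total = sum(counts.values())
    kept = sum(n for ch, n in counts.items() if ch in char_to_id)
    return kept / total


corpus = ['我我我的的们']            # toy corpus: 我 x3, 的 x2, 们 x1
char_to_id, counts = build_short_dict(corpus, size=2)
print(encode('我们的', char_to_id))              # [1, 0, 2]
print(round(coverage(counts, char_to_id), 4))   # 0.8333
```

In the post's setting, size would be 4000 and the coverage figure corresponds to the "over 95%" statistic quoted above.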
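With the parameters above, after 200,000 training steps, the recognition results on the validation data are:
- Viterbi disabled (average of 6 runs): 96.16%
- Viterbi enabled (average of 7 runs): 97.53%
The Viterbi algorithm brings a gain of about 1.4 percentage points, but training takes roughly 20 times as long.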
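As a quick check of these figures, the gain can be computed directly from the two reported accuracies. The relative error reduction is not stated in the post; it is derived here from the same two numbers.

```python
acc_viterbi_off = 96.16   # reported validation accuracy (%), Viterbi disabled
acc_viterbi_on = 97.53    # reported validation accuracy (%), Viterbi enabled

gain_points = acc_viterbi_on - acc_viterbi_off
err_off = 100.0 - acc_viterbi_off     # error rate without Viterbi
err_on = 100.0 - acc_viterbi_on       # error rate with Viterbi
relative_error_reduction = (err_off - err_on) / err_off

print(round(gain_points, 2))                     # 1.37
print(round(100 * relative_error_reduction, 1))  # 35.7
```

So the gain of about 1.4 percentage points corresponds to removing a little over a third of the remaining recognition errors.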
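This concludes the series. Unfortunately, for copyright reasons the dataset cannot be uploaded for others to use, so only the code is provided for reference. If you have any questions or find bugs in the code, feel free to get in touch: [email protected] Thank you!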