
Handwritten Chinese Article Recognition (5): Training and Validation Results

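Handwritten Chinese Article Recognition (1): Problem Description

https://blog.csdn.net/foreseerwang/article/details/80833749

Handwritten Chinese Article Recognition (2): Building the Sample Set

https://blog.csdn.net/foreseerwang/article/details/80842498

Handwritten Chinese Article Recognition (3): Data Feeding

https://blog.csdn.net/foreseerwang/article/details/80914473

Handwritten Chinese Article Recognition (4): Building the Model

https://blog.csdn.net/foreseerwang/article/details/81076936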

This is the last article in the series; it covers model training and the validation results.

Model training code:

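import os
import time

import numpy as np
import tensorflow as tf

# FLAGS, logger, DataIterator and build_graph are defined elsewhere in the
# project (see the flag list below and the earlier parts of this series).

def train(train_files, valid_files):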
    print('Begin training')

    time0=time.time()

    train_feeder = DataIterator(train_files, istrain=True)
    valid_feeder = DataIterator(valid_files, istrain=False)
    model_name = 'chinese-rec-model'

    train_start = time.time()

    trn_dataset = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
    train_image_batch = tf.reshape(trn_dataset[0],[-1,FLAGS.image_size,FLAGS.image_size,
                                               FLAGS.image_channel])
    train_label_batch = tf.reshape(trn_dataset[1],[-1])
    train_len_batch = trn_dataset[2]
    train_mask_batch = tf.reshape(trn_dataset[3], [-1])

    val_dataset = valid_feeder.input_pipeline(batch_size=FLAGS.batch_size, shuffle=False)
    valid_image_batch = tf.reshape(val_dataset[0],[-1,FLAGS.image_size,FLAGS.image_size,
                                               FLAGS.image_channel])
    valid_label_batch = tf.reshape(val_dataset[1],[-1])
    valid_len_batch = val_dataset[2]
    valid_mask_batch = tf.reshape(val_dataset[3], [-1])

    graphTRN = build_graph(top_k=1, images=train_image_batch, labels=train_label_batch,
                           seq_lens = train_len_batch, mask=train_mask_batch,
                           keep_prob=0.8, is_training=True)
    graphVAL = build_graph(top_k=1, images=valid_image_batch, labels=valid_label_batch,
                           seq_lens=valid_len_batch, mask=valid_mask_batch, keep_prob=1,
                           is_training=False, reuse_variables=True)

    accuracy_val = np.zeros([FLAGS.eval_steps//10])
    mask_val = np.zeros([FLAGS.eval_steps // 10])
    valid_accuracy = []

    with tf.Session() as sess:

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
        start_step = 0
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
                start_step += int(ckpt.split('-')[-1])

        logger.info(':::Training Start:::')
        try:
            i = 0
            while True:
                i += 1
                start_time = time.time()

                _, loss_val, train_summary, step = sess.run(
                    [graphTRN['train_op'], graphTRN['loss'],
                     graphTRN['merged_summary_op'], graphTRN['global_step']])

                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if (step % FLAGS.eval_steps == 1) and (step > 300):
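                    # Serialized summaries are bytes; start from b'' so the
                    # concatenation below also works on Python 3.
                    valid_summary = b''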

                    for iter_num in range(FLAGS.eval_steps//10):
                        accuracy_val[iter_num], valid_summary_val, mask_batch = \
                            sess.run([graphVAL['accuracy'], graphVAL['merged_summary_op'],
                                      graphVAL['mask']])
                        valid_summary = valid_summary+valid_summary_val
                        mask_val[iter_num] = sum(mask_batch)

                    accuracy_valid = np.sum(accuracy_val*mask_val)/np.sum(mask_val)
                    valid_accuracy.append(accuracy_valid)

                    test_writer.add_summary(valid_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_valid))
                    logger.info('{0} seconds passed'
                                .format(time.time() - time0))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name),
                               global_step=graphTRN['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name), global_step=graphTRN['global_step'])

    train_end = time.time()

    print('total training time: %f' % (train_end-train_start))
           

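In my view, there are two main points to note here:

1. Monitoring results on the validation dataset during training is the sensible approach: if the code is correct, the loss on the training dataset keeps falling until the model overfits, so only the validation results reflect how well the model actually performs. For how this is handled, see the graphTRN and graphVAL parts of the code, together with the use of tf.get_variable_scope().reuse_variables() mentioned in the previous article;

2. The code uses Dataset, whose output can be fed directly into the model function as each batch's data, which avoids building batches by hand and using placeholders. In my view, this approach is simpler and clearer.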
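The validation accuracy logged during training is not a plain mean of the per-batch accuracies: each batch is weighted by the number of unmasked characters it contains, as in `np.sum(accuracy_val*mask_val)/np.sum(mask_val)`. Below is a minimal pure-Python sketch of that weighting. It assumes that `graphVAL['accuracy']` is computed over unmasked positions only and that the mask is 1 for real characters and 0 for padding; the function name `weighted_accuracy` is hypothetical and not part of the original code.

```python
def weighted_accuracy(batch_accuracies, batch_masks):
    """Combine per-batch accuracies, weighting each batch by its count of real characters.

    batch_accuracies: one float per batch (assumed: accuracy over unmasked positions).
    batch_masks: one flat list of 0/1 values per batch (assumed: 1 = real character).
    """
    weights = [sum(mask) for mask in batch_masks]
    total = sum(weights)
    if total == 0:
        raise ValueError("no unmasked characters in any batch")
    return sum(a * w for a, w in zip(batch_accuracies, weights)) / total


# Two toy batches: the first holds 4 real characters, the second only 1.
accs = [0.75, 1.0]
masks = [[1, 1, 1, 1, 0, 0], [1, 0, 0, 0, 0, 0]]

overall = weighted_accuracy(accs, masks)   # (0.75 * 4 + 1.0 * 1) / 5
plain_mean = sum(accs) / len(accs)         # ignores how much padding each batch has
```

In this toy case the weighted result is 0.8 while the plain mean is 0.875, which is why batches with a lot of padding should not count as much as full ones.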

Validation code:

def validation(filenames):
    print('Begin validation')

    test_feeder = DataIterator(filenames, istrain=False)
    dataset = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1, shuffle=False)
    val_image_batch = tf.reshape(dataset[0],[-1,FLAGS.image_size,FLAGS.image_size,FLAGS.image_channel])
    val_label_batch = tf.reshape(dataset[1],[-1])
    val_len_batch = dataset[2]
    val_mask_batch = tf.reshape(dataset[3],[-1])

    graph = build_graph(top_k=3, images=val_image_batch, labels=val_label_batch,
                        seq_lens=val_len_batch, mask=val_mask_batch,
                        keep_prob=1.0, is_training=False)

    groundtruth = []
    seq_lens = []

    with tf.Session() as sess:
        saver = tf.train.Saver()

        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Start validation:::')
        try:
            i = 0
            acc_top_1 = 0.0
            while True:
                i += 1
                start_time = time.time()
                batch_labels, acc_1, seq_lens_batch =\
                    sess.run([graph['labels'], graph['accuracy'], graph['seq_lens']])

                groundtruth += batch_labels.tolist()
                acc_top_1 += acc_1*len(batch_labels)
                seq_lens += seq_lens_batch.tolist()

                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1)"
                            .format(i, end_time - start_time, acc_1))

        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            logger.info('size of acc_top_1: {0}; size of labels: {1}'.\
                        format(np.asarray(acc_top_1).shape, np.asarray(groundtruth).shape))
            acc_top_1 = acc_top_1 / len(groundtruth) #* FLAGS.batch_size / test_feeder.size
            logger.info('top 1 accuracy {0}'.format(acc_top_1))

    return {'groundtruth': groundtruth, 'seq_lens': seq_lens}
           

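With the training code above in place, this part is quite simple, since it is essentially the code that monitors the validation dataset during training.

Before giving the final validation results, here are the parameters used: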

tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to randomly adjust contrast")

tf.app.flags.DEFINE_integer('charset_size_long', 7356, "Choose the first `charset_size` characters only.")
tf.app.flags.DEFINE_integer('charset_size_short', 4000, "Choose the first `charset_size` characters only.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_integer('image_channel', 1, "channel number")
tf.app.flags.DEFINE_boolean('gray', True, "whether to convert RGB to grayscale")
tf.app.flags.DEFINE_integer('max_steps', 200002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 1000, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 100, "the steps to save")
tf.app.flags.DEFINE_integer('sent_len_max', 30, "maximum sentence length")
tf.app.flags.DEFINE_integer('char_vec_len', 1024, "the length of character vector")

tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint_0626v13/', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', '../data/hwdb_by_char_Train/', 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', '../data/hwdb_by_char_Test/', 'the test dataset dir')
tf.app.flags.DEFINE_string('inferenc_dir', './inference/', 'the inference dir')
tf.app.flags.DEFINE_string('log_dir', './log', 'the logging dir')
tf.app.flags.DEFINE_string('sample_dir', './sample', 'the sample dir')
tf.app.flags.DEFINE_string('train_hwdb_dir', '../hwdb/hwdb_by_char_Train_gbk', 'the train hwdb dir')
tf.app.flags.DEFINE_string('test_hwdb_dir', '../hwdb/hwdb_by_char_Test_gbk', 'the test hwdb dir')

tf.app.flags.DEFINE_string('char_dict_filename', 'char_dict_gbk20180518', 'dict: char->number')
tf.app.flags.DEFINE_string('char_rvs_dict_filename', 'char_dict_gbk_rvs20180518', 'dict: number->char')
tf.app.flags.DEFINE_boolean('short_dict', True, 'whether to use short dict')

tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('batch_size', 8, 'Validation batch size')
tf.app.flags.DEFINE_string('mode', 'validation', 'Running mode. One of {"train", "validation", "inference"}')

tf.app.flags.DEFINE_integer('lstm_size', 512, 'LSTM size')
tf.app.flags.DEFINE_integer('num_layers', 2, 'LSTM layers number')
tf.app.flags.DEFINE_boolean('viterbi', True, 'whether to use the Viterbi algorithm')
           

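Three of these parameters matter most:

  1. FLAGS.short_dict indicates whether to use the short character set, i.e. a set of 4000 Chinese characters, rather than the full HWDB set of 7356. These 4000 characters are the most frequently used ones according to statistics over more than a million articles, and they account for over 95% of the Chinese characters in those articles. Any Chinese character not among these 4000 is mapped to <UNK>. The short character set was used for the actual training and validation of this model;
  2. FLAGS.viterbi controls whether the Viterbi algorithm is used; how it is used was described in the previous article. The model was trained and validated both with and without Viterbi. Note that with Viterbi enabled, training time (about 1.5 hours per 1000 batches on a single GTX 1080 Ti) is more than 20 times that with it disabled, while recognition accuracy on the validation dataset improves by about 1.4 percentage points;
  3. FLAGS.mode is set to 'train' for training and to 'validation' for validation.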
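The short character set idea (keep the most frequent characters, map everything else to <UNK>) can be illustrated with a small sketch. This is not the project's dictionary-building code, which is not shown in this article; the function names and the toy corpus are hypothetical, and the real character set was derived from more than a million articles rather than three short strings.

```python
from collections import Counter

UNK = "<UNK>"


def build_short_charset(corpus_texts, size):
    """Return the `size` most frequent characters across all texts."""
    counts = Counter()
    for text in corpus_texts:
        counts.update(text)
    return [ch for ch, _ in counts.most_common(size)]


def coverage(corpus_texts, charset):
    """Fraction of character occurrences in the corpus that fall inside `charset`."""
    keep = set(charset)
    total = sum(len(text) for text in corpus_texts)
    hits = sum(1 for text in corpus_texts for ch in text if ch in keep)
    return hits / total if total else 0.0


def map_to_short(text, charset):
    """Replace every character outside `charset` with the UNK token."""
    keep = set(charset)
    return [ch if ch in keep else UNK for ch in text]


# Toy corpus standing in for the article collection.
corpus = ["的的的是", "的的是人", "的是龘"]
short = build_short_charset(corpus, size=2)   # the two most frequent characters
ratio = coverage(corpus, short)               # share of occurrences they cover
mapped = map_to_short("的是人龘", short)       # rare characters become <UNK>
```

With a real corpus, `coverage` is the kind of figure behind the "over 95%" statement above: a small set of frequent characters covers most of the text, so the classifier needs far fewer output classes at the cost of treating rare characters as unknown.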

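With the above parameter configuration, after 200,000 training steps, the recognition results on the validation data are:

  • Viterbi disabled (average of 6 runs): 96.16%
  • Viterbi enabled (average of 7 runs): 97.53%

Viterbi brings a gain of about 1.4 percentage points, but increases training time by a factor of about 20.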
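To put this trade-off in numbers, the short calculation below uses only the figures reported above. It assumes that the "about 1.5 hours per 1000 batches" figure refers to training with Viterbi enabled, and that all 200,000 steps run at that rate; both are readings of the text rather than measured values.

```python
acc_off = 96.16   # validation accuracy (%) with Viterbi disabled
acc_on = 97.53    # validation accuracy (%) with Viterbi enabled

# Absolute gain in percentage points.
gain_points = acc_on - acc_off

# Relative reduction in error rate.
err_off = 100.0 - acc_off
err_on = 100.0 - acc_on
relative_error_reduction = (err_off - err_on) / err_off

# Assumption: 1.5 hours per 1000 batches is the Viterbi-enabled rate.
steps = 200000
hours_on = steps / 1000 * 1.5
# "More than 20 times" slower implies at most this many hours with Viterbi disabled.
hours_off_upper = hours_on / 20

print("gain: %.2f points, error reduction: %.1f%%" % (gain_points, relative_error_reduction * 100))
print("training: ~%.0f h with Viterbi, at most ~%.0f h without" % (hours_on, hours_off_upper))
```

Under these assumptions the gain is 1.37 percentage points, which corresponds to roughly 36% fewer recognition errors, bought with about 300 hours of training instead of at most about 15.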

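This concludes the series. Unfortunately, for copyright reasons the dataset cannot be uploaded for others to use, so only the code is provided for reference. If you have any questions or find bugs in the code, feel free to get in touch: [email protected]   Thank you!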
