天天看点

Tensorflow使用flags定义命令行参数详解

TensorFlow定义了tf.app.flags,用于支持接受命令行传递参数,相当于接受argv,详细用法请看代码中的注释

例一:

import tensorflow as tf

##第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.app.flags.DEFINE_integer('int_name', ,"descript2")
tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")
FLAGS = tf.app.flags.FLAGS
##必须带参数,否则:'TypeError: main() takes no arguments (1 given)';   ##main的参数名随意定义,无要求
def main(_):  
    print(FLAGS.str_name)
    print(FLAGS.int_name)
    print(FLAGS.bool_name)

if __name__ == '__main__':
    tf.app.run()  #tf.app.run()的作用:先处理flag解析,然后执行main函数,
           
例二:
import tensorflow as tf
flags = tf.flags #flags是一个文件:flags.py,用于处理命令行参数的解析工作
logging = tf.logging

#调用flags内部的DEFINE_string函数来制定解析规则
flags.DEFINE_string("para_name_1","default_val", "description")
flags.DEFINE_bool("para_name_2","default_val", "description")

#FLAGS是一个对象,保存了解析后的命令行参数
FLAGS = flags.FLAGS

def main(_):
    FLAGS.para_name #调用命令行输入的参数

if __name__ == "__main__": #使用这种方式保证了,如果此文件被其它文件import的时候,不会执行main中的代码

    tf.app.run() #解析命令行参数,调用main函数 main(sys.argv)
'''
调用方法,在命令行窗口中输入:
~/ python script.py --para_name_1=name --para_name_2=name2
# 不传的话,会使用默认值
'''
           
例三:
#coding:utf-8
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("train_batch_size", , "batch size of train data")
tf.app.flags.DEFINE_integer("test_batch_size", , "batch size of test data")
tf.app.flags.DEFINE_float("learning_rate", , "learning rate")

def main(unused_argv):
    train_data_path = FLAGS.train_data_path
    print("train_data_path", train_data_path)
    train_batch_size = FLAGS.train_batch_size
    print("train_batch_size", train_batch_size)
    test_batch_size = FLAGS.test_batch_size
    print("test_batch_size", test_batch_size)
    size_sum = tf.add(train_batch_size, test_batch_size)
    with tf.Session() as sess:
        sum_result = sess.run(size_sum)
        print("sum_result", sum_result)

# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
    tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)
           

如果需要修改默认参数的值,则在命令行传入自定义参数值即可,若全部使用默认参数值,则可直接在命令行运行该 python 文件。

tf.app.run() 真正运行原理,还需查阅其源代码:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:] + flags_passthrough))
           

flags_passthrough=f._parse_flags(args=args)这里的_parse_flags就是我们tf.app.flags源码中用来解析命令行参数的函数。所以这一行就是解析参数的功能;

下面两行代码也就是 tf.app.run 的核心意思:执行程序中 main 函数,并解析命令行参数!