TensorFlow中维护的集合列表
在一个计算图中,可以通过集合(
collection
)来管理不同类别的资源。比如通过 tf.add_to_collection
函数可以将资源加入一个 或多个集合中,然后通过 tf.get_collection
获取一个集合里面的所有资源(如张量,变量,或者运行TensorFlow程序所需的队列资源等等) 集合名称 | 集合内容 | 使用场景 |
---|---|---|
| 所有变量 | 持久化 TensorFlow 模型 |
| 可学习的变量(一般指神经网络中的参数) | 模型训练、生成模型可视化内容 |
| 日志生成相关的张量 | TensorFlow 计算可视化 |
| 处理输入的 QueueRunner | 输入处理 |
| 所有计算了滑动平均值的变量 | 计算变量的滑动平均值 |
- TensorFlow中的所有变量都会被自动加入
集合中,通过tf.GraphKeys.VARIABLES
函数可以拿到当前计算图上的所有变量。拿到计算图上的所有变量有助于持久化整个计算图的运行状态。tf.global_variables()
- 当构建机器学习模型时,比如神经网络,可以通过变量声明函数中的
参数来区分需要优化的参数(比如神经网络的参数)和其他参数(比如迭代的轮数,即超参数),若trainable
,则此变量会被加入trainable = True
集合。然后通过tf.GraphKeys.TRAINABLE_VARIABLES
函数便可得到所有需要优化的参数。TensorFlow中提供的优化算法会将tf.trainable_variables
集合中的变量作为 默认的优化对象。tf.GraphKeys.TRAINABLE_VARIABLES
- 变量的类型是不可以改变的。
- 变量的维度一般是不能改变的,除非设置参数
(很少去改变它)validate_shape = False
TensorFlow中的 tf.Variable
函数随机数和常数的生成:
tf.Variable
函数名 | 随机数分布 | 主要参数 |
---|---|---|
| 正态分布 | 平均值、标准差、取值类型 |
| 满足正态分布的随机值,但若随机值偏离平均值超过2个标准差,则这个数会被重新随机 | |
| 平均分布 | 最大、最小值、取值类型 |
| Gramma分布 | 形状参数alpha、尺度参数beta、取值类型 |
功能 | 示例 | |
---|---|---|
| 产生全0的数组 | |
| 产生全1的数组 | |
| 产生一个全部为给定数组的数组 | |
| 产生一个给定值的常量 | |
TensorFlow 中的 tf.get_variable
变量初始化函数
tf.get_variable
初始化函数 | ||
---|---|---|
| 将变量初始化为给定常数 | 常数的取值 |
| 将变量初始化为满足正态分布的随机值 | 正态分布的均值和标准差 |
| 将变量初始化为满足正态分布的随机值,但若随机值偏离平均值超过2个标准差,则这个数会被重新随机 | |
| 将变量初始化为满足平均分布的随机值 | 最大、最小值 |
| 将变量初始化为满足平均分布但不影响输出数量级的随机值 | factor(产生随机值时乘以的系数) |
| 将变量初始化为全0 | 变量维度 |
| 将变量初始化为全1 |
-
函数将张量限定在一定的范围内:tf.clip_by_value
sess = tf.InteractiveSession()
v = tf.constant([[1., 2., 3.], [4., 5., 6.]])
tf.clip_by_value(v, 2.5, 4.5).eval() # 小于2.5的数值设为2.5,大于4.5的数值设为4.5
array([[ 2.5, 2.5, 3. ],
[ 4. , 4.5, 4.5]], dtype=float32)
-
对张量所有元素进行对数运算tf.log
tf.log(v).eval()
array([[ 0. , 0.69314718, 1.09861231],
[ 1.38629436, 1.60943794, 1.79175949]], dtype=float32)
-
tf.greater
的输入是两个张量,此函数会比较这两个张量中的每一个元素,并返回比较结果;
当输入维度不一致时会进行广播(broadcasting)
v1 = tf.constant([1., 2., 3., 4.])
v2 = tf.constant([4., 3., 2., 1.])
f = tf.greater(v1, v2)
f.eval()
Out[11]:
array([False, False, True, True], dtype=bool)
-
tf.where
函数有三个参数:
第一个选择条件根据,当选择条件为
时,会选择第二个参数中的值,否则使用第三个参数中的值:True
tf.where(f, v1, v2).eval()
array([ 4., 3., 3., 4.], dtype=float32)
探寻有趣之事!