tensorflow的幾個函數
-
tf.argmax()
通俗的講,該函數是傳回最大的那個數值所在的下标,難點在于第二個參數
先放一個例子,下面進行解釋
運作結果argmax_paramter = tf.Variable([[1, 32, 44, 56], [89, 12, 90, 33], [35, 69, 1, 10]]) argmax_0 = tf.argmax(argmax_paramter, 0) argmax_1 = tf.argmax(argmax_paramter, 1) print("argmax_0:", sess.run(argmax_0)) print("argmax_1:", sess.run(argmax_1))
第二個參數為0時比較所有數組相同位置上的數
第二個參數為1時分别比較每個數組中數的大小
- tf.reduce_mean()
reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)
該函數用于計算張量tensor沿着指定的數軸(tensor的某一次元)上的的平均值,主要用 作降維或者計算tensor(圖像)的平均值
input_tensor: 輸入的tensor
axis: 指定的軸,如果不指定,則計算所有元素的均值
keep_dims:是否降次元,設定為True,輸出的結果保持輸入tensor的形狀, 設定為False,輸出結果會降低次元
name: 操作的名稱
reduction_indices:用來指定軸
import tensorflow as tf # 次元為2,形狀為[2,3] tensor = [[1,2,3], [1,2,3]] tensor = tf.cast(x,tf.float32) mean_all = tf.reduce_mean(xx, keep_dims=False) mean_0 = tf.reduce_mean(xx, axis=0, keep_dims=False) mean_1 = tf.reduce_mean(xx, axis=1, keep_dims=False) with tf.Session() as sess: m_a,m_0,m_1 = sess.run([mean_all, mean_0, mean_1]) print(m_a) # output: 2.0 print(m_0) # output: [1 2 3] print(m_1) # output: [2 2] # 如果設定保持原來的張量的次元,keep_dims=True print(m_a) # output: [[2]] print(m_0) # output: [[1 2 3]] print(m_1) # output: [[2], [2]]
- tf.equal()
該函數是判斷x,y是否相等,逐個元素進行判斷equal(x, y, name=None)
import tensorflow as tf a = [[1,2,3],[4,5,6]] b = [[1,0,3],[1,5,1]] with tf.Session() as sess: print(sess.run(tf.equal(a,b))) # output: # [[ True False True] # [False True False]]
- tf.cast()
該函數的功能是将x的資料類型轉換為dtypetf.cast(x, dtype, name=None)
a = tf.Variable([1,0,0,1,1]) b = tf.cast(a,dtype=tf.bool) sess = tf.Session() sess.run(tf.initialize_all_variables()) print(sess.run(b)) # output:[True False False True True]