天天看點

TensorFlow scan()函數詳解tf.scan()幫助文檔

@[tf.scan()函數使用介紹]

tf.scan()幫助文檔

在python互動式解釋器下輸入help(tf.scan)檢視幫助文檔

TensorFlow scan()函數詳解tf.scan()幫助文檔

這句話說明了scan的作用。他将張量elems的第一維元素list作函數計算,帶入至fn,作為參數,累加作用,直到周遊完elems的所有一維元素。作用類似于functools裡面的reduce函數。

Demo

1維數組:

TensorFlow scan()函數詳解tf.scan()幫助文檔

scan會取出nums每一個一維資料,在這個代碼中就是首先取出1,然後和initializer=0進行相加,結果為1;然後在取出x=2,與上一步結果相加,結果為3;以此類推,下一步結果為3+3=6,最終的結果為:

[ 1 3 6 10 15 21]

這個例子很easy.

2維數組:

TensorFlow scan()函數詳解tf.scan()幫助文檔

這裡需要注意的是,我們在傳入Initializer的時候,需要自己判定應該傳入什麼類型的初始化值。在這裡需要傳入一個(6,)的向量。為什麼?

tf.scan()會選取elems的每一個第一個次元的子元素,那麼在這裡就是[1,2,3,4,5,6]和[1,2,3,4,5,6]兩個。然後累加作用于lambda表達式中,也就是說[1,2,3,4,5,6]+?,由于不允許一個向量+一個數字的形式出現,是以initialize隻能傳(6,)的向量。

計算過程:

x = [1,2,3,4,5,6],a = [1,1,1,1,1,1],x+a=[2,3,4,5,6,7]

a = [2,3,4,5,6,7],x=[1,2,3,4,5,6],x+a=[3,5,7,9,11,13]

最終輸出:

[

[2,3,4,5,6,7],

[3,5,7,9,11,13]

]

3維數組:隻給出Demo,自己分析原因

TensorFlow scan()函數詳解tf.scan()幫助文檔

輸出為:

TensorFlow scan()函數詳解tf.scan()幫助文檔