天天看点

tensorflow——tf.argmax()和axis详解

tf.argmax(input,axis)根据axis取值的不同返回每行或者每列最大值的索引。

举个例子:

import tensorflow as tf
import numpy as np
 
A = [[1,3,4,5,6]]
B = [[1,3,4], [2,4,1]]
 
with tf.Session() as sess:
    print(sess.run(tf.argmax(A, 1)))
    print(sess.run(tf.argmax(B, 1)))
           

输出:

[4]

[2 1]

输出[4]因为在A中6最大,6的下标是4。

同理在B[0]中4最大,B[1]中也是4最大,其下标分别为2和1。

这里有一个参数axis可以设置:

axis=0时比较每一列的元素,将每一列最大元素所在的索引记录下来,最后输出每一列最大元素所在的索引数组。

axis=1的时候,将每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组。

这样说比较乱,看个例子:

test = np.array([
[1, 2, 3],
 [2, 3, 4], 
 [5, 4, 3], 
 [8, 7, 2]])
np.argmax(test, 0)   #输出:array([3, 3, 1]
np.argmax(test, 1)   #输出:array([2, 2, 0, 0]
           

继续阅读