tf.argmax(input,axis)根据axis取值的不同返回每行或者每列最大值的索引。
举个例子:
import tensorflow as tf
import numpy as np
A = [[1,3,4,5,6]]
B = [[1,3,4], [2,4,1]]
with tf.Session() as sess:
print(sess.run(tf.argmax(A, 1)))
print(sess.run(tf.argmax(B, 1)))
输出:
[4]
[2 1]
输出[4]因为在A中6最大,6的下标是4。
同理在B[0]中4最大,B[1]中也是4最大,其下标分别为2和1。
这里有一个参数axis可以设置:
axis=0时比较每一列的元素,将每一列最大元素所在的索引记录下来,最后输出每一列最大元素所在的索引数组。
axis=1的时候,将每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组。
这样说比较乱,看个例子:
test = np.array([
[1, 2, 3],
[2, 3, 4],
[5, 4, 3],
[8, 7, 2]])
np.argmax(test, 0) #输出:array([3, 3, 1]
np.argmax(test, 1) #输出:array([2, 2, 0, 0]