天天看点

TensorFlow中将ckpt模型文件转化为pb文件:定义网络结构、去掉网络的Dropout单元、代码实现

定义网络结构

定义所要转化模型的网络结构,得到最后一层的输出,并给它起一个名字(这里命名为out)。之所以要命名,是因为后面把ckpt转成pb时会用到graph_util.convert_variables_to_constants函数,它有一个参数output_node_names,也就是输出节点的名字,所以必须事先为输出节点定义好名字。

# Convert the graph's variables to constants. `output_node_names` lists the
# output node(s) of the network; here that is the single node named 'out'.
output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # the graph definition to freeze, e.g. sess.graph_def
            output_node_names=['out'])
           

去掉网络的dropout单元

Dropout单元用于在训练过程中避免过拟合,预测时并不需要。因此在恢复网络结构时,不需要恢复Dropout单元,不带Dropout单元的网络结构为:

# First convolutional layer: 3x3 filters, 1 input channel -> 32 output channels,
# followed by ReLU and 2x2 max pooling.
    W_conv1 = weight_variable([3, 3, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # Second convolutional layer: 3x3 filters, 32 -> 64 channels, ReLU, 2x2 max pooling.
    W_conv2 = weight_variable([3, 3, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # Fully connected layer: flatten the pooled features to 7 * 7 * 64 values -> 1024 units.
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    # Output layer: h_fc1 feeds the softmax directly (no Dropout in between).
    # The op is named 'out' so it can be passed as the output node name.
    W_fc2 = weight_variable([1024, 2])
    b_fc2 = bias_variable([2])
    out = tf.nn.softmax((tf.matmul(h_fc1, W_fc2) + b_fc2),name='out')
           

代码实现

import tensorflow as tf
from tensorflow.python.framework import graph_util
from PIL import Image
import os
import numpy as np

#构造网络函数
def weight_variable(shape):
    """Create a weight variable of the given shape.

    The initial value comes from ``tf.truncated_normal`` with ``stddev=0.1``.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))


def bias_variable(shape):
    """Create a bias variable of the given shape with every element set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(x, W):
    """Convolve ``x`` with the filter ``W`` using stride 1 and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')


def max_pool_2x2(x):
    """Max-pool ``x`` over 2x2 windows with stride 2 and 'SAME' padding."""
    pool_window = [1, 2, 2, 1]
    pool_strides = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=pool_window, strides=pool_strides,
                          padding='SAME')



def freeze_graph(path, output):
    """Freeze a checkpoint of the small CNN into a single ``.pb`` file.

    The inference network (two conv/pool stages, one fully connected layer
    and a softmax output named ``'out'``) is rebuilt here without the Dropout
    unit. The checkpoint values are restored into it, the variables are
    converted to constants, and the resulting GraphDef is serialized.

    Args:
        path: Checkpoint prefix handed to ``Saver.restore``,
            e.g. ``'H:/models/model-3.ckpt'``.
        output: Destination path of the frozen ``.pb`` file.

    Note:
        Restoring matches variables by name. This assumes the checkpoint was
        written by a graph that created its variables in the same order as
        below, using the default names -- TODO confirm against the training
        script.
    """
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8,
                                allow_growth=True)  # cap this process at 0.8 of each GPU's memory
    config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)

    # Build in a private graph so that repeated calls, or nodes already
    # present in the default graph, cannot shift the auto-generated names.
    with tf.Graph().as_default() as graph:
        # Input: a single 28x28 image, reshaped to NHWC with one channel.
        x = tf.placeholder(tf.float32, [1, 28, 28], name="placeholder1")
        x_image = tf.reshape(x, [-1, 28, 28, 1])

        # First convolutional layer.
        W_conv1 = weight_variable([3, 3, 1, 32])
        b_conv1 = bias_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
        h_pool1 = max_pool_2x2(h_conv1)

        # Second convolutional layer.
        W_conv2 = weight_variable([3, 3, 32, 64])
        b_conv2 = bias_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
        h_pool2 = max_pool_2x2(h_conv2)

        # Fully connected layer (two 2x2 poolings reduce 28x28 to 7x7).
        W_fc1 = weight_variable([7 * 7 * 64, 1024])
        b_fc1 = bias_variable([1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

        # Output layer: no Dropout between h_fc1 and the softmax.
        W_fc2 = weight_variable([1024, 2])
        b_fc2 = bias_variable([2])
        out = tf.nn.softmax((tf.matmul(h_fc1, W_fc2) + b_fc2), name='out')

        # The Saver covers exactly the variables defined above. The saved
        # meta graph is deliberately NOT imported: it would add the training
        # graph a second time next to the network rebuilt here.
        saver = tf.train.Saver()

        # One configured session does both the restore and the freezing. No
        # initializer run is needed because restore assigns every variable.
        with tf.Session(graph=graph, config=config) as sess:
            saver.restore(sess, path)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=graph.as_graph_def(),
                output_node_names=[out.op.name])

    with tf.gfile.GFile(output, 'wb') as fgraph:
        fgraph.write(output_graph_def.SerializeToString())


if __name__ == '__main__':
    # Example run: freeze the model-3 checkpoint into a .pb file beside it.
    checkpoint_prefix = 'H:/models/model-3.ckpt'
    frozen_pb_path = 'H:/models/model-3.pb'
    freeze_graph(checkpoint_prefix, frozen_pb_path)
           

继续阅读