
TensorFlow: reading data from files (use case: large datasets; personally tested and working)

There is plenty of material online about reading samples and labels from files, but most of it is incomplete: some covers only the theory, some only the conversion to .tfrecords, and some has no example of actually using the result. My boyfriend and I tinkered with this for two days before arriving at the complete version below; I hope it helps whoever finds it.

1. Prepare the samples and labels

The samples are shown in Figure 1 and the label file train_y.csv in Figure 2; this is a binary classification problem.

[Figure 1: sample images]

[Figure 2: the label file train_y.csv]

2. Generate the list file that records the samples

Our images are stored at the path shown in the red box in Figure 3, and the label file train_y.csv at the path in the green box.

We train with ray14_train.py, and that .py file is not in the same directory as train_y.csv. So in train_y.csv we need to turn the image-name column into relative paths, as shown in Figure 4, and save the new CSV as y_train.csv. The test set gets the same treatment (see the sketch after the next code block).

[Figure 3: storage paths of the images (red box) and of train_y.csv (green box)]

[Figure 4: y_train.csv, with the image-name column turned into relative paths]

import os
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf  # used from step 5 onwards

base_path = os.path.join('images', 'images224')
train_y_path = os.path.join(base_path, 'train_y.csv')
# read the two columns (image name, label) as strings
train_y = np.loadtxt(train_y_path, delimiter=",", skiprows=0, usecols=(0, 1), dtype=str)
train_y_pd = pd.DataFrame(train_y)
# prefix each image name with its directory so the column becomes a relative path
for i in range(train_y.shape[0]):
    train_y_pd.iloc[i, 0] = os.path.join(base_path, train_y[i, 0])
train_y_pd.to_csv(os.path.join(base_path, 'y_train.csv'), header=None, index=None)

Run step 2 first to produce y_train.csv and y_test.csv; the actual reading starts at step 3.
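
The code above only converts the training labels. A minimal sketch of the analogous test-set conversion, assuming the raw test labels sit next to train_y.csv in a file called test_y.csv (that file name is an assumption):

test_y_path = os.path.join(base_path, 'test_y.csv')  # 'test_y.csv' is an assumed name
test_y = np.loadtxt(test_y_path, delimiter=",", skiprows=0, usecols=(0, 1), dtype=str)
test_y_pd = pd.DataFrame(test_y)
for i in range(test_y.shape[0]):
    test_y_pd.iloc[i, 0] = os.path.join(base_path, test_y[i, 0])
test_y_pd.to_csv(os.path.join(base_path, 'y_test.csv'), header=None, index=None)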

3. Read the CSV and store the image paths and labels in arrays

def load_file(example_list_file):
    # each CSV row is (image path, integer label); paths come back as byte strings
    lines = np.genfromtxt(example_list_file, delimiter=",", dtype=[('col1', 'S120'), ('col2', 'i8')])
    examples = []
    labels = []
    for example, label in lines:
        examples.append(example)
        labels.append(label)
    # convert to numpy arrays
    return np.asarray(examples), np.asarray(labels), len(lines)
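
For example, loading the training list produced in step 2 (the printed path and label are illustrative):

examples, labels, total = load_file(os.path.join('images', 'images224', 'y_train.csv'))
print(examples[0], labels[0], total)  # e.g. b'images/images224/....png' 0 86524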
           

4. Read the images with cv2

def extract_image(filename, height, width):
    # cv2.imread returns the image in BGR channel order
    image = cv2.imread(filename)
    # our images are already 224x224, so the resize stays commented out
    # image = cv2.resize(image, (height, width))
    b, g, r = cv2.split(image)
    rgb_image = cv2.merge([r, g, b])
    return rgb_image
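
The split/merge pair just swaps BGR to RGB; cv2.cvtColor does the same thing in one call:

rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # equivalent BGR -> RGB conversion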
           

5. Convert the images and labels into a tfrecords file

def trans2tfRecord(train_file, name, output_dir, height, width):
    # create the output directory if it does not exist yet
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    _examples, _labels, examples_num = load_file(train_file)
    # store the .tfrecords file inside output_dir
    filename = os.path.join(output_dir, name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    for i, [example, label] in enumerate(zip(_examples, _labels)):
        # the path was stored as bytes, so decode it to str for cv2
        example = example.decode("UTF-8")
        image = extract_image(example, height, width)
        image_raw = image.tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label)
                }))
        writer.write(example.SerializeToString())
    writer.close()
    return filename
           
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
           

6. Read the training data back from the tfrecords file

def read_tfRecord(file_tfRecord, shuffle=False):
    # This function takes a file name; TF turns it into a filename queue that
    # supplies the training or test data.
    # tf.train.string_input_producer has two important arguments. One is
    # num_epochs, which says how many times the full sample set is enqueued;
    # leave it at the default None, meaning "indefinitely". Normally you enqueue
    # once per training epoch, but dequeuing starts as soon as the program runs,
    # and None keeps the queue from ever running dry.
    # The other is shuffle, which says whether the order within one epoch is
    # shuffled (although in my tests the order came out shuffled whether it was
    # True or False).
    queue = tf.train.string_input_producer([file_tfRecord], shuffle=shuffle)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'depth': tf.FixedLenFeature([], tf.int64),
                'label': tf.FixedLenFeature([], tf.int64)
            })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [224, 224, 3])
    image = tf.cast(image, tf.float32)
    # per-image zero-mean, unit-variance normalization
    image = tf.image.per_image_standardization(image)
    label = tf.cast(features['label'], tf.int64)
    print(image, label)  # sanity check of the tensor shapes
    return image, label
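
The reshape above hardcodes 224×224×3. Since every record also stores its height, width, and depth, the shape can instead be recovered from the record itself; a sketch of that variant:

height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
# reshape using the dimensions stored in the record instead of hardcoded ones
image = tf.reshape(image, tf.stack([height, width, depth]))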
           

7. Call steps 3-6 and start training

with tf.Session() as sess:
    # training script
    batch = 16   # batch size (matches the batch=16 passed to fit_ATDA below)
    epochs = 30  # number of training epochs
    base_path = os.path.join('images', 'images224')
    data_train_path = os.path.join(base_path, 'y_train.csv')
    data_test_path = os.path.join(base_path, 'y_test.csv')
    # Only needed on the first run; once the files exist, these two lines can be
    # commented out. They build y_train.tfrecords and y_test.tfrecords from the
    # CSVs, holding the training and test samples with their labels.
    filename = trans2tfRecord(data_train_path, 'y_train', base_path, 224, 224)
    filename2 = trans2tfRecord(data_test_path, 'y_test', base_path, 224, 224)
    img_batch, path_batch = read_tfRecord(filename, shuffle=True)
    img_batch2, path_batch2 = read_tfRecord(filename2, shuffle=False)
    image_batches, label_batches = tf.train.batch([img_batch, path_batch], batch_size=batch, capacity=4096)
    image_batches2, label_batches2 = tf.train.batch([img_batch2, path_batch2], batch_size=batch, capacity=4096)
    tf.local_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # define the model
    model = ATDA(sess=sess)
    model.create_model()
    # train: (image_batches, label_batches) is the training set and
    # (image_batches2, label_batches2) the test set; n is the training-set size,
    # my_number the test-set size, my_catelogy the number of classes, and batch
    # the batch size
    model.fit_ATDA(source_train=image_batches, y_train=label_batches,
                   target_val=image_batches2, y_val=label_batches2,
                   nb_epoch=epochs, n=86524, my_number=25596, my_catelogy=2, batch=16)
    coord.request_stop()  # ask the queue threads to stop
    coord.join(threads)   # wait for them to finish
           

8. model.fit_ATDA(): the training loop.

def fit_ATDA(self, source_train, y_train, target_val, y_val, nb_epoch=30,
             n=86524, my_number=25596, my_catelogy=2, batch=4):
    for e in range(nb_epoch):
        n_batch = 0
        for my_batch_train in range(int(n / batch)):
            # dequeue one batch of images and labels from the input pipeline
            Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
            Xu_batch = transform_batch_images(Xu_batch)  # the author's augmentation helper
            Yu_batch = np_utils.to_categorical(Yu_batch, 2)  # one-hot labels (keras.utils.np_utils)
            feed_dict = {self.x: Xu_batch, self.y_: Yu_batch, self.istrain: True}
            # cost and Ft_loss are the model's loss tensors (defined elsewhere in
            # the class); fetch them under new names so the tensors themselves
            # are not overwritten after the first iteration
            total_loss, ft_loss = self.sess.run([cost, Ft_loss], feed_dict=feed_dict)
            n_batch += 1
            # print the losses every 1000 minibatches
            if n_batch % 1000 == 0:
                print("Epoch %d  total_loss %f Ft_loss %f" % (e + 1, total_loss, ft_loss))
           

The line that actually reads from the file is:

Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
           

9. The test code is not shown here; it is similar to step 8.
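
For completeness, a minimal evaluation sketch in the same style, assuming the model exposes a per-batch accuracy tensor self.acc (a hypothetical name) and reusing the test batches from step 7:

def evaluate(self, target_val, y_val, my_number=25596, batch=16):
    acc_sum = 0.0
    n_batches = int(my_number / batch)
    for _ in range(n_batches):
        # dequeue one test batch from the input pipeline
        Xv_batch, Yv_batch = self.sess.run([target_val, y_val])
        Yv_batch = np_utils.to_categorical(Yv_batch, 2)
        feed_dict = {self.x: Xv_batch, self.y_: Yv_batch, self.istrain: False}
        acc_sum += self.sess.run(self.acc, feed_dict=feed_dict)  # self.acc is hypothetical
    print("test accuracy %f" % (acc_sum / n_batches))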

References:

1. https://zhuanlan.zhihu.com/p/27238630

2. https://www.cnblogs.com/wktwj/p/7257526.html
