Reading the DeepLabv3+ code: data_generator.py
I. Dataset
1. __init__
class Dataset(object):
  """Represents input dataset for deeplab model."""

  def __init__(self,
               dataset_name,  # Dataset name.
               split_name,    # A train/val split name.
               dataset_dir,   # The directory of the dataset sources.
               batch_size,    # Batch size.
               crop_size,     # The size used to crop the image and label.
               min_resize_value=None,  # Desired size of the smaller image side.
               max_resize_value=None,  # Maximum allowed size of the larger image side.
               resize_factor=None,     # Resized dimensions are multiple of factor plus one.
               min_scale_factor=1.,    # Minimum scale factor value.
               max_scale_factor=1.,    # Maximum scale factor value.
               scale_factor_step_size=0,  # Step size from minimum to maximum scale factor.
               model_variant=None,  # Model variant (string) for choosing how to mean-subtract the images.
                                    # See feature_extractor.network_map for supported model variants.
               num_readers=1,       # Number of readers for data provider.
               is_training=False,   # Boolean, if dataset is for training or not.
               should_shuffle=False,  # Boolean, if should shuffle the input data.
               should_repeat=False):  # Boolean, if should repeat the input data.
    if dataset_name not in _DATASETS_INFORMATION:
      raise ValueError('The specified dataset is not supported yet.')
    self.dataset_name = dataset_name

    # Number of images in each split (train/val/...).
    splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
    if split_name not in splits_to_sizes:
      raise ValueError('data split name %s not recognized' % split_name)

    # The backbone model variant to use, e.g. xception_65.
    if model_variant is None:
      tf.logging.warning('Please specify a model_variant. See '
                         'feature_extractor.network_map for supported model '
                         'variants.')

    self.split_name = split_name
    self.dataset_dir = dataset_dir
    self.batch_size = batch_size
    self.crop_size = crop_size
    self.min_resize_value = min_resize_value
    self.max_resize_value = max_resize_value
    self.resize_factor = resize_factor
    self.min_scale_factor = min_scale_factor
    self.max_scale_factor = max_scale_factor
    self.scale_factor_step_size = scale_factor_step_size
    self.model_variant = model_variant
    self.num_readers = num_readers
    self.is_training = is_training
    self.should_shuffle = should_shuffle
    self.should_repeat = should_repeat

    self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes  # Number of classes.
    self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label   # Label value to ignore.
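As a usage illustration, here is a minimal sketch of constructing the class for PASCAL VOC training. The import path and the tfrecord directory are assumptions, and the crop size / scale factors are just plausible training settings, not values taken from this file:

from deeplab import data_generator  # assumes the repo root is on PYTHONPATH

dataset = data_generator.Dataset(
    dataset_name='pascal_voc_seg',
    split_name='train',
    dataset_dir='/path/to/pascal_voc_seg/tfrecord',  # hypothetical path
    batch_size=4,
    crop_size=[513, 513],        # [crop_height, crop_width]
    min_scale_factor=0.5,
    max_scale_factor=2.0,
    scale_factor_step_size=0.25,
    model_variant='xception_65',
    num_readers=2,
    is_training=True,
    should_shuffle=True,
    should_repeat=True)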
2. get_one_shot_iterator
Returns an iterator that goes over the dataset once.
Returns:
  An iterator of type tf.data.Iterator.
def get_one_shot_iterator(self):
  """Gets an iterator that iterates across the dataset once.

  Returns:
    An iterator of type tf.data.Iterator.
  """
  # Get the list of paths to the dataset's TFRecord files.
  files = self._get_all_files()

  # tf.data.TFRecordDataset builds a dataset from one or more TFRecord files;
  # map() applies a custom transform to every element.
  dataset = (
      tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
      .map(self._parse_function, num_parallel_calls=self.num_readers)      # Parse image, label, size, filename.
      .map(self._preprocess_image, num_parallel_calls=self.num_readers))   # Preprocess image and label.

  if self.should_shuffle:  # Whether to shuffle the input data.
    # Fill a buffer with 100 elements; every read draws one element at random
    # from the buffer and refills it from the remaining data. A perfect shuffle
    # would require a buffer at least as large as the whole dataset.
    dataset = dataset.shuffle(buffer_size=100)

  if self.should_repeat:
    dataset = dataset.repeat()  # Repeat forever for training.
  else:
    dataset = dataset.repeat(1)  # For evaluation, one pass over the dataset is enough.

  # .batch() groups elements into batches; .prefetch() overlaps producing data
  # with consuming it, reducing idle time.
  dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)

  # Creates an iterator that enumerates the dataset's elements once
  # (deprecated in recent TF versions).
  return dataset.make_one_shot_iterator()
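Continuing the construction sketch from section 1, the iterator is typically consumed by calling get_next() to obtain a dict of tensors and evaluating it in a session (TF 1.x style; the dict keys are the constants from common.py):

import tensorflow as tf
from deeplab import common  # assumed import path

iterator = dataset.get_one_shot_iterator()
samples = iterator.get_next()  # dict keyed by common.IMAGE, common.LABEL, ...

with tf.Session() as sess:
  batch = sess.run(samples)
  print(batch[common.IMAGE].shape)  # e.g. (4, 513, 513, 3) with the settings above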
3. _get_all_files
Gets the paths of all input files.
Returns:
  A list of input files.
def _get_all_files(self):
  file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(self.dataset_dir,
                              file_pattern % self.split_name)
  # tf.gfile.Glob() returns a list of file paths matching the pattern, e.g.
  # file_pattern='tensorflow/models/research/deeplab/datasets/pascal_voc_seg/tfrecord/train-*'.
  return tf.gfile.Glob(file_pattern)
where
_FILE_PATTERN = '%s-*'
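A quick, self-contained illustration of how the glob pattern is assembled (the directory below is just the example path from the comment above, not a path this code requires):

import os

_FILE_PATTERN = '%s-*'
dataset_dir = 'tensorflow/models/research/deeplab/datasets/pascal_voc_seg/tfrecord'
print(os.path.join(dataset_dir, _FILE_PATTERN % 'train'))
# tensorflow/models/research/deeplab/datasets/pascal_voc_seg/tfrecord/train-*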
4. _parse_function
Parses an example proto.
Args:
  A proto in tf.Example format.
Returns:
  A dictionary with the parsed image, label, height, width and image name.
Note:
  Currently only jpeg and png images are supported.
def _parse_function(self, example_proto):
  # Currently only supports jpeg and png.
  # Need to use this logic because the shape is not known for
  # tf.image.decode_image and we rely on this info to extend label if needed.
  def _decode_image(content, channels):
    return tf.cond(
        tf.image.is_jpeg(content),
        lambda: tf.image.decode_jpeg(content, channels),
        lambda: tf.image.decode_png(content, channels))

  features = {
      'image/encoded':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/filename':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format':
          tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/height':
          tf.FixedLenFeature((), tf.int64, default_value=0),
      'image/width':
          tf.FixedLenFeature((), tf.int64, default_value=0),
      'image/segmentation/class/encoded':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/segmentation/class/format':
          tf.FixedLenFeature((), tf.string, default_value='png'),
  }

  parsed_features = tf.parse_single_example(example_proto, features)

  image = _decode_image(parsed_features['image/encoded'], channels=3)

  label = None
  if self.split_name != common.TEST_SET:
    label = _decode_image(
        parsed_features['image/segmentation/class/encoded'], channels=1)

  image_name = parsed_features['image/filename']
  if image_name is None:
    image_name = tf.constant('')

  sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: parsed_features['image/height'],
      common.WIDTH: parsed_features['image/width'],
  }

  if label is not None:
    if label.get_shape().ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')

    label.set_shape([None, None, 1])

    sample[common.LABELS_CLASS] = label

  return sample
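For reference, here is a sketch of how an example carrying exactly the feature keys that _parse_function expects could be serialized when building the TFRecords. The file names, helper functions, and sizes are hypothetical; the repository ships its own conversion scripts for the standard datasets under deeplab/datasets:

import tensorflow as tf

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Raw encoded bytes of the JPEG image and the PNG segmentation map
# (paths and file name are hypothetical).
with tf.gfile.GFile('JPEGImages/2007_000033.jpg', 'rb') as f:
  image_data = f.read()
with tf.gfile.GFile('SegmentationClassRaw/2007_000033.png', 'rb') as f:
  seg_data = f.read()

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_feature(image_data),
    'image/filename': _bytes_feature(b'2007_000033'),
    'image/format': _bytes_feature(b'jpeg'),
    'image/height': _int64_feature(366),  # placeholder; use the decoded image's height
    'image/width': _int64_feature(500),   # placeholder; use the decoded image's width
    'image/segmentation/class/encoded': _bytes_feature(seg_data),
    'image/segmentation/class/format': _bytes_feature(b'png'),
}))
# example.SerializeToString() is what gets written into a TFRecord shard.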
5. _preprocess_image
Preprocesses the image and label.
Args:
  A sample containing the image and label.
Returns:
  The preprocessed sample.
def _preprocess_image(self, sample):
  image = sample[common.IMAGE]
  label = sample[common.LABELS_CLASS]

  original_image, image, label = input_preprocess.preprocess_image_and_label(
      image=image,
      label=label,
      crop_height=self.crop_size[0],
      crop_width=self.crop_size[1],
      min_resize_value=self.min_resize_value,
      max_resize_value=self.max_resize_value,
      resize_factor=self.resize_factor,
      min_scale_factor=self.min_scale_factor,
      max_scale_factor=self.max_scale_factor,
      scale_factor_step_size=self.scale_factor_step_size,
      ignore_label=self.ignore_label,
      is_training=self.is_training,
      model_variant=self.model_variant)

  # Replace the image in the sample with the preprocessed one.
  sample[common.IMAGE] = image

  if not self.is_training:
    # Original image is only used during visualization.
    sample[common.ORIGINAL_IMAGE] = original_image

  if label is not None:
    sample[common.LABEL] = label

  # Remove common.LABELS_CLASS key from the sample since it is only used to
  # derive label and not used in training and evaluation.
  sample.pop(common.LABELS_CLASS, None)

  return sample
II. _DATASETS_INFORMATION
Information about the supported datasets, including cityscapes and others. To use your own dataset, it also has to be registered here (a sketch of registering a custom dataset follows after the PASCAL VOC example below).
_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
}
The INFORMATION entry for the pascal_voc_seg dataset, which contains the number of images in each split, the number of classes, and so on. It is defined with the DatasetDescriptor namedtuple.
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)
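To train on a custom dataset, an analogous descriptor has to be added and registered in _DATASETS_INFORMATION. The names and numbers below are purely illustrative:

_MY_DATASET_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 800,   # number of training images (illustrative)
        'val': 200,     # number of validation images (illustrative)
    },
    num_classes=3,      # e.g. 2 foreground classes + 1 background class
    ignore_label=255,
)

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'my_dataset': _MY_DATASET_INFORMATION,  # hypothetical custom dataset
}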
III. DatasetDescriptor
DatasetDescriptor is a tuple describing the properties of a dataset. It is defined with collections.namedtuple and holds the dataset's metadata.
# Named tuple to describe the dataset properties.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor',
    [
        'splits_to_sizes',  # Splits of the dataset into training, val and test.
        'num_classes',      # Number of semantic classes, including the
                            # background class (if exists). For example, there
                            # are 20 foreground classes + 1 background class in
                            # the PASCAL VOC 2012 dataset. Thus, we set
                            # num_classes=21.
        'ignore_label',     # Ignore label value.
    ])
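Being a plain namedtuple, the descriptor's fields are read as attributes, which is exactly how Dataset.__init__ uses them. For example, with the PASCAL VOC entry shown above:

print(_PASCAL_VOC_SEG_INFORMATION.num_classes)               # 21
print(_PASCAL_VOC_SEG_INFORMATION.ignore_label)              # 255
print(_PASCAL_VOC_SEG_INFORMATION.splits_to_sizes['train'])  # 1464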