天天看點

深入淺出TensorFlow2函數——tf.data.Dataset.batch

分類目錄:《深入淺出TensorFlow2函數》總目錄

函數:

batch(batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None,name=None)
           

該函數可以将此資料集的連續元素合并到batch中。

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
# [array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64)]
           

函數傳回值将有一個額外的外部次元,即

batch_size

。如果

batch_size

未将輸入元素的數量

N

N

N平均分割,且

drop_remainder

False

,則最後一個元素的

batch_size

N % batch_size

。如果需要依賴于具有相同尺寸的batch,則應将

drop_rements

參數設定為

True

,以防止生成較小的批。如果程式要求資料具有靜态已知形狀,則應使用

drop_rements=True

。如果沒有

drop_rements=True

,則輸出資料集的形狀将具有未知的前導次元,因為最終批次可能更小。

參數 意義
batch_size [

tf.int64

/

tf.Tensor

]表示要在單個批次中組合的此資料集的連續元素數。
drop_remainder [可選,

tf.bool

/

tf.Tensor

]表示如果最後一批元素少于批次大小,是否應删除最後一批元素,預設為

False

num_parallel_calls [可選,

tf.int64

/

tf.Tensor

]表示異步并行計算的批數。如果未指定,将按順序計算批次。如果值為

tf.data.AUTOTUNE

被使用,則根據可用資源動态設定并行調用的數量。
deterministic [可選]當

num_parallel_calls

被指定時,如果指定了此布爾值(

True

False

),則它控制轉換生成元素的順序。如果設定為

False

,則允許轉換産生無序的元素,即損失性能的情況下換取确定性。如果未指定,則

tf.data.Options.deterministic

(預設為

True

)來控制行為。
name [可選]

tf.data

操作的名稱
傳回值 意義
Dataset 一個

tf.data.Dataset

的資料集。
def batch(self,
            batch_size,
            drop_remainder=False,
            num_parallel_calls=None,
            deterministic=None,
            name=None):
    """Combines consecutive elements of this dataset into batches.
    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
    >>> dataset = tf.data.Dataset.range(8)
    >>> dataset = dataset.batch(3, drop_remainder=True)
    >>> list(dataset.as_numpy_iterator())
    [array([0, 1, 2]), array([3, 4, 5])]
    The components of the resulting element will have an additional outer
    dimension, which will be `batch_size` (or `N % batch_size` for the last
    element if `batch_size` does not divide the number of input elements `N`
    evenly and `drop_remainder` is `False`). If your program depends on the
    batches having the same outer dimension, you should set the `drop_remainder`
    argument to `True` to prevent the smaller batch from being produced.
    Note: If your program requires data to have a statically known shape (e.g.,
    when using XLA), you should use `drop_remainder=True`. Without
    `drop_remainder=True` the shape of the output dataset will have an unknown
    leading dimension due to the possibility of a smaller final batch.
    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_size` elements; the default behavior is not to drop the smaller
        batch.
      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
        representing the number of batches to compute asynchronously in
        parallel.
        If not specified, batches will be computed sequentially. If the value
        `tf.data.AUTOTUNE` is used, then the number of parallel
        calls is set dynamically based on available resources.
      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
        boolean is specified (`True` or `False`), it controls the order in which
        the transformation produces elements. If set to `False`, the
        transformation is allowed to yield elements out of order to trade
        determinism for performance. If not specified, the
        `tf.data.Options.deterministic` option (`True` by default) controls the
        behavior.
      name: (Optional.) A name for the tf.data operation.
    Returns:
      Dataset: A `Dataset`.
    """
    if num_parallel_calls is None or DEBUG_MODE:
      if deterministic is not None and not DEBUG_MODE:
        warnings.warn("The `deterministic` argument has no effect unless the "
                      "`num_parallel_calls` argument is specified.")
      return BatchDataset(self, batch_size, drop_remainder, name=name)
    else:
      return ParallelBatchDataset(
          self,
          batch_size,
          drop_remainder,
          num_parallel_calls,
          deterministic,
          name=name)
           

繼續閱讀