
08 Jun 2020

0. Data conversion


NumPy operations automatically convert a tensor to a NumPy ndarray.
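A minimal sketch of the automatic conversion, assuming TensorFlow 2.x with eager execution; the tensor values are made up for illustration:

```python
import numpy as np
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Passing a tensor to a NumPy function converts it implicitly;
# the result is a plain NumPy ndarray, not a tf.Tensor.
summed = np.add(t, 1)
print(type(summed))
print(summed)
```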

0.1 Manual conversion
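One way to convert explicitly in either direction, again assuming eager execution: Tensor.numpy() returns an ndarray, and tf.convert_to_tensor() goes the other way. The variable names are only illustrative.

```python
import numpy as np
import tensorflow as tf

tensor = tf.constant([1, 2, 3])

array = tensor.numpy()               # tf.Tensor -> numpy.ndarray
back = tf.convert_to_tensor(array)   # numpy.ndarray -> tf.Tensor

print(type(array), type(back))
```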


0.2 Iteration


1. Creating a Dataset object

The Dataset object is a Python iterable, so you can iterate over a dataset directly.
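A small sketch of what "iterable" means in practice, assuming eager execution; the input list is arbitrary:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

# A for loop (or comprehension) consumes the dataset one element at a time.
values = [int(element.numpy()) for element in dataset]
print(values)

# iter()/next() work as with any other Python iterable.
first = next(iter(dataset))
print(first.numpy())
```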

1.1 Creating from NumPy arrays

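import tensorflow as tf

# Load Fashion-MNIST as NumPy arrays and scale the pixel values to [0, 1].
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images / 255

# Each dataset element is one (image, label) pair.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))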

1.2 Python generators

The Dataset.from_generator constructor converts a Python generator into a fully functional tf.data.Dataset.

def count(stop):
  i = 0
  while i < stop:
    yield i
    i += 1

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes=())
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())


[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
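Newer TensorFlow releases (2.4 and later, as far as I know) also accept an output_signature argument in place of output_types/output_shapes. A sketch under that assumption, with a smaller stop value so the output stays short:

```python
import tensorflow as tf

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

# One TensorSpec describes both the dtype and the (scalar) shape of each element.
ds_counter = tf.data.Dataset.from_generator(
    count,
    args=[5],
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32),
)

values = [int(v) for v in ds_counter.as_numpy_iterator()]
print(values)
```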

1.3 TFRecord data

1.4 loading text

1.5 CSV data

1.6 sets of files


2. Data operations

2.1 batch


2.2 training workflows

Common steps when iterating over the same data for multiple epochs:

  1. Use the Dataset.repeat() transformation
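import matplotlib.pyplot as plt

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)

def plot_batch_sizes(ds):
  # Draw one bar per batch so that short batches stand out.
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')
  plt.show()

# repeat() before batch(): batches straddle epoch boundaries, so only the final batch can be short.
titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

# batch() before repeat(): every epoch ends with its own, possibly short, batch.
titanic_batches = titanic_lines.batch(128).repeat(3)
plot_batch_sizes(titanic_batches)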
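The difference between the two orderings can be modelled without TensorFlow. The sketch below is plain Python, not the tf.data implementation; the helpers batch and repeat are hypothetical stand-ins, and the 300-element input is an assumed size, not the real length of the Titanic file.

```python
from itertools import islice

def batch(items, size):
    """Group an iterable into lists of at most `size` items."""
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk

def repeat(make_iter, count):
    """Replay a stream `count` times; make_iter() starts a fresh pass."""
    for _ in range(count):
        yield from make_iter()

num_lines = 300  # assumed dataset size, for illustration only
lines = range(num_lines)

# repeat then batch: epochs are concatenated first, so batches cross epoch boundaries.
repeat_then_batch = [len(b) for b in batch(repeat(lambda: lines, 3), 128)]

# batch then repeat: each epoch is batched on its own and ends with a short batch.
batch_then_repeat = [len(b) for b in repeat(lambda: batch(lines, 128), 3)]

print(repeat_then_batch)
print(batch_then_repeat)
```

With these assumed numbers, the first ordering yields a single short batch at the very end, while the second yields one short batch per epoch, which makes the epoch boundaries visible in the plot.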



  2. Randomly shuffle the input data

The Dataset.shuffle() transformation maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
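The buffer mechanism can be sketched in plain Python. This is a simplified model of the idea, not TensorFlow's actual implementation; buffered_shuffle and its parameters are hypothetical names.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a pseudo-random order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer; nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a uniformly chosen element and
        # put the incoming item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(20), buffer_size=5, seed=0))
unshuffled = list(buffered_shuffle(range(10), buffer_size=1))
print(shuffled)
```

The model shows why the buffer size matters: an element can only be emitted once it has entered the buffer, so a small buffer gives a weak shuffle (a buffer of 1 does not shuffle at all), while a buffer at least as large as the dataset gives a full uniform shuffle.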

2.3 Preprocessing data

The Dataset.map(f) transformation produces a new dataset by applying a given function f to each element of the input dataset.
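A minimal sketch of Dataset.map(f), using a toy dataset and a squaring function chosen only for illustration:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# map() returns a new dataset; the input dataset is left unchanged.
squared = dataset.map(lambda x: x * x)

result = [int(v) for v in squared.as_numpy_iterator()]
print(result)
```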

Use the tf.py_function() operation inside a Dataset.map() transformation to call arbitrary Python libraries.
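A hedged sketch of that pattern. Here math.sqrt stands in for any Python library call that cannot run inside a TensorFlow graph; py_sqrt and tf_sqrt are hypothetical helper names.

```python
import math
import tensorflow as tf

def py_sqrt(x):
    # Runs as ordinary Python; x arrives as an eager tensor.
    return math.sqrt(float(x.numpy()))

def tf_sqrt(x):
    # Wrap the Python function so it can be used inside Dataset.map().
    y = tf.py_function(func=py_sqrt, inp=[x], Tout=tf.float32)
    # py_function drops static shape information, so restore it by hand.
    y.set_shape(())
    return y

dataset = tf.data.Dataset.from_tensor_slices([1.0, 4.0, 9.0])
roots = [float(v) for v in dataset.map(tf_sqrt).as_numpy_iterator()]
print(roots)
```

The wrapped function executes in the Python interpreter, so it is typically slower than native TensorFlow ops and is best kept for steps that have no TensorFlow equivalent.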

2.4 resampling
