TensorFlow Datasets: The Bad Parts

Random Access

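from typing import Any, List

import torch.utils.data


class RandomAccessDataset(torch.utils.data.Dataset):
    def __init__(self, data: List) -> None:
        self.data = data

    def __len__(self) -> int:
        # The total number of elements is known up front.
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        # Any element can be fetched directly by its index.
        return self.data[index]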

Sequential Access

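from typing import Iterator, List

def sequential_dataset(data: List) -> Iterator:
    # Elements can only be consumed in order, one at a time.
    for item in data:
        yield item

import itertools
import tensorflow as tf

def gen():
    # Infinite generator: yields (i, [1, 1, ..., 1]) with i ones.
    for i in itertools.count(1):
        yield (i, [1] * i)

dataset = tf.data.Dataset.from_generator(
    gen,  # the generator callable is the first argument
    (tf.int64, tf.int64),
    (tf.TensorShape([]), tf.TensorShape([None])))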

Sequential Access in TensorFlow Datasets

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1,2,3])

for element in dataset:
    print(element)
>>> tf.Tensor(1, shape=(), dtype=int32)
>>> tf.Tensor(2, shape=(), dtype=int32)
>>> tf.Tensor(3, shape=(), dtype=int32)

dataset = dataset.map(lambda x: x*2)
for element in dataset:
    print(element)
>>> tf.Tensor(2, shape=(), dtype=int32)
>>> tf.Tensor(4, shape=(), dtype=int32)
>>> tf.Tensor(6, shape=(), dtype=int32)
>>> TypeError: 'TensorSliceDataset' object does not support indexing

>>> [1,2,3]

Data Shuffling
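
To make the difference concrete, here is a minimal plain-Python sketch contrasting buffer-based shuffling of a sequential stream with index-based shuffling of a random-access dataset. It only emulates the documented behaviour of Dataset.shuffle(buffer_size) (fill a buffer, sample from it, refill); it is not TensorFlow's implementation, and the names buffered_shuffle and full_shuffle are made up for illustration.

```python
import random
from typing import Iterable, Iterator, List


def buffered_shuffle(items: Iterable[int], buffer_size: int, seed: int = 0) -> Iterator[int]:
    """Shuffle a stream using only a fixed-size buffer (hypothetical helper)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[int] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        slot = rng.randrange(buffer_size)
        yield buffer[slot]   # emit a random buffered element...
        buffer[slot] = item  # ...and replace it with the incoming one
    rng.shuffle(buffer)      # drain whatever is left at the end
    yield from buffer


def full_shuffle(data: List[int], seed: int = 0) -> List[int]:
    """With random access, shuffle the indices and read in that order."""
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    return [data[i] for i in indices]


records = list(range(20))
streamed = list(buffered_shuffle(records, buffer_size=4))
shuffled = full_shuffle(records)
```

In the buffered version, the element emitted at position k always comes from the first k + buffer_size inputs, so a small buffer only shuffles locally. The index-based version can place any element anywhere.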

Data Sharding

  1. You need to split your dataset into more files than there are workers in your distributed training job. If you have a large dataset stored in a small number of files, you’re out of luck. Moreover, any size imbalance between those files will produce stragglers, hurting training performance.
  2. More likely, you might not realize any of this! A lot of real-world data loading code just converts a Python generator into a TensorFlow Dataset using Dataset.from_generator(). This appears to work at small scale, but runs into serious performance problems as your dataset grows.
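
The cost described in the second point can be sketched in plain Python. This is an emulation under stated assumptions, not TensorFlow code: it assumes that sharding a generator-backed dataset means every worker reads the whole stream and discards the elements that belong to other workers. CountingSource, shard_sequential, and shard_random_access are hypothetical names used only for this illustration.

```python
from typing import Iterator, List, Sequence


class CountingSource:
    """Hypothetical sequential data source that counts records it produces."""

    def __init__(self, num_records: int) -> None:
        self.num_records = num_records
        self.records_read = 0

    def __iter__(self) -> Iterator[int]:
        for record in range(self.num_records):
            self.records_read += 1  # every record is read, even if discarded
            yield record


def shard_sequential(source: CountingSource, num_shards: int, index: int) -> List[int]:
    """Keep every num_shards-th record, but still iterate over all of them."""
    return [r for i, r in enumerate(source) if i % num_shards == index]


def shard_random_access(data: Sequence[int], num_shards: int, index: int) -> List[int]:
    """With random access, a worker touches only the records it owns."""
    return [data[i] for i in range(index, len(data), num_shards)]


source = CountingSource(12)
sequential_shard = shard_sequential(source, num_shards=4, index=1)
random_access_shard = shard_random_access(list(range(12)), num_shards=4, index=1)
```

Both functions return the same three records for worker 1, but the sequential version has read all twelve records to get them. With four workers, the total read cost is four times the dataset size.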

Saving and Restoring Iterator State


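from typing import Iterator

dataset = RandomAccessDataset([1, 2, 3])  # placeholder data for illustration

def sequential_access_dataset() -> Iterator:
    # Sequential access is trivially built on top of random access.
    for index in range(len(dataset)):
        yield dataset[index]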
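
As a rough sketch of why this matters for checkpointing: when the underlying data supports random access, the iterator's entire state can be a single integer. The ResumableIterator class below is hypothetical and written in plain Python. It works with any object that supports len() and indexing, which includes a list or a dataset like the one wrapped above.

```python
from typing import Any, Sequence


class ResumableIterator:
    """Hypothetical iterator whose state is just the next index to read."""

    def __init__(self, data: Sequence[Any], position: int = 0) -> None:
        self.data = data
        self.position = position

    def __iter__(self) -> "ResumableIterator":
        return self

    def __next__(self) -> Any:
        if self.position >= len(self.data):
            raise StopIteration
        item = self.data[self.position]
        self.position += 1
        return item

    def state(self) -> int:
        # Saving the iterator means saving one integer.
        return self.position


data = ["a", "b", "c", "d", "e"]
iterator = ResumableIterator(data)
first_two = [next(iterator), next(iterator)]

checkpoint = iterator.state()  # e.g. written alongside model weights

# Later, possibly in a new process: resume without replaying earlier items.
restored = ResumableIterator(data, position=checkpoint)
remaining = list(restored)
```

Restoring costs the same regardless of how far into the dataset training had progressed, because nothing before the saved position needs to be read again.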




