pymc.variational.IterableDataset

class pymc.variational.IterableDataset

A re-iterable, out-of-core source of rows, like torch.utils.data.IterableDataset.

Subclass it and implement __iter__() to yield np.ndarray blocks of rows, each of shape (rows, *sample_shape); DataLoader then re-batches those blocks into fixed-size minibatches. Each call to __iter__ must return a fresh iterator so that the dataset can be replayed across epochs.
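A minimal sketch of such a subclass, assuming only what is described above (a settable n_rows and an overridable __iter__). ChunkedArrayDataset and rebatch are hypothetical names. The class slices an in-memory array to stand in for reading chunks from disk. rebatch only illustrates the idea of regrouping variable-size blocks into fixed-size minibatches; it is not DataLoader's actual code, and dropping the final partial batch is an assumption of this sketch. A fallback stand-in base class is defined so the example also runs where pymc.variational.IterableDataset is unavailable.

```python
import numpy as np

try:
    from pymc.variational import IterableDataset
except ImportError:
    # Hypothetical stand-in so the sketch also runs without a PyMC release that ships this class.
    class IterableDataset:
        n_rows = None


class ChunkedArrayDataset(IterableDataset):
    """Hypothetical example: serve an array in blocks of `chunk_rows` rows.

    In real out-of-core use, each block would be read from disk instead.
    """

    def __init__(self, data, chunk_rows):
        super().__init__()
        self.data = np.asarray(data)
        self.chunk_rows = chunk_rows

    def __iter__(self):
        # A generator method returns a brand-new iterator on every call,
        # which is what lets the dataset be replayed once per epoch.
        for start in range(0, len(self.data), self.chunk_rows):
            yield self.data[start : start + self.chunk_rows]


def rebatch(blocks, batch_size):
    """Hypothetical helper: regroup variable-size blocks into fixed-size minibatches.

    Illustrates the idea only; it drops the final partial batch, which may not
    match what DataLoader actually does.
    """
    pending = []
    n_pending = 0
    for block in blocks:
        pending.append(block)
        n_pending += len(block)
        while n_pending >= batch_size:
            merged = np.concatenate(pending, axis=0)
            yield merged[:batch_size]
            pending = [merged[batch_size:]]
            n_pending = len(pending[0])


data = np.arange(20.0).reshape(10, 2)  # 10 rows, sample_shape == (2,)
ds = ChunkedArrayDataset(data, chunk_rows=3)
print([b.shape for b in ds])  # [(3, 2), (3, 2), (3, 2), (1, 2)]
print([b.shape for b in rebatch(ds, batch_size=4)])  # [(4, 2), (4, 2)]
```

Because __iter__ is a generator method, iterating ds a second time starts again from the first block; a single generator object stored on the instance would be exhausted after the first epoch.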

Optionally, set n_rows to the total row count when it is cheap to obtain (e.g. from file metadata), so that a DataLoader with total_size="auto" can resolve N, the total number of rows, without a counting pass over the data.
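As a sketch of the "cheap from file metadata" case, the hypothetical NpyFileDataset below reads the row count of a .npy file by opening it with NumPy's mmap_mode="r", which maps the file lazily so the shape comes from the header without loading row data. It assumes n_rows can be assigned as an ordinary instance attribute after calling the base __init__, as the wording above suggests; the fallback stand-in base class is again only there so the sketch runs without this PyMC class.

```python
import os
import tempfile

import numpy as np

try:
    from pymc.variational import IterableDataset
except ImportError:
    # Hypothetical stand-in so the sketch also runs without a PyMC release that ships this class.
    class IterableDataset:
        n_rows = None


class NpyFileDataset(IterableDataset):
    """Hypothetical example: stream rows of a .npy file in fixed-size blocks."""

    def __init__(self, path, chunk_rows=1024):
        super().__init__()
        self.path = path
        self.chunk_rows = chunk_rows
        # mmap_mode="r" maps the file lazily: the shape comes from the .npy
        # header, so no row data is read to obtain the count.
        self.n_rows = np.load(path, mmap_mode="r").shape[0]

    def __iter__(self):
        rows = np.load(self.path, mmap_mode="r")
        for start in range(0, rows.shape[0], self.chunk_rows):
            # np.array copies just this slice out of the memory map.
            yield np.array(rows[start : start + self.chunk_rows])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rows.npy")
    np.save(path, np.arange(30.0).reshape(10, 3))
    ds = NpyFileDataset(path, chunk_rows=4)
    n_rows = ds.n_rows
    block_shapes = [b.shape for b in ds]

print(n_rows)  # 10
print(block_shapes)  # [(4, 3), (4, 3), (2, 3)]
```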

A plain zero-argument factory (Callable[[], Iterator[np.ndarray]]) or any other re-iterable is also accepted directly by DataLoader; subclassing this base class is only needed when you want to attach custom behavior or n_rows to a source.
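The factory form can be as small as a generator function, since calling it with no arguments returns a new iterator each time. The sketch below uses a hypothetical make_blocks and contrasts it with a single generator object, which is exhausted after one pass and so would only be good for one epoch. Passing it to DataLoader is not shown, since DataLoader's other arguments are not documented on this page.

```python
import numpy as np


def make_blocks():
    """Hypothetical zero-argument factory: every call returns a fresh iterator of row blocks."""
    # Seeding inside the factory makes each replay yield identical blocks.
    rng = np.random.default_rng(0)
    for _ in range(3):
        yield rng.normal(size=(5, 2))  # 5 rows, sample_shape == (2,)


# Each call restarts from the beginning, so every epoch sees the same 3 blocks.
epoch_1 = list(make_blocks())
epoch_2 = list(make_blocks())
print(all(np.array_equal(a, b) for a, b in zip(epoch_1, epoch_2)))  # True

# A generator *object* is not re-iterable: the second pass is empty.
one_shot = make_blocks()
print(len(list(one_shot)), len(list(one_shot)))  # 3 0
```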

Methods

IterableDataset.__init__(*args, **kwargs)

Attributes

n_rows