pymc.variational.DataLoader

class pymc.variational.DataLoader(dataset, *, batch_size, shuffle=False, buffer_size=None, seed=None, sample_shape=None, dtype='float64', total_size=None, preprocess_fn=None)

Turn an out-of-core dataset into fixed-size minibatches for variational inference.

Like torch.utils.data.DataLoader, it batches (and optionally shuffles) an IterableDataset into the minibatch stream that Trainer feeds to the model. It is iterable and sized; note that len(loader) is the dataset size N, not the number of batches. As long as the source yields bounded chunks, the full dataset is never resident in memory at once.

Parameters:
dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]

The source of rows. An IterableDataset, a re-iterable (including a plain np.ndarray), or a zero-arg factory returning a fresh iterator (preferred, so the stream can be restarted each epoch). It may yield single samples (e.g. the rows of a raw array) or blocks of any size; the loader re-batches them, in order, to exactly batch_size rows. Trailing rows that do not fill a final batch are dropped at the end of a pass, like drop_last=True in PyTorch (required here because the model observes a fixed-shape placeholder). With shuffle=True the dropped remainder differs per epoch; with a fixed replay order it is the same rows every pass.
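The re-batching behaviour described above can be illustrated with a small NumPy generator. This is only a sketch, not the loader's implementation: `rebatch` is a hypothetical name, and it assumes every item yielded by the source is a block with a leading row axis (the real loader also accepts single samples).

```python
from typing import Iterable, Iterator, List

import numpy as np


def rebatch(blocks: Iterable[np.ndarray], batch_size: int) -> Iterator[np.ndarray]:
    """Re-chunk blocks of rows, in order, into batches of exactly `batch_size` rows.

    Hypothetical helper for illustration; assumes each block has a leading row axis.
    """
    pending: List[np.ndarray] = []  # rows received but not yet emitted
    n_pending = 0                   # number of rows held in `pending`
    for block in blocks:
        block = np.asarray(block)
        pending.append(block)
        n_pending += block.shape[0]
        # Emit as many full batches as the pending rows allow.
        while n_pending >= batch_size:
            rows = np.concatenate(pending, axis=0)
            yield rows[:batch_size]
            rest = rows[batch_size:]
            pending = [rest] if rest.shape[0] else []
            n_pending = rest.shape[0]
    # Rows still pending here do not fill a batch and are dropped,
    # mirroring the drop_last=True behaviour described above.


# Blocks of 3, 4, 2 and 5 rows (14 rows in total) re-batched to 4 rows each.
data = np.arange(14)
blocks = [data[0:3], data[3:7], data[7:9], data[9:14]]
batches = list(rebatch(blocks, batch_size=4))
```

With 14 rows and batch_size=4, three full batches are produced and the last two rows are dropped; because the block order is fixed, the same two rows are dropped on every pass.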

batch_size : int

Leading dimension of every yielded minibatch.

shuffle : bool, default False

If True, wrap the source in a bounded shuffle_buffer() of buffer_size rows. A bounded buffer approximates i.i.d. batches only when the stream is already roughly unordered; it cannot fix strongly time- or row-ordered data. For such data, pre-shuffle on disk (see the module docstring).

buffer_size : int, optional

Shuffle-buffer size in rows when shuffle=True. Defaults to 50 * batch_size. Ignored when shuffle=False. A buffer at least as large as the dataset holds all of it in memory (a full shuffle).

seed : int, optional

Seed for the shuffle buffer (ignored when shuffle=False).
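To make the limits of a bounded shuffle concrete, here is a minimal standard-library sketch of a generic shuffle buffer. It is not the library's shuffle_buffer(); the name `shuffle_buffer_sketch` and the swap-on-arrival strategy are assumptions chosen for illustration.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def shuffle_buffer_sketch(
    rows: Iterable[T], buffer_size: int, seed: Optional[int] = None
) -> Iterator[T]:
    """Approximate shuffle that holds at most `buffer_size` rows in memory.

    Hypothetical illustration, not the library's shuffle_buffer().
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    for row in rows:
        if len(buffer) < buffer_size:
            buffer.append(row)       # fill phase: nothing is emitted yet
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]              # emit a random resident row ...
        buffer[i] = row              # ... and put the incoming row in its slot
    rng.shuffle(buffer)              # end of stream: drain the rest in random order
    yield from buffer


# An ordered stream of 100 rows through a 10-row buffer.
out = list(shuffle_buffer_sketch(range(100), buffer_size=10, seed=0))
```

In this sketch the row emitted at output position j always comes from the first j + buffer_size input rows, so late rows can never appear early. That is why a small buffer leaves strongly ordered data largely ordered, while a buffer at least as large as the dataset degenerates to a full in-memory shuffle.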

sample_shape : tuple of int, optional

Trailing shape of a single observation. () for scalar observations, (k,) to stream k columns (e.g. features + the observed column). Defaults to dataset.shape[1:] for a raw np.ndarray source (its rows are the samples, like torch’s TensorDataset), else ().

dtype : str, default "float64"

Dtype each prepared batch is cast to; match the dtype of the pm.Data placeholder the batches are streamed into.

total_size : int or "auto", optional

The true dataset size N (a positive integer), or "auto" to infer it from the source’s n_rows if available, else by a single counting pass. Pass it on to the observed distribution as total_size=len(loader) so that the minibatch log-likelihood is rescaled by N / batch_size (the same mechanism as pm.Minibatch). Unlike with pm.Minibatch, N cannot be inferred from a resident array: None triggers a warning at construction, and a non-positive value raises (it would otherwise silently disable or invert the rescaling).
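The N / batch_size rescaling can be checked with plain arithmetic. The sketch below uses only the standard library and a toy standard-normal log-density (an assumption made for illustration; it does not call PyMC): scaling each minibatch log-likelihood by N / batch_size gives an estimate of the full-data log-likelihood, and averaging those estimates over a partition of the data recovers it exactly.

```python
import math


def normal_logp(x: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Log-density of a Normal(mu, sigma) at x (toy likelihood for the demo)."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2


data = [0.1 * i for i in range(12)]  # toy dataset with N = 12
N, batch_size = len(data), 4

# Log-likelihood of the full dataset.
full_logp = sum(normal_logp(x) for x in data)

# Log-likelihood of each minibatch, without and with the N / batch_size factor.
unscaled = [
    sum(normal_logp(x) for x in data[i:i + batch_size])
    for i in range(0, N, batch_size)
]
scaled = [(N / batch_size) * lp for lp in unscaled]

mean_scaled = sum(scaled) / len(scaled)        # matches full_logp
mean_unscaled = sum(unscaled) / len(unscaled)  # too small by a factor N / batch_size
```

Without the factor, the likelihood term is under-weighted relative to the prior by N / batch_size, which is why a missing or non-positive total_size is treated as an error condition rather than passed through silently.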

preprocess_fn : callable, optional

Pure transform applied to each batch before validation (e.g. normalization). It must preserve the row count and sample_shape; to select columns, do it at the source instead (parquet_source(columns=...)).
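The contract on preprocess_fn, sample_shape and dtype can be sketched as a small NumPy helper. `prepare_batch` is a hypothetical name, and the exact order of the dtype cast relative to validation is an assumption; the source only states that the transform runs before validation and that prepared batches are cast to dtype.

```python
from typing import Callable, Optional, Tuple

import numpy as np


def prepare_batch(
    batch: np.ndarray,
    batch_size: int,
    sample_shape: Tuple[int, ...] = (),
    dtype: str = "float64",
    preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> np.ndarray:
    """Apply the transform, check the shape, then cast (hypothetical helper)."""
    if preprocess_fn is not None:
        batch = preprocess_fn(batch)  # pure transform, e.g. normalization
    batch = np.asarray(batch)
    expected = (batch_size, *sample_shape)
    if batch.shape != expected:
        # The transform changed the row count or the trailing sample shape.
        raise ValueError(f"expected batch shape {expected}, got {batch.shape}")
    return batch.astype(dtype, copy=False)


raw = np.arange(12, dtype="int64").reshape(4, 3)

# A normalizing transform keeps the (4, 3) shape, so the batch passes validation.
normalized = prepare_batch(
    raw, batch_size=4, sample_shape=(3,), dtype="float32",
    preprocess_fn=lambda b: (b - b.mean(axis=0)) / b.std(axis=0),
)
```

A transform that selects columns, such as `lambda b: b[:, :2]`, changes the trailing shape and fails this check, which is why column selection belongs at the source.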

Methods

DataLoader.__init__(dataset, *, batch_size)

Attributes

batch_size

batches_seen

rows_streamed

Total rows streamed into the model (grows past N across epochs).

total_size

The dataset size N (pass to the distribution's total_size).