pymc.variational.DataLoader

class pymc.variational.DataLoader(dataset, *, batch_size, shuffle=False, buffer_size=None, seed=None, sample_shape=None, dtype='float64', total_size=None, preprocess_fn=None)
Turn an out-of-core dataset into fixed-size minibatches for variational inference.

Like torch.utils.data.DataLoader, it batches (and optionally shuffles) an IterableDataset into the minibatch stream that Trainer feeds to the model. It is iterable and sized, but len(loader) is the dataset size N, not the number of batches that len() reports for a PyTorch DataLoader. As long as the source yields bounded chunks (and the shuffle buffer is smaller than the dataset), the full dataset is never resident in memory at once.

Parameters:
- dataset
  IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]
  The source of rows: an IterableDataset, a re-iterable (including a plain np.ndarray), or a zero-argument factory that returns a fresh iterator (preferred, so that the stream can be restarted each epoch). It may yield single samples (e.g. the rows of a raw array) or blocks of any size; the loader re-batches them, in order, into batches of exactly batch_size rows. Trailing rows that do not fill a final batch are dropped at the end of each pass, like drop_last=True in PyTorch (required here because the model observes a fixed-shape placeholder). With shuffle=True the dropped remainder differs per epoch; with a fixed replay order it is the same rows every pass.
- batch_size
  int
  Leading dimension of every yielded minibatch.
- shuffle
  bool, default False
  If True, wrap the source in a bounded shuffle_buffer() of buffer_size rows. This only approximates i.i.d. batches for a stream that is already unordered; a bounded buffer cannot fix strongly time- or row-ordered data (pre-shuffle on disk for that; see the module docstring).
- buffer_size
  int, optional
  Shuffle-buffer size in rows when shuffle=True. Defaults to 50 * batch_size; ignored when shuffle=False. A buffer at least as large as the dataset holds all of it in memory (a full shuffle).
- seed
  int, optional
  Seed for the shuffle buffer (ignored when shuffle=False).
- sample_shape
  tuple of int, optional
  Trailing shape of a single observation: () for scalar observations, (k,) to stream k columns (e.g. features plus the observed column). Defaults to dataset.shape[1:] for a raw np.ndarray source (its rows are the samples, like torch's TensorDataset), else ().
- dtype
  str, default "float64"
  Dtype each prepared batch is cast to; match it to the dtype of the pm.Data placeholder the batches are streamed into.
- total_size
  int or "auto", optional
  The true dataset size N (a positive integer), or "auto" to infer it (from the source's n_rows if available, otherwise with a single counting pass). Pass it on to the observed distribution as total_size=len(loader) so that the minibatch log-likelihood is rescaled by N / batch_size (the same mechanism as pm.Minibatch). Unlike with pm.Minibatch, N cannot be read off a resident array, so it must be given or inferred: None warns at construction, and a non-positive value raises (it would otherwise silently disable or invert the rescaling).
- preprocess_fn
  callable, optional
  Pure transform applied to each batch before validation (e.g. normalization). It must preserve the row count and sample_shape; to select columns, do it at the source instead (parquet_source(columns=...)).
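The re-batching rule described under dataset and batch_size (input of any block size, output of exactly batch_size rows in arrival order, remainder dropped) can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the library's code: the helper rebatch is hypothetical, and it assumes each source item is either one sample of shape sample_shape or a block of shape (n, *sample_shape).

```python
from collections.abc import Iterable, Iterator

import numpy as np


def rebatch(
    source: Iterable[np.ndarray], batch_size: int, sample_shape: tuple[int, ...] = ()
) -> Iterator[np.ndarray]:
    """Hypothetical helper: re-chunk samples/blocks into batches of exactly batch_size rows.

    Rows are emitted in arrival order; a trailing remainder smaller than
    batch_size is dropped at the end, mirroring drop_last=True.
    """
    pending: list[np.ndarray] = []  # blocks received but not yet emitted
    n_pending = 0
    for item in source:
        block = np.asarray(item)
        if block.shape == sample_shape:  # a single sample: promote to a 1-row block
            block = block[np.newaxis, ...]
        pending.append(block)
        n_pending += block.shape[0]
        while n_pending >= batch_size:
            stacked = np.concatenate(pending, axis=0)
            yield stacked[:batch_size]
            rest = stacked[batch_size:]
            pending = [rest] if rest.shape[0] else []
            n_pending = rest.shape[0]
    # Any rows still in `pending` (fewer than batch_size) are dropped here.


# Blocks of uneven size: 3 + 5 + 4 = 12 rows with batch_size=5 -> two batches, 2 rows dropped.
blocks = [np.arange(0, 3), np.arange(3, 8), np.arange(8, 12)]
print([b.tolist() for b in rebatch(blocks, batch_size=5)])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

Only the pending rows are ever concatenated, so memory stays bounded by roughly one batch plus one incoming block, which is what lets an out-of-core source stream through.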
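shuffle=True is described as wrapping the source in a bounded buffer of buffer_size rows. The sketch below shows the common bounded-buffer technique (the same idea as tf.data's Dataset.shuffle): fill the buffer, then repeatedly emit a randomly chosen resident row and put the next incoming row in its slot. It is an assumption about how such a buffer generally works, not a copy of pymc's shuffle_buffer(); the name bounded_shuffle is hypothetical, and the 50 * batch_size default is taken from the buffer_size description.

```python
from collections.abc import Iterable, Iterator
from typing import Optional, TypeVar

import numpy as np

T = TypeVar("T")


def bounded_shuffle(
    rows: Iterable[T], buffer_size: int, seed: Optional[int] = None
) -> Iterator[T]:
    """Hypothetical helper: approximate shuffle that holds at most buffer_size rows."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = np.random.default_rng(seed)
    buffer: list[T] = []
    for row in rows:
        if len(buffer) < buffer_size:
            buffer.append(row)  # still filling the buffer
            continue
        # Buffer full: emit a random resident row and replace it with the new one.
        i = int(rng.integers(len(buffer)))
        yield buffer[i]
        buffer[i] = row
    # Source exhausted: drain the remaining rows in random order.
    for i in rng.permutation(len(buffer)):
        yield buffer[i]


batch_size = 32
shuffled = list(bounded_shuffle(range(10_000), buffer_size=50 * batch_size, seed=0))
```

The sketch also shows why a bounded buffer cannot repair strongly ordered data: the row emitted at output position k can only come from the first k + buffer_size input rows, so late rows never mix with early ones.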
Methods
- DataLoader.__init__(dataset, *, batch_size)

Attributes
- batch_size
- batches_seen
- rows_streamed
  Total rows streamed into the model (grows past N across epochs).
- total_size
  The dataset size N (pass to the distribution's total_size).
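The N / batch_size rescaling that total_size triggers can be checked numerically. This plain-NumPy sketch (not PyMC) assumes i.i.d. rows and uses a made-up array of per-row log-likelihood terms: it scales each minibatch sum by N / B and shows that, over one shuffled pass of disjoint minibatches with B dividing N, the average of those estimates equals the full-data sum.

```python
import numpy as np

rng = np.random.default_rng(42)
N, B = 1_000, 100                       # dataset size and batch_size (B divides N here)
loglik = rng.normal(-1.5, 0.3, size=N)  # made-up per-row log-likelihood terms

full = loglik.sum()                     # full-data log-likelihood sum
scale = N / B                           # the total_size rescaling factor

perm = rng.permutation(N)               # one shuffled pass ...
batches = perm.reshape(N // B, B)       # ... split into N // B disjoint minibatches
estimates = scale * loglik[batches].sum(axis=1)  # rescaled minibatch sums

print(f"full sum       {full:.3f}")
print(f"one minibatch  {estimates[0]:.3f}")
print(f"mean over pass {estimates.mean():.3f}")
```

Any single rescaled minibatch is a noisy estimate of the full sum; without the factor, the likelihood would count only B rows and the prior would be over-weighted by a factor of N / B.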