pymc.variational.Trainer

class pymc.variational.Trainer(*, method='advi', dataloader, model=None, data_name='batch', **fit_kwargs)
Drive variational inference over a DataLoader without user callbacks.

Follows the design of PyMC's variational-inference rework and of PyTorch Lightning: the Trainer owns the training loop, the DataLoader owns batching (and len(dataloader) is the dataset size N), and the model owns the math. The model exposes a pm.Data placeholder, and the Trainer streams minibatches into it with model.set_data once per step, so no user callbacks are needed.

Parameters:
- method : str or Inference, default "advi"
  Variational method, forwarded to pymc.fit(): a name ("advi", "fullrank_advi", …) or an Inference instance. pm.fit applies model and random_seed only when method is a name; an instance is already bound to a model, so configure it at construction (e.g. ADVI(random_seed=...)).
- dataloader : DataLoader
  The minibatch source. len(dataloader) is N; the model should pass it to the observed distribution's total_size.
- model : pymc.Model, optional
  Defaults to the model on the context stack.
- data_name : str, default "batch"
  Name of the pm.Data placeholder that minibatches are streamed into. Must match the name used for pm.Data(name, ...) in the model.
- **fit_kwargs
  Default keyword arguments forwarded to pymc.fit() (e.g. obj_optimizer); per-call kwargs to fit() override them.
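How fit_kwargs and per-call arguments combine, and why data_name has to match the placeholder, can be sketched in a few lines of plain Python. This is a hypothetical illustration, not the Trainer's source: resolve_fit_kwargs and check_placeholder are names invented here, and the check assumes the model's pm.Data names are available as a set of strings.

```python
from typing import Any


def resolve_fit_kwargs(fit_kwargs: dict[str, Any], call_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical: merge constructor defaults with per-call kwargs; per-call values win."""
    # In a dict display, later keys overwrite earlier ones.
    return {**fit_kwargs, **call_kwargs}


def check_placeholder(data_names: set[str], data_name: str) -> None:
    """Hypothetical: fail fast if the model has no pm.Data placeholder called data_name."""
    if data_name not in data_names:
        raise KeyError(
            f"no pm.Data placeholder named {data_name!r}; model has {sorted(data_names)}"
        )
```

For example, resolve_fit_kwargs({"obj_n_mc": 1, "progressbar": False}, {"progressbar": True}) gives {"obj_n_mc": 1, "progressbar": True}: the per-call progressbar wins and the constructor's obj_n_mc is kept.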
Notes

The per-step set_data currently lives in the Trainer. Once the VI rework's Inference.step(batch) lands, it moves there; at that point the total_size rescaling can be derived from len(dataloader) and dropped from the model body entirely.

Examples
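```python
import numpy as np
import pymc as pm
from pymc.variational import Trainer

# DataLoader and parquet_source are not documented on this page; import them from where they are defined.

loader = DataLoader(
    parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto"
)

with pm.Model() as model:
    b = pm.Normal("b", 0.0, 3.0, shape=4)
    # Placeholder: columns 0-2 are features, column 3 is the binary label.
    batch = pm.Data("batch", np.zeros((4096, 4)))
    logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2]
    pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader))
    # model=None resolves to the model on the context stack, so build the Trainer inside the block.
    approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000)
```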
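The total_size=len(loader) argument in the example is what makes minibatch training target the full-data objective: PyMC multiplies the observed log-likelihood of a B-row minibatch by N / B. The plain-Python sketch below checks that arithmetic for the Bernoulli-logit likelihood used above. It does not call PyMC, the data are made up, the function names are hypothetical, and it scales by the actual length of each batch.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bernoulli_logit_loglik(logits: list[float], ys: list[int]) -> float:
    """Sum of log p(y | logit) over a batch: y * logit - log(1 + exp(logit))."""
    return sum(y * eta - softplus(eta) for eta, y in zip(logits, ys))


def scaled_minibatch_loglik(logits: list[float], ys: list[int], total_size: int) -> float:
    """Rescale a minibatch log-likelihood by N / B, mirroring what total_size does."""
    batch_size = len(ys)
    return (total_size / batch_size) * bernoulli_logit_loglik(logits, ys)


# Made-up data: N = 8 rows split into two disjoint minibatches of B = 4.
logits = [0.5, -1.0, 2.0, 0.0, -0.3, 1.2, -2.5, 0.8]
ys = [1, 0, 1, 1, 0, 1, 0, 0]
N, B = len(ys), 4

full = bernoulli_logit_loglik(logits, ys)
estimates = [
    scaled_minibatch_loglik(logits[i : i + B], ys[i : i + B], N) for i in range(0, N, B)
]
epoch_average = sum(estimates) / len(estimates)
```

Averaged over the disjoint minibatches of one epoch, the rescaled estimates reproduce the full-data log-likelihood (here N is divisible by B); a single minibatch drawn uniformly at random gives a noisy but unbiased estimate of it.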
Methods

Trainer.__init__(*, dataloader[, method, model, data_name, ...])
Trainer.fit([n])    Fit for n steps, streaming minibatches into the model's placeholder.
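The fit(n) loop can be sketched as below, under the assumption (not stated on this page) that the Trainer starts a new pass over the DataLoader whenever an epoch ends before n steps are done. stream_batches and fit_loop are hypothetical names; set_data stands in for pm.Model.set_data(name, values) and step for one optimization step of the chosen Inference.

```python
from typing import Any, Callable, Iterable, Iterator


def stream_batches(dataloader: Iterable[Any], n: int) -> Iterator[Any]:
    """Hypothetical: yield exactly n batches, starting a fresh pass over the loader each epoch."""
    produced = 0
    while produced < n:
        empty = True
        for batch in dataloader:  # a new iterator per epoch lets the loader reshuffle
            empty = False
            yield batch
            produced += 1
            if produced == n:
                return
        if empty:
            raise ValueError("dataloader yielded no batches")


def fit_loop(
    set_data: Callable[[str, Any], None],
    step: Callable[[], float],
    dataloader: Iterable[Any],
    n: int,
    data_name: str = "batch",
) -> list[float]:
    """Hypothetical: stream n minibatches into the placeholder, one step per batch."""
    losses = []
    for batch in stream_batches(dataloader, n):
        set_data(data_name, batch)  # exactly one placeholder update per step
        losses.append(step())  # one optimization step on that minibatch
    return losses
```

The sketch calls iter() on the loader once per epoch instead of wrapping it in itertools.cycle, because cycle replays a saved copy of the first pass, so a loader that reshuffles each epoch would never get the chance to.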