pymc.variational.Trainer

class pymc.variational.Trainer(*, method='advi', dataloader, model=None, data_name='batch', **fit_kwargs)

Drive variational inference over a DataLoader without user callbacks.

Follows the design of PyMC’s variational-inference rework and of PyTorch Lightning: the Trainer owns the training loop, the DataLoader owns batching (with len(dataloader) returning the dataset size N), and the model owns the math. The model declares a pm.Data placeholder, and the Trainer streams one minibatch into it with model.set_data at every step, so no user callbacks are needed.
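A minimal sketch of that division of labor in plain Python, assuming nothing beyond the standard library. ToyLoader, ToyModel, endless and train are hypothetical stand-ins, not the Trainer's or PyMC's API: the loader's len() returns the dataset size N (the convention on this page, unlike PyTorch's DataLoader, whose len() is the number of batches), and step_fn stands in for one optimizer step.

```python
from typing import Callable, Iterator, List, Sequence


class ToyLoader:
    """Hypothetical DataLoader stand-in that owns batching."""

    def __init__(self, rows: Sequence, batch_size: int):
        self.rows = list(rows)
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Dataset size N (this page's convention), not the number of batches
        return len(self.rows)

    def __iter__(self) -> Iterator[list]:
        for start in range(0, len(self.rows), self.batch_size):
            yield self.rows[start:start + self.batch_size]


class ToyModel:
    """Hypothetical model stand-in holding named data placeholders."""

    def __init__(self, data_name: str):
        self.data = {data_name: None}

    def set_data(self, name: str, values) -> None:
        # An unknown name is an error, which is why data_name must match
        if name not in self.data:
            raise KeyError(name)
        self.data[name] = values


def endless(loader: ToyLoader) -> Iterator[list]:
    """Re-iterate the loader every epoch (a real loader could reshuffle here)."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("loader yielded no batches")


def train(model: ToyModel, loader: ToyLoader, data_name: str,
          n_steps: int, step_fn: Callable[[ToyModel], None]) -> None:
    """The loop the Trainer owns: one set_data and one step per iteration."""
    batches = endless(loader)
    for _ in range(n_steps):
        model.set_data(data_name, next(batches))  # stream the minibatch in
        step_fn(model)  # stand-in for one optimizer step on that minibatch


loader = ToyLoader(range(10), batch_size=4)
model = ToyModel("batch")
seen: List[list] = []
train(model, loader, "batch", n_steps=5,
      step_fn=lambda m: seen.append(m.data["batch"]))
print(len(loader), seen)
```

The loop never calls back into user code; the only contract between loop and model is that data_name names a placeholder the model actually has.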

Parameters:
method : str or Inference, default “advi”

Variational method, forwarded to pymc.fit(): either a method name ("advi", "fullrank_advi", …) or an Inference instance. pm.fit applies model and random_seed only when given a name; an instance is already bound to a model, so configure it at construction (e.g. ADVI(random_seed=...)).
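The name-versus-instance rule can be modeled in a few lines. This is a hypothetical sketch of the behavior described above, not pm.fit's source: ToyInference and resolve_method are made-up names, and plain strings stand in for models.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ToyInference:
    """Hypothetical stand-in for an Inference bound to a model at construction."""
    method: str
    model: str
    random_seed: Optional[int] = None


def resolve_method(method: Union[str, ToyInference], model: str,
                   random_seed: Optional[int] = None) -> ToyInference:
    """Names are built with model and random_seed; instances pass through unchanged."""
    if isinstance(method, str):
        return ToyInference(method=method, model=model, random_seed=random_seed)
    # Already bound: the model and random_seed given to this call are ignored
    return method


by_name = resolve_method("advi", model="m", random_seed=1)
prebuilt = ToyInference("advi", model="m", random_seed=7)
by_instance = resolve_method(prebuilt, model="m", random_seed=99)
print(by_name.random_seed, by_instance.random_seed)  # 1 7
```

Under this rule, a random_seed passed alongside a prebuilt instance has no effect, which is why the seed belongs in the instance's constructor.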

dataloader : DataLoader

The minibatch source. len(dataloader) is the dataset size N; the model should pass it as total_size to the observed distribution, so that each minibatch log-likelihood is rescaled to the full dataset.
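What total_size buys can be checked numerically. A minimal numpy sketch that does not call PyMC, assuming the four-column layout of the Examples section (three features, then a 0/1 label) and equal-size batches; loglik and the synthetic data are illustrative only. Multiplying each minibatch log-likelihood by N / B (the rescaling that total_size requests) turns it into an estimate of the full-data log-likelihood, and averaging those estimates over one pass of equal-size batches recovers the full value.

```python
import numpy as np


def loglik(b: np.ndarray, rows: np.ndarray) -> float:
    """Bernoulli-logit log-likelihood for rows laid out as [x0, x1, x2, y]."""
    logit = b[0] + b[1] * rows[:, 0] + b[2] * rows[:, 1] + b[3] * rows[:, 2]
    y = rows[:, 3]
    # log p(y | logit) = y * logit - log(1 + exp(logit)), computed stably
    return float(np.sum(y * logit - np.logaddexp(0.0, logit)))


rng = np.random.default_rng(0)
N, B = 1000, 100  # dataset size and batch size (illustrative)
X = rng.normal(size=(N, 3))
y = rng.integers(0, 2, size=N).astype(float)
data = np.column_stack([X, y])
b = np.array([0.1, -0.5, 0.3, 0.8])

full = loglik(b, data)
# Each minibatch sum is scaled by N / B, as total_size does
scaled = [N / B * loglik(b, data[i:i + B]) for i in range(0, N, B)]
estimate = float(np.mean(scaled))
print(full, estimate)
```

Any single scaled batch is a noisy estimate of the full value, which is the noise stochastic VI works with; without the rescaling, the likelihood would be underweighted relative to the prior by a factor of N / B.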

model : pymc.Model, optional

Defaults to the model on the context stack.

data_name : str, default “batch”

Name of the pm.Data placeholder that minibatches are streamed into. It must match the name passed to pm.Data(name, ...) in the model.

**fit_kwargs

Default keyword arguments forwarded to pymc.fit() (e.g. obj_optimizer). Keyword arguments passed to a particular fit() call override them.
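The override rule amounts to a dict merge in which later keys win. A hypothetical sketch: ToyTrainer is not the real class, and the string value stands in for a real optimizer object passed as obj_optimizer.

```python
from typing import Any, Dict


class ToyTrainer:
    """Hypothetical stand-in showing how constructor defaults and per-call kwargs combine."""

    def __init__(self, **fit_kwargs: Any):
        # Copy, so later changes to the caller's dict cannot leak in
        self.fit_kwargs: Dict[str, Any] = dict(fit_kwargs)

    def resolve_fit_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        # Constructor defaults first; this call's kwargs win on any shared key
        return {**self.fit_kwargs, **kwargs}


trainer = ToyTrainer(obj_optimizer="adam-stand-in", progressbar=False)
first = trainer.resolve_fit_kwargs(progressbar=True)
second = trainer.resolve_fit_kwargs()
print(first)   # {'obj_optimizer': 'adam-stand-in', 'progressbar': True}
print(second)  # {'obj_optimizer': 'adam-stand-in', 'progressbar': False}
```

Building a fresh dict per call leaves the stored defaults intact, so an override applies to that one call only.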

Notes

The per-step set_data call currently lives in the Trainer. Once the VI rework’s Inference.step(batch) lands, that call moves there; at that point the total_size rescaling can be derived from len(dataloader) and dropped from the model body entirely.

Examples

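import numpy as np
import pymc as pm
from pymc.variational import Trainer

# DataLoader and parquet_source are assumed to be in scope; this page does not give their module.
loader = DataLoader(
    parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto"
)
with pm.Model() as model:
    b = pm.Normal("b", 0.0, 3.0, shape=4)
    # Placeholder sized to one batch: columns 0-2 are features, column 3 is the label
    batch = pm.Data("batch", np.zeros((4096, 4)))
    logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2]
    # total_size=len(loader) rescales the minibatch log-likelihood to the full dataset
    pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader))
    approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000)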

Methods

Trainer.__init__(*, dataloader[, method, model, data_name])

Trainer.fit([n])

Fit for n steps, streaming minibatches into the model's placeholder.