Skip to content

New ADVI fails for variables whose value transform changes the variable's shape #646

Description

@Dekermanjian

Quick repro of the issue:

from typing import Any

import numpy as np
import pymc as pm

from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal
from pymc_extras.inference.advi.training import SVIModule, SVITrainer

class SGDOptimizer:
    """Plain gradient-descent optimizer with a fixed step size.

    It carries no internal state: ``init`` produces ``None`` and ``update``
    hands whatever state it receives straight back to the caller.
    """

    def __init__(self, learning_rate: float = 1e-5):
        # Step size multiplied into every gradient in ``update``.
        self.learning_rate = learning_rate

    def init(self, params: dict[str, np.ndarray]) -> None:
        """Create the optimizer state (always ``None``); ``params`` is unused."""
        return None

    def update(
        self,
        grads: dict[str, np.ndarray],
        state: None,
        params: dict[str, np.ndarray],
    ) -> tuple[dict[str, np.ndarray], None]:
        """Apply one descent step ``param - learning_rate * grad`` per parameter.

        Args:
            grads: Gradient for each parameter, keyed like ``params``.
            state: Optimizer state; passed through untouched.
            params: Current parameter values.

        Returns:
            A new dict of updated parameters (same key order as ``params``)
            and the unchanged ``state``.

        Raises:
            KeyError: If some key of ``params`` has no entry in ``grads``.
        """
        step = self.learning_rate
        stepped = {}
        for name, value in params.items():
            stepped[name] = value - step * grads[name]
        return stepped, state

class NormalModel(SVIModule):
    """SVI module that pairs a diagonal-normal guide with plain SGD.

    Used by this repro to drive ``SVITrainer`` on the model defined below.
    """

    def configure_guide(self, model):
        """Return an ``AutoDiagonalNormal`` guide built for ``model``."""
        guide = AutoDiagonalNormal(model)
        return guide

    def configure_optimizer(self, params: dict[str, np.ndarray]) -> tuple[Any, dict[str, Any]]:
        """Create an ``SGDOptimizer`` (learning rate 1e-5) and its initial state.

        NOTE(review): ``SGDOptimizer.init`` returns ``None``, so the
        ``dict[str, Any]`` state annotation does not match the actual value.
        """
        sgd = SGDOptimizer(learning_rate=1e-5)
        return sgd, sgd.init(params)

    def apply_gradients(
        self,
        params: dict[str, np.ndarray],
        grads: dict[str, np.ndarray],
        optimizer: Any,
        optimizer_state: dict[str, Any],
    ):
        """Run one optimizer step and return ``(new_params, new_state)``.

        The result of ``optimizer.update`` is unpacked into exactly two values
        and repacked as a tuple.
        """
        new_params, new_state = optimizer.update(grads, optimizer_state, params)
        return new_params, new_state

# Minimal model: a Dirichlet prior over 3 category probabilities and a
# Categorical likelihood on three observations.
# NOTE(review): the Dirichlet's default value transform presumably maps the
# length-3 probability vector to a length-2 unconstrained vector. Per the issue
# title, this shape change is what the new ADVI fails on; confirm against the
# PyMC transform docs.
with pm.Model() as m:
    p = pm.Dirichlet("p", np.ones(3))
    obs = pm.Categorical('obs', p=p, observed=[0, 1, 2])

    # Control case: pm.sample() on the same model succeeds, so the model
    # itself is valid.
    idata = pm.sample() # samples successfully

# Fit the same model with the NormalModel SVI module (stick_the_landing=True),
# 10,000 steps with one draw per step. This is the call that raises.
svi_trainer = SVITrainer(
    module=NormalModel(), stick_the_landing=True
)
svi_state = svi_trainer.fit(n_steps=10_000, model=m, draws_per_step=1) # Fails

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions