
postprocessing_vectorize default does not match docs/warnings #8238

@zongleon

Description


Sorry for the extremely minor issue.

Although the documentation states that the default is "scan", the value actually used is "vmap": postprocessing_vectorize defaults to None in the function signature and is then set to "vmap" in the function body. This caused me some confusion. After seeing the FutureWarning, I removed the explicit postprocessing_vectorize="scan" argument from my function call, the call fell back to "vmap", and I consequently ran into an out-of-memory (OOM) error.

pymc/pymc/sampling/jax.py

Lines 573 to 574 in bca2b1e

postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
    How to vectorize the postprocessing: vmap or sequential scan

pymc/pymc/sampling/jax.py

Lines 606 to 614 in bca2b1e

if postprocessing_vectorize is not None:
    import warnings
    warnings.warn(
        'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
        FutureWarning,
    )
else:
    postprocessing_vectorize = "vmap"
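To make the mismatch concrete, here is a minimal standalone sketch that reproduces only the quoted default-resolution branch. `resolve_postprocessing_vectorize` is a hypothetical name used for illustration; it is not a PyMC API, and the rest of the real sampling function is not modelled. It shows that passing "scan" explicitly keeps "scan" but emits the FutureWarning, while omitting the argument silently yields "vmap" instead of the documented "scan".

```python
import warnings
from typing import Literal, Optional


# Hypothetical stand-in mirroring only the default-resolution logic quoted
# from pymc/sampling/jax.py; this is not the real PyMC function.
def resolve_postprocessing_vectorize(
    postprocessing_vectorize: Optional[Literal["vmap", "scan"]] = None,
) -> str:
    if postprocessing_vectorize is not None:
        # An explicit value is honoured, but a deprecation notice is emitted.
        warnings.warn(
            'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
            FutureWarning,
        )
    else:
        # Omitting the argument silently falls back to "vmap",
        # not the "scan" default stated in the docstring.
        postprocessing_vectorize = "vmap"
    return postprocessing_vectorize


DOCUMENTED_DEFAULT = "scan"  # default stated in the quoted docstring

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    explicit = resolve_postprocessing_vectorize("scan")  # warns, keeps "scan"
    implicit = resolve_postprocessing_vectorize()  # no warning, returns "vmap"

print(explicit, implicit)  # scan vmap
print([w.category.__name__ for w in caught])  # ['FutureWarning']
print(implicit == DOCUMENTED_DEFAULT)  # False
```

In other words, following the warning's advice and dropping the argument changes the behaviour from sequential scan to vmap without any further notice.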
