Sorry for the extremely minor issue.
Although the documentation states that the default is `"scan"`, the value actually used is `"vmap"`. The default for `postprocessing_vectorize` is `None` in the function signature, and it is set to `"vmap"` in the function body. This caused some confusion: after seeing the warning, I removed the explicit `postprocessing_vectorize="scan"` argument from my function call and consequently ran into an out-of-memory (OOM) error.
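The mismatch can be reproduced in isolation. The following is a minimal sketch that uses a hypothetical stand-in function, `sample_sketch`, rather than PyMC's actual sampler. It copies only the documented default and the resolution of the `None` sentinel; the `FutureWarning` branch is omitted.

```python
import inspect
from typing import Literal, Optional


def sample_sketch(
    postprocessing_vectorize: Optional[Literal["vmap", "scan"]] = None,
) -> str:
    """Hypothetical stand-in, not the real PyMC sampler.

    postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
        How to vectorize the postprocessing: vmap or sequential scan
    """
    if postprocessing_vectorize is None:
        # The None sentinel is resolved here, to "vmap" rather than the documented "scan".
        postprocessing_vectorize = "vmap"
    return postprocessing_vectorize


# What the signature advertises vs. what actually runs when the argument is omitted.
sig_default = inspect.signature(sample_sketch).parameters["postprocessing_vectorize"].default
print(sig_default)      # None
print(sample_sketch())  # vmap, although the docstring says "scan"
```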
pymc/pymc/sampling/jax.py, lines 573 to 574 in bca2b1e:

```
postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
    How to vectorize the postprocessing: vmap or sequential scan
```

pymc/pymc/sampling/jax.py, lines 606 to 614 in bca2b1e:

```python
if postprocessing_vectorize is not None:
    import warnings

    warnings.warn(
        'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
        FutureWarning,
    )
else:
    postprocessing_vectorize = "vmap"
```
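Until the docstring or the default is changed, one possible workaround is to keep passing `postprocessing_vectorize="scan"` explicitly and silence only this particular `FutureWarning`. Because the warning says the option will be removed, this is a temporary measure. The sketch assumes a hypothetical stand-in, `sample_stub`, that mirrors the branch in lines 606 to 614. With the real library, you would wrap your actual sampling call in the same way.

```python
import warnings
from typing import Literal, Optional


def sample_stub(
    postprocessing_vectorize: Optional[Literal["vmap", "scan"]] = None,
) -> str:
    # Hypothetical stand-in mirroring lines 606 to 614 of jax.py (not the real sampler).
    if postprocessing_vectorize is not None:
        warnings.warn(
            'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
            FutureWarning,
        )
    else:
        postprocessing_vectorize = "vmap"
    return postprocessing_vectorize


def sample_with_scan() -> str:
    # Keep the lower-memory sequential scan and ignore only this specific FutureWarning.
    # catch_warnings() restores the previous warning filters when the block exits.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=r"postprocessing_vectorize=",
            category=FutureWarning,
        )
        return sample_stub(postprocessing_vectorize="scan")


print(sample_with_scan())  # scan
```

Scoping the filter with `catch_warnings()` avoids hiding the deprecation globally, so other `FutureWarning`s still surface.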