Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(
# HMCState returned by hmc.init_kernel
self._init_state_cache = {}
self._cache = {}
self._partial_map_fn: partial | None = None
self._collection_params = {}
self._set_collection_params()

Expand Down Expand Up @@ -699,13 +700,29 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
collect_fields = tuple(collect_fields.keys())
remove_sites = tuple(remove_sites.keys())

partial_map_fn = partial(
self._single_chain_mcmc,
args=args,
kwargs=kwargs,
collect_fields=collect_fields,
remove_sites=remove_sites,
)
# Reuse the same partial object across run() calls so that
# jax.pmap sees a stable function identity and reuses the cached
# compilation.
# Since JAX 0.8.0 (aa2f995b1db548df5f1da92b43514f5314e75e4e),
# pmap is implemented via jit(shard_map) and jit caches by function
# identity: a new partial each call triggers a fresh trace + XLA compilation
# whose artifacts are never freed, causing unbounded memory
# growth.
if self._partial_map_fn is None:
self._partial_map_fn = partial(
self._single_chain_mcmc,
args=args,
kwargs=kwargs,
collect_fields=collect_fields,
remove_sites=remove_sites,
)
else:
self._partial_map_fn.keywords["args"] = args
self._partial_map_fn.keywords["kwargs"] = kwargs
self._partial_map_fn.keywords["collect_fields"] = collect_fields
self._partial_map_fn.keywords["remove_sites"] = remove_sites
Comment on lines +720 to +723
Copy link
Copy Markdown
Author

@solutionseekeras solutionseekeras Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

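I'm not sure this is correct: perhaps the memoization should be keyed on `(id(func), args, kwargs, collect_fields, remove_sites)` rather than on the function identity alone, so that a change in any of these four values cannot reuse a stale cached trace.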


partial_map_fn = self._partial_map_fn
map_args = (rng_key, init_state, init_params)
if self.num_chains == 1:
states_flat, last_state = partial_map_fn(map_args)
Expand Down
30 changes: 30 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,36 @@ def get_samples(rng_key, data):
)


def test_reuse_mcmc_run_stable_partial_identity():
    """Check that consecutive ``run()`` calls keep one cached partial object.

    Motivation: when pmap is implemented via jit(shard_map) (JAX >= 0.8.0),
    jit caches by function identity. Building a fresh ``functools.partial``
    on every ``run()`` call would therefore trigger a new trace + XLA
    compilation each time, and those artifacts are never freed, so memory
    grows without bound in long-running services.

    The test runs a tiny single-chain NUTS sampler twice with different keys
    and asserts that ``mcmc._partial_map_fn`` is populated after the first
    run and is the very same object after the second.
    """

    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    # Keep warmup/sample counts minimal: only object identity is under test,
    # not the quality of the posterior draws.
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=5,
        num_samples=5,
        num_chains=1,
        progress_bar=False,
    )

    # First run must create and store the partial.
    mcmc.run(random.key(0))
    cached_after_first_run = mcmc._partial_map_fn
    assert cached_after_first_run is not None

    # Second run (different key) must reuse it rather than replace it.
    mcmc.run(random.key(1))
    cached_after_second_run = mcmc._partial_map_fn
    assert cached_after_second_run is cached_after_first_run, (
        "_partial_map_fn must be the same object across run() calls "
        "to avoid pmap/jit recompilation leaks"
    )


@pytest.mark.parametrize("num_chains", [1, 2])
@pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
@pytest.mark.parametrize("progress_bar", [True, False])
Expand Down
Loading