diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py
index a1f7cb5d0..57c52a2da 100644
--- a/numpyro/infer/mcmc.py
+++ b/numpyro/infer/mcmc.py
@@ -389,6 +389,7 @@ def __init__(
         # HMCState returned by hmc.init_kernel
         self._init_state_cache = {}
         self._cache = {}
+        self._partial_map_fn: partial | None = None
         self._collection_params = {}
         self._set_collection_params()
 
@@ -699,13 +700,35 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
         collect_fields = tuple(collect_fields.keys())
         remove_sites = tuple(remove_sites.keys())
 
-        partial_map_fn = partial(
-            self._single_chain_mcmc,
-            args=args,
-            kwargs=kwargs,
-            collect_fields=collect_fields,
-            remove_sites=remove_sites,
-        )
+        # Reuse the same partial object across run() calls so that
+        # jax.pmap sees a stable function identity and reuses the cached
+        # compilation.
+        # Since JAX 0.8.0 (aa2f995b1db548df5f1da92b43514f5314e75e4e),
+        # pmap is implemented via jit(shard_map) and jit caches by function
+        # identity: a new partial each call triggers a fresh trace + XLA compilation
+        # whose artifacts are never freed, causing unbounded memory
+        # growth.
+        if self._partial_map_fn is None:
+            self._partial_map_fn = partial(
+                self._single_chain_mcmc,
+                args=args,
+                kwargs=kwargs,
+                collect_fields=collect_fields,
+                remove_sites=remove_sites,
+            )
+        else:
+            # Refresh the bound keywords in place: rebinding would create a
+            # new object and defeat the identity-based cache reuse.
+            # NOTE(review): confirm a cache hit cannot reuse values captured
+            # from the previous run's args/kwargs at trace time.
+            self._partial_map_fn.keywords.update(
+                args=args,
+                kwargs=kwargs,
+                collect_fields=collect_fields,
+                remove_sites=remove_sites,
+            )
+
+        partial_map_fn = self._partial_map_fn
         map_args = (rng_key, init_state, init_params)
         if self.num_chains == 1:
             states_flat, last_state = partial_map_fn(map_args)
diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py
index 39637e23d..d76753d93 100644
--- a/test/infer/test_mcmc.py
+++ b/test/infer/test_mcmc.py
@@ -863,6 +863,43 @@ def get_samples(rng_key, data):
     )
 
 
+def test_reuse_mcmc_run_stable_partial_identity():
+    """Regression test: repeated run() calls must reuse the same partial object.
+
+    When pmap is implemented via jit(shard_map) (JAX >= 0.8.0), jit caches
+    by function identity. Creating a new functools.partial each run() call
+    causes a fresh trace + XLA compilation whose artifacts are never freed,
+    leading to unbounded memory growth in long-running services.
+
+    The reused partial must also carry the arguments of the latest run()
+    call, otherwise a later run would be invoked with stale inputs.
+    """
+
+    def model(loc):
+        numpyro.sample("x", dist.Normal(loc, 1))
+
+    mcmc = MCMC(
+        NUTS(model),
+        num_warmup=5,
+        num_samples=5,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.key(0), 0.0)
+    first_partial = mcmc._partial_map_fn
+    assert first_partial is not None
+    assert first_partial.keywords["args"] == (0.0,)
+
+    mcmc.run(random.key(1), 1.0)
+    assert mcmc._partial_map_fn is first_partial, (
+        "_partial_map_fn must be the same object across run() calls "
+        "to avoid pmap/jit recompilation leaks"
+    )
+    assert first_partial.keywords["args"] == (1.0,), (
+        "the reused partial must be rebound to the latest run() arguments"
+    )
+
+
 @pytest.mark.parametrize("num_chains", [1, 2])
 @pytest.mark.parametrize("chain_method", ["parallel", "sequential", "vectorized"])
 @pytest.mark.parametrize("progress_bar", [True, False])