mcmc: Fix memory leak when chain_method="parallel"#2172

Open
solutionseekeras wants to merge 1 commit into pyro-ppl:master from solutionseekeras:fix-pmap-compilation-cache-leak

Conversation

@solutionseekeras

Since jax >= 0.8.0, functions are cached by identity. Creating a new partial() on each run produces a new identity, so nothing can be reused from the cache. Fix this by reusing the same partial object across runs, updating its keywords in place instead of constructing a new partial().
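The identity-caching behavior can be illustrated in plain Python. This is a toy sketch: `fake_jit` and its dict stand in for jit's trace/compilation cache and are not jax API.

```python
from functools import partial


def step(x, scale=1.0):
    return x * scale


# Toy cache keyed by function object, standing in for jit's trace cache.
# functools.partial uses identity-based hashing, so each partial is its own key.
cache = {}


def fake_jit(fn):
    if fn not in cache:
        cache[fn] = f"trace #{len(cache)}"
    return cache[fn]


# A fresh partial per call: new identity, so every call misses the cache.
for _ in range(3):
    fake_jit(partial(step, scale=2.0))
print(len(cache))  # 3 -- one cached trace per throwaway partial

# One reused partial: stable identity, so only the first call populates the cache.
cache.clear()
p = partial(step, scale=2.0)
for _ in range(3):
    p.keywords["scale"] = 2.0  # update keywords in place instead of rebuilding
    fake_jit(p)
print(len(cache))  # 1
```

The second loop mirrors the fix in this PR: mutate the keywords of one long-lived partial rather than constructing a new one per `run()`.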

Reproducing code:

"""
Minimal reproducer for numpyro MCMC memory leak with parallel chains.

numpyro's MCMC.run() creates a new functools.partial each call and wraps it
in jax.pmap. Since jax >=0.8.0, pmap is implemented as jit(shard_map), and
jit caches by function identity. A new partial = new identity = new trace +
compilation that is never freed.
"""

from __future__ import annotations

import gc

import numpyro
import numpyro.distributions as dist
import psutil
from jax import random
from numpyro.infer import MCMC, NUTS


def get_rss_mib() -> float:
    return psutil.Process().memory_info().rss / (1024**2)


def model() -> None:
    x = numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(x, 0.1), obs=1.0)


def main() -> None:
    numpyro.set_host_device_count(4)
    rng_key = random.key(0)

    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10, num_chains=4, progress_bar=False, chain_method="parallel")

    rss_start = get_rss_mib()
    for _ in range(50):
        mcmc.run(rng_key)
        gc.collect()
    rss_end = get_rss_mib()

    print(f"numpyro=={numpyro.__version__}, chains=4, parallel")
    print(f"RSS: {rss_start:.0f} -> {rss_end:.0f} MiB (+{rss_end - rss_start:.0f} MiB over 50 calls)")


if __name__ == "__main__":
    main()

NB! I did not create an issue for tracking this. Let me know if you require that.

@Qazalbash
Collaborator

@solutionseekeras, thanks for the PR. There is no need to create an issue.

cc: @fehiepsi, @juanitorduz

@solutionseekeras
Author

solutionseekeras commented Apr 15, 2026

Updated reproducer, demonstrating the patch (by monkeypatching):

"""
Minimal reproducer for numpyro MCMC memory leak.
"""

from __future__ import annotations

import gc
from functools import partial

import jax
import numpyro
import numpyro.distributions as dist
import numpyro.infer.mcmc as mcmc_module
import psutil
from jax import random
from numpyro.infer import MCMC, NUTS


def get_rss_mib() -> float:
    return psutil.Process().memory_info().rss / (1024**2)


def model() -> None:
    x = numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(x, 0.1), obs=1.0)


_original_partial = partial
_cached_partials: dict[int, partial] = {}


def _stable_partial(func, *args, **kwargs):
    """Return a cached partial when wrapping _single_chain_mcmc, otherwise fall back to normal partial."""
    if getattr(func, "__name__", None) == "_single_chain_mcmc":
        # Key on the underlying function + bound self to handle multiple MCMC instances
        key = id(getattr(func, "__self__", func))
        if key not in _cached_partials:
            _cached_partials[key] = _original_partial(func, *args, **kwargs)
        else:
            _cached_partials[key].keywords.update(kwargs)
        return _cached_partials[key]
    return _original_partial(func, *args, **kwargs)


def run_test(label: str, chain_method: str, patched: bool) -> None:
    if patched:
        mcmc_module.partial = _stable_partial
    else:
        mcmc_module.partial = _original_partial

    mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10, num_chains=4, progress_bar=False, chain_method=chain_method)

    rng_key = random.key(0)
    rss_start = get_rss_mib()
    for _ in range(50):
        mcmc.run(rng_key)
        gc.collect()
    rss_end = get_rss_mib()
    print(f"{label}: +{rss_end - rss_start:.0f} MiB over 50 calls")


def main() -> None:
    numpyro.set_host_device_count(4)

    print(f"jax=={jax.__version__}, numpyro=={numpyro.__version__}\n")

    run_test("parallel,   unpatched", "parallel", patched=False)
    run_test("sequential, unpatched", "sequential", patched=False)
    run_test("parallel,   patched  ", "parallel", patched=True)
    run_test("sequential, patched  ", "sequential", patched=True)


if __name__ == "__main__":
    main()

Results (Linux 6.17.7-ba29.fc43.x86_64; note that the leak appears slower on Mac M-series machines, but it still leaks):

jax==0.9.2, numpyro==0.20.1

parallel,   unpatched: +1127 MiB over 50 calls
sequential, unpatched: +132 MiB over 50 calls
parallel,   patched  : +10 MiB over 50 calls
sequential, patched  : +21 MiB over 50 calls

Edit: This is curious: I ran the reproducer with jax 0.7.2, and it seems to still leak there (but quite a bit slower, and still fixed by this patch):

jax==0.7.2, numpyro==0.20.1

parallel,   unpatched: +318 MiB over 50 calls
sequential, unpatched: +86 MiB over 50 calls
parallel,   patched  : +8 MiB over 50 calls
sequential, patched  : +20 MiB over 50 calls
jax==0.8.0, numpyro==0.20.1

parallel,   unpatched: +1111 MiB over 50 calls
sequential, unpatched: +123 MiB over 50 calls
parallel,   patched  : +11 MiB over 50 calls
sequential, patched  : +23 MiB over 50 calls
jax==0.9.1, numpyro==0.20.0

parallel,   unpatched: +1097 MiB over 50 calls
sequential, unpatched: +138 MiB over 50 calls
parallel,   patched  : +9 MiB over 50 calls
sequential, patched  : +22 MiB over 50 calls

Review thread on numpyro/infer/mcmc.py, lines +720 to +723:
self._partial_map_fn.keywords["args"] = args
self._partial_map_fn.keywords["kwargs"] = kwargs
self._partial_map_fn.keywords["collect_fields"] = collect_fields
self._partial_map_fn.keywords["remove_sites"] = remove_sites
Author

@solutionseekeras Apr 15, 2026


I wonder about the correctness of this: Perhaps the memoization should use (id(func), *<these-four>) instead of just the function.
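That keying idea could look something like the following hypothetical sketch. `keyed_partial` and its `repr()`-based key are illustrative only: real `args`/`kwargs` may contain arrays or dicts that are not hashable, so a production version would need a proper hashable surrogate.

```python
from functools import partial

_cache = {}


def keyed_partial(func, *, args, kwargs, collect_fields, remove_sites):
    # Key on function identity plus the four fields that influence the traced
    # computation; repr() is only a stand-in for a real hashable surrogate.
    key = (id(func), repr(args), repr(kwargs),
           tuple(collect_fields), tuple(remove_sites))
    if key not in _cache:
        _cache[key] = partial(func, args=args, kwargs=kwargs,
                              collect_fields=collect_fields,
                              remove_sites=remove_sites)
    return _cache[key]
```

An identical configuration then reuses one partial object (stable identity for jit's cache), while a changed configuration gets its own entry instead of silently reusing a partial compiled for different fields.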

@solutionseekeras
Author

I am not confident in the correctness of this patch, mostly because the compilation can no longer take the four arguments args, kwargs, collect_fields, and remove_sites into account. You will surely know this better than I do.

If the patch is not correct, however, I think this PR can still be considered a bug report: it seems clear to me that there is quite a severe memory leak here that should be addressed.

@fehiepsi
Member

Would the leak happen if you run jax.jit(mcmc.run)(key)?

@solutionseekeras
Author

No, that does seem to fix it (at least in the reproducer)! I'll test it out in the production code where we saw this initially.

