mcmc: Fix memory leak when chain_method="parallel"#2172
solutionseekeras wants to merge 1 commit into pyro-ppl:master
Conversation
Since jax >= 0.8.0, functions are cached based on identity. Creating a new `partial()` on each run produces a new identity, so the cache never hits. Fix this by reusing the same `partial` object across runs, updating only its keywords instead of constructing a new `partial()`.
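A library-free sketch of the failure mode (the identity-keyed cache below is a simplified stand-in for jax's internal caching, not its actual implementation): every `partial()` call creates a fresh object, so a cache keyed on object identity misses on every run and keeps growing.

```python
from functools import partial

def square(x):
    return x * x

# Identity-keyed cache: a simplified stand-in for jax's function cache.
cache: dict[int, str] = {}

def lookup(fn) -> bool:
    """Return True on a cache hit for this exact object, caching it otherwise."""
    hit = id(fn) in cache
    cache[id(fn)] = "compiled"
    return hit

p1 = partial(square, 2)
p2 = partial(square, 2)  # equivalent arguments, but a brand-new object

print(p1 is p2)    # False: a fresh identity on every construction
print(lookup(p1))  # False: first miss, entry added
print(lookup(p2))  # False: misses again, so the cache grows
print(lookup(p1))  # True: reusing the very same object finally hits
```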
@solutionseekeras, thanks for the PR. There is no need to create an issue. cc: @fehiepsi, @juanitorduz
Updated reproducer, demonstrating the patch (by monkeypatching):

```python
"""
Minimal reproducer for numpyro MCMC memory leak.
"""
from __future__ import annotations

import gc
from functools import partial

import jax
import numpyro
import numpyro.distributions as dist
import numpyro.infer.mcmc as mcmc_module
import psutil
from jax import random
from numpyro.infer import MCMC, NUTS


def get_rss_mib() -> float:
    return psutil.Process().memory_info().rss / (1024**2)


def model() -> None:
    x = numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("obs", dist.Normal(x, 0.1), obs=1.0)


_original_partial = partial
_cached_partials: dict[int, partial] = {}


def _stable_partial(func, *args, **kwargs):
    """Return a cached partial when wrapping _single_chain_mcmc, otherwise fall back to normal partial."""
    if getattr(func, "__name__", None) == "_single_chain_mcmc":
        # Key on the underlying function + bound self to handle multiple MCMC instances
        key = id(getattr(func, "__self__", func))
        if key not in _cached_partials:
            _cached_partials[key] = _original_partial(func, *args, **kwargs)
        else:
            _cached_partials[key].keywords.update(kwargs)
        return _cached_partials[key]
    return _original_partial(func, *args, **kwargs)


def run_test(label: str, chain_method: str, patched: bool) -> None:
    if patched:
        mcmc_module.partial = _stable_partial
    else:
        mcmc_module.partial = _original_partial
    mcmc = MCMC(
        NUTS(model),
        num_warmup=10,
        num_samples=10,
        num_chains=4,
        progress_bar=False,
        chain_method=chain_method,
    )
    rng_key = random.key(0)
    rss_start = get_rss_mib()
    for _ in range(50):
        mcmc.run(rng_key)
        gc.collect()
    rss_end = get_rss_mib()
    print(f"{label}: +{rss_end - rss_start:.0f} MiB over 50 calls")


def main() -> None:
    numpyro.set_host_device_count(4)
    print(f"jax=={jax.__version__}, numpyro=={numpyro.__version__}\n")
    run_test("parallel, unpatched", "parallel", patched=False)
    run_test("sequential, unpatched", "sequential", patched=False)
    run_test("parallel, patched ", "parallel", patched=True)
    run_test("sequential, patched ", "sequential", patched=True)


if __name__ == "__main__":
    main()
```

Results (Linux): …

Edit: This is curious: I ran the reproducer with jax …
```python
self._partial_map_fn.keywords["args"] = args
self._partial_map_fn.keywords["kwargs"] = kwargs
self._partial_map_fn.keywords["collect_fields"] = collect_fields
self._partial_map_fn.keywords["remove_sites"] = remove_sites
```
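The in-place mutation works because `functools.partial` exposes its bound keyword arguments as a plain, mutable dict, so the arguments can be rebound without changing the object's identity. A standalone sketch of the mechanism (the names are illustrative, not numpyro's):

```python
from functools import partial

def collect(*, args=None, collect_fields=None):
    return (args, collect_fields)

fn = partial(collect, args=(1,), collect_fields=("z",))
original_id = id(fn)

# Rebind the keywords in place: the object's identity is preserved,
# so an identity-keyed cache still recognises this callable.
fn.keywords["args"] = (2,)
fn.keywords["collect_fields"] = ("x",)

print(id(fn) == original_id)  # True: same object
print(fn())                   # ((2,), ('x',)): the updated values are used
```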
I wonder about the correctness of this: perhaps the memoization should use `(id(func), *<these-four>)` as the key instead of just the function.
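A sketch of that alternative keying, assuming the four bound values are hashable (the `stable_partial` helper here is hypothetical, not the PR's code):

```python
from functools import partial

_cache: dict[tuple, partial] = {}

def stable_partial(func, **kwargs):
    """Memoize partials on (function identity, bound values) rather than
    on the function alone, so a changed value yields a new identity."""
    key = (id(func), tuple(sorted(kwargs.items())))
    if key not in _cache:
        _cache[key] = partial(func, **kwargs)
    return _cache[key]

def f(*, a, b):
    return a + b

p1 = stable_partial(f, a=1, b=2)
p2 = stable_partial(f, a=1, b=2)  # same values -> same object, cache can hit
p3 = stable_partial(f, a=1, b=3)  # changed value -> distinct object, forces a retrace

print(p1 is p2)  # True
print(p1 is p3)  # False
```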
I am not confident of the correctness of this patch, mostly because the compilation can no longer take the four keyword values into account. If the patch is not correct, however, I think this can be considered a bug report: it seems clear to me that there is quite a severe memory leak here that should be addressed.
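A toy illustration of the worry (the identity-keyed cache below is a stand-in, not jax's real tracing machinery): if cached work is keyed on object identity alone, mutating the keywords in place can silently reuse a stale result.

```python
from functools import partial

def compute(*, scale):
    return scale * 10

# "Compile" once per object identity, as the concern above describes.
compiled: dict[int, int] = {}

def run_cached(fn):
    if id(fn) not in compiled:
        compiled[id(fn)] = fn()
    return compiled[id(fn)]

p = partial(compute, scale=1)
print(run_cached(p))     # 10

p.keywords["scale"] = 2  # in-place update: identity unchanged
print(run_cached(p))     # still 10: the stale cached result is served
```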
Would the leak happen if you run …?
No, that does seem to fix it (at least in the reproducer)! I'll test it out in the production code where we saw this initially.
NB! I did not create an issue for tracking this. Let me know if you require that.