Skip to content

Fix xla_pmap_p import for JAX versions that removed pmap#2173

Open
saitcakmak wants to merge 1 commit intopyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import
Open

Fix xla_pmap_p import for JAX versions that removed pmap#2173
saitcakmak wants to merge 1 commit intopyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import

Conversation

@saitcakmak
Copy link
Copy Markdown

Summary

  • JAX 0.10.0 removed the C++ pmap infrastructure (including xla_pmap_p). This causes an ImportError when importing numpyro with newer JAX versions.
  • Guard the xla_pmap_p import with a try/except and skip the provenance tracking rule registration when it's unavailable.

Test plan

  • python -c "from numpyro.ops.provenance import eval_provenance" succeeds
  • eval_provenance works end-to-end (tested with lambda x, y, z: x + y)
  • All 10 tests in test/ops/test_provenance.py pass

JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a
recent release. Guard the import so numpyro works with both old and
new JAX versions.
@Qazalbash
Copy link
Copy Markdown
Collaborator

Thanks @saitcakmak

@Qazalbash Qazalbash requested a review from juanitorduz April 16, 2026 14:23
@juanitorduz
Copy link
Copy Markdown
Collaborator

juanitorduz commented Apr 16, 2026

LGTM.

It seems the failing tests in Python 3.14 CI (test-inference) are not caused by this PR. They fail inside funsor, a transitive dependency:

  .venv/lib/python3.14/site-packages/funsor/jax/ops.py:203                                                                                                                                                           
  E       TypeError: clip() got an unexpected keyword argument 'a_max' 

funsor calls jnp.clip(..., a_max=...), but recent JAX releases dropped the deprecated a_max/a_min kwargs (replaced by max/min, aligning with NumPy 2.x).

I am looking into the upstream issue: pyro-ppl/funsor#611 @fehiepsi would you mind taking a look?

Qazalbash added a commit to kokabsc/gwkokab that referenced this pull request Apr 18, 2026
sethaxen added a commit to sethaxen/CAGPJax that referenced this pull request Apr 20, 2026
* chore: Temporarily upperbound jax

Until numpyro v0.20 is compatible with jax v0.10, pyro-ppl/numpyro#2173

* fix: Make support a property

Change made in numpyro v0.20.0

* fix: Explicitly import jax.test_util
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

JAX 0.10.0 breaks NumPyro

4 participants