Add MPS support to the gemma node#345
Conversation
|
Can anyone with permissions review this please? M3 Ultra Mac Studio support is needed! |
|
Implemented similar fix here: Comfy-Org/ComfyUI#12809 |
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR improves device/dtype handling for the Gemma encoder path, aiming to make generation RNG forking safer across backends and to avoid unsupported dtypes on MPS.
Changes:
- Adjusted
torch.random.fork_rng(...)usage to avoid passing a CUDA device on non-CUDA backends. - Made
dtypeactually configurable (instead of always forcingbfloat16) and propagated it into embeddings connector loaders. - Added MPS-specific dtype selection (fallback to
float16) when loading the model.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
c8faaba to
7728d42
Compare
Thanks, #12809 is complementary rather than overlapping, it fixes device placement (text encoders now land on MPS instead of CPU). This PR fixes the node itself once it's there ( |
José Lugo (@josephlugo) I've rebased the PR, would you be able to run it on your M3 Ultra and report back? |
Adds MPS (Apple Silicon) support to the Gemma text-encoder node.
Related to #302 (gemma text-encoder portion of MPS support)
Changes
fork_rngdevice guard —torch.random.fork_rng(devices=[...])onlyaccepts CUDA devices; on MPS it raises. The device list is now passed only
when the model is on CUDA. Seeded generation on CUDA is unchanged.
dtype—_LTXVGemmaTextEncoderModel.__init__previously hardcoded
torch.bfloat16, overwriting thedtypeargument;it's now only the default. With this,
clip_dtypeflows throughltxv_gemma_clipinto the model and theload_text_embeddings_pipelineloaders.
clip_dtypefalls back to
float16when the torch device is MPS (CUDA keepsbfloat16). This also keeps the newtorch.autocast(dtype=model.dtype)context in
_enhanceon a dtype MPS actually supports.Rebased onto current master (single commit, one file, +11/−2). An earlier
revision also propagated
dtypeinto the embeddings connectors; that's nolonger needed — the
load_text_embeddings_pipelinerefactor on masteralready threads it through.
Testing status
fork_rngdevice constraints; the CUDApath is preserved unchanged.
Apple Silicon welcome.