perf(dsv4 ratio4): coarsen softmax_pool to POOL_GROUP=2 batches per task#497
Conversation
There was a problem hiding this comment.
Code Review
This pull request coarsens the softmax_pool SPMD loop in decode_compressor_ratio4.py by grouping batches into POOL_GROUP tasks and unrolling them to prevent intermittent NaN issues. Feedback on these changes suggests replacing the newly added assert statement with a ValueError check, as assertions can be globally disabled in optimized Python environments.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| HEAD_DIM_TILE = 128 | ||
| RMS_TILE = 16 | ||
| POOL_GROUP = 4 # batches per softmax_pool task (pl.unroll; B % POOL_GROUP == 0) | ||
| assert B % POOL_GROUP == 0, "B must be divisible by POOL_GROUP" |
There was a problem hiding this comment.
Using assert statements for production configuration or runtime validation is discouraged because assertions can be globally disabled when Python is run with optimization flags (e.g., python -O). It is safer and more robust to raise a ValueError instead, which is also consistent with the validation pattern used in decode_compressor_ratio128.py.
| assert B % POOL_GROUP == 0, "B must be divisible by POOL_GROUP" | |
| if B % POOL_GROUP != 0: | |
| raise ValueError("B must be divisible by POOL_GROUP") |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughThis PR refactors the softmax-pooling stage in the DeepSeek v4 decode compressor to use grouped batch parallelization. It introduces a ChangesGrouped Softmax Pooling
Estimated Code Review Effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly Related PRs
Suggested Labels
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
11194ad to
6fe66fe
Compare
The softmax_pool loop ran one task per batch (pl.spmd(B)=64 tasks), each a small online-softmax over HEAD_DIM//HEAD_TILE head tiles -- fine-grained, high per-task tail overhead. Group POOL_GROUP=2 batches per task via pl.unroll (B//POOL_GROUP=32 tasks). pl.unroll (NOT pl.range) is required: trace-time unroll gives each batch its own independent AST values for the online-softmax accumulators (mi/li/oi), whereas a runtime pl.range would loop-carry them across batches -> intermittent NaN/507018. Each batch keeps its own data-dependent gate and positions, so the unroll is bit-identical to the per-batch form. Standalone decode_compressor_ratio4 on a2a3 (kv/compress_state/cmp_kv_cache all PASS), softmax_pool task count / Total Test Time A/B: POOL_GROUP=1 (64 tasks): 327.32 us (Exec% 58.3, tail OH 8.1 us) POOL_GROUP=2 (32 tasks): 261.22 us (-20.2%, Exec% 85.0, tail OH 3.0 us) POOL_GROUP=4 (16 tasks): 324.70 us (over-coarsened, core-starved -> back to baseline) POOL_GROUP=2 is the sweet spot. In the full CSA orchestrator this scope is overlapped by qr_rope so the win is largely hidden there; the standalone gain is the scope-level signal.
What
Coarsen the ratio-4 compressor
softmax_poolloop from one task per batch (pl.spmd(B)= 64 tasks) to POOL_GROUP=2 batches per task (pl.spmd(B // POOL_GROUP)= 32 tasks), with an innerpl.unroll(POOL_GROUP)over the batches.pl.unroll(NOTpl.range): trace-time unroll gives each batch its own independent AST values for the online-softmax accumulators (mi/li/oi). A runtimepl.rangewould loop-carry them across batches → intermittent NaN/507018.pos_b + S >= COMPRESS_RATIO) and window positions, so the unroll is bit-identical to the per-batch form.Why / Tuning
softmax_poolwas a fine-grained scope (64 tiny online-softmax tasks, Exec% 58%, tail OH 8.1µs) that floods the scheduler. Grouping batches per task amortizes the per-task overhead. A/B sweep on standalonedecode_compressor_ratio4.py(a2a3,kv/compress_state/cmp_kv_cacheall PASS):POOL_GROUP=2 is the sweet spot — a clean U-curve. POOL_GROUP=4 over-coarsens (16 big serial tasks pinned to 16 cores) and collapses back to baseline.
Note on the full CSA orchestrator
In
decode_attention_csa.pythis scope is overlapped byqr_ropein the same window, so the standalone −20% is largely hidden at the CSA Total level (CSA wall-clock is dominated byqk_pv~442µs andgather_kv~332µs). The standalone gain is the scope-level signal; this change is bit-identical and a strict scope-level win, with no CSA regression.