Add DeepSeek V4 unified decode layer #496
📝 Walkthrough
This PR adds comprehensive end-to-end smoke tests for DeepSeek-V4 decode layers, composing local attention mechanisms (SWA, HCA, CSA, or auto-selected) with quantized MoE expert routing across 2 distributed ranks. It includes JIT kernels, host orchestration with buffer allocation, golden reference implementations for validation, and test infrastructure with CLI support.
Changes: DeepSeek-V4 Decode Layer MoE Integration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces decode_layer_swa_moe_ep.py to implement DeepSeek-V4 decode layers combining SWA, HCA, and CSA attention mechanisms with MoE EP, alongside minor JIT inline and test function updates in moe_ep.py. The review feedback highlights several opportunities to improve code quality, including addressing the brittle, hardcoded attention selection logic in the auto-routing function, and refactoring duplicated code for window buffer allocations and tensor spec generation into reusable helper functions.
if layer_id < 2:
    attention_swa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        kv_cache, block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
elif layer_id % 2 == 1:
    attention_hca(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        hca_cmp_wkv, hca_cmp_wgate, hca_cmp_ape, hca_cmp_norm_w,
        hca_compress_state, hca_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
else:
    attention_csa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        csa_cmp_wkv, csa_cmp_wgate, csa_cmp_ape, csa_cmp_norm_w,
        csa_compress_state, csa_compress_state_block_table,
        csa_idx_wq_b, csa_idx_wq_b_scale, csa_weights_proj, csa_hadamard_idx,
        csa_inner_wkv, csa_inner_wgate, csa_inner_ape, csa_inner_norm_w,
        csa_inner_compress_state, csa_inner_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        idx_kv_cache, idx_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
The logic for selecting the attention mechanism is tightly coupled to the specific compress_ratios pattern in the current model configuration. This makes the code brittle. If MODEL_CONFIG.compress_ratios changes in a way that violates the hardcoded layer_id < 2 or layer_id % 2 == 1 checks, this function will silently use the wrong attention mechanism. This could lead to incorrect behavior that is hard to debug.
The golden path implementation in golden_decode_layer_auto uses _attention_kind_for_layer, which is more robust as it directly uses the configuration.
While pl.jit functions have limitations on accessing external configuration, this hardcoded logic is a significant risk. Please consider adding a comment to highlight this strong assumption about the configuration, or an assertion to verify that the layer_id corresponds to the expected attention mechanism based on the MODEL_CONFIG if possible at JIT time.
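One way to act on the assertion suggestion is a host-side check that runs before the layer is compiled. The sketch below is illustrative only: `hardcoded_kind` mirrors the quoted if/elif/else rule, while `dispatch_mismatches`, `EXAMPLE_KINDS` and `EXAMPLE_RATIOS` are hypothetical names and values standing in for `_attention_kind_for_layer` and `MODEL_CONFIG.compress_ratios`, whose real contents are not shown in this thread.

```python
from typing import Callable, List, Sequence, Tuple


def hardcoded_kind(layer_id: int) -> str:
    """Mirror the if/elif/else rule used by the JIT auto-dispatch."""
    if layer_id < 2:
        return "swa"
    if layer_id % 2 == 1:
        return "hca"
    return "csa"


def dispatch_mismatches(
    compress_ratios: Sequence[int],
    kind_for_ratio: Callable[[int], str],
) -> List[Tuple[int, str, str]]:
    """Return (layer_id, expected, actual) for every layer where they differ."""
    mismatches = []
    for layer_id, ratio in enumerate(compress_ratios):
        expected = kind_for_ratio(ratio)    # what the configuration asks for
        actual = hardcoded_kind(layer_id)   # what the hardcoded rule picks
        if expected != actual:
            mismatches.append((layer_id, expected, actual))
    return mismatches


# Hypothetical ratio-to-kind mapping and ratio list, for illustration only.
EXAMPLE_KINDS = {1: "swa", 4: "csa", 128: "hca"}
EXAMPLE_RATIOS = [1, 1, 4, 128, 4, 128]

if __name__ == "__main__":
    bad = dispatch_mismatches(EXAMPLE_RATIOS, EXAMPLE_KINDS.__getitem__)
    assert not bad, f"layer_id rule disagrees with config: {bad}"
```

Running such a check once at test setup would turn a silent wrong-attention selection into an immediate failure, without requiring the JIT function itself to read the configuration.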
pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
This block of code for allocating window buffers is duplicated across host_orch, host_orch_hca, host_orch_csa, and host_orch_auto.
To improve maintainability and reduce redundancy, this logic can be extracted into a helper function.
Example of a helper function:
def _alloc_moe_windows():
    pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
    count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
    recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
    recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
    recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
    recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
    routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
    combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
    return (
        pub_counts_buf, count_done_buf, recv_x_buf, recv_scale_buf,
        recv_w_buf, recv_r_route_buf, routed_y_buf_buf, combine_done_buf
    )
You can then call this helper in each host_orch_* function.
- pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
- count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
- recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
- recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
- recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
- recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
- routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
- combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
+ (
+     pub_counts_buf, count_done_buf, recv_x_buf, recv_scale_buf, recv_w_buf,
+     recv_r_route_buf, routed_y_buf_buf, combine_done_buf
+ ) = _alloc_moe_windows()
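A table-driven variant of the same helper might keep the size arithmetic in one place and make it testable without the device runtime. This is only a sketch: `moe_window_sizes` and `alloc_moe_windows` are hypothetical names, the allocator is passed in as a plain callable rather than calling `pld.alloc_window_buffer` directly, and it is an assumption that the host orchestration functions are allowed to call a Python helper like this.

```python
from typing import Any, Callable, Dict


def moe_window_sizes(
    n_ranks: int, n_local: int, recv_max: int,
    d: int, w_pad: int, idx_pad: int, n_routes: int,
) -> Dict[str, int]:
    """Sizes handed to the allocator, in the same order as the quoted block."""
    return {
        "pub_counts": n_ranks * n_ranks * n_local * 4,
        "count_done": n_ranks * 4,
        "recv_x": n_local * recv_max * d * 2,
        "recv_scale": n_local * recv_max * w_pad * 4,
        "recv_w": n_local * recv_max * w_pad * 4,
        "recv_r_route": n_local * recv_max * idx_pad * 4,
        "routed_y_buf": n_routes * d * 2,
        "combine_done": n_ranks * 4,
    }


def alloc_moe_windows(
    alloc: Callable[[int], Any], sizes: Dict[str, int]
) -> Dict[str, Any]:
    """Allocate one buffer per entry; dict order matches insertion order."""
    return {name: alloc(size) for name, size in sizes.items()}


if __name__ == "__main__":
    # Small made-up dimensions; bytearray stands in for the real allocator.
    sizes = moe_window_sizes(2, 4, 8, 16, 8, 8, 6)
    buffers = alloc_moe_windows(bytearray, sizes)
    print({name: len(buf) for name, buf in buffers.items()})
```

Each `host_orch_*` function could then unpack the returned dictionary by name, so adding or resizing a window touches a single table instead of four copies.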
for spec in moe_specs:
    if not isinstance(spec, TensorSpec):
        continue
    if spec.name in {"x_hc", "x_next"}:
        continue
    if spec.name == "tid2eid":
        def init_tid2eid():
            base = torch.arange(VOCAB, dtype=torch.int32).reshape(VOCAB, 1) * TOPK
            offs = torch.arange(TOPK, dtype=torch.int32).reshape(1, TOPK)
            table = (base + offs) % N_EXPERTS_GLOBAL
            return table.unsqueeze(0).expand(N_RANKS, -1, -1).contiguous()

        specs.append(TensorSpec("tid2eid", spec.shape, spec.dtype, init_value=init_tid2eid))
    elif spec.name == "input_ids":
        def init_input_ids():
            ids = torch.arange(T, dtype=torch.int64)
            return ids.unsqueeze(0).expand(N_RANKS, -1).contiguous()

        specs.append(TensorSpec("input_ids", spec.shape, spec.dtype, init_value=init_input_ids))
    else:
        specs.append(moe_tensor_specs[spec.name])

specs.extend([
    TensorSpec("x_next", [N_RANKS, T, HC_MULT, D], torch.bfloat16, is_output=True),
    ScalarSpec("layer_id", torch.int32, layer_id),
])
This block of code for adding MoE tensor specs and the final specs is duplicated. It appears here inside the if attention_mode == "auto" and unified_host: block, and again in the else block (lines 1524-1550).
To improve readability and maintainability of this very long function, this duplicated logic should be extracted into a helper function.
_add_moe_and_final_specs(specs, moe_specs, moe_tensor_specs, layer_id)
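The body of such a helper could look roughly like the following. This is a sketch under assumptions, not the repository's code: `TensorSpec` and `ScalarSpec` here are minimal stand-in dataclasses with string dtypes, and the extra `init_overrides` and `x_next_shape` parameters are hypothetical additions that replace the inline `init_tid2eid` / `init_input_ids` closures and the module-level shape constants.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class TensorSpec:  # stand-in for the test harness class of the same name
    name: str
    shape: List[int]
    dtype: str
    init_value: Optional[Callable[[], Any]] = None
    is_output: bool = False


@dataclass
class ScalarSpec:  # stand-in for the test harness class of the same name
    name: str
    dtype: str
    value: int


def add_moe_and_final_specs(
    specs: list,
    moe_specs: list,
    moe_tensor_specs: Dict[str, TensorSpec],
    init_overrides: Dict[str, Callable[[], Any]],
    x_next_shape: List[int],
    layer_id: int,
) -> list:
    """Append MoE tensor specs, then the x_next output and layer_id scalar."""
    for spec in moe_specs:
        if not isinstance(spec, TensorSpec):
            continue  # scalar specs are not forwarded
        if spec.name in {"x_hc", "x_next"}:
            continue  # supplied by the decode layer itself
        if spec.name in init_overrides:
            # Same shape and dtype, but with a deterministic initializer.
            specs.append(TensorSpec(spec.name, spec.shape, spec.dtype,
                                    init_value=init_overrides[spec.name]))
        else:
            specs.append(moe_tensor_specs[spec.name])
    specs.extend([
        TensorSpec("x_next", x_next_shape, "bfloat16", is_output=True),
        ScalarSpec("layer_id", "int32", layer_id),
    ])
    return specs
```

Both branches of the `attention_mode` check could then share one call, with the `tid2eid` and `input_ids` initializers passed in through `init_overrides`.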
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/decode_layer_swa_moe_ep.py`:
- Around line 486-526: The auto-dispatch branch in
models/deepseek/v4/decode_layer_swa_moe_ep.py is calling attention_swa,
attention_hca and attention_csa but discarding their return values; each
attention_* (decorated `@pl.jit.inline`) returns the updated tensor (e.g., x_out)
and should be assigned back to x_attn. Fix by replacing the bare calls with
assignments—x_attn = attention_swa(...), x_attn = attention_hca(...), and x_attn
= attention_csa(...)—matching the earlier correct usage elsewhere in the file.
📒 Files selected for processing (2)
models/deepseek/v4/decode_layer_swa_moe_ep.py
models/deepseek/v4/moe_ep.py
if layer_id < 2:
    attention_swa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        kv_cache, block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
elif layer_id % 2 == 1:
    attention_hca(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        hca_cmp_wkv, hca_cmp_wgate, hca_cmp_ape, hca_cmp_norm_w,
        hca_compress_state, hca_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
else:
    attention_csa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        csa_cmp_wkv, csa_cmp_wgate, csa_cmp_ape, csa_cmp_norm_w,
        csa_compress_state, csa_compress_state_block_table,
        csa_idx_wq_b, csa_idx_wq_b_scale, csa_weights_proj, csa_hadamard_idx,
        csa_inner_wkv, csa_inner_wgate, csa_inner_ape, csa_inner_norm_w,
        csa_inner_compress_state, csa_inner_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        idx_kv_cache, idx_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check decorators on attention_swa, attention_hca, attention_csa
echo "=== Checking attention_swa decorator ==="
rg -n -B2 "^def attention_swa\(" --type=py
echo "=== Checking attention_hca decorator ==="
rg -n -B2 "^def attention_hca\(" --type=py
echo "=== Checking attention_csa decorator ==="
rg -n -B2 "^def attention_csa\(" --type=py
Repository: hw-native-sys/pypto-lib
Length of output: 712
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Confirm decode_layer_swa_moe_ep usage around the flagged block
echo "=== decode_layer_swa_moe_ep usage (lines ~450-560) ==="
sed -n '440,560p' models/deepseek/v4/decode_layer_swa_moe_ep.py | nl -ba
# Check where attention_swa/hca/csa return (show a window around first/last return)
for f in \
models/deepseek/v4/decode_attention_swa.py \
models/deepseek/v4/decode_attention_hca.py \
models/deepseek/v4/decode_attention_csa.py
do
echo "=== Returns in $f ==="
# print function header line number
rg -n "^def $(basename "$f" .py)\(" "$f" || true
# show all return lines with a bit of context
rg -n "return " "$f" || true
# show tail of file section (in case return is near end)
tail -n 80 "$f" | nl -ba
done
# Check decode_layer_hca_moe_ep and decode_layer_csa_moe_ep whether they capture x_attn
echo "=== decode_layer_hca_moe_ep usage (search for attention_* = attention_*) ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_hca_moe_ep.py || true
echo "=== decode_layer_csa_moe_ep usage (search for x_attn = attention_*) ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_csa_moe_ep.py || true
Repository: hw-native-sys/pypto-lib
Length of output: 183
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== decode_layer_swa_moe_ep usage (lines ~450-560) ==="
sed -n '440,560p' models/deepseek/v4/decode_layer_swa_moe_ep.py | cat -n
for f in \
models/deepseek/v4/decode_attention_swa.py \
models/deepseek/v4/decode_attention_hca.py \
models/deepseek/v4/decode_attention_csa.py
do
echo "=== Attention function: $f (show return lines) ==="
# show function header
rg -n "^def $(basename "$f" .py)\(" "$f" || true
# show lines around return statements (limit output)
rg -n "return " "$f" || true
done
echo "=== decode_layer_hca_moe_ep: x_attn capture from attention_* ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_hca_moe_ep.py || true
echo "=== decode_layer_csa_moe_ep: x_attn capture from attention_* ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_csa_moe_ep.py || true
echo "=== decode_layer_swa_moe_ep: any x_attn = attention_* calls ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_swa_moe_ep.py || true
echo "=== decode_layer_swa_moe_ep: attention_swa/hca/csa callsites without assignment (look near ~486-526) ==="
sed -n '470,540p' models/deepseek/v4/decode_layer_swa_moe_ep.py | cat -n
Repository: hw-native-sys/pypto-lib
Length of output: 17100
Capture attention_* return values in auto dispatch
attention_swa / attention_hca / attention_csa are @pl.jit.inline and return a tensor (e.g., return x_out), but in models/deepseek/v4/decode_layer_swa_moe_ep.py the if/elif/else auto-dispatch calls are made without assigning back to x_attn (the returned value is discarded). Earlier in the same file the code correctly uses x_attn = attention_*..., so this branch is inconsistent and should capture the return like the others.
Additionally, decode_layer_hca_moe_ep.py / decode_layer_csa_moe_ep.py are not present under models/deepseek/v4/ in this repo, so the issue should be scoped to the actual HCA/CSA auto-dispatch implementation files found there.
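The difference between the two call styles can be shown with a small pure-Python analogy. It is an illustration only: `attention_stub`, `layer_discarding` and `layer_capturing` are hypothetical names, tuples stand in for tensors, and whether the real `@pl.jit.inline` functions also write into their `x_attn` argument in place is not established by this thread.

```python
from typing import Tuple

Tensor = Tuple[float, ...]  # immutable stand-in for a traced tensor


def attention_stub(x_hc: Tensor, x_attn: Tensor) -> Tensor:
    """Return a new value; the x_attn argument itself is left untouched."""
    x_out = tuple(a + b for a, b in zip(x_hc, x_attn))
    return x_out


def layer_discarding(x_hc: Tensor, x_attn: Tensor) -> Tensor:
    attention_stub(x_hc, x_attn)           # result is dropped
    return x_attn                          # still the original input


def layer_capturing(x_hc: Tensor, x_attn: Tensor) -> Tensor:
    x_attn = attention_stub(x_hc, x_attn)  # result is rebound to x_attn
    return x_attn


if __name__ == "__main__":
    x_hc, x_attn = (1.0, 2.0), (0.0, 0.0)
    print(layer_discarding(x_hc, x_attn))
    print(layer_capturing(x_hc, x_attn))
```

Under value-return semantics, only the capturing form passes the attention output on to the next stage, which is why the review asks for `x_attn = attention_*(...)` in all three branches.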
d20723a to bb63bc4
- Add a unified decode-layer host path that selects SWA, HCA, or CSA attention by layer_id inside the layer function.
- Reuse the MoE EP path through an inline body plus standalone wrapper so the existing MoE smoke remains runnable.
- Add auto-mode tensor specs and golden routing for the unified host.
Summary
- […] models/deepseek/v4/decode_layer_ep.py as the DeepSeek V4 decode-layer smoke that composes layer-id selected SWA/HCA/CSA attention with MoE EP2.
- […] host_orch_auto, while the JIT decode layer selects the attention implementation from layer_id.
- […] kv_cache in-place and only expose x_next as the composed layer output.
Related Issues
None