
Add DeepSeek V4 unified decode layer #496

Open
high-cloud wants to merge 1 commit into hw-native-sys:main from high-cloud:feat/deepseek-v4-unified-decode-layer

Conversation

@high-cloud

@high-cloud high-cloud commented Jun 10, 2026

Contributor

Summary

  • Add models/deepseek/v4/decode_layer_ep.py as the DeepSeek V4 decode-layer smoke test, composing SWA/HCA/CSA attention (selected by layer_id) with MoE EP2.
  • Keep a single host orchestration entry, host_orch_auto; the JIT decode layer selects the attention implementation from layer_id.
  • Update kv_cache in place and expose only x_next as the composed layer output.

Related Issues

None

@coderabbitai

coderabbitai Bot commented Jun 10, 2026


Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c6973851-1958-44b8-a45a-06c04cd5b4e2

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Walkthrough

This PR adds comprehensive end-to-end smoke tests for DeepSeek-V4 decode layers, composing local attention mechanisms (SWA, HCA, CSA, or auto-selected) with quantized MoE expert routing across 2 distributed ranks. It includes JIT kernels, host orchestration with buffer allocation, golden reference implementations for validation, and test infrastructure with CLI support.

Changes

DeepSeek-V4 Decode Layer MoE Integration

| Layer / File(s) | Summary |
|---|---|
| Module setup and imports (models/deepseek/v4/decode_layer_swa_moe_ep.py) | Configures test module with PyPTO JIT/distributed utilities, attention (SWA/HCA/CSA), MoE EP, and model constants; enforces DP2 and unified host consistency constraints. |
| MoE EP test wrapper and inlining (models/deepseek/v4/moe_ep.py) | Changes moe_ep from @pl.jit to @pl.jit.inline, adds moe_ep_test JIT wrapper that forwards all arguments to moe_ep, and updates l3_moe_ep call site to use the test wrapper. |
| Decode layer JIT kernels (models/deepseek/v4/decode_layer_swa_moe_ep.py) | Implements four JIT kernels (decode_layer_swa_moe_ep, decode_layer_hca_moe_ep, decode_layer_csa_moe_ep, decode_layer_auto_moe_ep) that compute per-rank attention output and feed it into moe_ep with distributed plumbing. |
| Host orchestration and distributed buffers (models/deepseek/v4/decode_layer_swa_moe_ep.py) | Provides four host orchestration entry points (host_orch, host_orch_hca, host_orch_csa, host_orch_auto) that allocate shared distributed window buffers for MoE combine, iterate over ranks, and invoke the corresponding decode kernel per device. |
| Golden references and validation utilities (models/deepseek/v4/decode_layer_swa_moe_ep.py) | Implements golden reference functions for each attention mode and helper utilities (_validate_layer_id, _attention_kind_for_layer, _resolve_attention_mode, _host_and_golden_for_mode, _ranked_init) for correctness checking and attention mode resolution. |
| Test infrastructure and CLI (models/deepseek/v4/decode_layer_swa_moe_ep.py) | Provides build_tensor_specs(...) generator that assembles ranked tensor specifications with special handling for initializers and attention parameter reordering; includes CLI __main__ that parses configuration, resolves attention mode, and runs distributed JIT testing with validation. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#452: Updates to moe_ep decorator and moe_ep_test wrapper directly extend the EP MoE orchestration refactored in that PR.
  • hw-native-sys/pypto-lib#426: New decode layer smoke tests and distributed orchestration build on the 2-rank EP MoE program structure and call-path changes from that PR.
  • hw-native-sys/pypto-lib#450: The moe_ep decorator change and new moe_ep_test wrapper enable proper inlining and testing integration for the distributed MoE EP stage composed into decode layers.

Suggested labels

enhancement

Poem

🐰 Four paths through attention's art—
SWA, HCA, CSA, auto's smart—
Distributed MoE combines the prize,
Window buffers synchronize,
Golden reference checks the way,
Decode smoke tests save the day! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main addition to the PR: a unified decode layer for DeepSeek V4 that composes multiple attention mechanisms with MoE EP2. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The pull request description clearly outlines the addition of a unified decode layer for DeepSeek V4 that composes attention mechanisms with MoE EP2, matching the changes in the summary. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces decode_layer_swa_moe_ep.py to implement DeepSeek-V4 decode layers combining SWA, HCA, and CSA attention mechanisms with MoE EP, alongside minor JIT inline and test function updates in moe_ep.py. The review feedback highlights several opportunities to improve code quality, including addressing the brittle, hardcoded attention selection logic in the auto-routing function, and refactoring duplicated code for window buffer allocations and tensor spec generation into reusable helper functions.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +486 to +526
if layer_id < 2:
    attention_swa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        kv_cache, block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
elif layer_id % 2 == 1:
    attention_hca(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        hca_cmp_wkv, hca_cmp_wgate, hca_cmp_ape, hca_cmp_norm_w,
        hca_compress_state, hca_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
else:
    attention_csa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        csa_cmp_wkv, csa_cmp_wgate, csa_cmp_ape, csa_cmp_norm_w,
        csa_compress_state, csa_compress_state_block_table,
        csa_idx_wq_b, csa_idx_wq_b_scale, csa_weights_proj, csa_hadamard_idx,
        csa_inner_wkv, csa_inner_wgate, csa_inner_ape, csa_inner_norm_w,
        csa_inner_compress_state, csa_inner_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        idx_kv_cache, idx_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )


high

The logic for selecting the attention mechanism is tightly coupled to the specific compress_ratios pattern in the current model configuration. This makes the code brittle. If MODEL_CONFIG.compress_ratios changes in a way that violates the hardcoded layer_id < 2 or layer_id % 2 == 1 checks, this function will silently use the wrong attention mechanism. This could lead to incorrect behavior that is hard to debug.

The golden path implementation in golden_decode_layer_auto uses _attention_kind_for_layer, which is more robust as it directly uses the configuration.

While pl.jit functions have limitations on accessing external configuration, this hardcoded logic is a significant risk. Please consider adding a comment to highlight this strong assumption about the configuration, or an assertion to verify that the layer_id corresponds to the expected attention mechanism based on the MODEL_CONFIG if possible at JIT time.
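
One way to act on this suggestion is a host-side check that runs before the JIT layer is launched. The sketch below is plain Python rather than PyPTO code. `ATTENTION_KINDS` is a hypothetical per-layer table standing in for whatever `_attention_kind_for_layer` derives from `MODEL_CONFIG.compress_ratios` (the real values are not shown in this thread), and `hardcoded_kind` and `check_dispatch_matches_config` are illustrative names, not functions from the PR.

```python
# Hypothetical per-layer attention kinds; in the PR these would be derived from
# MODEL_CONFIG.compress_ratios, whose real contents are not shown in this thread.
ATTENTION_KINDS = ["swa", "swa", "csa", "hca", "csa", "hca"]


def hardcoded_kind(layer_id: int) -> str:
    """Mirror the branch structure of the reviewed auto-dispatch code."""
    if layer_id < 2:
        return "swa"
    if layer_id % 2 == 1:
        return "hca"
    return "csa"


def check_dispatch_matches_config(kinds: list[str]) -> None:
    """Raise if any layer's configured kind disagrees with the hardcoded rule."""
    for layer_id, expected in enumerate(kinds):
        actual = hardcoded_kind(layer_id)
        if actual != expected:
            raise ValueError(
                f"layer {layer_id}: config expects {expected!r}, "
                f"hardcoded dispatch selects {actual!r}"
            )


# Passes silently for the hypothetical table above.
check_dispatch_matches_config(ATTENTION_KINDS)
```

Running such a check once at module import or in the test CLI would turn a silent mismatch into an immediate error, without requiring the JIT function itself to read the configuration.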

Comment on lines +595 to +602
pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)


medium

This block of code for allocating window buffers is duplicated across host_orch, host_orch_hca, host_orch_csa, and host_orch_auto.

To improve maintainability and reduce redundancy, this logic can be extracted into a helper function.

Example of a helper function:

def _alloc_moe_windows():
    pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
    count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
    recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
    recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
    recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
    recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
    routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
    combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
    return (
        pub_counts_buf, count_done_buf, recv_x_buf, recv_scale_buf, 
        recv_w_buf, recv_r_route_buf, routed_y_buf_buf, combine_done_buf
    )

You can then call this helper in each host_orch_* function.

Suggested change
- pub_counts_buf = pld.alloc_window_buffer(N_RANKS * N_RANKS * N_LOCAL * 4)
- count_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
- recv_x_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * D * 2)
- recv_scale_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
- recv_w_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * W_PAD * 4)
- recv_r_route_buf = pld.alloc_window_buffer(N_LOCAL * RECV_MAX * IDX_PAD * 4)
- routed_y_buf_buf = pld.alloc_window_buffer(N_ROUTES * D * 2)
- combine_done_buf = pld.alloc_window_buffer(N_RANKS * 4)
+ (
+     pub_counts_buf, count_done_buf, recv_x_buf, recv_scale_buf, recv_w_buf,
+     recv_r_route_buf, routed_y_buf_buf, combine_done_buf
+ ) = _alloc_moe_windows()
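
A table-driven variant of the same helper is sketched below, as an alternative rather than as the reviewer's proposal. It assumes the trailing factors of 2 and 4 in the size expressions are per-element byte sizes, which the review does not state. The constants are placeholder values, `alloc_moe_windows` is a hypothetical name, and `alloc` is a stand-in callable for `pld.alloc_window_buffer`. Whether the host orchestration functions permit this kind of Python-level helper is also not established in this thread.

```python
from math import prod

# Placeholder constants for illustration only; the real values live in the PR's module.
N_RANKS, N_LOCAL, RECV_MAX, D, W_PAD, IDX_PAD, N_ROUTES = 2, 4, 16, 64, 8, 8, 32

# (name, logical shape, assumed bytes per element)
MOE_WINDOWS = [
    ("pub_counts", (N_RANKS, N_RANKS, N_LOCAL), 4),
    ("count_done", (N_RANKS,), 4),
    ("recv_x", (N_LOCAL, RECV_MAX, D), 2),
    ("recv_scale", (N_LOCAL, RECV_MAX, W_PAD), 4),
    ("recv_w", (N_LOCAL, RECV_MAX, W_PAD), 4),
    ("recv_r_route", (N_LOCAL, RECV_MAX, IDX_PAD), 4),
    ("routed_y_buf", (N_ROUTES, D), 2),
    ("combine_done", (N_RANKS,), 4),
]


def alloc_moe_windows(alloc):
    """Allocate every MoE window via `alloc(num_bytes)` and return them in table order."""
    return tuple(alloc(prod(shape) * itemsize) for _, shape, itemsize in MOE_WINDOWS)


# Stand-in allocator that just records the requested size in bytes.
sizes = alloc_moe_windows(lambda num_bytes: num_bytes)
```

Keeping the shapes and element sizes in one table makes each byte count auditable and leaves a single place to change if a buffer's layout changes.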

Comment on lines +1365 to +1390
for spec in moe_specs:
    if not isinstance(spec, TensorSpec):
        continue
    if spec.name in {"x_hc", "x_next"}:
        continue
    if spec.name == "tid2eid":
        def init_tid2eid():
            base = torch.arange(VOCAB, dtype=torch.int32).reshape(VOCAB, 1) * TOPK
            offs = torch.arange(TOPK, dtype=torch.int32).reshape(1, TOPK)
            table = (base + offs) % N_EXPERTS_GLOBAL
            return table.unsqueeze(0).expand(N_RANKS, -1, -1).contiguous()

        specs.append(TensorSpec("tid2eid", spec.shape, spec.dtype, init_value=init_tid2eid))
    elif spec.name == "input_ids":
        def init_input_ids():
            ids = torch.arange(T, dtype=torch.int64)
            return ids.unsqueeze(0).expand(N_RANKS, -1).contiguous()

        specs.append(TensorSpec("input_ids", spec.shape, spec.dtype, init_value=init_input_ids))
    else:
        specs.append(moe_tensor_specs[spec.name])

specs.extend([
    TensorSpec("x_next", [N_RANKS, T, HC_MULT, D], torch.bfloat16, is_output=True),
    ScalarSpec("layer_id", torch.int32, layer_id),
])


medium

This block of code for adding MoE tensor specs and the final specs is duplicated. It appears here inside the if attention_mode == "auto" and unified_host: block, and again in the else block (lines 1524-1550).

To improve readability and maintainability of this very long function, this duplicated logic should be extracted into a helper function.

        _add_moe_and_final_specs(specs, moe_specs, moe_tensor_specs, layer_id)

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/decode_layer_swa_moe_ep.py`:
- Around line 486-526: The auto-dispatch branch in
models/deepseek/v4/decode_layer_swa_moe_ep.py is calling attention_swa,
attention_hca and attention_csa but discarding their return values; each
attention_* (decorated `@pl.jit.inline`) returns the updated tensor (e.g., x_out)
and should be assigned back to x_attn. Fix by replacing the bare calls with
assignments—x_attn = attention_swa(...), x_attn = attention_hca(...), and x_attn
= attention_csa(...)—matching the earlier correct usage elsewhere in the file.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: b15031b6-0b85-44c7-83b2-bdc4c922907c

📥 Commits

Reviewing files that changed from the base of the PR and between cef942a and 7a90f1c.

📒 Files selected for processing (2)
  • models/deepseek/v4/decode_layer_swa_moe_ep.py
  • models/deepseek/v4/moe_ep.py

Comment on lines +486 to +526
if layer_id < 2:
    attention_swa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        kv_cache, block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
elif layer_id % 2 == 1:
    attention_hca(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        hca_cmp_wkv, hca_cmp_wgate, hca_cmp_ape, hca_cmp_norm_w,
        hca_compress_state, hca_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )
else:
    attention_csa(
        x_hc,
        hc_attn_fn, hc_attn_scale, hc_attn_base,
        attn_norm_w, wq_a, wq_b, wq_b_scale,
        wkv, gamma_cq, gamma_ckv, freqs_cos, freqs_sin,
        csa_cmp_wkv, csa_cmp_wgate, csa_cmp_ape, csa_cmp_norm_w,
        csa_compress_state, csa_compress_state_block_table,
        csa_idx_wq_b, csa_idx_wq_b_scale, csa_weights_proj, csa_hadamard_idx,
        csa_inner_wkv, csa_inner_wgate, csa_inner_ape, csa_inner_norm_w,
        csa_inner_compress_state, csa_inner_compress_state_block_table,
        kv_cache, block_table, cmp_kv, cmp_block_table,
        idx_kv_cache, idx_block_table,
        attn_sink, seqused_kv, wo_a, wo_b, wo_b_scale,
        x_attn,
        start_pos,
    )


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check decorators on attention_swa, attention_hca, attention_csa

echo "=== Checking attention_swa decorator ==="
rg -n -B2 "^def attention_swa\(" --type=py

echo "=== Checking attention_hca decorator ==="
rg -n -B2 "^def attention_hca\(" --type=py

echo "=== Checking attention_csa decorator ==="
rg -n -B2 "^def attention_csa\(" --type=py

Repository: hw-native-sys/pypto-lib

Length of output: 712


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Confirm decode_layer_swa_moe_ep usage around the flagged block
echo "=== decode_layer_swa_moe_ep usage (lines ~450-560) ==="
sed -n '440,560p' models/deepseek/v4/decode_layer_swa_moe_ep.py | nl -ba

# Check where attention_swa/hca/csa return (show a window around first/last return)
for f in \
  models/deepseek/v4/decode_attention_swa.py \
  models/deepseek/v4/decode_attention_hca.py \
  models/deepseek/v4/decode_attention_csa.py
do
  echo "=== Returns in $f ==="
  # print function header line number
  rg -n "^def $(basename "$f" .py)\(" "$f" || true
  # show all return lines with a bit of context
  rg -n "return " "$f" || true
  # show tail of file section (in case return is near end)
  tail -n 80 "$f" | nl -ba
done

# Check decode_layer_hca_moe_ep and decode_layer_csa_moe_ep whether they capture x_attn
echo "=== decode_layer_hca_moe_ep usage (search for attention_* = attention_*) ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_hca_moe_ep.py || true

echo "=== decode_layer_csa_moe_ep usage (search for x_attn = attention_*) ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_csa_moe_ep.py || true

Repository: hw-native-sys/pypto-lib

Length of output: 183


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== decode_layer_swa_moe_ep usage (lines ~450-560) ==="
sed -n '440,560p' models/deepseek/v4/decode_layer_swa_moe_ep.py | cat -n

for f in \
  models/deepseek/v4/decode_attention_swa.py \
  models/deepseek/v4/decode_attention_hca.py \
  models/deepseek/v4/decode_attention_csa.py
do
  echo "=== Attention function: $f (show return lines) ==="
  # show function header
  rg -n "^def $(basename "$f" .py)\(" "$f" || true
  # show lines around return statements (limit output)
  rg -n "return " "$f" || true
done

echo "=== decode_layer_hca_moe_ep: x_attn capture from attention_* ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_hca_moe_ep.py || true

echo "=== decode_layer_csa_moe_ep: x_attn capture from attention_* ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_csa_moe_ep.py || true

echo "=== decode_layer_swa_moe_ep: any x_attn = attention_* calls ==="
rg -n "x_attn\s*=\s*attention_(swa|hca|csa)\(" models/deepseek/v4/decode_layer_swa_moe_ep.py || true

echo "=== decode_layer_swa_moe_ep: attention_swa/hca/csa callsites without assignment (look near ~486-526) ==="
sed -n '470,540p' models/deepseek/v4/decode_layer_swa_moe_ep.py | cat -n

Repository: hw-native-sys/pypto-lib

Length of output: 17100


Capture attention_* return values in auto dispatch

attention_swa / attention_hca / attention_csa are @pl.jit.inline and return a tensor (e.g., return x_out), but in models/deepseek/v4/decode_layer_swa_moe_ep.py the if/elif/else auto-dispatch calls are made without assigning back to x_attn (the returned value is discarded). Earlier in the same file the code correctly uses x_attn = attention_*..., so this branch is inconsistent and should capture the return like the others.

Additionally, decode_layer_hca_moe_ep.py / decode_layer_csa_moe_ep.py are not present under models/deepseek/v4/ in this repo, so the issue should be scoped to the actual HCA/CSA auto-dispatch implementation files found there.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_layer_swa_moe_ep.py` around lines 486 - 526, The
auto-dispatch branch in models/deepseek/v4/decode_layer_swa_moe_ep.py is calling
attention_swa, attention_hca and attention_csa but discarding their return
values; each attention_* (decorated `@pl.jit.inline`) returns the updated tensor
(e.g., x_out) and should be assigned back to x_attn. Fix by replacing the bare
calls with assignments—x_attn = attention_swa(...), x_attn = attention_hca(...),
and x_attn = attention_csa(...)—matching the earlier correct usage elsewhere in
the file.
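
The distinction this comment draws can be illustrated in plain Python, without any PyPTO semantics. In the sketch below, `attention_stub` is a hypothetical stand-in that returns a new value and leaves its output argument untouched. Whether the real `@pl.jit.inline` functions also write into `x_attn` in place is not established in this thread, and that property determines whether the bare calls are actually a bug.

```python
def attention_stub(x, x_out):
    """Hypothetical stand-in: compute a result and return it without mutating x_out."""
    result = [2 * v for v in x]
    return result


x = [1, 2, 3]

# Bare call: the returned list is dropped and x_attn keeps its initial contents.
x_attn = [0, 0, 0]
attention_stub(x, x_attn)
discarded = list(x_attn)

# Assignment: x_attn is rebound to the returned result.
x_attn = attention_stub(x, x_attn)
captured = list(x_attn)
```

If the attention functions do update their output argument in place, both forms would yield the same `x_attn`, and the assignment would mainly serve consistency with the other call sites in the file.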

@high-cloud high-cloud force-pushed the feat/deepseek-v4-unified-decode-layer branch 9 times, most recently from d20723a to bb63bc4, on June 15, 2026 11:27
- Add a unified decode-layer host path that selects SWA, HCA, or CSA attention by layer_id inside the layer function.
- Reuse the MoE EP path through an inline body plus standalone wrapper so the existing MoE smoke remains runnable.
- Add auto-mode tensor specs and golden routing for the unified host.
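
The "inline body plus standalone wrapper" arrangement from the commit message can be sketched structurally as follows. This is not PyPTO code: `jit` and `jit_inline` are hypothetical stand-ins for `@pl.jit` and `@pl.jit.inline` that only tag the function, and the two-argument signatures are simplified placeholders for the real `moe_ep` parameter list.

```python
import functools


# Hypothetical stand-ins for @pl.jit and @pl.jit.inline: they only tag the function
# so the structure can be shown without PyPTO installed.
def jit(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper.kind = "jit"
    return wrapper


def jit_inline(fn):
    fn.kind = "inline"
    return fn


@jit_inline
def moe_ep(x, scale):
    """Shared body, meant to be inlined into a larger JIT function."""
    return [v * scale for v in x]


@jit
def moe_ep_test(x, scale):
    """Standalone entry point that forwards all arguments to the shared body."""
    return moe_ep(x, scale)


@jit
def decode_layer(x, scale):
    """Composed layer: an attention stand-in followed by the same shared body."""
    x_attn = [v + 1 for v in x]
    return moe_ep(x_attn, scale)
```

With this split, the existing MoE smoke test keeps a JIT entry point (`moe_ep_test`) while the decode layer reuses the same body inside its own JIT function.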