refactor: pluggable guidance variant registry #1129

Open
FlexOr2 wants to merge 6 commits into ace-step:main from FlexOr2:feat/pluggable-guidance-registry

FlexOr2 commented Apr 22, 2026

Summary

Replaces the hardcoded apg_forward / adg_forward / cfg_forward dispatch inside every generate_audio() with a pluggable registry keyed by a string guidance_variant and configured by a free-form guidance_params dict. Exposes both on /release_task. API docs updated (en/ja/ko/zh).

Closes #1124. Maintainer approval quote (@ChuxiJ, 2026-04-22):

approve, we can have a refact pr

Branched on top of upstream/main at 1d9d2d3.

Related PRs

Design decisions

  • typing.Protocol for the wrapper signature, not Pydantic or abc.ABC. Guidance runs in the per-step sampler loop on GPU — a Pydantic validator per call would add measurable overhead. The wrappers obey a structural Protocol: fn(pred_cond, pred_uncond, guidance_scale, state, **params) -> Tensor. This matches the Diffusers callback_on_step_end style.
  • Registry lives in its own module (acestep/models/common/guidance_registry.py) rather than being folded into apg_guidance.py. Keeps the raw math functions as a thin numerical layer; the wrappers are the adapter. apg_guidance.py is untouched.
  • Unknown variant raises ValueError listing the registered names. No silent default to APG — silent fallbacks make parameter bugs invisible in a diffusion stack.
  • State dict is the sampler's responsibility. Sampler creates an empty dict at the start of generate_audio, stashes per-step values (latents, sigma, step_role) into it before each call, and drops it on return. Wrappers persist cross-step bookkeeping (APG's MomentumBuffer) in the same dict. state["step_role"] ∈ {"main", "corrector"} so Heun correctors can degrade APG → CFG and ADG → CFG-at-sigma=0 exactly as the pre-refactor code did.
  • CFG=1 is a no-op by construction (uncond + 1 * (cond - uncond) == cond); no special case is needed. Documented in guidance_cfg docstring.
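
The decisions above can be sketched as follows. This is a minimal illustration, not the PR's actual module: the names GuidanceFn, register_guidance, get_guidance_fn, registered_guidance_names and guidance_cfg come from the PR text, but their bodies here are assumptions, and the tensor type is left as Any so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict, List, Protocol


class GuidanceFn(Protocol):
    """Structural signature every guidance wrapper is expected to satisfy."""

    def __call__(
        self,
        pred_cond: Any,
        pred_uncond: Any,
        guidance_scale: float,
        state: Dict[str, Any],
        **params: Any,
    ) -> Any: ...


# Assumed module-level store mapping variant name -> wrapper.
_GUIDANCE_REGISTRY: Dict[str, GuidanceFn] = {}


def register_guidance(name: str) -> Callable[[GuidanceFn], GuidanceFn]:
    """Decorator that files a wrapper under ``name``."""

    def decorator(fn: GuidanceFn) -> GuidanceFn:
        _GUIDANCE_REGISTRY[name] = fn
        return fn

    return decorator


def registered_guidance_names() -> List[str]:
    """Return the registered variant names in sorted order."""
    return sorted(_GUIDANCE_REGISTRY)


def get_guidance_fn(name: str) -> GuidanceFn:
    """Look up a wrapper; fail loudly instead of silently defaulting."""
    try:
        return _GUIDANCE_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown guidance variant {name!r}; "
            f"registered: {registered_guidance_names()}"
        ) from None


@register_guidance("cfg")
def guidance_cfg(pred_cond, pred_uncond, guidance_scale, state, **params):
    """Plain CFG; with guidance_scale == 1 the formula reduces to pred_cond."""
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```

Because the contract is a Protocol, any callable with a compatible signature can be registered without inheriting from a base class. Note that the CFG=1 identity is exact in real arithmetic; in floating point it holds only up to rounding.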

Scope

In:

  • PyTorch guidance dispatch for all four model families (base, sft, xl_base, xl_sft), both Euler and Heun samplers, main + corrector steps.
  • HTTP surface: guidance_variant and guidance_params on GenerateMusicRequest, with snake_case + camelCase alias resolution, JSON-string decoding with None fallback on malformed input.
  • End-to-end plumbing: HTTP body → GenerationParams → generate_music → service_generate → _build_service_generate_kwargs → model.generate_audio.
  • Documentation: docs/{en,ja,ko,zh}/API.md gain rows for guidance_variant and guidance_params next to use_adg.
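
A rough sketch of the HTTP-side handling listed above (alias resolution, then JSON-string decoding with a None fallback). The helper names resolve_alias and normalize_guidance_params and the shape of the alias table are assumptions for illustration; the real PARAM_ALIASES and request builder may be organised differently.

```python
import json
from typing import Any, Dict, Mapping, Optional, Tuple

# Assumed alias table; the real PARAM_ALIASES may have a different shape.
PARAM_ALIASES: Dict[str, Tuple[str, ...]] = {
    "guidance_variant": ("guidance_variant", "guidanceVariant"),
    "guidance_params": ("guidance_params", "guidanceParams"),
}


def resolve_alias(body: Mapping[str, Any], field: str) -> Any:
    """Return the value of the first alias of ``field`` found in ``body``."""
    for key in PARAM_ALIASES[field]:
        if key in body:
            return body[key]
    return None


def normalize_guidance_params(raw: Any) -> Optional[Dict[str, Any]]:
    """Accept a dict or a JSON-encoded object; anything else becomes None."""
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str) and raw.strip():
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return None  # malformed JSON falls back to None
        return parsed if isinstance(parsed, dict) else None
    return None
```

For example, a multipart form that sends guidanceParams as the string '{"eta": 0.0}' would resolve to the dict {"eta": 0.0}, while a truncated string such as '{bad' would resolve to None.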

Out:

  • MLX. _mlx_apg_forward in acestep/models/mlx/dit_generate.py uses a different state-dict shape and does not expose eta. MLX parity with the PyTorch registry is a separate patch; called out explicitly so nobody assumes they get free MLX support from this PR.

Backward compatibility

The hard invariant: generate_audio(..., guidance_variant="apg_classic", guidance_params={}) produces output that matches the pre-refactor call site under torch.allclose at the default tolerance (rtol=1e-5, atol=1e-8); on same-input deterministic reruns this is effectively byte-identical. The legacy use_adg=True path also preserves its behaviour byte-identically by auto-resolving to the registered adg variant when the caller has not specified a non-default guidance_variant.

Proof (all passing):

  • guidance_registry_test::test_apg_classic_wrapper_matches_apg_forward — wrapper vs raw apg_forward on fresh momentum.
  • guidance_registry_sampler_regression_test::test_full_euler_sequence_matches_hand_written_loop — N-step Euler parity with momentum carry-over.
  • guidance_registry_sampler_regression_test::test_heun_sequence_matches_hand_written_loop — Heun predictor+corrector parity for APG (predictor via apg_forward, corrector via cfg_forward).
  • guidance_registry_sampler_regression_test::test_heun_sequence_matches_hand_written_loop_with_zero_sigma_corrector — Heun parity for ADG including the sigma == 0 corrector safeguard.
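
The shape of these parity tests can be illustrated with a toy model. Everything below is a stand-in, assuming scalar predictions instead of tensors: ToyMomentum, toy_forward, toy_cfg and toy_wrapper are hypothetical and only mimic the roles of MomentumBuffer, the raw guidance math, and a registry wrapper that keeps its buffer in the sampler-owned state dict.

```python
import math
from typing import Any, Dict, List, Tuple


class ToyMomentum:
    """Stand-in for a cross-step buffer such as APG's MomentumBuffer."""

    def __init__(self, momentum: float = -0.75) -> None:
        self.momentum = momentum
        self.running = 0.0

    def update(self, value: float) -> float:
        self.running = value + self.momentum * self.running
        return self.running


def toy_forward(cond: float, uncond: float, scale: float, buf: ToyMomentum) -> float:
    """Toy raw-math layer: momentum-smoothed difference, then a scaled mix."""
    diff = buf.update(cond - uncond)
    return cond + (scale - 1.0) * diff


def toy_cfg(cond: float, uncond: float, scale: float) -> float:
    return uncond + scale * (cond - uncond)


def toy_wrapper(cond, uncond, scale, state: Dict[str, Any], **params) -> float:
    """Adapter: keeps the buffer in ``state`` and degrades on corrector steps."""
    if state.get("step_role") == "corrector":
        return toy_cfg(cond, uncond, scale)
    if "_momentum" not in state:
        state["_momentum"] = ToyMomentum(**params)
    return toy_forward(cond, uncond, scale, state["_momentum"])


def reference_loop(preds: List[Tuple[float, float]], scale: float) -> List[float]:
    buf = ToyMomentum()  # one buffer shared by every step of the hand-written loop
    return [toy_forward(c, u, scale, buf) for c, u in preds]


def registry_loop(preds: List[Tuple[float, float]], scale: float) -> List[float]:
    state: Dict[str, Any] = {}  # created once per generation, dropped on return
    outputs = []
    for cond, uncond in preds:
        state["step_role"] = "main"
        outputs.append(toy_wrapper(cond, uncond, scale, state))
    return outputs


preds = [(1.0, 0.5), (0.8, 0.4), (0.6, 0.1)]
reference = reference_loop(preds, scale=3.0)
dispatched = registry_loop(preds, scale=3.0)
# Compare lengths explicitly so a dropped step cannot pass unnoticed.
parity = len(reference) == len(dispatched) and all(
    math.isclose(a, b) for a, b in zip(reference, dispatched)
)
```

The point of the multi-step loop is that a wrapper which rebuilt its buffer on every call would agree with the reference on the first step only, so single-call parity alone would not catch a momentum carry-over bug.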

Test plan

  • python -m unittest acestep.models.common.guidance_registry_test acestep.models.common.guidance_registry_sampler_regression_test — 16 tests, all green.
  • python -m unittest discover -s acestep -p "*_test.py" -t . — matches origin/main (same pre-existing errors, no new failures). +24 new tests added by this PR vs the baseline.
  • Registry unit tests: default contents, decorator wiring, unknown-name ValueError, per-wrapper parity with raw functions, momentum state persistence, fresh-state semantics, corrector-step degradation.
  • Sampler-level regression tests (3 tests): Euler APG parity, Heun APG parity, Heun ADG parity with sigma == 0 safeguard.
  • HTTP-level tests: Pydantic defaults, snake/camelCase alias resolution, JSON-string decoding with None fallback on malformed input, GenerationParams forwarding with and without the fields present on the request.
  • Manual smoke: real generation with each of {apg_classic, cfg, adg, adg_w_norm, adg_wo_clip}. Hardware-bound, leaving to reviewer / CI with weights.

Happy to adjust

Any decision flagged above is a starting point, not a commitment. If you prefer a different protocol signature, registry location, state-dict convention, or naming scheme — push back and I'll revise. The goal is to land something maintainable, not to litigate my initial choices.

Summary by CodeRabbit

  • New Features

    • Added pluggable guidance controls for music generation: choose a guidance variant and supply per-variant parameters; these flow through the full generation pipeline.
  • Tests

    • Added unit and regression tests for variant selection, alias (snake/camel) resolution, JSON parsing/validation of params, and end-to-end propagation.
  • Documentation

    • API docs updated to document guidance_variant and guidance_params (camelCase aliases and multipart JSON handling).

coderabbitai Bot commented Apr 22, 2026

Contributor

Walkthrough

This PR adds a pluggable guidance-function registry and threads two new HTTP request fields—guidance_variant and guidance_params—through the request parsing, builder, job setup, service, inference, and model layers; includes registry implementation, model integrations, tests, and docs updates.

Changes

Guidance registry + HTTP plumbing

| Cohort / File(s) | Summary |
|---|---|
| HTTP API models & parsing: acestep/api/http/release_task_models.py, acestep/api/http/release_task_models_test.py, acestep/api/http/release_task_param_parser.py, acestep/api/http/release_task_param_parser_test.py | Added guidance_variant and guidance_params to the Pydantic request model and tests; extended PARAM_ALIASES to resolve snake_case and camelCase aliases. |
| Request builder & job wiring: acestep/api/http/release_task_request_builder.py, acestep/api/http/release_task_request_builder_test.py, acestep/api/job_generation_setup.py, acestep/api/job_generation_setup_test.py | Normalized guidance_params (JSON string → dict or None), added guidance_variant/guidance_params to payloads and GenerationParams, and added tests for parsing and wiring. |
| Service / handler plumbing: acestep/core/generation/handler/generate_music.py, acestep/core/generation/handler/generate_music_execute.py, acestep/core/generation/handler/service_generate.py, acestep/core/generation/handler/service_generate_execute.py | Extended method signatures to accept and forward guidance_variant and guidance_params through the generation call chain. |
| Core inference params: acestep/inference.py | Added guidance_variant and guidance_params to GenerationParams and passed them into DiT generation kwargs. |
| Guidance registry (new module): acestep/models/common/guidance_registry.py | Added a registry API (register_guidance, get_guidance_fn, registered_guidance_names) and five built-in guidance wrappers implementing a unified GuidanceFn contract plus per-generation state management and corrector fallbacks. |
| Model integrations (4 variants): acestep/models/base/..., acestep/models/sft/..., acestep/models/xl_base/..., acestep/models/xl_sft/... | Replaced hardcoded guidance imports/branches with registry-driven guidance_fn; added guidance_variant/guidance_params to generate_audio() signatures; introduced shared guidance_state and guidance_kwargs. |
| Registry tests & sampler parity: acestep/models/common/guidance_registry_test.py, acestep/models/common/guidance_registry_sampler_regression_test.py | Added tests validating registry contents, extensibility, error messages, state lifecycle, corrector semantics, and parity between registry-dispatched wrappers and legacy guidance functions. |
| Docs: docs/en/API.md, docs/ja/API.md, docs/ko/API.md, docs/zh/API.md | Documented guidance_variant and guidance_params as new POST /release_task parameters, including camelCase aliases and multipart JSON encoding note. |

Sequence Diagram(s)

sequenceDiagram
    participant Client as Client
    participant HTTP as HTTP Handler
    participant Parser as Param Parser
    participant Builder as Request Builder
    participant Job as Job Setup
    participant Service as Service Layer
    participant Model as DiT Model
    participant Registry as Guidance Registry

    Client->>HTTP: POST /release_task (guidance_variant, guidance_params)
    HTTP->>Parser: parse + resolve aliases
    Parser->>Builder: extracted fields
    Builder->>Builder: normalize guidance_params (JSON -> dict|None)
    Builder->>Job: GenerateMusicRequest
    Job->>Service: GenerationParams with guidance fields
    Service->>Model: generate_audio(..., guidance_variant, guidance_params)
    Model->>Registry: get_guidance_fn(guidance_variant)
    Registry-->>Model: guidance_fn
    Model->>Model: init guidance_state, guidance_kwargs
    loop per CFG step
      Model->>Model: set state["latents"], state["sigma"], state["step_role"]
      Model->>Model: guidance_fn(pred_cond, pred_uncond, scale, state, **guidance_kwargs)
    end
    Model-->>Client: audio payload

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • ace-step/ACE-Step-1.5#1120 — Similar plumbing: threads additional generation-time parameters through the same call chain (generate_music → service → model).
  • ace-step/ACE-Step-1.5#978 — Adds generation control parameters across the same call chain and overlaps on signature/propagation concerns.
  • ace-step/ACE-Step-1.5#479 — Contains the legacy APG/ADG/CFG guidance implementations that the new registry dispatches to.

Suggested reviewers

  • ChuxiJ

Poem

🐰 I hopped through code with eager paws,

Registered guidance, cleared old laws.
Variants lined up, params in store,
State kept tidy, called once more.
Now generation hops freer than before.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title 'refactor: pluggable guidance variant registry' accurately summarizes the primary change: replacing hardcoded guidance dispatch with a pluggable registry keyed by guidance_variant and configured via guidance_params. |
| Linked Issues check | ✅ Passed | The PR fully implements the core objectives from #1124 (pluggable guidance registry with GuidanceFn protocol, sampler state dict, HTTP surface, ValueError handling) and #1123 (expose APG eta/momentum via guidance_params), including registry in guidance_registry.py, wrappers for all 5 variants, HTTP plumbing with snake/camelCase aliases, and comprehensive unit/regression tests. |
| Out of Scope Changes check | ✅ Passed | All changes are directly aligned with RFC objectives: guidance registry implementation, HTTP API exposure with aliases, GenerationParams threading, model handler updates, documentation in 4 languages, and supporting tests. No out-of-scope refactoring or unrelated modifications detected. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 93.22% which is sufficient. The required threshold is 80.00%. |



FlexOr2 commented Apr 22, 2026

Author

@coderabbitai review

coderabbitai Bot commented Apr 22, 2026

Contributor
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

coderabbitai Bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
acestep/core/generation/handler/service_generate_execute.py (1)

192-225: ⚠️ Potential issue | 🟡 Minor

MLX path silently ignores guidance_variant/guidance_params.

The PyTorch branch (line 225) picks up the new fields via generate_kwargs, but _mlx_run_diffusion (lines 192–213) is not passed them. PR scope excludes MLX from the registry, which is fine — but when a user supplies a non-default guidance_variant over HTTP while the MLX backend is active, the selection is silently dropped with no log/warning. Consider either logging a one-time warning when guidance_variant != "apg_classic" (or guidance_params is non-empty) and the MLX path is taken, or documenting this caveat in docs/*/API.md.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/service_generate_execute.py` around lines 192
- 225, The MLX branch is dropping guidance settings: when calling
_mlx_run_diffusion from service_generate_execute (the MLX path), it does not
pass generate_kwargs["guidance_variant"] or ["guidance_params"], so non-default
guidance is silently ignored; fix by either (A) adding
guidance_variant=generate_kwargs.get("guidance_variant", "apg_classic") and
guidance_params=generate_kwargs.get("guidance_params", {}) to the
_mlx_run_diffusion call so the MLX runner receives these options, or (B) if MLX
truly doesn't support them, emit a one-time logger.warning in the MLX branch
(before calling _mlx_run_diffusion) when generate_kwargs.get("guidance_variant")
!= "apg_classic" or generate_kwargs.get("guidance_params") is non-empty to alert
users that guidance was dropped.
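
Option (B) above could look roughly like the sketch below. The helper name warn_if_guidance_dropped, the logger name, and the module-level flag are all assumptions for illustration; the actual MLX branch may prefer a different mechanism for one-time warnings.

```python
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger("acestep.mlx_guidance")  # hypothetical logger name
_warned_mlx_guidance = False  # module-level flag so the warning fires once


def warn_if_guidance_dropped(
    guidance_variant: Optional[str],
    guidance_params: Optional[Dict[str, Any]],
) -> bool:
    """Log once when the MLX path is about to ignore non-default guidance.

    Returns True if this call emitted the warning, False otherwise.
    """
    global _warned_mlx_guidance
    non_default = guidance_variant not in (None, "apg_classic") or bool(guidance_params)
    if non_default and not _warned_mlx_guidance:
        logger.warning(
            "MLX backend does not support guidance_variant=%r / "
            "guidance_params=%r; these settings are ignored.",
            guidance_variant,
            guidance_params,
        )
        _warned_mlx_guidance = True
        return True
    return False
```

Calling the helper just before _mlx_run_diffusion would keep the MLX runner's signature unchanged while still making the dropped selection visible in the logs.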
🧹 Nitpick comments (6)
acestep/models/common/guidance_registry.py (2)

77-84: Optional: detect duplicate registrations.

register_guidance silently overwrites an existing name. For a public registration surface (third-party variants, per issue #1124), a duplicate is almost always a bug in load order; raising would catch it immediately. Low priority — feel free to defer.

Proposed fix
     def decorator(fn: GuidanceFn) -> GuidanceFn:
+        if name in _GUIDANCE_REGISTRY:
+            raise ValueError(
+                f"Guidance variant {name!r} is already registered."
+            )
         _GUIDANCE_REGISTRY[name] = fn
         return fn
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/guidance_registry.py` around lines 77 - 84, The
register_guidance decorator currently overwrites entries in _GUIDANCE_REGISTRY;
change it to detect duplicates by checking if name is already a key in
_GUIDANCE_REGISTRY inside the decorator (or before assignment) and raise a clear
exception (e.g., ValueError) that includes the conflicting name and guidance
function name instead of silently overwriting; update the function
register_guidance and its inner decorator to perform this guard using the unique
symbols register_guidance and _GUIDANCE_REGISTRY.

117-154: Dims parameter default is correct; add clarifying comment for maintainability.

The wrapper default dims=(1,) is intentional and matches the pre-refactor behavior. All test cases (guidance_registry_test.py, guidance_registry_sampler_regression_test.py, and MLX implementation comment in dit_generate.py:112) explicitly confirm that apg_forward(..., dims=[1]) is the standard usage for the ACE-Step shape convention [B, T, C].

The divergence from apg_forward's documented default of dims=[-1] is correct for this wrapper but creates a maintainability gap. Add a one-line comment above the wrapper's dims parameter explaining why (1,) is used instead of matching the underlying function's default:

dims: Any = (1,),  # APG projects along sequence/time axis (dim 1), not the trailing default
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/guidance_registry.py` around lines 117 - 154, The
wrapper guidance_apg_classic intentionally uses dims=(1,) rather than
apg_forward's default [-1]; add a one-line clarifying comment on the dims
parameter (e.g., next to "dims: Any = (1,),") stating that APG should project
along the sequence/time axis (dim 1) to match the ACE-Step shape convention [B,
T, C] and existing tests (guidance_registry_test.py,
guidance_registry_sampler_regression_test.py) and MLX usage in dit_generate.py;
this keeps the divergence from apg_forward's default explicit and maintainable.
acestep/api/http/release_task_request_builder.py (1)

39-48: Narrow exception and hoist the import; consider reusing existing JSON-parsing helper.

  • import json as _json inside the function is unusual — json is used elsewhere in this module's siblings, hoist to module-level imports.
  • Ruff flags the bare except Exception (BLE001). Narrow to (json.JSONDecodeError, TypeError) so genuine bugs (e.g. AttributeError) aren't swallowed.
  • release_task_param_parser.py::RequestParser._parse_json already implements exactly this pattern (dict pass-through, JSON-string → dict, fallback). Consider reusing/exposing it to avoid the duplicated logic drift (e.g. that helper does a .strip() check before json.loads, this one doesn't).
♻️ Proposed fix
+import json
+
 from typing import Any, Optional
@@
-    guidance_params = parser.get("guidance_params")
-    if isinstance(guidance_params, str):
-        import json as _json
-        try:
-            parsed_guidance = _json.loads(guidance_params)
-            guidance_params = parsed_guidance if isinstance(parsed_guidance, dict) else None
-        except Exception:
-            guidance_params = None
-    elif guidance_params is not None and not isinstance(guidance_params, dict):
-        guidance_params = None
+    guidance_params = parser.get("guidance_params")
+    if isinstance(guidance_params, str):
+        try:
+            parsed_guidance = json.loads(guidance_params) if guidance_params.strip() else None
+        except (json.JSONDecodeError, TypeError):
+            parsed_guidance = None
+        guidance_params = parsed_guidance if isinstance(parsed_guidance, dict) else None
+    elif guidance_params is not None and not isinstance(guidance_params, dict):
+        guidance_params = None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/api/http/release_task_request_builder.py` around lines 39 - 48, Hoist
the json import to module scope and avoid the bare except: when handling
guidance_params (the parser.get("guidance_params") branch) narrow the exception
to (json.JSONDecodeError, TypeError) and apply the same strip+json.loads→dict
logic currently in RequestParser._parse_json; better yet, call or reuse
RequestParser._parse_json (or extract its helper) instead of duplicating logic
so guidance_params follows the same dict/pass-through, JSON-string→dict, and
fallback behavior as RequestParser._parse_json.
acestep/core/generation/handler/service_generate.py (1)

41-42: LGTM — parameters threaded through correctly.

Optional nit: docstring Args section doesn't document guidance_variant/guidance_params. Worth a one-line addition for API clarity, but not blocking.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/service_generate.py` around lines 41 - 42,
Add one-line descriptions for the two new parameters to the function's docstring
Args section: document guidance_variant (str) as the guidance algorithm/variant
to use (default "apg_classic") and guidance_params (Optional[Dict[str, Any]]) as
optional tuning parameters for the guidance engine. Locate the docstring for the
function that declares guidance_variant/guidance_params and append these two
entries to the Args block for API clarity.
acestep/core/generation/handler/generate_music.py (1)

225-240: Document the new guidance arguments in the public docstring.

generate_music() now exposes guidance_variant and guidance_params, but the Args section omits them. As per coding guidelines, Docstrings are mandatory for all new or modified Python modules, classes, and functions, and must include purpose plus key inputs/outputs and raised exceptions when relevant.

📝 Proposed docstring addition
             guidance_scale: CFG guidance value.
+            guidance_variant: Registered guidance implementation to use.
+            guidance_params: Optional per-variant guidance parameter overrides.
             seed: Optional explicit seed from caller/UI.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music.py` around lines 225 - 240,
The generate_music() docstring’s Args section is missing the newly added
guidance_variant and guidance_params parameters; update the public docstring for
the generate_music function to list and describe these two arguments (expected
types, allowed values or enum for guidance_variant, structure and keys for
guidance_params, default behavior if omitted), include any effects on
guidance_scale or interaction with other args (seed, infer_method), and mention
relevant raised exceptions or validation errors thrown when invalid guidance
values are provided so the docstring fully documents inputs/outputs and error
conditions.
acestep/models/common/guidance_registry_test.py (1)

1-238: Split this new test module below the 200 LOC cap.

At 238 lines, this new file exceeds the repo’s hard cap. A clean split would be registry API tests vs. wrapper/state/corrector parity tests. Based on learnings, “only raise module-size concerns when a file exceeds 200 lines of code (LOC).” As per coding guidelines, “Module LOC policy is met (<=150 target, <=200 hard cap or justified exception).”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/guidance_registry_test.py` around lines 1 - 238, This
file exceeds the 200 LOC cap—split it into two test modules: one for registry
API tests (keep _fixed_tensors helper and GuidanceRegistryTests) and a second
for wrapper/state/corrector parity tests (move GuidanceWrapperParityTests,
GuidanceStateLifecycleTests, GuidanceCorrectorStepTests and any tests that
reference apg_forward, adg_forward, cfg_forward, MomentumBuffer,
get_guidance_fn, register_guidance, registered_guidance_names, _MOMENTUM_KEY).
Ensure both new modules import the shared helper _fixed_tensors (or duplicate it
if simpler), update imports to reference acestep.models.common.apg_guidance and
acestep.models.common.guidance_registry symbols used in the moved tests, and run
tests to confirm names like get_guidance_fn("apg_classic") still resolve; remove
the original oversized file once split.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/core/generation/handler/generate_music.py`:
- Around line 205-206: Change the public parameter guidance_variant in
generate_music (and any callers in the same internal call chain) from a concrete
default "apg_classic" to Optional[str] (i.e., allow None) so upstream callers
can omit it; propagate this None sentinel through internal functions rather than
converting to "apg_classic" immediately, and only resolve the effective variant
at the guidance-dispatch boundary (the function/method that actually selects
guidance behavior—e.g., the guidance dispatcher) with logic like: if
guidance_variant is None: resolved_variant = "adg" if use_adg else "apg_classic"
else: resolved_variant = guidance_variant. Ensure signatures and any
tests/places that construct the call use Optional[str] and handle None
accordingly.

In `@acestep/models/base/modeling_acestep_v15_base.py`:
- Around line 1867-1868: The default for guidance_variant currently masks
whether the caller explicitly selected APG; change the parameter type to
Optional[str] (guidance_variant: Optional[str] = None) and update the selection
logic so that if guidance_variant is not None you always honor it, otherwise
fall back to legacy behavior (if use_adg True -> "adg", else -> "apg_classic").
Apply the same change to the other signatures/handlers in this file (the block
covering lines ~1945-1954) so explicit guidance_variant always takes precedence
over use_adg.

In `@acestep/models/common/guidance_registry_sampler_regression_test.py`:
- Around line 77-78: Update the parity zip calls in the test to enforce
equal-length sequences: change all occurrences of zip(reference_outputs,
reg_outputs) (and the other similar zips used for comparing outputs) to
zip(reference_outputs, reg_outputs, strict=True) in the
guidance_registry_sampler_regression_test.py loops so length drift fails loudly
(apply the same change to the other zip instances around the comparisons).

In `@acestep/models/common/guidance_registry.py`:
- Line 225: Replace the hardcoded 3.14 approximations with math.pi: add import
math to the modules that define angle_clip defaults, change occurrences of 3.14
/ 6 to math.pi / 6 and 3.14 / 3 to math.pi / 3 (look for the angle_clip default
assignments and APG/ADG guidance default constructors or registry entries
referencing angle_clip), and update any unit tests that assert angle_clip ==
3.14 / 6 to use the new math.pi-based values so tests reflect the exact π
constants.

In `@acestep/models/sft/modeling_acestep_v15_base.py`:
- Around line 1868-1869: Change the guidance_variant parameter default from the
string "apg_classic" to None in the affected function(s) so None acts as the
"not provided" sentinel (e.g., the function/method signatures that currently
include guidance_variant: str = "apg_classic" should become guidance_variant:
Optional[str] = None); then, where guidance_variant is resolved (look for the
resolution logic that currently considers use_adg), only apply the legacy
fallback when guidance_variant is None: if guidance_variant is None set
guidance_variant = "adg" when use_adg is True, otherwise set guidance_variant =
"apg_classic"; apply the same change to the other occurrence around the guidance
handling code referenced in the file (the block near lines 1953-1962) so
explicit guidance_variant values are never overridden by use_adg.

In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py`:
- Around line 1880-1881: The function currently defaults guidance_variant to the
string "apg_classic", which causes explicit guidance_variant="apg_classic" to be
treated as “not provided” and be overridden by legacy use_adg logic; change the
parameter default to None (use this sentinel for “not provided”) in the
signature where guidance_variant is declared, and update the mapping logic that
reads use_adg to only map use_adg -> "adg" when guidance_variant is None (leave
any explicit guidance_variant unchanged); ensure callers preserve None for
unspecified variant so legacy use_adg compatibility continues to work (adjust
any nearby callers/converters that currently pass the literal default if
needed).

In `@acestep/models/xl_sft/modeling_acestep_v15_xl_base.py`:
- Around line 1880-1881: The function parameter guidance_variant should use None
(or a sentinel) as its default instead of the concrete string "apg_classic" so
callers can explicitly force APG even when legacy use_adg is true; update the
signature where guidance_variant: str = "apg_classic" to guidance_variant:
Optional[str] = None and change the resolution logic (the code that currently
maps guidance_variant="apg_classic", use_adg=True to "adg") so it only maps
use_adg->"adg" when guidance_variant is None; also update the HTTP/generation
plumbing that constructs calls to this API to pass None when the client omitted
guidance_variant rather than pre-filling "apg_classic" so the downstream
resolver can correctly honor an explicit guidance_variant value.

In `@docs/en/API.md`:
- Around line 215-216: Update the API docs for the guidance fields to clarify
encoding and aliases: state that guidance_params (and its camelCase alias
guidanceParams) is an object in JSON schemas but must be sent as a JSON-encoded
string when using form/multipart requests, and that guidance_variant (alias
guidanceVariant) accepts the registered variant names (`apg_classic`, `cfg`,
`adg`, `adg_w_norm`, `adg_wo_clip`); include a short example sentence showing a
form field carrying guidance_params as a JSON string (e.g.
'{"eta":1.0,"norm_threshold":2.5}') and mention that defaults depend on the
variant.

In `@docs/zh/API.md`:
- Around line 212-213: The docs example for guidance_params is inconsistent:
change the example `{"eta": 1.0, "norm_threshold": 2.5, "momentum": -0.75}` to
use `eta: 0.0` (keeping `momentum: -0.75`) so it matches the PR objective
(pre-refactor APG default) and align with the example in the model field
definition; update the `guidance_params` example in the docs entry for
`guidance_variant`/`guidance_params` and verify the example in the
`release_task_models.py` model field uses the same `eta: 0.0` value to keep all
references consistent.
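
Several of the inline comments above ask for the same sentinel-based resolution. A minimal sketch of that logic, assuming a hypothetical helper named resolve_guidance_variant called at the guidance-dispatch boundary:

```python
from typing import Optional


def resolve_guidance_variant(guidance_variant: Optional[str], use_adg: bool) -> str:
    """Resolve the effective variant at the dispatch boundary.

    None means "the caller did not choose"; only then does the legacy
    use_adg flag decide between "adg" and "apg_classic".
    """
    if guidance_variant is not None:
        return guidance_variant  # an explicit choice always wins
    return "adg" if use_adg else "apg_classic"
```

With this shape, guidance_variant="apg_classic" combined with use_adg=True stays on APG, which a concrete "apg_classic" default cannot distinguish from an omitted argument.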

---

Outside diff comments:
In `@acestep/core/generation/handler/service_generate_execute.py`:
- Around line 192-225: The MLX branch is dropping guidance settings: when
calling _mlx_run_diffusion from service_generate_execute (the MLX path), it does
not pass generate_kwargs["guidance_variant"] or ["guidance_params"], so
non-default guidance is silently ignored; fix by either (A) adding
guidance_variant=generate_kwargs.get("guidance_variant", "apg_classic") and
guidance_params=generate_kwargs.get("guidance_params", {}) to the
_mlx_run_diffusion call so the MLX runner receives these options, or (B) if MLX
truly doesn't support them, emit a one-time logger.warning in the MLX branch
(before calling _mlx_run_diffusion) when generate_kwargs.get("guidance_variant")
!= "apg_classic" or generate_kwargs.get("guidance_params") is non-empty to alert
users that guidance was dropped.
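
Option (B) could look roughly like the sketch below. The helper name `warn_if_guidance_dropped` and the module-level flag are hypothetical, and treating `None` as "default" is an assumption that anticipates the `Optional[str]` sentinel discussed elsewhere in this review.

```python
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)
_warned_guidance_dropped = False  # module-level flag so the warning fires once


def warn_if_guidance_dropped(generate_kwargs: Dict[str, Any]) -> bool:
    """Warn once when a backend would ignore non-default guidance settings.

    Returns True if this call emitted the warning.
    """
    global _warned_guidance_dropped
    variant = generate_kwargs.get("guidance_variant")
    params = generate_kwargs.get("guidance_params")
    # Assumption: None and "apg_classic" both count as the default variant.
    non_default = variant not in (None, "apg_classic") or bool(params)
    if non_default and not _warned_guidance_dropped:
        logger.warning(
            "MLX path ignores guidance_variant=%r guidance_params=%r",
            variant,
            params,
        )
        _warned_guidance_dropped = True
        return True
    return False
```

Calling the helper immediately before the MLX diffusion call would surface the dropped settings without spamming the log on every request.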

---

Nitpick comments:
In `@acestep/api/http/release_task_request_builder.py`:
- Around line 39-48: Hoist the json import to module scope and avoid the bare
except. When handling guidance_params (the parser.get("guidance_params") branch),
narrow the exception to (json.JSONDecodeError, TypeError) and apply the same
strip + json.loads → dict logic currently in RequestParser._parse_json. Better
yet, call or reuse RequestParser._parse_json (or extract its helper) instead of
duplicating logic, so guidance_params follows the same dict pass-through,
JSON-string → dict, and fallback behavior as RequestParser._parse_json.
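
The normalization described above can be sketched as follows. The helper name `parse_guidance_params` is hypothetical, and this is not the body of `RequestParser._parse_json`; it only illustrates the dict pass-through, JSON-string to dict conversion, and narrow-exception fallback.

```python
import json
from typing import Any, Dict, Optional


def parse_guidance_params(raw: Any) -> Optional[Dict[str, Any]]:
    """Normalize a guidance_params value to a dict, or None if unusable."""
    if isinstance(raw, dict):
        return raw  # already an object (JSON request body)
    if not isinstance(raw, str) or not raw.strip():
        return None  # missing, empty, or an unsupported type
    try:
        parsed = json.loads(raw.strip())
    except json.JSONDecodeError:
        return None  # malformed JSON falls back to "no overrides"
    # Valid JSON that is not an object (a list, a number) is also rejected.
    return parsed if isinstance(parsed, dict) else None
```

Because the input type is checked before `json.loads` runs, only `json.JSONDecodeError` needs catching here; any other exception would indicate a real bug and is left to propagate.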

In `@acestep/core/generation/handler/generate_music.py`:
- Around line 225-240: The generate_music() docstring’s Args section is missing
the newly added guidance_variant and guidance_params parameters; update the
public docstring for the generate_music function to list and describe these two
arguments (expected types, allowed values or enum for guidance_variant,
structure and keys for guidance_params, default behavior if omitted), include
any effects on guidance_scale or interaction with other args (seed,
infer_method), and mention relevant raised exceptions or validation errors
thrown when invalid guidance values are provided so the docstring fully
documents inputs/outputs and error conditions.

In `@acestep/core/generation/handler/service_generate.py`:
- Around line 41-42: Add one-line descriptions for the two new parameters to the
function's docstring Args section: document guidance_variant (str) as the
guidance algorithm/variant to use (default "apg_classic") and guidance_params
(Optional[Dict[str, Any]]) as optional tuning parameters for the guidance
engine. Locate the docstring for the function that declares
guidance_variant/guidance_params and append these two entries to the Args block
for API clarity.

In `@acestep/models/common/guidance_registry_test.py`:
- Around line 1-238: This file exceeds the 200 LOC cap—split it into two test
modules: one for registry API tests (keep _fixed_tensors helper and
GuidanceRegistryTests) and a second for wrapper/state/corrector parity tests
(move GuidanceWrapperParityTests, GuidanceStateLifecycleTests,
GuidanceCorrectorStepTests and any tests that reference apg_forward,
adg_forward, cfg_forward, MomentumBuffer, get_guidance_fn, register_guidance,
registered_guidance_names, _MOMENTUM_KEY). Ensure both new modules import the
shared helper _fixed_tensors (or duplicate it if simpler), update imports to
reference acestep.models.common.apg_guidance and
acestep.models.common.guidance_registry symbols used in the moved tests, and run
tests to confirm names like get_guidance_fn("apg_classic") still resolve; remove
the original oversized file once split.

In `@acestep/models/common/guidance_registry.py`:
- Around line 77-84: The register_guidance decorator currently overwrites
entries in _GUIDANCE_REGISTRY; change it to detect duplicates by checking if
name is already a key in _GUIDANCE_REGISTRY inside the decorator (or before
assignment) and raise a clear exception (e.g., ValueError) that includes the
conflicting name and guidance function name instead of silently overwriting;
update the function register_guidance and its inner decorator to perform this
guard using the unique symbols register_guidance and _GUIDANCE_REGISTRY.
- Around line 117-154: The wrapper guidance_apg_classic intentionally uses
dims=(1,) rather than apg_forward's default [-1]; add a one-line clarifying
comment on the dims parameter (e.g., next to "dims: Any = (1,),") stating that
APG should project along the sequence/time axis (dim 1) to match the ACE-Step
shape convention [B, T, C] and existing tests (guidance_registry_test.py,
guidance_registry_sampler_regression_test.py) and MLX usage in dit_generate.py;
this keeps the divergence from apg_forward's default explicit and maintainable.
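
A minimal sketch of the duplicate guard follows. It reuses the names `register_guidance`, `get_guidance_fn`, and `_GUIDANCE_REGISTRY` from the PR, but the bodies are illustrative assumptions rather than the PR's code, and the `cfg` wrapper applies the textbook classifier-free guidance formula to plain floats instead of tensors.

```python
from typing import Any, Callable, Dict

_GUIDANCE_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_guidance(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that registers a guidance wrapper and rejects duplicate names."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        if name in _GUIDANCE_REGISTRY:
            raise ValueError(
                f"guidance variant {name!r} is already registered "
                f"(while registering {fn.__name__})"
            )
        _GUIDANCE_REGISTRY[name] = fn
        return fn

    return decorator


def get_guidance_fn(name: str) -> Callable[..., Any]:
    """Look up a wrapper; unknown names raise ValueError listing known ones."""
    try:
        return _GUIDANCE_REGISTRY[name]
    except KeyError:
        known = sorted(_GUIDANCE_REGISTRY)
        raise ValueError(f"unknown guidance variant {name!r}; registered: {known}") from None


@register_guidance("cfg")
def guidance_cfg(pred_cond, pred_uncond, guidance_scale, state, **params):
    # Textbook CFG on scalars, for illustration only.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```

Failing loudly on a name collision matches the PR's stated preference for raising on unknown variants instead of silently defaulting.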

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 1edc4be0-c208-40c8-b8ca-1798872823c3

📥 Commits

Reviewing files that changed from the base of the PR and between 1d9d2d3 and f693645.

📒 Files selected for processing (24)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/guidance_registry.py
  • acestep/models/common/guidance_registry_sampler_regression_test.py
  • acestep/models/common/guidance_registry_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
  • docs/en/API.md
  • docs/ja/API.md
  • docs/ko/API.md
  • docs/zh/API.md

Comment on lines +205 to +206
guidance_variant: str = "apg_classic",
guidance_params: Optional[Dict[str, Any]] = None,

⚠️ Potential issue | 🟡 Minor

Preserve “omitted” vs explicit apg_classic.

Defaulting this public hop to "apg_classic" makes downstream legacy resolution unable to tell whether the caller omitted guidance_variant or explicitly requested APG while use_adg=True. Use an Optional[str] sentinel through the internal call chain and resolve the effective default only at the guidance-dispatch boundary.

🐛 Proposed direction
-        guidance_variant: str = "apg_classic",
+        guidance_variant: Optional[str] = None,
         guidance_params: Optional[Dict[str, Any]] = None,

Then resolve downstream as:

if guidance_variant is None:
    resolved_variant = "adg" if use_adg else "apg_classic"
else:
    resolved_variant = guidance_variant
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/generation/handler/generate_music.py` around lines 205 - 206,
Change the public parameter guidance_variant in generate_music (and any callers
in the same internal call chain) from a concrete default "apg_classic" to
Optional[str] (i.e., allow None) so upstream callers can omit it; propagate this
None sentinel through internal functions rather than converting to "apg_classic"
immediately, and only resolve the effective variant at the guidance-dispatch
boundary (the function/method that actually selects guidance behavior—e.g., the
guidance dispatcher) with logic like: if guidance_variant is None:
resolved_variant = "adg" if use_adg else "apg_classic" else: resolved_variant =
guidance_variant. Ensure signatures and any tests/places that construct the call
use Optional[str] and handle None accordingly.
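
The sentinel resolution described above can be sketched as a small helper. The function name `resolve_guidance_variant` is hypothetical; the branch logic mirrors the snippet proposed in this comment.

```python
from typing import Optional


def resolve_guidance_variant(guidance_variant: Optional[str], use_adg: bool) -> str:
    """Resolve the effective variant at the guidance-dispatch boundary.

    None means the caller omitted the field, so the legacy use_adg flag
    decides. Any explicit string is honored as-is, even with use_adg=True.
    """
    if guidance_variant is None:
        return "adg" if use_adg else "apg_classic"
    return guidance_variant
```

With this shape, `resolve_guidance_variant("apg_classic", use_adg=True)` stays on APG, which is the case the previous string default could not distinguish from an omitted field.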

Comment thread acestep/models/base/modeling_acestep_v15_base.py Outdated
Comment thread acestep/models/common/guidance_registry_sampler_regression_test.py Outdated
guidance_scale: float,
state: Dict[str, Any],
*,
angle_clip: float = 3.14 / 6,

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Confirm whether pre-refactor adg_*forward call sites used 3.14 or math.pi.
rg -nP --type=py -C2 '(adg_forward|adg_w_norm_forward|adg_wo_clip_forward)\s*\('
git log -p -S 'angle_clip' --all -- 'acestep/models/**/*.py' | head -120

Repository: ace-step/ACE-Step-1.5

Length of output: 6326


🏁 Script executed:

cat -n acestep/models/common/guidance_registry.py | sed -n '220,235p'

Repository: ace-step/ACE-Step-1.5

Length of output: 671


🏁 Script executed:

cat -n acestep/models/common/guidance_registry.py | sed -n '250,270p'

Repository: ace-step/ACE-Step-1.5

Length of output: 856


🏁 Script executed:

cat -n acestep/models/common/apg_guidance.py | sed -n '100,180p'

Repository: ace-step/ACE-Step-1.5

Length of output: 4196


🏁 Script executed:

cat -n acestep/models/common/apg_guidance.py | sed -n '185,230p'

Repository: ace-step/ACE-Step-1.5

Length of output: 1763


🏁 Script executed:

cat -n acestep/models/common/guidance_registry.py | head -30

Repository: ace-step/ACE-Step-1.5

Length of output: 1726


🏁 Script executed:

git log --all --oneline -- acestep/models/common/guidance_registry.py | head -20

Repository: ace-step/ACE-Step-1.5

Length of output: 141


🏁 Script executed:

git log --all -p -S '3.14' -- acestep/models/common/apg_guidance.py acestep/models/common/guidance_registry.py | head -150

Repository: ace-step/ACE-Step-1.5

Length of output: 5363


🏁 Script executed:

git log --all -p -S 'pi/6' -- acestep/models/common/apg_guidance.py acestep/models/common/guidance_registry.py | head -150

Repository: ace-step/ACE-Step-1.5

Length of output: 5363


🏁 Script executed:

cat -n acestep/models/common/guidance_registry.py | sed -n '1,50p'

Repository: ace-step/ACE-Step-1.5

Length of output: 2435


🏁 Script executed:

cat -n acestep/models/common/guidance_registry.py | sed -n '275,310p'

Repository: ace-step/ACE-Step-1.5

Length of output: 1209


🏁 Script executed:

rg -n "3.14|math.pi" --type=py acestep/models/common/guidance_registry_test.py

Repository: ace-step/ACE-Step-1.5

Length of output: 142


Use math.pi instead of 3.14 for ADG angle-clip defaults.

Both acestep/models/common/guidance_registry.py (lines 225, 257) and acestep/models/common/apg_guidance.py (lines 113, 197) use the truncated π approximation 3.14. The inline comment at apg_guidance.py:113 (# pi/6 by default) shows the intent was exact π values. The 3.14 approximation introduces ~0.05% error and violates the coding guideline to use explicit, readable code. Add import math to both files and replace 3.14 / 6 with math.pi / 6 and 3.14 / 3 with math.pi / 3. Note that tests in guidance_registry_test.py:112 explicitly reference angle_clip=3.14 / 6 and will need updating.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/guidance_registry.py` at line 225, Replace the
hardcoded 3.14 approximations with math.pi: add import math to the modules that
define angle_clip defaults, change occurrences of 3.14 / 6 to math.pi / 6 and
3.14 / 3 to math.pi / 3 (look for the angle_clip default assignments and APG/ADG
guidance default constructors or registry entries referencing angle_clip), and
update any unit tests that assert angle_clip == 3.14 / 6 to use the new
math.pi-based values so tests reflect the exact π constants.
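
To quantify the difference under discussion, the following sketch compares the two defaults directly. It uses only the standard library and makes no claim about which value the project should keep.

```python
import math

approx_clip = 3.14 / 6      # value used by the existing defaults
exact_clip = math.pi / 6    # value implied by the "pi/6" comment

# Relative difference between the truncated and exact constants.
relative_error = abs(exact_clip - approx_clip) / exact_clip

# The same gap expressed in degrees, which is easier to reason about.
gap_degrees = math.degrees(exact_clip - approx_clip)
```

The relative error comes out at roughly 0.05%, consistent with the figure quoted above. Whether that gap matters depends on whether byte-identical parity with the pre-refactor constants is the priority.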

Comment thread acestep/models/sft/modeling_acestep_v15_base.py Outdated
Comment thread acestep/models/xl_base/modeling_acestep_v15_xl_base.py Outdated
Comment thread acestep/models/xl_sft/modeling_acestep_v15_xl_base.py Outdated
Comment thread docs/en/API.md Outdated
Comment thread docs/zh/API.md Outdated
Comment on lines +212 to +213
| `guidance_variant` | string | `"apg_classic"` | 已注册的引导变体名称。内置: `apg_classic`、`cfg`、`adg`、`adg_w_norm`、`adg_wo_clip` |
| `guidance_params` | object | `null` | 变体特定的参数覆盖(例如 `apg_classic` 的 `{"eta": 1.0, "norm_threshold": 2.5, "momentum": -0.75}`)。默认值取决于变体 |

⚠️ Potential issue | 🟡 Minor

Example eta value inconsistent with defaults and other docs.

The example shows "eta": 1.0, but per the PR objectives the pre-refactor APG default is eta=0.0 (momentum=-0.75 is correct). The model field description in acestep/api/http/release_task_models.py uses "eta": 0.1 as its example. Align the three so users don't infer 1.0 is the default.

📝 Suggested fix
-| `guidance_params` | object | `null` | 变体特定的参数覆盖(例如 `apg_classic` 的 `{"eta": 1.0, "norm_threshold": 2.5, "momentum": -0.75}`)。默认值取决于变体 |
+| `guidance_params` | object | `null` | 变体特定的参数覆盖(例如 `apg_classic` 的 `{"eta": 0.0, "norm_threshold": 2.5, "momentum": -0.75}`,与重构前默认一致)。默认值取决于变体 |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-| `guidance_variant` | string | `"apg_classic"` | 已注册的引导变体名称。内置: `apg_classic`、`cfg`、`adg`、`adg_w_norm`、`adg_wo_clip` |
-| `guidance_params` | object | `null` | 变体特定的参数覆盖(例如 `apg_classic` 的 `{"eta": 1.0, "norm_threshold": 2.5, "momentum": -0.75}`)。默认值取决于变体 |
+| `guidance_variant` | string | `"apg_classic"` | 已注册的引导变体名称。内置: `apg_classic`、`cfg`、`adg`、`adg_w_norm`、`adg_wo_clip` |
+| `guidance_params` | object | `null` | 变体特定的参数覆盖(例如 `apg_classic` 的 `{"eta": 0.0, "norm_threshold": 2.5, "momentum": -0.75}`,与重构前默认一致)。默认值取决于变体 |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/zh/API.md` around lines 212 - 213, The docs example for guidance_params
is inconsistent: change the example `{"eta": 1.0, "norm_threshold": 2.5,
"momentum": -0.75}` to use `eta: 0.0` (keeping `momentum: -0.75`) so it matches
the PR objective (pre-refactor APG default) and align with the example in the
model field definition; update the `guidance_params` example in the docs entry
for `guidance_variant`/`guidance_params` and verify the example in the
`release_task_models.py` model field uses the same `eta: 0.0` value to keep all
references consistent.

@FlexOr2

FlexOr2 commented Apr 22, 2026

Thanks @coderabbitai — review addressed in b80d91a. Summary:

Major (findings 1–5, the use_adg override bug):
Core fix applied. guidance_variant default changed from "apg_classic" to Optional[str] = None across the HTTP Pydantic model, the request builder, GenerationParams, all four handler layers, and all four model generate_audio signatures. Resolution logic now distinguishes "caller omitted" (None → legacy use_adg fallback applies) from "caller explicitly set apg_classic" (always honored, never silently switched to adg). Default-value test expectations updated.

Minor:

  • Finding 6 (Ruff B905): strict=True added to all five zip() calls in the sampler regression tests.
  • Findings 8, 9 (docs): default column now shows null in all four language API.md files with an inline note on the use_adg fallback. Added multipart form-encoding guidance for guidance_params, camelCase alias mention (guidanceVariant / guidanceParams), and the confusing {"eta": 1.0} example shrunk to {"eta": 0.1}.

Finding 7 (3.14 vs math.pi in angle_clip default): verified the registry wrappers' defaults (3.14 / 6 for adg, 3.14 / 3 for adg_w_norm) match the pre-refactor values in apg_guidance.py exactly. Consistent by design; no change needed.

PR still Draft — happy to take another round of feedback.

@coderabbitai

coderabbitai Bot commented Apr 22, 2026

@FlexOr2 Thanks for the thorough summary — the fixes in b80d91a look well-reasoned. The Optional[str] = None sentinel approach for guidance_variant is the right call: it cleanly separates "caller was silent" from "caller explicitly requested apg_classic", which was the root of the override bug. The strict=True zip fixes and docs updates are appreciated too. Acknowledged on finding 7 — preserving 3.14/6 and 3.14/3 to stay byte-identical with the pre-refactor constants is a sound design choice; I won't flag it again.

Triggering a fresh incremental review on b80d91a now.


🧠 Learnings used
Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: AGENTS.md:0-0
Timestamp: 2026-02-28T14:14:55.765Z
Learning: AI-Agent Workflow: Add/update focused tests. Self-review only changed hunks for regressions and scope creep. Summarize risk, validation, and non-target impact in PR notes.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 837
File: acestep/core/generation/handler/init_service_loader_components.py:44-47
Timestamp: 2026-03-14T10:35:13.950Z
Learning: In `acestep/core/generation/handler/init_service_loader.py` (origin/main, pre-extraction), `_load_text_encoder_and_tokenizer()` already used `self.text_encoder.to("cpu").to(self.dtype)` in the CPU-offload branch — using `self.dtype` (which can be bfloat16) instead of `torch.float32`. This is a pre-existing behavior, not introduced by PR `#837`. PR `#837` moves it unchanged into `init_service_loader_components.py` and intends to fix it there. Do not flag the original `init_service_loader.py` pattern as a new regression in future reviews.

Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: AGENTS.md:0-0
Timestamp: 2026-02-28T14:14:55.765Z
Learning: PR Readiness: New/updated tests cover changed behavior and edge cases. No unrelated refactor/formatting churn.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 774
File: acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py:70-95
Timestamp: 2026-03-05T23:12:28.095Z
Learning: In ace-step/ACE-Step-1.5, the hardcoded English placeholder strings in `acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py` (e.g., `"Music description..."` for `edit_caption`, `"pop, electronic, dance..."` for `edit_genre`, `Dropdown` choices `["Use Global Ratio", "Caption", "Genre"]` and its default value, and the lyrics placeholder `"[Verse 1]\nLyrics here...\n\n[Chorus]\n..."`) are pre-existing localization gaps carried forward from before the decomposition refactor. They are deferred to a follow-up PR and should not be flagged as new issues introduced by decomposition PRs.
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
acestep/models/xl_base/modeling_acestep_v15_xl_base.py (1)

1843-1883: Document the new guidance inputs on generate_audio.

The public generation entrypoint now accepts guidance_variant and guidance_params, but the modified function still has no docstring describing these inputs or the ValueError from unknown variants.

Proposed docstring addition
         guidance_params: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        """Generate audio latents using the configured sampler and guidance.
+
+        Args:
+            guidance_variant: Registered guidance variant name. If ``None``,
+                legacy ``use_adg`` selects between APG classic and ADG.
+            guidance_params: Optional keyword parameters forwarded to the
+                selected guidance wrapper.
+
+        Returns:
+            Dictionary containing generated target latents and timing metadata.
+
+        Raises:
+            ValueError: If ``guidance_variant`` is not registered.
+        """
         # Backward-compat: accept the old misspelled key "diffusion_guidance_sale"

As per coding guidelines, "Docstrings are mandatory for all new or modified Python modules, classes, and functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/xl_base/modeling_acestep_v15_xl_base.py` around lines 1843 -
1883, The generate_audio function now accepts new parameters guidance_variant
and guidance_params but lacks a docstring describing them (and the ValueError
raised for unknown guidance variants); update the function-level docstring for
generate_audio to document guidance_variant (allowed string values and
semantics), guidance_params (expected dict keys/types and defaults), and clearly
state that an unknown guidance_variant will raise ValueError, including any
examples or default behaviors (e.g., default None or behavior when not
provided); ensure the docstring follows existing style in the module and
mentions any interactions with diffusion_guidance_scale, guidance_interval
params, and where to find supported variants.
acestep/models/common/guidance_registry_sampler_regression_test.py (1)

77-78: Make the parity assertions exact.

These tests document a byte-identical contract, but default torch.allclose() allows drift. Use zero-tolerance assertions so registry changes cannot introduce tiny deltas unnoticed.

Proposed fix
-        for ref, got in zip(reference_outputs, reg_outputs, strict=True):
-            self.assertTrue(torch.allclose(ref, got))
+        for ref, got in zip(reference_outputs, reg_outputs, strict=True):
+            torch.testing.assert_close(got, ref, rtol=0, atol=0)
@@
-        for (ref_main, ref_corr), (got_main, got_corr) in zip(reference_outputs, reg_outputs, strict=True):
-            self.assertTrue(torch.allclose(ref_main, got_main))
-            self.assertTrue(torch.allclose(ref_corr, got_corr))
+        for (ref_main, ref_corr), (got_main, got_corr) in zip(reference_outputs, reg_outputs, strict=True):
+            torch.testing.assert_close(got_main, ref_main, rtol=0, atol=0)
+            torch.testing.assert_close(got_corr, ref_corr, rtol=0, atol=0)
@@
-        for (ref_main, ref_corr), (got_main, got_corr) in zip(reference_outputs, reg_outputs, strict=True):
-            self.assertTrue(torch.allclose(ref_main, got_main))
-            self.assertTrue(torch.allclose(ref_corr, got_corr))
+        for (ref_main, ref_corr), (got_main, got_corr) in zip(reference_outputs, reg_outputs, strict=True):
+            torch.testing.assert_close(got_main, ref_main, rtol=0, atol=0)
+            torch.testing.assert_close(got_corr, ref_corr, rtol=0, atol=0)

Also applies to: 112-114, 168-170

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/common/guidance_registry_sampler_regression_test.py` around
lines 77 - 78, Replace the fuzzy comparisons using torch.allclose in
guidance_registry_sampler_regression_test.py with exact byte-equality checks so
tiny numeric drift fails the test: where the test iterates "for ref, got in
zip(reference_outputs, reg_outputs, strict=True)" (and the similar loops at the
other occurrences), change the assertion to require exact equality (e.g., use
torch.equal / Tensor.equal via self.assertTrue or
self.assertTrue(ref.equal(got))) so the test enforces zero-tolerance parity
between reference_outputs and reg_outputs.
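
The gap between a tolerant and an exact check can be illustrated without tensors. The sketch below re-implements, on plain floats, the tolerance rule that `torch.allclose` documents (`|got - ref| <= atol + rtol * |ref|`, with defaults `rtol=1e-5` and `atol=1e-8`). The helper name and the drift value are hypothetical.

```python
def within_allclose_tolerance(
    got: float, ref: float, rtol: float = 1e-5, atol: float = 1e-8
) -> bool:
    """Scalar version of the documented torch.allclose tolerance rule."""
    return abs(got - ref) <= atol + rtol * abs(ref)


ref = 1.0
drifted = ref + 5e-6  # small numeric drift, well above float rounding noise

passes_tolerant = within_allclose_tolerance(drifted, ref)              # default tolerances
passes_zero_tol = within_allclose_tolerance(drifted, ref, 0.0, 0.0)    # rtol=0, atol=0
passes_exact = drifted == ref                                          # exact comparison
```

Under the default tolerances the drifted value is accepted, while both the zero-tolerance and exact checks reject it. That is the behavior a byte-identical parity test is meant to enforce.
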
acestep/models/xl_sft/modeling_acestep_v15_xl_base.py (1)

1843-1883: Document the new guidance inputs on generate_audio.

This public generation entrypoint now accepts registry-specific guidance controls; add a concise docstring covering guidance_variant, guidance_params, returned data, and unknown-variant errors.

Proposed docstring addition
         guidance_params: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        """Generate audio latents using the configured sampler and guidance.
+
+        Args:
+            guidance_variant: Registered guidance variant name. If ``None``,
+                legacy ``use_adg`` selects between APG classic and ADG.
+            guidance_params: Optional keyword parameters forwarded to the
+                selected guidance wrapper.
+
+        Returns:
+            Dictionary containing generated target latents and timing metadata.
+
+        Raises:
+            ValueError: If ``guidance_variant`` is not registered.
+        """
         # Backward-compat: accept the old misspelled key "diffusion_guidance_sale"

As per coding guidelines, "Docstrings are mandatory for all new or modified Python modules, classes, and functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/models/xl_sft/modeling_acestep_v15_xl_base.py` around lines 1843 -
1883, Add a concise docstring to the public generate_audio function that
documents the purpose and types of the new registry-specific guidance inputs
(guidance_variant: Optional[str], guidance_params: Optional[Dict[str, Any]]),
expected format/keys for guidance_params (a short example or accepted keys),
what the function returns (describe returned tensor(s)/dict fields such as audio
latents, masks, and any metadata), and the behavior when an unknown
guidance_variant is passed (explicitly state that the function validates
guidance_variant and raises a ValueError with a clear message listing supported
variants). Reference the generate_audio signature and ensure the docstring is
placed immediately below the def line and mentions
cfg_interval_start/cfg_interval_end if relevant to guidance behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/api/http/release_task_request_builder.py`:
- Around line 39-48: The current parsing block for guidance_params uses a broad
except and an inline import; update it to import the stdlib json module at the
top of the file and replace the broad except with an explicit
json.JSONDecodeError handler when calling _json.loads on guidance_params
(keeping the logic that sets guidance_params to the parsed dict or None and
preserving the fallback for non-dict values); reference the
parser.get("guidance_params") variable, the parsed_guidance local, and the
json.JSONDecodeError exception in your change so only JSON decode errors are
caught and other exceptions bubble up.

In `@acestep/inference.py`:
- Around line 134-139: Update the GenerationParams class docstring to document
the new public fields guidance_variant and guidance_params: describe that
guidance_variant is an optional string controlling the guidance strategy (None =
"omitted" with legacy use_adg behavior of defaulting to 'adg' vs 'apg_classic',
and an explicit string always takes precedence over use_adg) and that
guidance_params is an optional dict of guidance-specific parameters; mention
types and typical values/semantics and any precedence behavior so callers
understand how guidance selection works. Ensure the docstring is added/updated
in the GenerationParams class definition (where guidance_variant and
guidance_params are declared) under the Attributes section following existing
style and brevity conventions.

In `@acestep/models/base/modeling_acestep_v15_base.py`:
- Around line 1867-1868: The public method generate_audio was extended with new
parameters guidance_variant and guidance_params but lacks docstring coverage;
update the generate_audio(...) docstring to briefly describe the purpose of
guidance_variant (allowed variant names) and guidance_params (dict of
hyperparameters), specify the return shape/type (e.g., bytes, AudioSegment or
tuple with sample rate), and document that a ValueError is raised for unknown
guidance_variant values; mention default behavior when guidance_variant is None
and any expected keys in guidance_params if relevant.

In `@acestep/models/sft/modeling_acestep_v15_base.py`:
- Around line 1868-1869: Update the generate_audio() docstring in
modeling_acestep_v15_base.py to document the two new public parameters
guidance_variant and guidance_params: describe the allowed types and values
(guidance_variant: Optional[str], with the supported variant names and the
behavior when it is None; guidance_params: Optional[Dict[str, Any]], with
example keys and expected types) and the effect they have on generation. Add a
Raises section documenting the ValueError raised when an unregistered
guidance_variant is passed, so callers know the error semantics; reference the
generate_audio() signature and the internal branch that validates
guidance_variant so the unknown-variant error is explicitly mentioned.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 656581f2-9d31-4ace-8294-d3a05ac6d64d

📥 Commits

Reviewing files that changed from the base of the PR and between f693645 and b80d91a.

📒 Files selected for processing (19)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/guidance_registry_sampler_regression_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
  • docs/en/API.md
  • docs/ja/API.md
  • docs/ko/API.md
  • docs/zh/API.md
✅ Files skipped from review due to trivial changes (6)
  • acestep/api/http/release_task_models_test.py
  • acestep/core/generation/handler/generate_music_execute.py
  • docs/zh/API.md
  • acestep/api/job_generation_setup_test.py
  • docs/en/API.md
  • docs/ja/API.md
🚧 Files skipped from review as they are similar to previous changes (4)
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/service_generate_execute.py
  • docs/ko/API.md

Comment thread acestep/api/http/release_task_request_builder.py
Comment thread acestep/inference.py
Comment thread acestep/models/base/modeling_acestep_v15_base.py
Comment thread acestep/models/sft/modeling_acestep_v15_base.py
@FlexOr2

FlexOr2 commented Apr 22, 2026

Author

Second round addressed in 8ecd184:

  • Docstrings on the two new GenerationParams fields (slotted into the existing Attributes block in inference.py) and on generate_audio in all four model variants (base, sft, xl_base, xl_sft) — covers guidance_variant, guidance_params, and the ValueError path from unknown variant names.
  • Narrowed except Exception -> except json.JSONDecodeError around guidance_params JSON parsing in release_task_request_builder.py; import promoted to the stdlib group at module top per PEP 8.

No behaviour change, tests unchanged.
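For readers following along, the narrowed exception handling can be sketched roughly as below. This is only an illustration: the helper name parse_guidance_params is hypothetical and is not taken from release_task_request_builder.py, and treating non-object JSON as "not provided" is an assumption.

```python
import json
from typing import Any, Dict, Optional, Union


def parse_guidance_params(
    raw: Union[str, Dict[str, Any], None],
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: decode guidance_params, returning None on bad JSON."""
    if raw is None or isinstance(raw, dict):
        # Already absent or already decoded: pass through unchanged.
        return raw
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        # Only malformed JSON is swallowed; unrelated errors still propagate.
        return None
    # Assumption: anything other than a JSON object counts as "not provided".
    return decoded if isinstance(decoded, dict) else None
```

Catching json.JSONDecodeError rather than Exception means an unexpected failure (for example a TypeError from a non-string payload) surfaces instead of being silently turned into None.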

@FlexOr2

FlexOr2 commented Apr 22, 2026

Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Apr 22, 2026

Contributor
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@FlexOr2 FlexOr2 marked this pull request as ready for review April 25, 2026 15:53
FlexOr2 and others added 6 commits May 7, 2026 20:46
Add a Protocol-typed registry at acestep/models/common/guidance_registry.py
that exposes five guidance variants — apg_classic, cfg, adg, adg_w_norm,
adg_wo_clip — behind a single callable contract:

    fn(pred_cond, pred_uncond, guidance_scale, state, **params) -> Tensor

The sampler owns a per-generate_audio 'state' dictionary and populates
per-step values (latents, sigma) before calling into the registered fn.
Variant parameters (eta, norm_threshold, angle_clip, momentum) arrive via
**params so users can tune a single variant without signature churn.

The wrappers call the existing raw functions in apg_guidance.py without
modifying them; apg_classic carries a MomentumBuffer in state, ADG
variants read latents/sigma out of state, and corrector steps
(state['step_role'] == 'corrector') preserve the pre-refactor behaviour
byte-for-byte: APG corrector downgrades to CFG, ADG corrector at sigma=0
downgrades to CFG.

get_guidance_fn() raises ValueError with the registered names on an
unknown variant instead of silently falling back to a default.
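A minimal sketch of the registry contract described in this commit message, assuming plain floats in place of tensors so it runs without torch. The function names match those imported in the PR's test module, but the bodies here are illustrative guesses, not the code in guidance_registry.py; the cfg body is textbook classifier-free guidance and may differ from the repository's cfg_forward.

```python
from typing import Any, Callable, Dict, List, Protocol


class GuidanceFn(Protocol):
    """Structural contract from the commit message (floats stand in for tensors)."""

    def __call__(self, pred_cond: float, pred_uncond: float,
                 guidance_scale: float, state: Dict[str, Any],
                 **params: Any) -> float: ...


_GUIDANCE_REGISTRY: Dict[str, GuidanceFn] = {}


def register_guidance(name: str) -> Callable[[GuidanceFn], GuidanceFn]:
    """Decorator that stores a guidance function under the given name."""
    def decorator(fn: GuidanceFn) -> GuidanceFn:
        _GUIDANCE_REGISTRY[name] = fn
        return fn
    return decorator


def registered_guidance_names() -> List[str]:
    """Return the registered variant names in sorted order."""
    return sorted(_GUIDANCE_REGISTRY)


def get_guidance_fn(name: str) -> GuidanceFn:
    """Look up a variant; unknown names raise instead of falling back."""
    try:
        return _GUIDANCE_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Unknown guidance variant {name!r}; "
            f"registered: {registered_guidance_names()}"
        ) from None


@register_guidance("cfg")
def cfg(pred_cond, pred_uncond, guidance_scale, state, **params):
    # Textbook CFG blend; an assumption, not the repository's implementation.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```

Because the lookup raises rather than defaulting, a typo in the variant name fails at the start of generation instead of silently producing output from a different guidance method.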

Thread the pluggable guidance variant through all four model families
(base, sft, xl_base, xl_sft).  Each generate_audio() now accepts two new
keyword arguments:

    guidance_variant: str = 'apg_classic'
    guidance_params: Optional[Dict[str, Any]] = None

Inside generate_audio the previously hard-coded
apg_forward/adg_forward/cfg_forward call sites (both the main sampler
step and the Heun corrector) are replaced by a single guidance_fn
dispatch sourced from the registry.  A local guidance_state dict is
created once per call and passed to every step; per-step values
(latents, sigma) and the corrector/main role marker are written into it
before each call.
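The state handling described above can be sketched as follows. This is a toy loop under stated assumptions: run_sampler and counting_guidance are hypothetical names, floats stand in for tensors, and the latent update is a placeholder rather than the real Euler or Heun step.

```python
from typing import Any, Callable, Dict, List


def run_sampler(guidance_fn: Callable[..., float],
                sigmas: List[float], scale: float) -> List[float]:
    """Hypothetical loop showing only how guidance_state is populated."""
    guidance_state: Dict[str, Any] = {}  # created once per call
    latents = 0.0                        # stand-in for the latent tensor
    outputs: List[float] = []
    for sigma in sigmas:
        # Per-step values are written before each guidance call.
        guidance_state["latents"] = latents
        guidance_state["sigma"] = sigma
        guidance_state["step_role"] = "main"
        pred_cond, pred_uncond = 1.0, 0.5  # placeholder model predictions
        guided = guidance_fn(pred_cond, pred_uncond, scale, guidance_state)
        latents = latents - sigma * guided  # toy update, not the real step
        outputs.append(guided)
    return outputs  # guidance_state goes out of scope here


def counting_guidance(pred_cond, pred_uncond, guidance_scale, state, **params):
    # Cross-step bookkeeping lives in the same dict (here, a call counter).
    state["calls"] = state.get("calls", 0) + 1
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```

A Heun corrector would set guidance_state["step_role"] = "corrector" before its second guidance call, which is what lets a wrapper choose a different code path for that step.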

Backward compatibility is preserved:

  * Default ('apg_classic', {}) with use_adg=False reproduces the
    pre-refactor apg_forward + cfg_forward (corrector) sequence.
  * Default variant + use_adg=True resolves to the 'adg' variant, which
    reproduces adg_forward + (adg_forward | cfg_forward at sigma=0).
  * Explicit guidance_variant always wins over use_adg.

The sampler-level regression test
(guidance_registry_sampler_regression_test.py) proves byte-identical
output to the hand-written pre-refactor call sequence across both Euler
and Heun sampler paths.

Thread the two new guidance fields all the way from the HTTP request
body to model.generate_audio():

  GenerateMusicRequest
    -> PARAM_ALIASES (snake_case + camelCase)
    -> build_generate_music_request (JSON-string decoding with None
       fallback on malformed input)
    -> GenerationParams
    -> build_generation_setup
    -> generate_music handler
    -> _run_generate_music_service_with_progress
    -> service_generate
    -> _build_service_generate_kwargs
    -> model.generate_audio(guidance_variant=..., guidance_params=...)

Both fields are optional and default to the backward-compatible values
('apg_classic', None), so existing clients see no behaviour change.  A
missing 'guidance_params' payload leaves the choice of per-variant
parameters up to the registry wrapper defaults (eta=0, norm_threshold=2.5,
momentum=-0.75, angle_clip as documented per variant).

New tests:
  * release_task_models_test — Pydantic defaults + explicit acceptance
  * release_task_param_parser_test — snake_case + camelCase alias
    resolution
  * release_task_request_builder_test — JSON-string decoding and the
    invalid-JSON-becomes-None safeguard
  * job_generation_setup_test — GenerationParams forwarding of both
    fields, including the 'missing attribute on request' default path

Adds the two new HTTP params introduced by this PR to the API reference
tables in all four language docs (en, ja, ko, zh), next to the existing
use_adg row.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Semantic fix (CodeRabbit findings 1–5, severity Major):
- guidance_variant default changes from "apg_classic" to Optional[str] = None
  across the HTTP Pydantic model, the request builder, GenerationParams,
  all four handler layers, and all four model generate_audio signatures.
- Resolution logic now distinguishes "caller omitted" (None) from "caller
  explicitly requested apg_classic". Only None falls back to the legacy
  use_adg -> adg mapping; an explicit variant always wins over use_adg.
- Updated docstrings + updated the two default-value test expectations
  (GenerateMusicRequest and build_generation_setup) to assert None.
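The resolution rule in this fix can be written as a small function. The sketch below is illustrative only; the helper name resolve_guidance_variant is hypothetical and the actual code may inline this logic in each generate_audio.

```python
from typing import Optional


def resolve_guidance_variant(guidance_variant: Optional[str], use_adg: bool) -> str:
    """Hypothetical helper mirroring the None-vs-explicit rule described above."""
    if guidance_variant is not None:
        # An explicit request always wins over use_adg.
        return guidance_variant
    # Caller omitted the variant: fall back to the legacy use_adg mapping.
    return "adg" if use_adg else "apg_classic"
```

With a string default of "apg_classic" the third case below would be indistinguishable from the second, which is why the default had to become None:

resolve_guidance_variant(None, False)          -> "apg_classic"
resolve_guidance_variant(None, True)           -> "adg"
resolve_guidance_variant("apg_classic", True)  -> "apg_classic"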

Style (CodeRabbit finding 6):
- Added strict=True to all five zip() calls in the sampler regression
  tests so pair-length drift fails loudly (Ruff B905).

Docs (CodeRabbit findings 8, 9):
- docs/{en,ja,ko,zh}/API.md: default column shows null with an inline
  note on use_adg fallback, plus multipart form-encoding guidance for
  guidance_params, camelCase alias mention, and the eta example shrunk
  to {"eta": 0.1} (was {"eta": 1.0} which was both confusingly close
  to typical defaults and diverged across locales).

3.14 vs math.pi (CodeRabbit finding 7):
- Verified the angle_clip defaults in guidance_registry.py (3.14/6 for
  adg, 3.14/3 for adg_w_norm) match the pre-refactor values in
  apg_guidance.py exactly. No change; consistent by design.
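A quick arithmetic check of why swapping in math.pi would not be a no-op, using the adg default quoted above. The variable names are illustrative only.

```python
import math

legacy_adg_clip = 3.14 / 6     # default quoted for the adg variant
exact_adg_clip = math.pi / 6   # what a switch to math.pi would produce
gap = exact_adg_clip - legacy_adg_clip

# The gap is small but non-zero, so a clip threshold would shift slightly.
print(f"gap in radians: {gap:.6f}")
```

Any non-zero change in the clip threshold can alter outputs, which would conflict with the byte-identical parity the regression tests assert.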

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

CodeRabbit re-review of b80d91a flagged four minor nits (no new
semantic issues). All four addressed:

- Docstring for the two new GenerationParams fields in inference.py,
  slotting into the existing Attributes block.
- Docstring section on generate_audio in all four model variants
  (base, sft, xl_base, xl_sft) covering guidance_variant,
  guidance_params, and the ValueError path from unknown variant names.
- Narrowed `except Exception` -> `except json.JSONDecodeError` around
  guidance_params parsing in release_task_request_builder.py, with the
  import promoted to the stdlib group at module top per PEP 8.

No behaviour change; tests unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@FlexOr2 FlexOr2 force-pushed the feat/pluggable-guidance-registry branch from 88bb2dd to e4f2ce0 Compare May 7, 2026 18:54

@coderabbitai coderabbitai Bot left a comment

Contributor


Actionable comments posted: 1

🧹 Nitpick comments (2)
acestep/api/job_generation_setup_test.py (1)

236-291: ⚡ Quick win

LGTM — consider adding a negative test for unknown guidance_variant.

Both new tests are correct and cover the expected forwarding behavior. One gap worth noting: the PR description states that unknown variant names raise ValueError, but there's no test exercising that path. Adding a third test case would complete the contract verification:

✅ Suggested additional test
def test_guidance_variant_unknown_raises_value_error(self) -> None:
    """An unregistered guidance_variant name should raise ValueError."""

    req = _base_req()
    req.guidance_variant = "nonexistent_variant"
    with self.assertRaises(ValueError):
        build_generation_setup(
            req=req,
            caption="cap",
            lyrics="lyr",
            bpm=None,
            key_scale="",
            time_signature="",
            audio_duration=None,
            thinking=False,
            sample_mode=False,
            format_has_duration=False,
            use_cot_caption=True,
            use_cot_language=True,
            lm_top_k=0,
            lm_top_p=0.9,
            parse_timesteps=lambda _value: None,
            is_instrumental=lambda _lyrics: False,
            default_dit_instruction="default instruction",
            task_instructions={},
        )

Note: If the ValueError is raised deeper in the call stack (e.g., inside generate_audio rather than build_generation_setup), a separate test at the registry or model level would be more appropriate.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/api/job_generation_setup_test.py` around lines 236 - 291, Add a
negative test that asserts build_generation_setup raises ValueError for an
unregistered guidance variant: create a test (e.g.,
test_guidance_variant_unknown_raises_value_error) that sets req.guidance_variant
to an invalid name like "nonexistent_variant" (leave guidance_params unset or
empty) and wraps the build_generation_setup(...) call in
self.assertRaises(ValueError); reference the existing helper _base_req and the
build_generation_setup function so the new test mirrors the other tests'
arguments and isolates the failure to guidance_variant resolution.
acestep/models/common/guidance_registry_test.py (1)

38-44: ⚡ Quick win

Avoid asserting the exact built-in registry set.

This will fail as soon as a new built-in guidance variant is added, even if the registry still behaves correctly. Assert that the required baseline names are present instead of requiring an exact match.

Suggested assertion shape
-        self.assertEqual(
-            sorted(registered_guidance_names()),
-            ["adg", "adg_w_norm", "adg_wo_clip", "apg_classic", "cfg"],
-        )
+        names = set(registered_guidance_names())
+        self.assertTrue(
+            {"adg", "adg_w_norm", "adg_wo_clip", "apg_classic", "cfg"} <= names
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/models/common/guidance_registry_test.py` around lines 38 - 44, The
test test_default_registry_has_five_variants currently asserts an exact list
from registered_guidance_names(); change it to assert that the required baseline
names are present rather than an exact match: call registered_guidance_names()
in the test and verify that the set (or all items) of baseline names
["adg","adg_w_norm","adg_wo_clip","apg_classic","cfg"] are contained in the
returned collection (e.g., using set.issuperset or all(name in ... for name in
...)) inside test_default_registry_has_five_variants so new built-ins don’t
break the test.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: cd241e6f-e24d-4d0a-aa85-2691143e5242

📥 Commits

Reviewing files that changed from the base of the PR and between 8ecd184 and e4f2ce0.

📒 Files selected for processing (24)
  • acestep/api/http/release_task_models.py
  • acestep/api/http/release_task_models_test.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/api/http/release_task_param_parser_test.py
  • acestep/api/http/release_task_request_builder.py
  • acestep/api/http/release_task_request_builder_test.py
  • acestep/api/job_generation_setup.py
  • acestep/api/job_generation_setup_test.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/inference.py
  • acestep/models/base/modeling_acestep_v15_base.py
  • acestep/models/common/guidance_registry.py
  • acestep/models/common/guidance_registry_sampler_regression_test.py
  • acestep/models/common/guidance_registry_test.py
  • acestep/models/sft/modeling_acestep_v15_base.py
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/xl_sft/modeling_acestep_v15_xl_base.py
  • docs/en/API.md
  • docs/ja/API.md
  • docs/ko/API.md
  • docs/zh/API.md
✅ Files skipped from review due to trivial changes (3)
  • docs/zh/API.md
  • docs/en/API.md
  • acestep/api/http/release_task_models_test.py
🚧 Files skipped from review as they are similar to previous changes (16)
  • acestep/api/job_generation_setup.py
  • acestep/api/http/release_task_param_parser.py
  • acestep/core/generation/handler/generate_music.py
  • acestep/core/generation/handler/generate_music_execute.py
  • acestep/core/generation/handler/service_generate_execute.py
  • acestep/api/http/release_task_models.py
  • acestep/inference.py
  • acestep/api/http/release_task_param_parser_test.py
  • docs/ko/API.md
  • acestep/api/http/release_task_request_builder.py
  • acestep/core/generation/handler/service_generate.py
  • acestep/api/http/release_task_request_builder_test.py
  • docs/ja/API.md
  • acestep/models/xl_base/modeling_acestep_v15_xl_base.py
  • acestep/models/common/guidance_registry_sampler_regression_test.py
  • acestep/models/common/guidance_registry.py

Comment on lines +1 to +238
"""Unit tests for the pluggable guidance-variant registry."""

from __future__ import annotations

import unittest

import torch

from acestep.models.common.apg_guidance import (
    MomentumBuffer,
    adg_forward,
    adg_w_norm_forward,
    adg_wo_clip_forward,
    apg_forward,
    cfg_forward,
)
from acestep.models.common.guidance_registry import (
    _GUIDANCE_REGISTRY,
    _MOMENTUM_KEY,
    get_guidance_fn,
    register_guidance,
    registered_guidance_names,
)


def _fixed_tensors(seed: int = 0, shape=(1, 4, 8)):
    """Deterministic tensor pair for parity tests."""

    generator = torch.Generator().manual_seed(seed)
    pred_cond = torch.randn(*shape, generator=generator)
    pred_uncond = torch.randn(*shape, generator=generator)
    return pred_cond, pred_uncond


class GuidanceRegistryTests(unittest.TestCase):
    """Structural guarantees on the default registry contents."""

    def test_default_registry_has_five_variants(self) -> None:
        """The five baseline guidance variants must ship with the module."""

        self.assertEqual(
            sorted(registered_guidance_names()),
            ["adg", "adg_w_norm", "adg_wo_clip", "apg_classic", "cfg"],
        )

    def test_register_guidance_adds_entry(self) -> None:
        """Decorator should add a callable entry resolvable by get_guidance_fn."""

        sentinel = torch.tensor(3.14)

        @register_guidance("unit_test_noop")
        def _fake(_pc, _pu, _gs, _state, **_params):
            return sentinel

        try:
            resolved = get_guidance_fn("unit_test_noop")
            self.assertIs(resolved({}, {}, 1.0, {}), sentinel)
        finally:
            _GUIDANCE_REGISTRY.pop("unit_test_noop", None)

    def test_get_guidance_fn_unknown_raises_valueerror(self) -> None:
        """Unknown variant lookup must surface a message listing registered names."""

        with self.assertRaises(ValueError) as ctx:
            get_guidance_fn("does_not_exist")
        message = str(ctx.exception)
        self.assertIn("does_not_exist", message)
        for name in ("apg_classic", "cfg", "adg"):
            self.assertIn(name, message)


class GuidanceWrapperParityTests(unittest.TestCase):
    """Wrappers must produce byte-identical output to the raw guidance functions."""

    def test_apg_classic_wrapper_matches_apg_forward(self) -> None:
        """APG wrapper with default params equals apg_forward with fresh momentum."""

        pred_cond, pred_uncond = _fixed_tensors(seed=1)
        state: dict = {}
        guided_via_registry = get_guidance_fn("apg_classic")(
            pred_cond, pred_uncond, 5.0, state,
        )
        expected = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=5.0,
            momentum_buffer=MomentumBuffer(),
            dims=[1],
        )
        self.assertTrue(torch.allclose(guided_via_registry, expected))

    def test_cfg_wrapper_matches_cfg_forward(self) -> None:
        """CFG wrapper delegates directly to cfg_forward."""

        pred_cond, pred_uncond = _fixed_tensors(seed=2)
        actual = get_guidance_fn("cfg")(pred_cond, pred_uncond, 4.0, {})
        self.assertTrue(torch.allclose(actual, cfg_forward(pred_cond, pred_uncond, 4.0)))

    def test_adg_wrapper_matches_adg_forward(self) -> None:
        """ADG wrapper forwards latents/sigma from state to adg_forward."""

        pred_cond, pred_uncond = _fixed_tensors(seed=3)
        latents, _ = _fixed_tensors(seed=33)
        state = {"latents": latents, "sigma": 0.5}
        actual = get_guidance_fn("adg")(pred_cond, pred_uncond, 6.0, state)
        expected = adg_forward(
            latents=latents,
            noise_pred_cond=pred_cond,
            noise_pred_uncond=pred_uncond,
            sigma=0.5,
            guidance_scale=6.0,
            angle_clip=3.14 / 6,
        )
        self.assertTrue(torch.allclose(actual, expected))

    def test_adg_w_norm_wrapper_matches_raw(self) -> None:
        """ADG-with-norm wrapper matches raw adg_w_norm_forward."""

        pred_cond, pred_uncond = _fixed_tensors(seed=4)
        latents, _ = _fixed_tensors(seed=44)
        state = {"latents": latents, "sigma": 0.4}
        actual = get_guidance_fn("adg_w_norm")(pred_cond, pred_uncond, 6.0, state)
        expected = adg_w_norm_forward(
            latents=latents,
            noise_pred_cond=pred_cond,
            noise_pred_uncond=pred_uncond,
            sigma=0.4,
            guidance_scale=6.0,
        )
        self.assertTrue(torch.allclose(actual, expected))

    def test_adg_wo_clip_wrapper_matches_raw(self) -> None:
        """ADG-without-clip wrapper matches raw adg_wo_clip_forward."""

        pred_cond, pred_uncond = _fixed_tensors(seed=5)
        latents, _ = _fixed_tensors(seed=55)
        state = {"latents": latents, "sigma": 0.3}
        actual = get_guidance_fn("adg_wo_clip")(pred_cond, pred_uncond, 7.0, state)
        expected = adg_wo_clip_forward(
            latents=latents,
            noise_pred_cond=pred_cond,
            noise_pred_uncond=pred_uncond,
            sigma=0.3,
            guidance_scale=7.0,
        )
        self.assertTrue(torch.allclose(actual, expected))


class GuidanceStateLifecycleTests(unittest.TestCase):
    """State dict semantics across multiple APG steps."""

    def test_momentum_state_persists_across_steps(self) -> None:
        """APG momentum buffer must accumulate across repeated wrapper calls."""

        pred_cond, pred_uncond = _fixed_tensors(seed=6)
        state: dict = {}
        fn = get_guidance_fn("apg_classic")

        buffer = MomentumBuffer()
        expected_first = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=buffer,
            dims=[1],
        )
        expected_second = apg_forward(
            pred_cond=pred_cond,
            pred_uncond=pred_uncond,
            guidance_scale=3.0,
            momentum_buffer=buffer,
            dims=[1],
        )

        first = fn(pred_cond, pred_uncond, 3.0, state)
        self.assertIn(_MOMENTUM_KEY, state)
        second = fn(pred_cond, pred_uncond, 3.0, state)

        self.assertTrue(torch.allclose(first, expected_first))
        self.assertTrue(torch.allclose(second, expected_second))

    def test_state_dict_is_empty_at_fresh_call(self) -> None:
        """Each new generate_audio call should get a fresh empty state dict."""

        pred_cond, pred_uncond = _fixed_tensors(seed=7)
        fn = get_guidance_fn("apg_classic")

        first_state: dict = {}
        fn(pred_cond, pred_uncond, 3.0, first_state)
        self.assertIn(_MOMENTUM_KEY, first_state)

        second_state: dict = {}
        self.assertNotIn(_MOMENTUM_KEY, second_state)
        result = fn(pred_cond, pred_uncond, 3.0, second_state)
        self.assertIn(_MOMENTUM_KEY, second_state)
        self.assertTrue(torch.isfinite(result).all())


class GuidanceCorrectorStepTests(unittest.TestCase):
    """Corrector-step behaviour mirrors the pre-registry sampler."""

    def test_apg_corrector_falls_back_to_cfg(self) -> None:
        """APG corrector must emit cfg_forward output and must not touch momentum."""

        pred_cond, pred_uncond = _fixed_tensors(seed=8)
        state: dict = {"step_role": "corrector"}
        result = get_guidance_fn("apg_classic")(pred_cond, pred_uncond, 5.0, state)
        self.assertTrue(torch.allclose(result, cfg_forward(pred_cond, pred_uncond, 5.0)))
        self.assertNotIn(_MOMENTUM_KEY, state)

    def test_adg_corrector_at_sigma_zero_falls_back_to_cfg(self) -> None:
        """ADG corrector must degrade to CFG when sigma == 0 to avoid NaNs."""

        pred_cond, pred_uncond = _fixed_tensors(seed=9)
        latents, _ = _fixed_tensors(seed=99)
        state = {"latents": latents, "sigma": 0.0, "step_role": "corrector"}
        result = get_guidance_fn("adg")(pred_cond, pred_uncond, 5.0, state)
        self.assertTrue(torch.allclose(result, cfg_forward(pred_cond, pred_uncond, 5.0)))

    def test_adg_corrector_nonzero_sigma_runs_adg(self) -> None:
        """ADG corrector with sigma > 0 must match adg_forward output."""

        pred_cond, pred_uncond = _fixed_tensors(seed=10)
        latents, _ = _fixed_tensors(seed=100)
        state = {"latents": latents, "sigma": 0.5, "step_role": "corrector"}
        result = get_guidance_fn("adg")(pred_cond, pred_uncond, 5.0, state)
        expected = adg_forward(
            latents=latents,
            noise_pred_cond=pred_cond,
            noise_pred_uncond=pred_uncond,
            sigma=0.5,
            guidance_scale=5.0,
        )
        self.assertTrue(torch.allclose(result, expected))


if __name__ == "__main__":
    unittest.main()

Contributor


🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Split this test module by responsibility.

This new file is already over the 200-LOC cap and bundles four separate concerns: registry structure, wrapper parity, state lifecycle, and corrector behavior. Please split it into a couple of focused test modules, or add the follow-up split plan in the PR notes.

As per coding guidelines, "If a module would exceed 200 LOC, split by responsibility before merging, or add a short justification in PR notes with a concrete follow-up split plan."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@acestep/models/common/guidance_registry_test.py` around lines 1 - 238, This
test module groups four distinct responsibilities into one file
(GuidanceRegistryTests, GuidanceWrapperParityTests, GuidanceStateLifecycleTests,
GuidanceCorrectorStepTests) and exceeds the 200-LOC guideline; split the tests
into focused modules (e.g. guidance_registry_test.py for GuidanceRegistryTests,
guidance_wrapper_parity_test.py for GuidanceWrapperParityTests,
guidance_state_lifecycle_test.py for GuidanceStateLifecycleTests,
guidance_corrector_step_test.py for GuidanceCorrectorStepTests) keeping the same
imports and fixtures (_fixed_tensors, MomentumBuffer, get_guidance_fn, etc.), or
if you cannot split now, add a short PR note with a concrete follow-up split
plan naming these target files and a timeline before merging.

Development

Successfully merging this pull request may close these issues.

RFC: pluggable guidance function registry for sampler customization
