fix: capture residual hidden states via forward hooks instead of output_hidden_states #384

Open
kali113 wants to merge 1 commit into p-e-w:master from kali113:fix/trinity-nano-residuals

Conversation

kali113 commented Jun 16, 2026

Contributor

Trinity-Nano-Preview crashes in get_residuals because AfmoeForCausalLM doesn't support output_hidden_states — the flag gets swallowed by **kwargs and hidden_states comes back None.

Replaced output_hidden_states with forward hooks on the decoder layers. A pre-hook on the first layer grabs the embedding output, and a hook on each layer grabs that layer's intermediate hidden state. This works for any architecture, regardless of output_hidden_states support.

Fixes #109
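
For readers who want to see the mechanism in isolation, here is a minimal sketch of the hook-based capture described above. It assumes a toy model: `ToyDecoder` and `capture_hidden_states` are hypothetical names invented for this illustration and are not the Heretic or AfmoeForCausalLM code. Only the PyTorch calls `register_forward_pre_hook` and `register_forward_hook` are real API.

```python
import torch
from torch import Tensor, nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder stack (not the real model)."""

    def __init__(self, vocab: int = 16, dim: int = 8, n_layers: int = 3) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, input_ids: Tensor) -> Tensor:
        hidden = self.embed(input_ids)
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


def capture_hidden_states(model: ToyDecoder, input_ids: Tensor) -> list[Tensor]:
    captured: list[Tensor] = []

    def pre_hook(module: nn.Module, args: tuple) -> None:
        # The input of the first layer is the embedding output.
        captured.append(args[0].detach())

    def post_hook(module: nn.Module, args: tuple, output) -> None:
        # Real decoder layers often return a tuple; the hidden state comes first.
        out = output[0] if isinstance(output, tuple) else output
        captured.append(out.detach())

    handles = [model.layers[0].register_forward_pre_hook(pre_hook)]
    handles += [layer.register_forward_hook(post_hook) for layer in model.layers]
    try:
        with torch.no_grad():
            model(input_ids)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
    return captured


model = ToyDecoder()
input_ids = torch.randint(0, 16, (2, 5))
states = capture_hidden_states(model, input_ids)

# Take the last position of every captured state and stack over layers.
residuals = torch.stack([h[:, -1, :] for h in states], dim=1)
print(residuals.shape)  # torch.Size([2, 4, 8]): (prompt, layer, component)
```

The sketch runs exactly one forward pass, so appending in call order is safe here. Whether that also holds under `generate()` is the question raised later in this review.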

Copilot AI review requested due to automatic review settings June 16, 2026 19:22

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Updates get_residuals() to collect per-layer residual/hidden vectors via PyTorch forward hooks so it works on model implementations where output_hidden_states isn’t reliably supported.

Changes:

  • Replaces reliance on outputs.hidden_states with forward-hook capture across embedding + decoder layers.
  • Adds a fallback path that still uses outputs.hidden_states when hooks don’t capture anything.
  • Ensures hook removal via try/finally around generation.
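
The hook-removal point in the list above can be illustrated with a small sketch. `temporary_hooks` is a hypothetical helper name, not something from this PR; the sketch only shows that a try/finally (wrapped here in a context manager) detaches hooks even when the forward pass raises.

```python
from contextlib import contextmanager
from typing import Iterator

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


@contextmanager
def temporary_hooks(handles: list[RemovableHandle]) -> Iterator[None]:
    """Hypothetical helper: remove every hook handle on exit, error or not."""
    try:
        yield
    finally:
        for handle in handles:
            handle.remove()


layer = nn.Linear(4, 4)
calls: list[int] = []
handle = layer.register_forward_hook(lambda module, args, output: calls.append(1))

try:
    with temporary_hooks([handle]):
        layer(torch.zeros(1, 4))  # Hook fires once.
        layer(torch.zeros(1, 3))  # Wrong input width, so forward raises.
except RuntimeError:
    pass

layer(torch.zeros(1, 4))  # The hook is already gone, so nothing is recorded.
print(len(calls))  # 1
```

Leaving a hook attached after a failed call would make every later forward pass keep appending to stale state, which is why the cleanup belongs in `finally`.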

Comment thread src/heretic/model.py Outdated
Comment on lines +714 to +720
            _, outputs = self.generate(
                prompts,
                max_new_tokens=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
                use_cache=False,
            )
Comment thread src/heretic/model.py Outdated
Comment on lines +703 to +711
        def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None:
            if not all_hidden:
                all_hidden.append(inp[0].detach())
        handles.append(layers[0].register_forward_pre_hook(capture_embedding))

        def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...]) -> None:
            all_hidden.append((out[0] if isinstance(out, tuple) else out).detach())
        for layer in layers:
            handles.append(layer.register_forward_hook(capture_layer))
Comment thread src/heretic/model.py
-        # This cast is valid because we passed output_hidden_states=True above.
-        hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
         all_hidden: list[Tensor] = []
         handles: list[Any] = []
Comment thread src/heretic/model.py Outdated
Comment on lines +694 to +702
        all_hidden: list[Tensor] = []
        handles: list[Any] = []
        layers = self.get_layers()

        # Forward hooks capture intermediate hidden states from each decoder layer,
        # working even for models with custom code that don't implement
        # output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano).
        # A pre-hook on the first layer captures the embedding output (its input),
        # which is the first element of the standard hidden_states tuple.

gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request updates the get_residuals method in src/heretic/model.py to capture intermediate hidden states using PyTorch forward hooks, providing a more robust solution for custom model architectures that do not support output_hidden_states. The review feedback correctly identifies a potential issue where multiple forward passes during generation could cause the hook callbacks to accumulate duplicate hidden states, leading to shape mismatches. A code suggestion is provided to pre-allocate the hidden states list and safely assign states to specific indices only once.

Comment thread src/heretic/model.py Outdated
Comment on lines +694 to +732
         all_hidden: list[Tensor] = []
         handles: list[Any] = []
         layers = self.get_layers()

         # Forward hooks capture intermediate hidden states from each decoder layer,
         # working even for models with custom code that don't implement
         # output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano).
         # A pre-hook on the first layer captures the embedding output (its input),
         # which is the first element of the standard hidden_states tuple.
         def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None:
             if not all_hidden:
                 all_hidden.append(inp[0].detach())
         handles.append(layers[0].register_forward_pre_hook(capture_embedding))

         def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...]) -> None:
             all_hidden.append((out[0] if isinstance(out, tuple) else out).detach())
         for layer in layers:
             handles.append(layer.register_forward_hook(capture_layer))

         try:
             _, outputs = self.generate(
                 prompts,
                 max_new_tokens=1,
                 output_hidden_states=True,
                 return_dict_in_generate=True,
                 use_cache=False,
             )
         finally:
             for handle in handles:
                 handle.remove()

         # The returned tensor has shape (prompt, layer, component).
-        residuals = torch.stack(
-            # layer_hidden_states has shape (prompt, position, component),
-            # so this extracts the hidden states at the end of each prompt,
-            # and stacks them up over the layers.
-            [layer_hidden_states[:, -1, :] for layer_hidden_states in hidden_states],
-            dim=1,
-        )
         if all_hidden:
             # Use hook-captured hidden states (reliable for all model architectures).
             # all_hidden: (embedding, layer_0, layer_1, ..., layer_N)
             residuals = torch.stack(
                 [h[:, -1, :] for h in all_hidden],
                 dim=1,
             )

Contributor

high

During generation, self.generate might trigger multiple forward passes (e.g., if there are multiple generation steps or internal helper passes). Since capture_layer unconditionally appends to all_hidden on every call, all_hidden can accumulate more hidden states than expected, leading to shape mismatches or incorrect layer alignment when stacking.

To make this robust, we can initialize all_hidden as a list of None with a fixed size of len(layers) + 1 and assign the captured states to their respective indices. This ensures that we only capture the first forward pass's states and that they are always in the correct order.

Additionally, the comments have been updated to adhere to the repository style guide (starting with a capital letter and ending with a period).

        layers = self.get_layers()
        all_hidden: list[Tensor | None] = [None] * (len(layers) + 1)
        handles: list[Any] = []

        # Forward hooks capture intermediate hidden states from each decoder layer.
        # This works even for models with custom code that do not implement output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano).
        # A pre-hook on the first layer captures the embedding output (its input), which is the first element of the standard hidden_states tuple.
        def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None:
            if all_hidden[0] is None and inp:
                all_hidden[0] = inp[0].detach()
        handles.append(layers[0].register_forward_pre_hook(capture_embedding))

        def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...], idx: int) -> None:
            if all_hidden[idx + 1] is None:
                val = out[0] if isinstance(out, tuple) else out
                if val is not None:
                    all_hidden[idx + 1] = val.detach()
        for idx, layer in enumerate(layers):
            handles.append(layer.register_forward_hook(
                lambda m, inp, out, idx=idx: capture_layer(m, inp, out, idx)
            ))

        try:
            _, outputs = self.generate(
                prompts,
                max_new_tokens=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
                use_cache=False,
            )
        finally:
            for handle in handles:
                handle.remove()

        # The returned tensor has shape (prompt, layer, component).
        if all_hidden and all(h is not None for h in all_hidden):
            # Use hook-captured hidden states (reliable for all model architectures).
            # The list all_hidden contains: (embedding, layer_0, layer_1, ..., layer_N).
            residuals = torch.stack(
                [h[:, -1, :] for h in all_hidden if h is not None],
                dim=1,
            )
References
  1. Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
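
The accumulation concern in this comment can be reproduced on a toy stack. This is a sketch under stated assumptions: the second call to `run` stands in for an extra forward pass that `generate()` might perform, and the names `run`, `make_hook`, `appended`, and `slots` are hypothetical. It contrasts unconditional appending with write-once indexed slots.

```python
from typing import Optional

import torch
from torch import Tensor, nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


def run(x: Tensor) -> Tensor:
    for layer in layers:
        x = layer(x)
    return x


appended: list[Tensor] = []                              # Grows on every call.
slots: list[Optional[Tensor]] = [None] * len(layers)     # One slot per layer.


def make_hook(idx: int):
    def hook(module: nn.Module, args: tuple, output: Tensor) -> None:
        appended.append(output.detach())
        if slots[idx] is None:
            # Only the first forward pass is recorded for this layer.
            slots[idx] = output.detach()
    return hook


handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(layers)]
try:
    with torch.no_grad():
        first = run(torch.ones(1, 2, 4))
        run(torch.zeros(1, 2, 4))  # Simulated second forward pass.
finally:
    for handle in handles:
        handle.remove()

print(len(appended), len(slots))  # 6 3
```

With two passes the appended list holds six tensors for three layers, so stacking it would misalign layers, while the slot list keeps exactly one tensor per layer from the first pass.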

kali113 force-pushed the fix/trinity-nano-residuals branch 2 times, most recently from 0b18bab to 2993502 on June 16, 2026 19:28

p-e-w commented Jun 17, 2026

Owner

Thanks for the PR! It's great that you figured out what the underlying problem is with this model. I actually have some very advanced hook functionality waiting in #211 already, and I'd prefer to implement the fix on top of that, so I'll let this PR rest until we can figure out how to best bring these things together, especially since this is really hard to test comprehensively and might have adverse effects on other models.

In the future, please write PR descriptions yourself instead of handing the task over to an LLM. I'm a human, and it takes (a lot of) human effort to review PRs. That effort is only worth it for me if there is a human on the other end. I can run Claude myself anytime I want; that's not what PRs are for.

kali113 commented Jun 17, 2026

Contributor Author

@p-e-w I will do that next time! Sorry

kali113 force-pushed the fix/trinity-nano-residuals branch from 2993502 to 01a154d on June 17, 2026 14:09
…ut_hidden_states

Models with custom code (e.g., AfmoeForCausalLM / Trinity-Nano) don't
implement output_hidden_states in their forward method, so passing
output_hidden_states=True to generate() returns hidden_states=None,
crashing get_residuals with 'NoneType is not iterable'.

Replace reliance on output_hidden_states with forward hooks on the
decoder layers. A pre-hook on the first layer captures the embedding
output (its input), and hooks on each layer capture the intermediate
hidden states. This works for any model architecture, regardless of
output_hidden_states support.

Fixes p-e-w#109
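
The failure mode described in the commit message can be sketched without any model at all. The `Output` class and `forward` function below are hypothetical stand-ins, not the AfmoeForCausalLM code; they only show how a flag absorbed by `**kwargs` leaves `hidden_states` as `None` and how iterating over it then fails.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Output:
    """Hypothetical model output container."""

    logits: list[float]
    hidden_states: Optional[tuple] = None


def forward(input_ids: list[int], **kwargs: Any) -> Output:
    # Assumption for illustration: the flag lands in kwargs and is never read,
    # so hidden_states keeps its default of None.
    return Output(logits=[0.0 for _ in input_ids])


outputs = forward([1, 2, 3], output_hidden_states=True)
print(outputs.hidden_states)  # None

crashed = False
try:
    # Mirrors a caller that loops over the per-layer hidden states.
    last_positions = [h for h in outputs.hidden_states]
except TypeError:
    crashed = True

print(crashed)  # True
```

No error is raised at the call site because unknown keyword arguments are accepted silently, which is why the problem only surfaces later in the caller.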
Development

Successfully merging this pull request may close these issues.

Support for Trinity-Nano

3 participants