fix: capture residual hidden states via forward hooks instead of output_hidden_states#384
fix: capture residual hidden states via forward hooks instead of output_hidden_states#384kali113 wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Updates get_residuals() to collect per-layer residual/hidden vectors via PyTorch forward hooks so it works on model implementations where output_hidden_states isn’t reliably supported.
Changes:
- Replaces reliance on
outputs.hidden_stateswith forward-hook capture across embedding + decoder layers. - Adds a fallback path that still uses
outputs.hidden_stateswhen hooks don’t capture anything. - Ensures hook removal via
try/finallyaround generation.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| _, outputs = self.generate( | ||
| prompts, | ||
| max_new_tokens=1, | ||
| output_hidden_states=True, | ||
| return_dict_in_generate=True, | ||
| use_cache=False, | ||
| ) |
| def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None: | ||
| if not all_hidden: | ||
| all_hidden.append(inp[0].detach()) | ||
| handles.append(layers[0].register_forward_pre_hook(capture_embedding)) | ||
|
|
||
| def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...]) -> None: | ||
| all_hidden.append((out[0] if isinstance(out, tuple) else out).detach()) | ||
| for layer in layers: | ||
| handles.append(layer.register_forward_hook(capture_layer)) |
| # This cast is valid because we passed output_hidden_states=True above. | ||
| hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0] | ||
| all_hidden: list[Tensor] = [] | ||
| handles: list[Any] = [] |
| all_hidden: list[Tensor] = [] | ||
| handles: list[Any] = [] | ||
| layers = self.get_layers() | ||
|
|
||
| # Forward hooks capture intermediate hidden states from each decoder layer, | ||
| # working even for models with custom code that don't implement | ||
| # output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano). | ||
| # A pre-hook on the first layer captures the embedding output (its input), | ||
| # which is the first element of the standard hidden_states tuple. |
There was a problem hiding this comment.
Code Review
This pull request updates the get_residuals method in src/heretic/model.py to capture intermediate hidden states using PyTorch forward hooks, providing a more robust solution for custom model architectures that do not support output_hidden_states. The review feedback correctly identifies a potential issue where multiple forward passes during generation could cause the hook callbacks to accumulate duplicate hidden states, leading to shape mismatches. A code suggestion is provided to pre-allocate the hidden states list and safely assign states to specific indices only once.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| all_hidden: list[Tensor] = [] | ||
| handles: list[Any] = [] | ||
| layers = self.get_layers() | ||
|
|
||
| # Forward hooks capture intermediate hidden states from each decoder layer, | ||
| # working even for models with custom code that don't implement | ||
| # output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano). | ||
| # A pre-hook on the first layer captures the embedding output (its input), | ||
| # which is the first element of the standard hidden_states tuple. | ||
| def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None: | ||
| if not all_hidden: | ||
| all_hidden.append(inp[0].detach()) | ||
| handles.append(layers[0].register_forward_pre_hook(capture_embedding)) | ||
|
|
||
| def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...]) -> None: | ||
| all_hidden.append((out[0] if isinstance(out, tuple) else out).detach()) | ||
| for layer in layers: | ||
| handles.append(layer.register_forward_hook(capture_layer)) | ||
|
|
||
| try: | ||
| _, outputs = self.generate( | ||
| prompts, | ||
| max_new_tokens=1, | ||
| output_hidden_states=True, | ||
| return_dict_in_generate=True, | ||
| use_cache=False, | ||
| ) | ||
| finally: | ||
| for handle in handles: | ||
| handle.remove() | ||
|
|
||
| # The returned tensor has shape (prompt, layer, component). | ||
| residuals = torch.stack( | ||
| # layer_hidden_states has shape (prompt, position, component), | ||
| # so this extracts the hidden states at the end of each prompt, | ||
| # and stacks them up over the layers. | ||
| [layer_hidden_states[:, -1, :] for layer_hidden_states in hidden_states], | ||
| dim=1, | ||
| ) | ||
| if all_hidden: | ||
| # Use hook-captured hidden states (reliable for all model architectures). | ||
| # all_hidden: (embedding, layer_0, layer_1, ..., layer_N) | ||
| residuals = torch.stack( | ||
| [h[:, -1, :] for h in all_hidden], | ||
| dim=1, | ||
| ) |
There was a problem hiding this comment.
During generation, self.generate might trigger multiple forward passes (e.g., if there are multiple generation steps or internal helper passes). Since capture_layer unconditionally appends to all_hidden on every call, all_hidden can accumulate more hidden states than expected, leading to shape mismatches or incorrect layer alignment when stacking.
To make this robust, we can initialize all_hidden as a list of None with a fixed size of len(layers) + 1 and assign the captured states to their respective indices. This ensures that we only capture the first forward pass's states and that they are always in the correct order.
Additionally, the comments have been updated to adhere to the repository style guide (starting with a capital letter and ending with a period).
layers = self.get_layers()
all_hidden: list[Tensor | None] = [None] * (len(layers) + 1)
handles: list[Any] = []
# Forward hooks capture intermediate hidden states from each decoder layer.
# This works even for models with custom code that do not implement output_hidden_states (e.g., AfmoeForCausalLM / Trinity-Nano).
# A pre-hook on the first layer captures the embedding output (its input), which is the first element of the standard hidden_states tuple.
def capture_embedding(m: Module, inp: tuple[Tensor, ...]) -> None:
if all_hidden[0] is None and inp:
all_hidden[0] = inp[0].detach()
handles.append(layers[0].register_forward_pre_hook(capture_embedding))
def capture_layer(m: Module, inp: tuple[Tensor, ...], out: Tensor | tuple[Tensor, ...], idx: int) -> None:
if all_hidden[idx + 1] is None:
val = out[0] if isinstance(out, tuple) else out
if val is not None:
all_hidden[idx + 1] = val.detach()
for idx, layer in enumerate(layers):
handles.append(layer.register_forward_hook(
lambda m, inp, out, idx=idx: capture_layer(m, inp, out, idx)
))
try:
_, outputs = self.generate(
prompts,
max_new_tokens=1,
output_hidden_states=True,
return_dict_in_generate=True,
use_cache=False,
)
finally:
for handle in handles:
handle.remove()
# The returned tensor has shape (prompt, layer, component).
if all_hidden and all(h is not None for h in all_hidden):
# Use hook-captured hidden states (reliable for all model architectures).
# The list all_hidden contains: (embedding, layer_0, layer_1, ..., layer_N).
residuals = torch.stack(
[h[:, -1, :] for h in all_hidden if h is not None],
dim=1,
)References
- Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. (link)
0b18bab to
2993502
Compare
|
Thanks for the PR! It's great that you figured out what the underlying problem is with this model. I actually have some very advanced hook functionality waiting in #211 already, and I'd prefer to implement the fix on top of that, so I'll let this PR rest until we can figure out how to best bring these things together, especially since this is really hard to test comprehensively and might have adverse effects on other models. In the future, please write PR descriptions yourself instead of handing the task over to an LLM. I'm a human, and it takes (a lot) of human effort to review PRs. That effort is only worth it for me if there is a human on the other end. I can run Claude myself anytime I want, that's not what PRs are for. |
|
@p-e-w I will do that next time! Sorry |
2993502 to
01a154d
Compare
…ut_hidden_states Models with custom code (e.g., AfmoeForCausalLM / Trinity-Nano) don't implement output_hidden_states in their forward method, so passing output_hidden_states=True to generate() returns hidden_states=None, crashing get_residuals with 'NoneType is not iterable'. Replace reliance on output_hidden_states with forward hooks on the decoder layers. A pre-hook on the first layer captures the embedding output (its input), and hooks on each layer capture the intermediate hidden states. This works for any model architecture, regardless of output_hidden_states support. Fixes p-e-w#109
01a154d to
fa80ae8
Compare
Trinity-Nano-Preview crashes in get_residuals because AfmoeForCausalLM doesn't support output_hidden_states — the flag gets swallowed by **kwargs and hidden_states comes back None.
Replaced output_hidden_states with forward hooks on the decoder layers. A pre-hook on the first layer grabs the embedding output, and hooks on each layer grab the intermediate hidden states. Works for any architecture.
Fixes #109