diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index 29912c63..e7454c4f 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -34,6 +34,7 @@ KVzapPress, KVzipPress, LagKVPress, + MergingPress, ObservedAttentionPress, PyramidKVPress, QFilterPress, @@ -121,4 +122,9 @@ "decoding_adakv_expected_attention_e2": DecodingPress(base_press=AdaKVPress(ExpectedAttentionPress(epsilon=1e-2))), "decoding_adakv_snapkv": DecodingPress(base_press=AdaKVPress(SnapKVPress())), "decoding_keydiff": DecodingPress(base_press=KeyDiffPress()), + # MergingPress: merge-on-evict during prefill (values-only merge preserves RoPE keys) + "merging_knorm": MergingPress(KnormPress()), + "merging_snapkv": MergingPress(SnapKVPress()), + "merging_adakv_snapkv": MergingPress(AdaKVPress(SnapKVPress())), + "merging_dms_kvzap_mlp": MergingPress(DMSPress(press=KVzapPress(model_type="mlp"))), } diff --git a/kvpress/__init__.py b/kvpress/__init__.py index f4da32a4..b3771bf8 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -29,6 +29,7 @@ from kvpress.presses.kvzip_press import KVzipPress from kvpress.presses.lagkv_press import LagKVPress from kvpress.presses.leverage_press import LeverageScorePress +from kvpress.presses.merging_press import MergingPress from kvpress.presses.non_causal_attention_press import NonCausalAttnPress from kvpress.presses.observed_attention_press import ObservedAttentionPress from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress @@ -87,4 +88,5 @@ "DMSPress", "FastKVzipPress", "KVComposePress", + "MergingPress", ] diff --git a/kvpress/presses/merging_press.py b/kvpress/presses/merging_press.py new file mode 100644 index 00000000..6bd0e903 --- /dev/null +++ b/kvpress/presses/merging_press.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# Epsilon for numerical stability — safe for float16 (min ~6e-8) and bfloat16
_EPS = 1e-6


def _merge_on_evict(
    keys: torch.Tensor,
    values: torch.Tensor,
    evict_mask: torch.Tensor,
    similarity_threshold: float,
    merge_keys: bool,
    value_norm_weighting: bool,
    max_merge_per_token: int = 0,
    merge_fraction: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Merge evicted tokens into their most cosine-similar survivors.

    Handles variable per-head eviction counts (from AdaKV, DMSPress, or uniform
    scorer-based eviction). Merged information is written into the survivor
    positions of full-length copies of the inputs; evicted positions are left
    untouched and the input tensors themselves are never modified.

    For every ``(batch, head)`` pair the routine:

    1. routes each evicted token to the survivor whose key has the highest
       cosine similarity,
    2. drops routes that fail the threshold gate or the fraction gate,
    3. replaces each survivor ``j`` that received at least one route by
       ``(v_j + sum_i w_i * v_i) / (1 + sum_i w_i)``.

    **Per-token value-reconstruction bound.** For a single evicted token *i*
    routed to survivor *j* with cosine similarity :math:`w = \\cos(k_i, k_j)`:

    .. math::

        \\|\\Delta v_{\\text{merge}}\\| \\leq \\frac{1}{1 + w} \\;\\|v_i\\|

    This bounds per-token reconstruction error in value space only, not the
    end-to-end output after softmax re-normalization.

    Parameters
    ----------
    keys : Tensor, shape ``(B, H, L, D)``
    values : Tensor, shape ``(B, H, L, D)``
    evict_mask : Tensor, shape ``(B, H, L)``, dtype bool
        ``True`` at positions to evict, ``False`` at positions to keep.
    similarity_threshold : float
        Minimum cosine similarity for a merge to proceed.
    merge_keys : bool
        Whether to merge evicted information into survivor keys.
    value_norm_weighting : bool
        Scale merge weight by relative value-vector L2 norm.
    max_merge_per_token : int, default=0
        Cap on merges per survivor. ``0`` disables. When a survivor receives
        more routes than the cap, the weights of those routes are scaled down
        proportionally rather than dropped.
    merge_fraction : float, default=1.0
        Fraction of evicted tokens (ranked by similarity) that are merged.
        Only tokens whose best similarity reaches the ``1 - merge_fraction``
        quantile of all evicted tokens' best similarities are kept, and they
        must still pass ``similarity_threshold``.

    Returns
    -------
    tuple[Tensor, Tensor]
        ``(new_keys, new_values)`` — same shape and dtype as the inputs. When
        ``merge_keys`` is ``False`` the returned keys are the input ``keys``
        object itself.

    References
    ----------
    .. [1] Bolya et al., "Token Merging: Your ViT But Faster", ICLR 2023.
    .. [2] Wan et al., "D2O: Dynamic Discriminative Operations", 2024.
    .. [3] Huang et al., "KeepKV: Lossless KV Cache Compression", 2025.
    """
    bsz, num_kv_heads, _, head_dim = keys.shape
    device = keys.device

    # Accumulate in float32 so that half-precision inputs do not overflow or
    # lose precision; results are cast back to the input dtype at the end.
    merged_values = values.float().clone()
    merged_keys = keys.float().clone() if merge_keys else None

    for b in range(bsz):
        for h in range(num_kv_heads):
            head_mask = evict_mask[b, h]
            evict_idx = head_mask.nonzero(as_tuple=True)[0]
            keep_idx = (~head_mask).nonzero(as_tuple=True)[0]
            n_kept = keep_idx.shape[0]

            # Nothing to merge, or nobody to merge into.
            if evict_idx.shape[0] == 0 or n_kept == 0:
                continue

            # Cosine similarity → nearest survivor
            evict_k = keys[b, h, evict_idx].float()
            kept_k = keys[b, h, keep_idx].float()
            e_norms = evict_k.norm(dim=-1, keepdim=True).clamp(min=_EPS)
            k_norms = kept_k.norm(dim=-1, keepdim=True).clamp(min=_EPS)
            sim = (evict_k / e_norms) @ (kept_k / k_norms).T
            max_sim, target = sim.max(dim=-1)

            # Threshold gate
            merge_ok = max_sim >= similarity_threshold

            # Fraction gate: keep only the top merge_fraction of evicted tokens.
            # The cut-off is taken over *all* evicted tokens. Tokens that failed
            # the threshold gate are the lowest-ranked ones anyway, so no
            # sentinel masking is needed. Masking them with +inf would place
            # them at the top of the ranking and push the quantile towards
            # inf/NaN, silently disabling merging.
            if merge_fraction < 1.0 and merge_ok.any():
                frac_threshold = max_sim.quantile(1.0 - merge_fraction)
                merge_ok = merge_ok & (max_sim >= frac_threshold)

            if not merge_ok.any():
                continue

            evict_v = values[b, h, evict_idx].float()

            # Similarity-weighted merge; rejected routes get weight zero.
            w = max_sim.clamp(min=0) * merge_ok.float()

            if value_norm_weighting:
                target_v = values[b, h, keep_idx[target]].float()
                ev_norm = evict_v.norm(dim=-1)
                tv_norm = target_v.norm(dim=-1)
                w = w * ev_norm / (ev_norm + tv_norm + _EPS)

            # Merge count cap: prevent survivor dilution
            if max_merge_per_token > 0:
                count = torch.zeros(n_kept, device=device, dtype=torch.float32)
                count.scatter_add_(0, target, merge_ok.float())
                excess = (count / max_merge_per_token).clamp(min=1.0)
                w = w / excess[target]

            w_exp = w.unsqueeze(-1)
            scatter_index = target.unsqueeze(-1).expand_as(evict_v)

            val_accum = torch.zeros(n_kept, head_dim, device=device, dtype=torch.float32)
            val_accum.scatter_add_(0, scatter_index, w_exp * evict_v)
            w_accum = torch.zeros(n_kept, device=device, dtype=torch.float32)
            w_accum.scatter_add_(0, target, w)

            # Survivors that received no positive weight keep their exact
            # original value instead of going through the division.
            active = (w_accum > 0).unsqueeze(-1)
            total_w = (1.0 + w_accum).unsqueeze(-1)

            orig_v = merged_values[b, h, keep_idx]
            new_v = (orig_v + val_accum) / total_w
            merged_values[b, h, keep_idx] = torch.where(active, new_v, orig_v)

            if merge_keys and merged_keys is not None:
                key_accum = torch.zeros(n_kept, head_dim, device=device, dtype=torch.float32)
                key_accum.scatter_add_(0, scatter_index, w_exp * evict_k)
                orig_k = merged_keys[b, h, keep_idx]
                new_k = (orig_k + key_accum) / total_w
                merged_keys[b, h, keep_idx] = torch.where(active, new_k, orig_k)

    result_values = merged_values.to(values.dtype)
    result_keys = merged_keys.to(keys.dtype) if merge_keys else keys
    return result_keys, result_values
@dataclass
class MergingPress(BasePress):
    """
    Press-agnostic merge-on-evict wrapper for KV cache compression.

    Wraps **any** :class:`BasePress` and replaces hard eviction with merge-on-evict:
    each evicted token is folded into its most similar surviving neighbor rather than
    being discarded. Values are blended via a similarity-weighted average; keys can
    optionally be merged depending on the ``merge_keys`` flag.

    **Composition modes:**

    * ``MergingPress(ScorerPress)``: calls ``.score()``, applies uniform per-head
      budget, returns truncated tensors with merged survivors.
    * ``MergingPress(AdaKVPress(ScorerPress))``: delegates to AdaKV's adaptive
      per-head budget allocation, then merges evicted tokens into survivors in-place.
    * ``MergingPress(DMSPress(ScorerPress))``: **post-hook composition** — lets
      DMSPress register its own hooks via its ``__call__`` context manager, then
      adds merge post-hooks that fire after each layer. The inner press runs
      exactly as standalone; MergingPress reads ``masked_key_indices`` and merges.
      Works with any hook-based press (DMSPress, KVzipPress, FastKVzipPress,
      KVComposePress).

    Which mode is used is decided at call time: ``compress`` branches on
    ``isinstance(self.press, ScorerPress)``, and ``__call__`` / ``forward_hook``
    branch on :meth:`_uses_hook_composition`.

    Parameters
    ----------
    press : BasePress
        The underlying press.
    similarity_threshold : float, default=0.0
        Minimum cosine similarity for a merge to proceed. Must lie in ``[0, 1]``.
    merge_keys : bool, default=False
        Whether to merge evicted keys. ``False`` preserves RoPE encoding.
    value_norm_weighting : bool, default=True
        Scale merge weight by relative value-vector L2 norm.
    max_merge_per_token : int, default=0
        Cap on merges per survivor. ``0`` disables.
    merge_fraction : float, default=1.0
        Fraction of evicted tokens (by similarity) that are merged. Must lie in
        ``(0, 1]``.

    See also
    --------
    Bolya et al., "Token Merging", ICLR 2023.
    Zhang et al., "CaM", ICML 2024.
    Wan et al., "KeepKV", 2025.

    NOTE(review): the ``_merge_on_evict`` docstring in this module attributes
    KeepKV to different authors — confirm the citation.
    """

    press: BasePress
    similarity_threshold: float = 0.0
    merge_keys: bool = False
    value_norm_weighting: bool = True
    max_merge_per_token: int = 0
    merge_fraction: float = 1.0

    def __post_init__(self):
        """Validate constructor arguments.

        Validation uses ``assert`` statements, so it is skipped when Python
        runs with ``-O``.
        """
        assert isinstance(self.press, BasePress), f"MergingPress requires a BasePress, got {type(self.press)}"
        assert 0.0 <= self.similarity_threshold <= 1.0
        assert self.max_merge_per_token >= 0, "max_merge_per_token must be non-negative"
        assert 0.0 < self.merge_fraction <= 1.0, "merge_fraction must be in (0, 1]"

    def post_init_from_model(self, model):
        """Forward model-dependent initialisation to the wrapped press."""
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        """Compression ratio of the wrapped press (this wrapper stores none of its own)."""
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def _merge_kwargs(self) -> dict:
        """Common kwargs for _merge_on_evict calls."""
        return dict(
            similarity_threshold=self.similarity_threshold,
            merge_keys=self.merge_keys,
            value_norm_weighting=self.value_norm_weighting,
            max_merge_per_token=self.max_merge_per_token,
            merge_fraction=self.merge_fraction,
        )

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress one layer's keys/values, merging evicted tokens into survivors.

        Returns the inputs unchanged when the wrapped press reports a
        compression ratio of ``0``. For a :class:`ScorerPress` the result is
        truncated along the sequence axis (see :meth:`_compress_scorer`). For
        any other press, the wrapped ``compress`` runs first and, if
        ``module.masked_key_indices`` is set afterwards, the positions it lists
        are merged into the remaining ones without changing the tensor length.
        """
        if self.press.compression_ratio == 0:
            return keys, values

        # Shape is captured *before* the inner press runs; head_dim is not used below.
        bsz, num_kv_heads, k_len, head_dim = keys.shape

        # --- ScorerPress path: uniform per-head merge, returns truncated tensors ---
        if isinstance(self.press, ScorerPress):
            return self._compress_scorer(module, hidden_states, keys, values, attentions, kwargs)

        # --- Mask-based press path (AdaKV, CriticalAdaKV, etc.) ---
        keys, values = self.press.compress(module, hidden_states, keys, values, attentions, kwargs)

        # NOTE(review): a stale ``masked_key_indices`` left on the module by an
        # earlier press would also be picked up here — confirm that the inner
        # press always resets it.
        mask_indices = getattr(module, "masked_key_indices", None)
        if mask_indices is None:
            # Inner press did not publish a mask: nothing to merge.
            return keys, values

        # Mask uses the pre-compression length; presumably the inner press keeps
        # the sequence length unchanged on this path — TODO confirm.
        evict_mask = torch.zeros(bsz, num_kv_heads, k_len, device=keys.device, dtype=torch.bool)
        evict_mask[tuple(mask_indices)] = True

        new_keys, new_values = _merge_on_evict(keys, values, evict_mask, **self._merge_kwargs())
        return new_keys, new_values

    def _compress_scorer(self, module, hidden_states, keys, values, attentions, kwargs):
        """Score tokens with the wrapped ScorerPress, merge the losers, return the winners.

        The wrapped press's ``score`` is called directly (its ``compress`` is
        not). The ``n_kept`` highest-scoring positions per head survive, where
        ``n_kept`` is ``k_len * (1 - compression_ratio)`` truncated to an int.
        """
        bsz, num_kv_heads, k_len, head_dim = keys.shape
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

        n_kept = int(k_len * (1 - self.press.compression_ratio))
        if n_kept >= k_len:
            return keys, values
        if n_kept <= 0:
            # Everything is evicted: return empty tensors along the sequence axis.
            return keys[:, :, :0, :].contiguous(), values[:, :, :0, :].contiguous()

        # Build evict mask from scores, merge in-place, then extract kept positions
        keep_idx = scores.topk(n_kept, dim=-1).indices
        evict_mask = torch.ones(bsz, num_kv_heads, k_len, device=keys.device, dtype=torch.bool)
        evict_mask.scatter_(2, keep_idx, False)

        new_keys, new_values = _merge_on_evict(keys, values, evict_mask, **self._merge_kwargs())

        # Extract kept positions (topk order to match ScorerPress.compress behavior)
        idx4 = keep_idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        return new_keys.gather(2, idx4), new_values.gather(2, idx4)

    def _uses_hook_composition(self) -> bool:
        """True if inner press uses forward_hook for eviction (not compress).

        These presses (e.g. DMSPress, KVzipPress, FastKVzipPress, KVComposePress)
        set ``module.masked_key_indices`` in their ``forward_hook``. MergingPress
        adds merge-on-evict as a post-processor after the inner hook, without
        modifying the inner press's execution or state.

        The check is structural: the inner press's class must inherit
        ``compress`` unchanged from :class:`BasePress` and must override
        ``forward_hook``.
        """
        return (type(self.press).compress is BasePress.compress
                and type(self.press).forward_hook is not BasePress.forward_hook)

    def _merge_post_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """Post-hook that merges evicted tokens into survivors.

        Reads ``module.masked_key_indices`` (set by any mask-based press) and
        applies merge-on-evict to fold evicted values into their most
        cosine-similar survivors. Purely additive — does not modify the inner
        press's execution or state.

        The merged tensors are written back into ``kwargs["past_key_values"]``
        for ``module.layer_idx``; ``output`` is returned unchanged.

        NOTE(review): there is no guard against running on later forward
        passes. If ``masked_key_indices`` persists across calls and the cache
        still holds the evicted positions, the same evicted tokens would be
        blended into the survivors again on every call — confirm whether this
        hook should skip decoding steps.
        """
        mask_indices = getattr(module, "masked_key_indices", None)
        if mask_indices is None or len(mask_indices[0]) == 0:
            return output

        cache = kwargs["past_key_values"]
        keys, values = extract_keys_and_values(cache, module.layer_idx)
        bsz, num_kv_heads, k_len, _ = keys.shape

        evict_mask = torch.zeros(bsz, num_kv_heads, k_len, device=keys.device, dtype=torch.bool)
        evict_mask[tuple(mask_indices)] = True

        new_keys, new_values = _merge_on_evict(keys, values, evict_mask, **self._merge_kwargs())

        # Write merged values back to cache
        cache_layer = cache.layers[module.layer_idx]
        if isinstance(cache, QuantizedCache):
            # Re-quantize the merged tensors and empty the unquantized buffers.
            # Both placeholders take their dtype from new_keys.
            cache_layer._quantized_keys = cache_layer._quantize(new_keys, axis=cache_layer.axis_key)
            cache_layer._quantized_values = cache_layer._quantize(new_values, axis=cache_layer.axis_value)
            cache_layer.keys = torch.zeros(0, dtype=new_keys.dtype, device=new_keys.device)
            cache_layer.values = torch.zeros(0, dtype=new_keys.dtype, device=new_keys.device)
            cache_layer.cumulative_length = new_keys.shape[2]
        else:
            cache_layer.keys = new_keys
            cache_layer.values = new_values

        return output

    @contextmanager
    def __call__(self, model: PreTrainedModel) -> Generator:
        """Context manager supporting both standard and hook-based inner presses.

        For standard presses (ScorerPress, AdaKV): delegates to
        :meth:`BasePress.__call__`.

        For hook-based presses (DMSPress, KVzipPress, etc.): lets the inner press
        register its own hooks via its own ``__call__``, then adds merge
        post-hooks on top. The inner press's contract is fully preserved.

        The merge hooks are removed in a ``finally`` clause, before the inner
        press's context manager exits.
        """
        if not self._uses_hook_composition():
            with super().__call__(model):
                yield
            return

        # Hook-based press: let inner press register its hooks via its own __call__
        with self.press(model):
            merge_hooks = []
            try:
                # Multimodal wrappers expose the text decoder as ``language_model``.
                language_model = (
                    model.model.language_model if hasattr(model.model, "language_model") else model.model
                )
                for layer in language_model.layers:
                    # Sliding-window layers of Gemma3 are left without a merge hook.
                    if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
                        continue
                    merge_hooks.append(
                        layer.self_attn.register_forward_hook(self._merge_post_hook, with_kwargs=True)
                    )
                yield
            finally:
                for hook in merge_hooks:
                    hook.remove()

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """Forward hook supporting both standard and hook-based inner presses.

        For standard presses: delegates to :meth:`BasePress.forward_hook`.

        For hook-based presses (fallback for nested composition, e.g. inside
        :class:`PrefillDecodingPress`): delegates to the inner press's
        ``forward_hook`` first, then applies merge post-processing.
        """
        if not self._uses_hook_composition():
            return super().forward_hook(module, input, kwargs, output)

        # Delegate to inner press hook (runs unchanged), then merge
        output = self.press.forward_hook(module, input, kwargs, output)
        return self._merge_post_hook(module, input, kwargs, output)
def test_merge_differs_from_hard_eviction(unit_test_model):  # noqa: F811
    """Merged values should differ from hard-evicted values."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 64), device=unit_test_model.device)

    base = KnormPress(compression_ratio=0.5)
    with base(unit_test_model):
        cache_hard = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_hard)

    wrapper = MergingPress(press=KnormPress(compression_ratio=0.5), similarity_threshold=0.0)
    with wrapper(unit_test_model):
        cache_merge = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_merge)

    # Both presses keep half of the 64 input tokens.
    assert cache_hard.get_seq_length() == cache_merge.get_seq_length() == 32
    any_diff = any(
        not torch.equal(cache_hard.layers[i].values, cache_merge.layers[i].values)
        for i in range(len(cache_hard.layers))
    )
    assert any_diff, "Merging produced identical values to hard eviction"


def test_default_preserves_keys(unit_test_model):  # noqa: F811
    """Default merge_keys=False should not modify keys (preserves RoPE)."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 64), device=unit_test_model.device)

    base = KnormPress(compression_ratio=0.5)
    with base(unit_test_model):
        cache_hard = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_hard)

    wrapper = MergingPress(press=KnormPress(compression_ratio=0.5))
    with wrapper(unit_test_model):
        cache_merge = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_merge)

    for i in range(len(cache_hard.layers)):
        assert torch.equal(cache_hard.layers[i].keys, cache_merge.layers[i].keys), (
            f"Layer {i}: merge_keys=False should not modify keys"
        )


def test_merge_preserves_more_info(unit_test_model):  # noqa: F811
    """Merge-on-evict stays closer to uncompressed cache than hard eviction."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 64), device=unit_test_model.device)

    cache_ref = DynamicCache()
    unit_test_model(input_ids.clone(), past_key_values=cache_ref)
    ref_values = [layer.values.float() for layer in cache_ref.layers]

    base = KnormPress(compression_ratio=0.7)
    with base(unit_test_model):
        cache_hard = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_hard)

    wrapper = MergingPress(press=KnormPress(compression_ratio=0.7), similarity_threshold=0.0)
    with wrapper(unit_test_model):
        cache_merge = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_merge)

    # NOTE(review): the compressed values are compared against the first
    # positions of the reference cache, not against the positions that were
    # actually kept — confirm that this is the intended metric.
    def recon_error(cache):
        return sum(
            (layer.values.float() - ref_values[i][:, :, : layer.values.shape[2]]).norm().item()
            for i, layer in enumerate(cache.layers)
        )

    assert recon_error(cache_merge) <= recon_error(cache_hard) + 1e-6


def test_half_precision_no_nan(unit_test_model):  # noqa: F811
    """Float32 accumulation must produce finite results in fp16."""
    # ``.to`` converts the fixture model itself, so remember its dtype and put
    # it back even when an assertion fails; otherwise a failure here would
    # leave the model in fp16 for whichever test uses the fixture next.
    original_dtype = next(unit_test_model.parameters()).dtype
    model = unit_test_model.to(torch.float16)
    try:
        torch.manual_seed(42)
        input_ids = torch.randint(0, 1024, (1, 64), device=model.device)

        wrapper = MergingPress(press=KnormPress(compression_ratio=0.5))
        with wrapper(model):
            cache = DynamicCache()
            model(input_ids, past_key_values=cache)

        for layer in cache.layers:
            assert torch.isfinite(layer.keys).all()
            assert torch.isfinite(layer.values).all()
    finally:
        model.to(original_dtype)


def test_batch_size_greater_than_one(unit_test_model):  # noqa: F811
    """Kernel must handle batch_size > 1 correctly."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (2, 64), device=unit_test_model.device)

    wrapper = MergingPress(press=KnormPress(compression_ratio=0.5))
    with wrapper(unit_test_model):
        cache = DynamicCache()
        unit_test_model(input_ids, past_key_values=cache)

    assert cache.get_seq_length() == 32
    for layer in cache.layers:
        assert layer.keys.shape[0] == 2


def test_adakv_composition(unit_test_model):  # noqa: F811
    """MergingPress(AdaKV) uses mask-based path and changes values."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device)

    plain = AdaKVPress(SnapKVPress(compression_ratio=0.5))
    with plain(unit_test_model):
        cache_plain = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_plain)

    wrapper = MergingPress(press=AdaKVPress(SnapKVPress(compression_ratio=0.5)))
    with wrapper(unit_test_model):
        cache_merge = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_merge)

    any_diff = any(
        not torch.equal(cache_plain.layers[i].values, cache_merge.layers[i].values)
        for i in range(len(cache_plain.layers))
    )
    assert any_diff, "MergingPress(AdaKV) should differ from plain AdaKV"


def test_dms_hook_composition(unit_test_model):  # noqa: F811
    """MergingPress(DMSPress) uses post-hook composition and changes values."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device)

    plain = DMSPress(press=RandomPress(), threshold=0.5, sliding_window_size=0)
    with plain(unit_test_model):
        cache_plain = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_plain)

    wrapper = MergingPress(
        press=DMSPress(press=RandomPress(), threshold=0.5, sliding_window_size=0),
        similarity_threshold=0.0,
    )
    assert wrapper._uses_hook_composition()
    with wrapper(unit_test_model):
        cache_merge = DynamicCache()
        unit_test_model(input_ids.clone(), past_key_values=cache_merge)

    any_diff = any(
        not torch.equal(cache_plain.layers[i].values, cache_merge.layers[i].values)
        for i in range(len(cache_plain.layers))
    )
    assert any_diff, "MergingPress(DMSPress) should differ from plain DMSPress"


def test_forward_hook_fallback(unit_test_model):  # noqa: F811
    """forward_hook delegation works for nested composition (PrefillDecodingPress path)."""
    torch.manual_seed(42)
    input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device)
    dms = DMSPress(press=RandomPress(), threshold=0.5, sliding_window_size=0)
    wrapper = MergingPress(press=dms, similarity_threshold=0.0)

    # Simulate PrefillDecodingPress: uses BasePress.__call__ which calls forward_hook directly
    from kvpress.presses.base_press import BasePress

    with BasePress.__call__(wrapper, unit_test_model):
        cache = DynamicCache()
        unit_test_model(input_ids, past_key_values=cache)
    assert cache.get_seq_length() > 0