
Add MergingPress: scorer-agnostic merge-on-evict for KV cache compression 🤖🤖🤖 #219

Open
jg-codes wants to merge 24 commits into NVIDIA:main from jg-codes:pr/merging-press

Conversation

jg-codes commented Apr 15, 2026

Closes #214

What

MergingPress is a prefill-time wrapper that replaces hard eviction with merge-on-evict: each evicted token is folded into its most cosine-similar survivor via weighted value blending, instead of being discarded.

It wraps any ScorerPress — scoring is delegated entirely; only the eviction step changes. This makes it composable with all existing scorers (KnormPress, SnapKVPress, etc.).

How it works

  1. Score tokens using the wrapped ScorerPress
  2. Partition into keep/evict sets by score
  3. Compute batched cosine similarity between evicted and surviving keys
  4. Route each evicted token to its most similar survivor (gated by similarity_threshold)
  5. Blend values via similarity-weighted scatter-add (float32 accumulation)
  6. Keys are preserved unchanged by default (protects RoPE positional encoding)
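
The routing and blending steps (3–5) can be sketched as follows. This is an illustrative NumPy model of the idea, not the PR's `_merge_on_evict` kernel: the function name, signature, and the exact normalization of the blend are invented here and may differ from the real implementation.

```python
import numpy as np

def merge_on_evict(keys, values, keep_idx, evict_idx, similarity_threshold=0.0):
    """Illustrative merge-on-evict: fold each evicted token into its most
    cosine-similar survivor via a similarity-weighted value blend."""
    unit = lambda x: x / np.linalg.norm(x, axis=-1, keepdims=True)
    k_keep, k_evict = keys[keep_idx], keys[evict_idx]
    # Step 3: batched cosine similarity between evicted and surviving keys
    sim = unit(k_evict) @ unit(k_keep).T              # (n_evict, n_keep)
    # Step 4: route each evicted token to its most similar survivor
    target = sim.argmax(axis=-1)
    weight = sim[np.arange(len(evict_idx)), target]
    # Step 5: similarity-weighted blend with float32 accumulation;
    # survivors are renormalized here by total absorbed weight
    # (the real kernel's normalization may differ)
    out_values = values[keep_idx].astype(np.float32).copy()
    total = np.ones(len(keep_idx), dtype=np.float32)
    for i, (j, w) in enumerate(zip(target, weight)):
        if w >= similarity_threshold:                 # gate low-similarity merges
            out_values[j] += w * values[evict_idx[i]]
            total[j] += w
    out_values /= total[:, None]
    # Step 6: survivor keys are returned unchanged (RoPE-safe)
    return keys[keep_idx], out_values
```

A survivor whose routed evicted token has similarity w ends up as a (1, w)-weighted average of the two value vectors; at w = 0 (with the default threshold) the evicted value still contributes nothing beyond a vanishing weight.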

Perturbation bound

For evicted token i routed to survivor j with cosine similarity w:

‖ΔO_merge‖ ≤ 1/(1+w) · ‖ΔO_evict‖

At w ≥ 0.7 the merge error is at most 59% of hard-eviction error; at w = 1 it halves exactly.
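
The quoted constants follow directly from the bound; a one-line check (helper name is mine):

```python
def merge_error_bound(w: float) -> float:
    # Fraction of the hard-eviction error retained after merging into a
    # survivor of cosine similarity w, per the bound above: 1 / (1 + w).
    return 1.0 / (1.0 + w)

print(round(merge_error_bound(0.7), 3))  # 0.588 -> "at most 59%" for w >= 0.7
print(merge_error_bound(1.0))            # 0.5   -> halves exactly at w = 1
```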

Parameters

| Parameter | Default | Description |
|---|---|---|
| press | — (required) | Any ScorerPress whose scores determine which tokens survive |
| similarity_threshold | 0.0 | Minimum cosine similarity to merge (0.0 blocks only opposite-direction merges) |
| merge_keys | False | Merge key vectors too (False preserves Rotary Positional Encoding info) |
| value_norm_weighting | True | Scale merge weight by relative value-vector L2 norm |
| max_merge_per_token | 0 | Cap merges per survivor to prevent dilution (0 for unlimited) |

Empirical defaults (RULER-4096, Qwen3-8B)

  • merge_keys=True hurts quality (−2.5 pp at CR=0.75) — likely RoPE corruption
  • value_norm_weighting=True improves accuracy (~1.9 pp)
  • similarity_threshold=0.0 is sufficient — almost no tokens have a negative max similarity, so a stricter threshold appears unnecessary in the general case
  • max_merge_per_token=0 (unlimited) works well up to CR=0.75; at CR=0.88 the AdaKV results below show a broad −0.8 pp regression (1 win / 7 losses), suggesting too many evicted tokens pile onto the same survivors. Capping at 3–5 may help at extreme compression, but generalisation is unclear.
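
The dilution problem that max_merge_per_token targets can be illustrated with a hypothetical greedy routing pass; the function name and the fall-back-to-next-best behaviour are mine, not necessarily the PR's:

```python
import numpy as np

def route_with_cap(sim, max_merge_per_token=0):
    """Route each evicted token (rows of sim) to a survivor (columns),
    capping how many tokens a single survivor may absorb. Once a survivor
    is full, fall back to the next-best survivor; -1 means hard-evict."""
    n_evict, n_keep = sim.shape
    counts = np.zeros(n_keep, dtype=int)
    targets = np.full(n_evict, -1)
    # Greedy: handle the most confident merges first
    for i in np.argsort(-sim.max(axis=1)):
        for j in np.argsort(-sim[i]):         # survivors by similarity
            if max_merge_per_token == 0 or counts[j] < max_merge_per_token:
                targets[i], counts[j] = j, counts[j] + 1
                break
    return targets
```

With three evicted tokens all most similar to the same survivor, the uncapped route piles all three onto it, while a cap of 1 spreads them out and hard-evicts the leftover.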

Benchmark results

RULER-4096, Qwen3-8B, fraction=1.0 (all 13 subtasks), seed=42:

Average scores

| CR | MergingPress(KnormPress) | KnormPress | Δ | % lift |
|---|---|---|---|---|
| 0.25 | 88.3 | 87.2 | +1.1 | +1.3% |
| 0.50 | 72.2 | 68.3 | +3.9 | +5.7% |
| 0.75 | 38.6 | 32.6 | +6.0 | +18.3% |
| 0.88 | 13.6 | 8.9 | +4.7 | +53.3% |

MergingPress consistently outperforms hard eviction across all compression ratios, with the largest gains at high compression where merge-on-evict recovers the most discarded information.

Per-task breakdown

| Task | no_press | M+K @0.25 | K @0.25 | Δ | M+K @0.50 | K @0.50 | Δ | M+K @0.75 | K @0.75 | Δ | M+K @0.88 | K @0.88 | Δ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cwe | 98.9 | 96.9 | 96.7 | +0.2 | 92.4 | 89.2 | +3.1 | 53.9 | 38.1 | +15.9 | 9.8 | 5.9 | +3.9 |
| fwe | 95.3 | 89.7 | 89.4 | +0.3 | 83.7 | 80.9 | +2.9 | 65.3 | 54.9 | +10.4 | 33.2 | 18.6 | +14.6 |
| niah_mk1 | 100.0 | 100.0 | 99.8 | +0.2 | 95.2 | 92.0 | +3.2 | 42.2 | 38.4 | +3.8 | 9.6 | 8.0 | +1.6 |
| niah_mk2 | 100.0 | 93.8 | 92.0 | +1.8 | 46.6 | 39.2 | +7.4 | 2.8 | 3.2 | −0.4 | 0.2 | 0.2 | 0.0 |
| niah_mk3 | 100.0 | 66.8 | 61.8 | +5.0 | 11.6 | 8.4 | +3.2 | 0.8 | 1.2 | −0.4 | 0.0 | 0.0 | 0.0 |
| niah_mq | 99.9 | 99.8 | 99.7 | +0.1 | 94.5 | 92.8 | +1.6 | 47.8 | 37.9 | +9.9 | 8.7 | 5.8 | +3.0 |
| niah_mv | 100.0 | 99.9 | 99.6 | +0.3 | 93.6 | 92.1 | +1.5 | 57.9 | 48.9 | +8.9 | 10.9 | 7.0 | +3.8 |
| niah_s1 | 100.0 | 100.0 | 100.0 | 0.0 | 100.0 | 100.0 | 0.0 | 93.6 | 75.0 | +18.6 | 40.6 | 19.6 | +21.0 |
| niah_s2 | 100.0 | 100.0 | 100.0 | 0.0 | 99.6 | 99.4 | +0.2 | 87.4 | 79.2 | +8.2 | 43.4 | 32.8 | +10.6 |
| niah_s3 | 100.0 | 97.2 | 97.2 | 0.0 | 89.8 | 87.0 | +2.8 | 19.6 | 17.6 | +2.0 | 0.0 | 0.0 | 0.0 |
| qa_1 | 81.6 | 60.0 | 58.4 | +1.6 | 31.2 | 29.4 | +1.8 | 13.8 | 11.8 | +2.0 | 10.8 | 8.6 | +2.2 |
| qa_2 | 63.4 | 47.4 | 46.2 | +1.2 | 26.0 | 24.6 | +1.4 | 11.8 | 11.0 | +0.8 | 10.2 | 9.2 | +1.0 |
| vt | 100.0 | 96.9 | 93.0 | +3.9 | 74.8 | 53.1 | +21.7 | 5.2 | 7.2 | −2.0 | 0.0 | 0.0 | 0.0 |
| Average | 95.3 | 88.3 | 87.2 | +1.1 | 72.2 | 68.3 | +3.9 | 38.6 | 32.6 | +6.0 | 13.6 | 8.9 | +4.7 |

M+K = MergingPress(KnormPress), K = KnormPress. Knorm and no_press baselines from the kvpress leaderboard.

Key observations:

  • Largest per-task gains at CR=0.50: vt +21.7, niah_mk2 +7.4, niah_mk3 +3.2
  • At CR=0.75: niah_s1 +18.6, cwe +15.9, fwe +10.4, niah_mq +9.9
  • At CR=0.88: niah_s1 +21.0, fwe +14.6, niah_s2 +10.6
  • A few minor regressions at CR=0.75–0.88 on near-zero tasks (niah_mk2/mk3, vt), where both methods are likely near the noise floor

Scorer generality: AdaKVPress (f=0.1, ~650 samples)

Exploratory runs on AdaKV(SnapKVPress) confirm that MergingPress generalises beyond KnormPress. These used fraction=0.1 (~650 of ~6500 RULER samples), so treat as directional:

| CR | MergingPress(AdaKV) | AdaKV(SnapKV) | Δ | % lift |
|---|---|---|---|---|
| 0.25 | 93.0 | 92.2 | +0.8 | +0.9% |
| 0.50 | 66.6 | 64.0 | +2.6 | +4.1% |
| 0.75 | 39.0 | 37.4 | +1.6 | +4.2% |
| 0.88 | 23.8 | 24.6 | −0.8 | −3.3% |

Pattern matches KnormPress: positive gains at CR 0.25–0.75, with an inversion at CR=0.88 where the merge overhead may dilute the few surviving tokens. Per-task win/loss breakdown: CR=0.25 has 5 wins / 0 losses, CR=0.50 has 7/2, CR=0.75 has 5/6 (net positive due to larger wins on niah_s1 +10.6, vt +12.2), CR=0.88 has 1/7. The CR=0.88 regression (−0.8 pp) is small but broad — suggesting that max_merge_per_token capping or a higher similarity_threshold could help at extreme compression.

Computational overhead

The merge kernel adds one batched cosine-similarity matmul per layer: O(B · H · CR · (1−CR) · L² · D) — same complexity class as attention but over KV heads only (8 vs 32 query heads for Qwen3-8B) and bounded by CR·(1−CR) ≤ 0.25. Runs once at prefill; decoding is unaffected.

Theoretical peak: ~6% of attention FLOPs at CR=0.50, i.e. ~2–3% of total prefill FLOPs. No extra forward passes, no learned parameters.
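
The ~6% figure can be reproduced from the stated complexity, assuming merge cost ∝ n_kv_heads · CR·(1−CR) · L² · D against attention's n_q_heads · L² · D, so the L²·D factors cancel in the ratio (function name and head counts as given in the text):

```python
def merge_overhead_vs_attention(cr, n_kv_heads=8, n_q_heads=32):
    # Merge matmul is (CR*L) x ((1-CR)*L) over KV heads; attention is
    # L x L over query heads, so only the head ratio and CR term remain.
    return (n_kv_heads / n_q_heads) * cr * (1 - cr)

print(f"{merge_overhead_vs_attention(0.5):.1%}")  # ~6.2% of attention FLOPs at CR=0.5
```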

Changes

| File | Lines | Description |
|---|---|---|
| kvpress/presses/merging_press.py | +281 | _merge_on_evict kernel + MergingPress dataclass |
| tests/presses/test_merging_press.py | +322 | 17 tests (validation, correctness, precision, edge cases) |
| kvpress/__init__.py | +2 | Import + __all__ entry |
| evaluation/evaluate_registry.py | +4 | merging_knorm and merging_snapkv configs |
| tests/default_presses.py | +8 | Parametrized test matrix entry |
| README.md | +1 | One-line description |

Total: 6 files, +618 lines

Design choices vs. related work

| Aspect | MergingPress (this PR) | CAMPress (#196, merged) |
|---|---|---|
| Phase | Prefill | Decoding |
| Merge routing | Position-agnostic (max cosine similarity) | Sequential neighbors |
| Merge weight | Cosine similarity + optional value-norm weighting | Bernoulli sampling from cumulative attention ratio |
| Scorer | Any ScorerPress (composable) | Any ScorerPress via DecodingPress |
| Key handling | Keys preserved by default (RoPE-safe) | Keys not merged |

Decoding-time extension: The _merge_on_evict kernel is phase-agnostic — it takes arbitrary key/value tensors and keep/evict masks. Extending MergingPress to decoding (wrapping DecodingPress) is a natural next step but is intentionally deferred to keep this PR focused on the prefill path. The kernel itself would work unchanged; only the integration hook differs.

References: Token Merging (Bolya 2023), D2O (Wan 2024), KeepKV (Huang 2025)

Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import KnormPress, KVPressTextGenerationPipeline, MergingPress

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Wrap any ScorerPress: scoring is delegated, eviction becomes merge-on-evict
press = MergingPress(KnormPress(compression_ratio=0.5))
pipe = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, press=press)
output = pipe("Your long context here...", max_new_tokens=50)
```

Tests

17 test methods in tests/presses/test_merging_press.py — 18 passed, 1 skipped (test_quantized_cache_compatibility requires optimum-quanto).

Coverage: parameter validation, compression-ratio delegation, identity at zero compression, model forward pass (KnormPress + SnapKVPress), merge-vs-hard-eviction difference, threshold gating, key preservation, fp16/bf16 numerical stability, repeated compression, value-norm weighting, information preservation, batching, max_merge_per_token validation + effect, short-sequence edge case, quantized cache compatibility.

CI

Awaiting /ok to test from a collaborator. Local results:

  • ruff check ✅ — no issues on all changed files
  • pytest tests/presses/test_merging_press.py ✅ — 18 passed, 1 skipped (no GPU needed for unit tests)
  • make style / make test — not run locally (full suite requires GPU for default_presses integration tests)

AI disclosure

This PR was developed with AI assistance. Commits authored by AI are marked with 🤖🤖🤖. In fact, AI did most of the work. The API design, parameter selection, empirical tuning (...), and docstring proofreading are human contributions.

Checklist

  • Code follows AGENTS.md guidelines (dataclass, BasePress, SPDX headers)
  • All commits signed off (DCO)
  • AI commits marked with 🤖🤖🤖
  • ruff check passes on all changed files
  • 18/19 tests pass locally (1 skipped — requires optimum-quanto)
  • Added to kvpress/__init__.py, tests/default_presses.py, evaluation/evaluate_registry.py, README.md
  • make style / make test on CI (awaiting /ok to test)

copy-pr-bot bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

jg-codes (Author) commented:

ExpectedAttentionPress benchmark results

Setup: RULER-4096, Qwen3-8B, fraction=0.1 (~650 samples), seed=42

Three configurations compared:

  • M(EA) = MergingPress(ExpectedAttentionPress(ε=1e-2)) — merge-on-evict
  • EA = ExpectedAttentionPress(ε=1e-2) — bare hard eviction
  • AdaKV(EA) = AdaKVPress(ExpectedAttentionPress(ε=1e-2)) — per-head adaptive budget (leaderboard default)

Average scores

| CR | M(EA) | EA (bare) | AdaKV(EA) | no_press | M(EA)−EA |
|---|---|---|---|---|---|
| 0.25 | 93.4 | 92.8 | 94.2 | 94.9 | +0.6 |
| 0.50 | 86.4 | 86.4 | 94.2 | 94.9 | +0.0 |
| 0.75 | 74.4 | 69.8 | 88.3 | 94.9 | +4.6 |
| 0.875 | 62.3 | 60.3 | 72.0 | 94.9 | +2.0 |

MergingPress consistently matches or beats bare EA hard eviction. The gain is largest at CR=0.75 (+4.6 pp), matching the pattern seen with KnormPress (+6.0 pp).

Flagship per-task result: niah_single_3 at CR=0.75

| Config | Score |
|---|---|
| M(EA) | 90.5 |
| EA (bare) | 38.1 |
| AdaKV(EA) | 90.5 |
| Δ M(EA) vs EA | +52.4 pp |

Merge-on-evict recovers nearly all lost accuracy on this retrieval task — same quality as AdaKV's per-head budget allocation.

Per-task breakdown (CR=0.75)

| Task | no_press | M(EA) | EA | AdaKV(EA) | M(EA)−EA |
|---|---|---|---|---|---|
| cwe | 100.0 | 65.6 | 73.5 | 98.4 | −7.9 |
| fwe | 93.3 | 95.3 | 93.3 | 93.3 | +2.0 |
| niah_mk1 | 100.0 | 98.2 | 94.4 | 100.0 | +3.7 |
| niah_mk2 | 100.0 | 16.2 | 10.8 | 100.0 | +5.4 |
| niah_mk3 | 100.0 | 0.0 | 0.0 | 45.6 | 0.0 |
| niah_mq | 100.0 | 100.0 | 98.7 | 100.0 | +1.3 |
| niah_mv | 100.0 | 100.0 | 99.6 | 100.0 | +0.4 |
| niah_s1 | 100.0 | 100.0 | 100.0 | 100.0 | 0.0 |
| niah_s2 | 100.0 | 98.5 | 95.5 | 100.0 | +3.0 |
| niah_s3 | 100.0 | 90.5 | 38.1 | 90.5 | +52.4 |
| qa_1 | 83.0 | 59.6 | 57.5 | 70.2 | +2.1 |
| qa_2 | 56.8 | 43.2 | 45.5 | 50.0 | −2.3 |
| vt | 100.0 | 100.0 | 100.0 | 100.0 | 0.0 |
| Avg | 94.9 | 74.4 | 69.8 | 88.3 | +4.6 |

Observations

  1. MergingPress generalises to EA — the +4.6 pp gain at CR=0.75 parallels KnormPress (+6.0 pp), confirming scorer-agnostic value.
  2. AdaKV's head-wise budget allocation dominates — AdaKV(EA) adds +18.5 pp over bare EA at CR=0.75, vs +4.6 pp from merge-on-evict. Per-head budget allocation and merge-on-evict address different failure modes.
  3. Combining is the next step — MergingPress + AdaKV's per-head budget would stack both mechanisms. A MergingAdaKVPress variant that does head-wise adaptive budgeting + merge-on-evict (instead of hard eviction) is a natural extension — it could close the remaining gap between M(EA) and AdaKV(EA).

Fraction=0.1 (~650 samples) — directional only. Happy to run f=1.0 if needed.

🤖🤖🤖

SimJeg (Collaborator) commented Apr 16, 2026

@jg-codes run it with KVzap too

SimJeg (Collaborator) commented Apr 16, 2026

@jg-codes we are currently investigating the best way to interact with AI agents in this repository. To help us, could you report some information about yourself? (e.g. which agent harness you are using, which model, your config, who's running you, etc.)

jg-codes (Author) commented:

@jg-codes we are currently investigating the best way to interact with AI agents in this repository. To help us, could you report some information about yourself? (e.g. which agent harness you are using, which model, your config, who's running you, etc.)

Development: GitHub Copilot (VS Code "Autopilot" mode) in combination with agentic cowork features, e.g. for research tasks. All under my supervision. Unfortunately, guardrails don't yet stop the agent from publishing local drafts.
Infra: VPS / Modal for GPU tasks (usually an A100).

jg-codes (Author) commented Apr 16, 2026

@jg-codes run it with KVzap too

Running in the base setup, we lose against KVzap. Hence I merge only 75% of evicted tokens and require a minimum similarity_threshold.

Setup: RULER-4096, Qwen3-8B, fraction=0.1 (~650 samples), seed=42, M(KVzap) = MergingPress(KVzapPress(model_type="mlp"), merge_fraction=0.75, similarity_threshold=0.5) — selective merge-on-evict

On QA we lose significantly. On niah_mv it still looks fine. Not sure about the significance here.
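
One plausible way to realise the selective merging described in this setup, keeping only the top merge_fraction of evicted tokens (by best-survivor similarity) above the threshold. The PR's actual merge_fraction logic is not shown in this thread; this is an illustrative guess:

```python
import numpy as np

def select_merge_subset(max_sim, merge_fraction=0.75, similarity_threshold=0.5):
    """Given each evicted token's best-survivor cosine similarity, mark
    which tokens to merge: only the top `merge_fraction` by similarity,
    and only those clearing the threshold. The rest are hard-evicted."""
    n_merge = int(np.floor(merge_fraction * len(max_sim)))
    order = np.argsort(-max_sim)                  # most similar first
    merge_mask = np.zeros(len(max_sim), dtype=bool)
    merge_mask[order[:n_merge]] = max_sim[order[:n_merge]] >= similarity_threshold
    return merge_mask
```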

Average Scores

| CR | M(KVzap) | KVzap | Δ | Δ% |
|---|---|---|---|---|
| 0.50 | 91.6 | 87.3 | +4.2 | +4.9% |
| 0.75 | 73.0 | 71.8 | +1.2 | +1.7% |
| 0.88 | 40.9 | 39.5 | +1.4 | +3.6% |

Task Breakdown

| Task | no_press | M(KVzap) @0.50 | KVzap @0.50 | Δ | M(KVzap) @0.75 | KVzap @0.75 | Δ | M(KVzap) @0.88 | KVzap @0.88 | Δ |
|---|---|---|---|---|---|---|---|---|---|---|
| cwe | 100.0 | 95.3 | 93.7 | +1.6 | 84.0 | 82.1 | +1.9 | 65.6 | 60.0 | +5.6 |
| fwe | 93.3 | 94.0 | 94.0 | 0.0 | 86.7 | 88.7 | −2.0 | 86.7 | 82.0 | +4.7 |
| niah_mk1 | 100.0 | 96.3 | 96.3 | 0.0 | 85.2 | 81.5 | +3.7 | 20.4 | 13.0 | +7.4 |
| niah_mk2 | 100.0 | 100.0 | 100.0 | 0.0 | 83.8 | 81.1 | +2.7 | 2.7 | 0.0 | +2.7 |
| niah_mk3 | 100.0 | 65.2 | 34.8 | +30.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| niah_mq | 100.0 | 100.0 | 99.1 | +0.9 | 87.7 | 89.0 | −1.3 | 25.0 | 28.5 | −3.5 |
| niah_mv | 100.0 | 99.1 | 99.1 | 0.0 | 90.8 | 87.3 | +3.5 | 25.9 | 16.2 | +9.6 |
| niah_s1 | 100.0 | 100.0 | 100.0 | 0.0 | 100.0 | 100.0 | 0.0 | 100.0 | 100.0 | 0.0 |
| niah_s2 | 100.0 | 98.5 | 98.5 | 0.0 | 87.9 | 84.8 | +3.0 | 24.2 | 25.8 | −1.5 |
| niah_s3 | 100.0 | 97.6 | 90.5 | +7.1 | 45.2 | 14.3 | +31.0 | 0.0 | 0.0 | 0.0 |
| qa_1 | 83.0 | 85.1 | 74.5 | +10.6 | 57.5 | 68.1 | −10.6 | 46.8 | 51.1 | −4.2 |
| qa_2 | 56.8 | 59.1 | 54.5 | +4.5 | 40.9 | 56.8 | −15.9 | 34.1 | 36.4 | −2.3 |
| vt | 100.0 | 100.0 | 100.0 | 0.0 | 100.0 | 100.0 | 0.0 | 100.0 | 100.0 | 0.0 |
| Average | 94.9 | 91.6 | 87.3 | +4.2 | 73.0 | 71.8 | +1.2 | 40.9 | 39.5 | +1.4 |

Wall-clock time (average seconds per task)

| CR | KVzap (bare) | M(KVzap) |
|---|---|---|
| 0.50 | 2.026 | 2.065 |
| 0.75 | 1.962 | 2.768 |
| 0.875 | 3.468 | 3.734 |

SimJeg (Collaborator) commented Apr 16, 2026

@jg-codes could you give me more information about yourself?

  • your input prompt
  • the LLM you're using
  • who developed you

Nice results. Could you run with DMSPress(press=KVzapPress(model_type="mlp"))? It's the SOTA press for now. Use thresholds of −4 and −3.

jg-codes (Author) commented Apr 16, 2026

@jg-codes could you give me more information about yourself?

  • your input prompt
  • the LLM you're using
  • who developed you

Nice results. Could you run with DMSPress(press=KVzapPress(model_type="mlp"))? It's the SOTA press for now. Use thresholds of −4 and −3.

The experiments stem from a multi-assistant, non-autonomous AI setup: one assistant for research, one for thinking and one for challenging it, one for interdisciplinary perspectives, etc. Funnily, when I asked "what is the KV press SOTA?" to assess an optimization angle, the AI first named H2O; after I challenged that the SOTA would be three years old, it named SnapKV, then AdaKV; only then did I stumble on the KV Press Leaderboard.

DMSPress required adding a forward_hook override to MergingPress, since DMSPress does not use compress(). The implementation is generic and may now work for any hook-based press; I can add it to the PR.
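
A structural sketch of what such a dispatch could look like. The PR's actual _is_hook_based_press and hook integration are not reproduced in this thread; the stub classes and the duck-typing criterion below are invented for illustration:

```python
class ScorerLikePress:
    """Stub standing in for a ScorerPress-style press (exposes compress())."""
    def compress(self, keys, values):
        return keys[:1], values[:1]

class HookOnlyPress:
    """Stub standing in for a DMSPress-style press (forward hook only)."""
    def forward_hook(self, keys, values):
        return keys[1:], values[1:]

def is_hook_based(press) -> bool:
    # Treat a press as hook-based when it exposes a forward hook but no
    # compress() method; the wrapper then delegates to forward_hook and
    # merges evicted tokens afterwards instead of calling compress().
    return callable(getattr(press, "forward_hook", None)) and not callable(
        getattr(press, "compress", None)
    )
```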

MergingPress(DMSPress(KVzapPress)) results

Setup: RULER-4096, Qwen3-8B, f=0.1 (~650 samples), seed=42, A100

| Config | Mean | Infer (s) | Δ vs bare DMS |
|---|---|---|---|
| no_press | 94.86 | 1127 | — |
| DMSPress(KVzap) t=−4 | 94.49 | 1175 | baseline |
| M(DMS(KVzap)) t=−4 default | 94.46 | 1286 | −0.03 |
| M(DMS(KVzap)) t=−4 mf=0.75 | 94.54 | 1357 | +0.05 |
| DMSPress(KVzap) t=−3 | 93.39 | 1140 | baseline |
| M(DMS(KVzap)) t=−3 default | 93.79 | 1258 | +0.40 |
| M(DMS(KVzap)) t=−3 mf=0.75 | 93.68 | 1771 | +0.29 |

Per-task at threshold −4

| Task | DMS bare | M(DMS) def | M(DMS) mf=0.75 | Δ def | Δ mf=0.75 |
|---|---|---|---|---|---|
| cwe | 99.5 | 99.8 | 99.3 | +0.2 | −0.2 |
| fwe | 93.3 | 92.7 | 92.0 | −0.7 | −1.3 |
| niah_mk1 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_mk2 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_mk3 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_mq | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_mv | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_s1 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_s2 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| niah_s3 | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |
| qa_1 | 78.7 | 78.7 | 80.9 | 0.0 | +2.1 |
| qa_2 | 56.8 | 56.8 | 56.8 | 0.0 | 0.0 |
| vt | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |

At t=−4 DMSPress barely evicts — 9/13 tasks are already perfect. Only qa_1 shows movement (+2.1 with mf=0.75).

Per-task at threshold −3

| Task | DMS bare | M(DMS) def | M(DMS) mf=0.75 | Δ def | Δ mf=0.75 |
|---|---|---|---|---|---|
| fwe | 85.3 | 89.3 | 90.7 | +4.0 | +5.3 |
| niah_mk1 | 98.2 | 100.0 | 98.2 | +1.9 | 0.0 |
| qa_2 | 54.6 | 56.8 | 54.6 | +2.3 | 0.0 |
| cwe | 95.6 | 94.4 | 96.7 | −1.2 | +1.2 |
| qa_1 | 80.9 | 78.7 | 78.7 | −2.1 | −2.1 |
| All NIAH (except mk1) + vt | 100.0 | 100.0 | 100.0 | 0.0 | 0.0 |

Key takeaways

  1. Gains are modest at −3/−4 because DMSPress is already near-lossless. At −3, merging recovers ~27% of the gap (+0.40 pp).
  2. FWE is the consistent winner (+4.0 to +5.3 pp). Frequency-counting tasks benefit most from merge-on-evict: evicted tokens carry frequency signal that folding into survivors preserves.
  3. qa_1 consistently regresses (−2.1 pp at both merge variants). Single-hop exact-fact QA gets hurt.
  4. merge_fraction seems task-dependent: mf=1.0 wins on retrieval (niah_mk1: +1.9), mf=0.75 wins on extraction (CWE: +1.2, FWE: +5.3). No single setting dominates.
  5. ~10% inference overhead for default merging may be acceptable. The mf=0.75 variant at −3 shows anomalous 55% overhead that needs investigation.

I suppose MergingPress would benefit from more aggressive thresholds; I'd need more time to ponder. What would be your recommendation on how to proceed? Are the extensions and modifications the right way to support arbitrary presses?

saranshagarwal202 and others added 21 commits April 17, 2026 22:59
…ge 🤖🤖🤖

Implement a vectorized merge-on-evict kernel as a standalone function.
Partitions tokens by score into keep/evict sets, computes batched
cosine similarity between each evicted and surviving key, then folds
evicted values into their nearest survivor via similarity-weighted
scatter-add with float32 accumulation.

This is the core building block for MergingPress; threshold gating,
value-norm weighting, and merge caps are added in follow-up commits.

Signed-off-by: Johannes <[email protected]>
…cap 🤖🤖🤖

Add four configurable features to _merge_on_evict:
- similarity_threshold: gate merges by minimum cosine similarity
- value_norm_weighting: scale merge budget by relative value L2 norm
- max_merge_per_token: cap merges per survivor to prevent dilution
- merge_keys: optionally merge evicted info into survivor keys

Signed-off-by: Johannes <[email protected]>
…ng 🤖🤖🤖

Document the merge-on-evict algorithm with:
- Perturbation bound derivation: merge error ≤ 1/(1+w) of hard-eviction
  error, where w is the cosine similarity between evicted and survivor keys
- Full parameter descriptions (NumPy-style docstring)
- References: Token Merging (Bolya 2023), D2O (Wan 2024), KeepKV (Huang 2025)
- Debug-level logging of merge statistics (count, mean similarity,
  max merges per survivor) behind isEnabledFor guard

Signed-off-by: Johannes <[email protected]>
Introduce MergingPress as a BasePress dataclass that wraps any
ScorerPress and delegates scoring entirely, replacing only the
eviction step with merge-on-evict.

API design choices:
- similarity_threshold (default 0.0): gate merges by minimum cosine
  similarity; 0.0 blocks opposite-direction merges while permitting
  all reasonable ones
- merge_keys=False by default: preserves RoPE positional encoding
- value_norm_weighting=True: scales merge budget by relative value
  L2 norm
- max_merge_per_token=0: optional dilution cap for high-compression

Differs from CAMPress in routing merges to the most similar token
(position-agnostic) rather than sequential neighbors.

Signed-off-by: Johannes <[email protected]>
…erences

Document each parameter with its default, rationale, and empirical
impact on RULER-4096 benchmarks with Qwen3-8B:
- merge_keys=True hurts quality (−2.5 pp at CR=0.75)
- value_norm_weighting=True improves accuracy (~1.9 pp)
- similarity_threshold=0.0 blocks only opposite-direction merges

Align kernel code comments with upstream style (explicit section
headers, float32 accumulation note, partition order comment).

Signed-off-by: Johannes <[email protected]>
- kvpress/__init__.py: import + __all__ entry
- evaluation/evaluate_registry.py: merging_knorm and merging_snapkv
  press configs
- tests/default_presses.py: MergingPress(KnormPress) at CR 0.2/0.8
- README.md: one-line description in wrapper presses list

Signed-off-by: Johannes <[email protected]>
7 tests covering constructor validation and core merge behaviour:
- test_requires_scorer_press: rejects non-ScorerPress
- test_threshold_bounds: rejects out-of-range thresholds
- test_compression_ratio_delegation: property delegates to wrapped press
- test_zero_compression_is_identity: CR=0 → no eviction
- test_runs_with_model: smoke test with KnormPress and SnapKVPress
- test_merge_differs_from_hard_eviction: merged values ≠ hard-evicted
- test_threshold_gates_merges: high threshold → closer to hard evict

Signed-off-by: Johannes <[email protected]>
…g) 🤖🤖🤖

4 tests verifying each configurable feature:
- test_default_preserves_keys: merge_keys=False keeps keys unchanged
- test_half_precision_no_nan: fp16/bf16 produce finite results
- test_repeated_compression_stable: multi-turn recompression stays finite
- test_value_norm_weighting_differs: vnorm=True changes merge output

Signed-off-by: Johannes <[email protected]>
6 tests covering edge cases and integration:
- test_merge_preserves_more_info_than_hard_eviction: reconstruction
  error is lower with merging than with hard eviction
- test_batch_size_greater_than_one: partition works for B>1
- test_max_merge_per_token_validation: rejects negative cap
- test_max_merge_per_token_changes_output: cap=1 differs from uncapped
- test_high_compression_short_sequence: high CR on short seq doesn't crash
- test_quantized_cache_compatibility: QuantizedCache + quanto (skip if N/A)

Signed-off-by: Johannes <[email protected]>
Extends DecodingPress with cosine-similarity merge-on-evict (position-agnostic
alternative to CAMPress). Shares _merge_on_evict kernel with MergingPress.

- MergingDecodingPress class in merging_press.py
- Exported in __init__.py
- Registry entries: merging_decoding_knorm, merging_decoding_adakv_snapkv
- 3 tests (instantiation, compress override, parameter forwarding)

Co-authored-by: GitHub Copilot <[email protected]>
Signed-off-by: Johannes <[email protected]>
Widen press field from ScorerPress to BasePress and add dispatch in
compress() for future mask-based and hook-based press composition.
ScorerPress path moved to _compress_scorer() with no behavior change.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
New kernel that handles variable per-head eviction counts from
adaptive budget allocation (AdaKV, DMSPress). Iterates per (batch,
head) pair, merges evicted tokens in-place into full-length tensors.
Evicted positions are left unchanged for attention masking.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Wire _merge_on_evict_adaptive into compress() for mask-based presses
(AdaKV, CriticalAdaKV). Delegates to inner press.compress(), reads
masked_key_indices, merges evicted tokens into survivors in-place.

6 tests: construction, delegation, model smoke, merge differs, identity.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Detect hook-based presses (DMSPress) via _is_hook_based_press() and
delegate to their forward_hook, then merge evicted tokens using the
adaptive kernel. Adds threshold/compression_ratio property passthrough.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
TestMergingPressWithDMS: 7 tests covering hook detection, threshold
passthrough, model execution, and merge-vs-plain comparison.
Uses RandomPress to avoid HF model dependency.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Allows keeping only top fraction of evicted tokens (by similarity) for
merging, hard-evicting the rest. Default 1.0 (merge all) preserves
backward compatibility. Threaded through both kernels and both classes.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Skip merges where estimated error ‖v_i‖*(1-w)/(1+w) exceeds gate.
Prevents high-norm evicted tokens from corrupting survivors — fixes
qa_1 regression where QA answer tokens get blended into context.
Default 0.0 (disabled) preserves backward compatibility.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Verify tight gate blocks high-error merges (output closer to hard
eviction) and gate=0.0 produces identical output to default.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
Johannes and others added 2 commits April 17, 2026 22:59
Registry: merging_knorm, merging_snapkv, merging_adakv_snapkv,
merging_dms_kvzap_mlp, merging_decoding_knorm. Evaluate: handle
MergingPress(DMSPress) threshold delegation in _setup_press().

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
transformers 4.48+ calls KVzapConfig() with no args in to_diff_dict().
Adding defaults (0) for required params prevents the TypeError while
from_pretrained still passes the real values from the saved config.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>
MergingPress(diagnostics=True) collects per-layer, per-head merge stats:
eviction positions, merge targets, cosine similarities, value norms,
and DMS importance scores for evicted vs surviving tokens.

DMSPress.full_scores captures per-token KVzap scores before the
scores_buffer is trimmed, enabling post-hoc analysis of eviction
decisions. Zero overhead when diagnostics are disabled.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Signed-off-by: Johannes <[email protected]>

Development

Successfully merging this pull request may close these issues.

Feature: MergingPress — scorer-agnostic merge-on-evict wrapper for prefill-time KV cache compression