retrieval-heads 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Max Zuo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,112 @@
1
+ Metadata-Version: 2.4
2
+ Name: retrieval-heads
3
+ Version: 0.1.0
4
+ Summary: Retrieval Head detection in LLMs with vLLM
5
+ Author-email: Max Zuo <zuo@brown.edu>
6
+ Requires-Python: >=3.12
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: matplotlib>=3.11.0
10
+ Requires-Dist: nnsight>=0.7.0
11
+ Requires-Dist: pyyaml>=6.0.3
12
+ Requires-Dist: rouge-score>=0.1.2
13
+ Requires-Dist: seaborn>=0.13.2
14
+ Requires-Dist: torch>=2.10.0
15
+ Requires-Dist: tqdm>=4.68.2
16
+ Requires-Dist: tyro>=1.0.13
17
+ Requires-Dist: vllm==0.19.0
18
+ Dynamic: license-file
19
+
20
+ # retrieval-heads
21
+
22
+ Retrieval head detection in LLMs using vLLM and nnsight activation tracing.
23
+
24
+ This is my attempt to faithfully reproduce [Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574), and should work out of the box with any model that uses vLLM's `Attention` or `GatedDeltaNetAttention` implementations.
25
+
26
+ Two main workflows:
27
+
28
+ 1. **Needle-in-a-haystack (NIAH)** – insert a known fact into a long context at varying depths and lengths, then measure retrieval accuracy (ROUGE-L).
29
+ 2. **Retrieval head detection** – trace query/key activations through every attention head on NIAH results to identify which heads are responsible for retrieval.
30
+
31
+ ## Example Results
32
+
33
+ ### NIAH Heatmap
34
+
35
+ ![NIAH Heatmap](imgs/heatmap.png)
36
+
37
+ ### Retrieval Head Detection
38
+
39
+ ![Retrieval Head Detection Heatmap](imgs/detect_heatmap.png)
40
+
41
+ ## Setup
42
+
43
+ Installation:
44
+ ```bash
45
+ git clone https://github.com/maxzuo/retrieval-heads.git && cd retrieval-heads
46
+ pip install -e .
47
+ ```
48
+ Tested using Python 3.12 and vLLM 0.19.0.
49
+
50
+ ## Usage
51
+
52
+ ### NIAH sweep
53
+
54
+ ```bash
55
+ retrieval-heads.niah --config configs/qwen3_5_9b.yaml
56
+ ```
57
+
58
+ Runs the needle-in-a-haystack evaluation across a grid of context lengths and
59
+ document depths. Results are written to `output_dir` as `results.jsonl` (one
60
+ JSON record per cell) alongside the resolved `config.yaml`.
61
+
62
+ Any config field can be overridden via CLI flags:
63
+
64
+ ```bash
65
+ retrieval-heads.niah --config configs/qwen3_5_9b.yaml \
66
+ --model.max-model-len 16384 \
67
+ --output-dir ./results/short
68
+ ```
69
+
70
+ ### Retrieval head detection
71
+
72
+ ```bash
73
+ retrieval-heads.detect --config configs/detect.yaml
74
+ ```
75
+
76
+ Takes NIAH result files as input, traces each forward pass with nnsight to
77
+ capture per-head query/key matrices, and scores each head on whether it attends
78
+ to the needle span. Outputs `detected.json` and `detected-agg.json`.
79
+
80
+ ### Visualization
81
+
82
+ ```bash
83
+ retrieval-heads.visualize niah --results results/qwen3_5_9b/results.jsonl
84
+ retrieval-heads.visualize detect --results results/detect/detected-agg.json
85
+ ```
86
+
87
+ ## Configuration
88
+
89
+ Configs are YAML files with the following sections:
90
+
91
+ ```yaml
92
+ model:
93
+ model: Qwen/Qwen3.5-9B
94
+ max_model_len: 32768
95
+ dtype: bfloat16
96
+ chat_template: path/to/template.jinja
97
+ language_model_only: true
98
+
99
+ haystack:
100
+ haystack_dir: ./PaulGrahamEssays
101
+ needle: "\nThe best thing to do in San Francisco is eat a sandwich...\n"
102
+ retrieval_question: "What is the best thing to do in San Francisco?"
103
+
104
+ sweep:
105
+ context_lengths: {min: 1000, max: 32000, intervals: 31}
106
+ document_depths: {min: 0, max: 100, intervals: 10}
107
+
108
+ output_dir: ./results/qwen3_5_9b
109
+ ```
110
+
111
+ Sweep dimensions accept either a `{min, max, intervals}` shorthand or an
112
+ explicit list of values.
@@ -0,0 +1,93 @@
1
+ # retrieval-heads
2
+
3
+ Retrieval head detection in LLMs using vLLM and nnsight activation tracing.
4
+
5
+ This is my attempt to faithfully reproduce [Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574), and should work out of the box with any model that uses vLLM's `Attention` or `GatedDeltaNetAttention` implementations.
6
+
7
+ Two main workflows:
8
+
9
+ 1. **Needle-in-a-haystack (NIAH)** – insert a known fact into a long context at varying depths and lengths, then measure retrieval accuracy (ROUGE-L).
10
+ 2. **Retrieval head detection** – trace query/key activations through every attention head on NIAH results to identify which heads are responsible for retrieval.
11
+
12
+ ## Example Results
13
+
14
+ ### NIAH Heatmap
15
+
16
+ ![NIAH Heatmap](imgs/heatmap.png)
17
+
18
+ ### Retrieval Head Detection
19
+
20
+ ![Retrieval Head Detection Heatmap](imgs/detect_heatmap.png)
21
+
22
+ ## Setup
23
+
24
+ Installation:
25
+ ```bash
26
+ git clone https://github.com/maxzuo/retrieval-heads.git && cd retrieval-heads
27
+ pip install -e .
28
+ ```
29
+ Tested using Python 3.12 and vLLM 0.19.0.
30
+
31
+ ## Usage
32
+
33
+ ### NIAH sweep
34
+
35
+ ```bash
36
+ retrieval-heads.niah --config configs/qwen3_5_9b.yaml
37
+ ```
38
+
39
+ Runs the needle-in-a-haystack evaluation across a grid of context lengths and
40
+ document depths. Results are written to `output_dir` as `results.jsonl` (one
41
+ JSON record per cell) alongside the resolved `config.yaml`.
42
+
43
+ Any config field can be overridden via CLI flags:
44
+
45
+ ```bash
46
+ retrieval-heads.niah --config configs/qwen3_5_9b.yaml \
47
+ --model.max-model-len 16384 \
48
+ --output-dir ./results/short
49
+ ```
50
+
51
+ ### Retrieval head detection
52
+
53
+ ```bash
54
+ retrieval-heads.detect --config configs/detect.yaml
55
+ ```
56
+
57
+ Takes NIAH result files as input, traces each forward pass with nnsight to
58
+ capture per-head query/key matrices, and scores each head on whether it attends
59
+ to the needle span. Outputs `detected.json` and `detected-agg.json`.
60
+
61
+ ### Visualization
62
+
63
+ ```bash
64
+ retrieval-heads.visualize niah --results results/qwen3_5_9b/results.jsonl
65
+ retrieval-heads.visualize detect --results results/detect/detected-agg.json
66
+ ```
67
+
68
+ ## Configuration
69
+
70
+ Configs are YAML files with the following sections:
71
+
72
+ ```yaml
73
+ model:
74
+ model: Qwen/Qwen3.5-9B
75
+ max_model_len: 32768
76
+ dtype: bfloat16
77
+ chat_template: path/to/template.jinja
78
+ language_model_only: true
79
+
80
+ haystack:
81
+ haystack_dir: ./PaulGrahamEssays
82
+ needle: "\nThe best thing to do in San Francisco is eat a sandwich...\n"
83
+ retrieval_question: "What is the best thing to do in San Francisco?"
84
+
85
+ sweep:
86
+ context_lengths: {min: 1000, max: 32000, intervals: 31}
87
+ document_depths: {min: 0, max: 100, intervals: 10}
88
+
89
+ output_dir: ./results/qwen3_5_9b
90
+ ```
91
+
92
+ Sweep dimensions accept either a `{min, max, intervals}` shorthand or an
93
+ explicit list of values.
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["setuptools>=75"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "retrieval-heads"
7
+ version = "0.1.0"
8
+ description = "Retrieval Head detection in LLMs with vLLM"
9
+ authors = [{name = "Max Zuo", email = "zuo@brown.edu"}]
10
+ readme = "README.md"
11
+ requires-python = ">=3.12"
12
+ dependencies = [
13
+ "matplotlib>=3.11.0",
14
+ "nnsight>=0.7.0",
15
+ "pyyaml>=6.0.3",
16
+ "rouge-score>=0.1.2",
17
+ "seaborn>=0.13.2",
18
+ "torch>=2.10.0",
19
+ "tqdm>=4.68.2",
20
+ "tyro>=1.0.13",
21
+ "vllm==0.19.0",
22
+ ]
23
+
24
+ [project.scripts]
25
+ "retrieval-heads.niah" = "retrieval_heads.cli:niah_cli"
26
+ "retrieval-heads.detect" = "retrieval_heads.cli:detect_cli"
27
+ "retrieval-heads.visualize" = "retrieval_heads.visualize:cli"
28
+
29
+ [dependency-groups]
30
+ dev = [
31
+ "pytest>=8",
32
+ ]
33
+
34
+ [tool.setuptools.packages.find]
35
+ include = ["retrieval_heads*"]
36
+
37
+ [tool.pytest.ini_options]
38
+ pythonpath = ["."]
39
+ testpaths = ["tests"]
@@ -0,0 +1,6 @@
1
+ from .configs import (
2
+ HaystackConfig,
3
+ ModelConfig,
4
+ RangeConfig,
5
+ SweepConfig,
6
+ )
@@ -0,0 +1,254 @@
1
+ import argparse
2
+ from dataclasses import asdict, dataclass, field
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import tqdm
9
+ import tyro
10
+ import yaml
11
+
12
+ from .configs import (
13
+ HaystackConfig,
14
+ ModelConfig,
15
+ SweepConfig,
16
+ )
17
+
18
+
19
+ # Experiment-level configs
20
@dataclass
class NIAHConfig:
    """Top-level config for a needle-in-a-haystack run.

    Bundles the model, haystack and sweep sections of one NIAH experiment
    together with the directory its outputs are written to.
    """
    # Model settings for the run.
    model: ModelConfig
    # Haystack directory, needle text and retrieval question.
    haystack: HaystackConfig
    # Context-length / document-depth grid; defaults to ``SweepConfig()``.
    sweep: SweepConfig = field(default_factory=SweepConfig)
    # Directory that receives ``results.jsonl`` and ``config.yaml``. Stays
    # ``None`` until provided by the YAML file or on the command line.
    output_dir: str | None = None
27
+
28
+
29
@dataclass
class DetectConfig:
    """Top-level config for retrieval-head detection.

    Describes which saved NIAH result records are traced and where the
    detector output is saved.
    """
    # Model settings handed to the detector.
    model: ModelConfig
    # Paths of NIAH JSONL result files to load; must be non-empty to run.
    results_files: tuple[str, ...] = ()
    # Directory passed to the detector's ``save``; must be set to run.
    output_dir: str | None = None
    # Optional filter: only records whose context length and document depth
    # are both listed in this sweep are traced. ``None`` keeps every record.
    sweep: SweepConfig | None = None
    # Optional minimum ``rougeL`` value; records scoring below it are skipped.
    score_threshold: float | None = None
37
+
38
+
39
+ # Config loading functions
40
+
41
+
42
def load_niah_config(path: str | os.PathLike) -> NIAHConfig:
    """Load a NIAHConfig from a YAML file.

    Args:
        path: YAML file with required ``model`` and ``haystack`` mappings and
            optional ``sweep`` and ``output_dir`` entries.

    Returns:
        The populated config. A missing or empty ``sweep`` entry yields the
        default ``SweepConfig()``.

    Raises:
        KeyError: If the ``model`` or ``haystack`` section is missing.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(path, encoding='utf-8') as f:
        raw = yaml.safe_load(f) or {}

    return NIAHConfig(
        model=ModelConfig(**raw['model']),
        haystack=HaystackConfig(**raw['haystack']),
        # A bare ``sweep:`` key parses to None, which cannot be **-unpacked;
        # treat it like an absent section.
        sweep=SweepConfig(**(raw.get('sweep') or {})),
        output_dir=raw.get('output_dir'),
    )
53
+
54
+
55
def load_detect_config(path: str | os.PathLike) -> DetectConfig:
    """Load a DetectConfig from a YAML file.

    Args:
        path: YAML file with a required ``model`` mapping and optional
            ``results_files`` (a single path or a list of paths), ``sweep``,
            ``output_dir`` and ``score_threshold`` entries.

    Returns:
        The populated config. Keys that are absent or left empty fall back to
        the ``DetectConfig`` defaults.

    Raises:
        KeyError: If the ``model`` section is missing.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(path, encoding='utf-8') as f:
        raw = yaml.safe_load(f) or {}

    # ``results_files:`` with no value parses to None; treat it as "not set"
    # so the caller gets the regular "results_files is required" error rather
    # than a TypeError from tuple(None).
    results_files = raw.get('results_files') or ()
    if isinstance(results_files, str):
        results_files = (results_files,)
    else:
        results_files = tuple(results_files)

    # A bare ``sweep:`` key (None) means "no filter"; an explicit empty
    # mapping still selects the default SweepConfig, as before.
    sweep_section = raw.get('sweep')

    return DetectConfig(
        model=ModelConfig(**raw['model']),
        results_files=results_files,
        output_dir=raw.get('output_dir'),
        sweep=(SweepConfig(**sweep_section)
               if sweep_section is not None else None),
        score_threshold=raw.get('score_threshold'),
    )
73
+
74
+
75
+ # Load results.jsonl file
76
+
77
+
78
+ def load_results(path: str | os.PathLike) -> list[dict[str, Any]]:
79
+ """Load and validate NIAH result records from a JSONL file."""
80
+ required_fields = {
81
+ 'context_length',
82
+ 'document_depth',
83
+ 'needle',
84
+ 'prompt',
85
+ 'token_ids',
86
+ }
87
+ results = []
88
+
89
+ with open(path) as f:
90
+ for line_number, line in enumerate(f, start=1):
91
+ if not line.strip():
92
+ continue
93
+ try:
94
+ result = json.loads(line)
95
+ except json.JSONDecodeError as error:
96
+ raise ValueError(
97
+ f'Invalid JSON in {path} at line {line_number}: {error.msg}'
98
+ ) from error
99
+
100
+ if not isinstance(result, dict):
101
+ raise ValueError(
102
+ f'Expected a JSON object in {path} at line {line_number}.')
103
+
104
+ missing = sorted(required_fields - result.keys())
105
+ if missing:
106
+ raise ValueError(
107
+ f'Missing required fields in {path} at line {line_number}: '
108
+ f'{", ".join(missing)}')
109
+ results.append(result)
110
+
111
+ if not results:
112
+ raise ValueError(f'No result records found in {path}.')
113
+
114
+ return results
115
+
116
+
117
def filter_results(
    results: list[dict[str, Any]],
    sweep: SweepConfig | None,
) -> list[dict[str, Any]]:
    """Keep only the records that fall on a cell of the given sweep.

    Args:
        results: Result records carrying ``context_length`` and
            ``document_depth`` keys.
        sweep: Sweep whose ``context_lengths`` and ``document_depths`` define
            the accepted values, or ``None`` to disable filtering.

    Returns:
        ``results`` itself when ``sweep`` is ``None``; otherwise a new list of
        the records whose context length and document depth both match
        exactly, in their original order.
    """
    if sweep is None:
        return results

    # Sets give constant-time membership tests for every record.
    wanted_lengths = set(sweep.context_lengths)
    wanted_depths = set(sweep.document_depths)

    kept = []
    for record in results:
        if record['context_length'] not in wanted_lengths:
            continue
        if record['document_depth'] not in wanted_depths:
            continue
        kept.append(record)
    return kept
132
+
133
+
134
+ # Main experiment functions
135
+
136
+
137
def niah_main(config: NIAHConfig):
    """Run a NIAH sweep and save results + resolved config.

    Writes ``results.jsonl`` (via ``NeedleInAHaystack.run``) and a dump of the
    resolved config as ``config.yaml`` into ``config.output_dir``.

    Args:
        config: Fully resolved NIAH configuration.

    Raises:
        ValueError: If ``config.output_dir`` is not set.
    """
    # Validate up front, mirroring detect_main, instead of failing later with
    # an opaque TypeError from os.path.join(None, ...).
    if config.output_dir is None:
        raise ValueError(
            'output_dir is required: set it in the --config YAML or pass '
            '--output-dir.')

    # Imported inside the function so that importing this module does not
    # import the .niah module.
    from .niah import NeedleInAHaystack

    # Make sure the target directory exists before anything is written to it;
    # this is a no-op when it is already there.
    os.makedirs(config.output_dir, exist_ok=True)

    niah = NeedleInAHaystack(
        haystack_config=config.haystack,
        sweep_config=config.sweep,
        model_config=config.model,
    )
    niah.run(output_path=os.path.join(config.output_dir, 'results.jsonl'))

    with open(os.path.join(config.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(asdict(config), f)
150
+
151
+
152
def detect_main(config: DetectConfig, detector_cls=None):
    """Trace the selected NIAH records and save the detector's output.

    Loads every file in ``config.results_files``, narrows the records to the
    configured sweep and score threshold, feeds each remaining record to the
    detector and finally calls its ``save`` with ``config.output_dir``.

    Args:
        config: Detection settings.
        detector_cls: Class instantiated as ``detector_cls(model_config=...)``.
            When ``None``, ``RetrievalHeadDetector`` from ``.detect`` is used.

    Raises:
        ValueError: If ``results_files`` or ``output_dir`` is unset, if no
            record survives the sweep or threshold filter, or if a threshold
            is set and a selected record has no ``rougeL`` value.
    """
    if not config.results_files:
        raise ValueError(
            'results_files is required: set it in the --config YAML or pass '
            '--results-files.')
    if config.output_dir is None:
        raise ValueError(
            'output_dir is required: set it in the --config YAML or pass '
            '--output-dir.')

    loaded = [
        record
        for results_path in config.results_files
        for record in load_results(results_path)
    ]
    selected = filter_results(loaded, config.sweep)
    if not selected:
        raise ValueError('No results matched the configured sweep.')

    threshold = config.score_threshold
    if threshold is not None:
        # Every candidate must carry a score before any comparison happens.
        if any('rougeL' not in record for record in selected):
            raise ValueError(
                'score_threshold is set but result record is missing "rougeL".')
        selected = [
            record for record in selected if record['rougeL'] >= threshold
        ]
        if not selected:
            raise ValueError('No results met the score threshold.')

    if detector_cls is None:
        # Deferred import: only needed when no detector class is injected.
        from .detect import RetrievalHeadDetector
        detector_cls = RetrievalHeadDetector

    detector = detector_cls(model_config=config.model)
    for record in tqdm.tqdm(selected, desc='Tracing results'):
        detector.calculate(
            prompt=record['prompt'],
            needle=record['needle'],
            completion_tokens=record['token_ids'],
        )

    detector.save(config.output_dir)
195
+
196
+
197
+ # cli entry points
198
+
199
def niah_cli():
    """Entry point for ``retrieval-heads.niah``.

    Pulls an optional ``--config`` YAML path off the command line first, uses
    the loaded file as the defaults for the tyro-generated CLI, and passes the
    remaining arguments to tyro so they can override individual fields.

    Raises:
        SystemExit: If no ``output_dir`` was given by either the YAML file or
            the command line.
    """
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument(
        '--config',
        type=str,
        default=None,
        help='Path to a YAML config with top-level "model", "haystack", "sweep" '
        'and optional "output_dir" sections.',
    )
    parsed, passthrough = config_parser.parse_known_args()

    # Without a config file tyro gets no defaults from us.
    if parsed.config:
        base_config = load_niah_config(parsed.config)
    else:
        base_config = tyro.MISSING_NONPROP

    config = tyro.cli(
        NIAHConfig,
        description='Run Needle-in-a-Haystack (NIAH) file.\n\n'
        'Pass --config with a YAML file to load config from file.',
        args=passthrough,
        default=base_config,
    )

    if config.output_dir is None:
        raise SystemExit(
            'output_dir is required: set it in the --config YAML or pass --output-dir.'
        )

    niah_main(config)
227
+
228
+
229
def detect_cli():
    """Entry point for ``retrieval-heads.detect``.

    Pulls an optional ``--config`` YAML path off the command line first, uses
    the loaded file as the defaults for the tyro-generated CLI, and passes the
    remaining arguments to tyro so they can override individual fields.

    Raises:
        SystemExit: With the error message when ``detect_main`` rejects the
            configuration or the selected results with a ``ValueError``.
    """
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument(
        '--config',
        type=str,
        default=None,
        # Lists the keys load_detect_config actually reads; the previous text
        # named "haystack", "results_file" and "output_path", none of which
        # exist for detection.
        help='Path to a YAML config with a top-level "model" section and '
        'optional "results_files", "sweep", "score_threshold" and '
        '"output_dir" entries.',
    )
    known, remaining = pre.parse_known_args()

    default = load_detect_config(
        known.config) if known.config else tyro.MISSING_NONPROP
    config = tyro.cli(
        DetectConfig,
        description='Run retrieval-head detection over saved NIAH JSONL results.'
        '\n\nPass --config with a YAML file to load config defaults.',
        args=remaining,
        default=default,
    )

    # Surface validation problems as a plain message instead of a traceback.
    try:
        detect_main(config)
    except ValueError as e:
        raise SystemExit(str(e)) from e
@@ -0,0 +1,55 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
@dataclass
class ModelConfig:
    """Model and engine options for a run.

    Only ``model`` is required; every other field has a default.

    NOTE(review): the field names appear to mirror vLLM engine arguments, but
    the code that consumes them is not part of this module -- confirm the
    exact semantics against the code that builds the engine.
    """
    # Model identifier, e.g. ``Qwen/Qwen3.5-9B`` in the README example.
    model: str
    # Maximum model length; ``None`` leaves it unset here.
    max_model_len: int | None = None
    trust_remote_code: bool = False
    # Defaults to the string ``'auto'``; the README example uses ``bfloat16``.
    dtype: str = 'auto'
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    enable_prefix_caching: bool = True
    # Optional chat template; the README example gives a path to a ``.jinja``
    # file.
    chat_template: str | None = None
    enforce_eager: bool = False
    language_model_only: bool = False
16
+
17
+
18
+ @dataclass
19
+ class RangeConfig:
20
+ min: int
21
+ max: int
22
+ intervals: int
23
+
24
+
25
+ @dataclass
26
+ class SweepConfig:
27
+ context_lengths: list[int] = field(
28
+ default_factory=lambda: list(range(1_000, 50_000, 1_000)))
29
+ document_depths: list[int] = field(
30
+ default_factory=lambda: list(range(0, 100, 10)))
31
+
32
+ def __post_init__(self):
33
+
34
+ def expand_range(r: list[float] | list[int] | RangeConfig) -> list[int]:
35
+ if isinstance(r, RangeConfig):
36
+ return list(range(r.min, r.max + 1, (r.max - r.min) // r.intervals))
37
+ elif isinstance(r, dict):
38
+ return list(
39
+ range(
40
+ r['min'],
41
+ r['max'] + 1,
42
+ (r['max'] - r['min']) // r['intervals'],
43
+ ))
44
+ else:
45
+ return sorted(list(r))
46
+
47
+ self.context_lengths = expand_range(self.context_lengths)
48
+ self.document_depths = expand_range(self.document_depths)
49
+
50
+
51
@dataclass
class HaystackConfig:
    """Haystack source plus the needle and the question asked about it.

    Only ``haystack_dir`` is required; the needle and question default to the
    San Francisco sandwich example.
    """
    # Directory of haystack documents, e.g. ``./PaulGrahamEssays`` in the
    # README example.
    haystack_dir: str
    # Fact inserted into the context; the default is wrapped in newlines.
    needle: str = (
        '\nThe best thing to do in San Francisco is eat a sandwich and sit in '
        'Dolores Park on a sunny day.\n')
    # Question posed to the model about the needle.
    retrieval_question: str = (
        'What is the best thing to do in San Francisco?')