retrieval-heads 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retrieval_heads-0.1.0/LICENSE +21 -0
- retrieval_heads-0.1.0/PKG-INFO +112 -0
- retrieval_heads-0.1.0/README.md +93 -0
- retrieval_heads-0.1.0/pyproject.toml +39 -0
- retrieval_heads-0.1.0/retrieval_heads/__init__.py +6 -0
- retrieval_heads-0.1.0/retrieval_heads/cli.py +254 -0
- retrieval_heads-0.1.0/retrieval_heads/configs.py +55 -0
- retrieval_heads-0.1.0/retrieval_heads/detect.py +256 -0
- retrieval_heads-0.1.0/retrieval_heads/niah.py +302 -0
- retrieval_heads-0.1.0/retrieval_heads/nnsight_utils.py +143 -0
- retrieval_heads-0.1.0/retrieval_heads/trace.py +161 -0
- retrieval_heads-0.1.0/retrieval_heads/visualize.py +261 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/PKG-INFO +112 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/SOURCES.txt +17 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/dependency_links.txt +1 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/entry_points.txt +4 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/requires.txt +9 -0
- retrieval_heads-0.1.0/retrieval_heads.egg-info/top_level.txt +1 -0
- retrieval_heads-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Max Zuo
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: retrieval-heads
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Retrieval Head detection in LLMs with vLLM
|
|
5
|
+
Author-email: Max Zuo <zuo@brown.edu>
|
|
6
|
+
Requires-Python: >=3.12
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: matplotlib>=3.11.0
|
|
10
|
+
Requires-Dist: nnsight>=0.7.0
|
|
11
|
+
Requires-Dist: pyyaml>=6.0.3
|
|
12
|
+
Requires-Dist: rouge-score>=0.1.2
|
|
13
|
+
Requires-Dist: seaborn>=0.13.2
|
|
14
|
+
Requires-Dist: torch>=2.10.0
|
|
15
|
+
Requires-Dist: tqdm>=4.68.2
|
|
16
|
+
Requires-Dist: tyro>=1.0.13
|
|
17
|
+
Requires-Dist: vllm==0.19.0
|
|
18
|
+
Dynamic: license-file
|
|
19
|
+
|
|
20
|
+
# retrieval-heads
|
|
21
|
+
|
|
22
|
+
Retrieval head detection in LLMs using vLLM and nnsight activation tracing.
|
|
23
|
+
|
|
24
|
+
This is my attempt to faithfully reproduce [Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574). It should work out of the box with any model that uses vLLM's `Attention` or `GatedDeltaNetAttention` implementations.
|
|
25
|
+
|
|
26
|
+
Two main workflows:
|
|
27
|
+
|
|
28
|
+
1. **Needle-in-a-haystack (NIAH)** – insert a known fact into a long context at varying depths and lengths, then measure retrieval accuracy (ROUGE-L).
|
|
29
|
+
2. **Retrieval head detection** – trace query/key activations through every attention head on NIAH results to identify which heads are responsible for retrieval.
|
|
30
|
+
|
|
31
|
+
## Example Results
|
|
32
|
+
|
|
33
|
+
### NIAH Heatmap
|
|
34
|
+
|
|
35
|
+

|
|
36
|
+
|
|
37
|
+
### Retrieval Head Detection
|
|
38
|
+
|
|
39
|
+

|
|
40
|
+
|
|
41
|
+
## Setup
|
|
42
|
+
|
|
43
|
+
Installation:
|
|
44
|
+
```bash
|
|
45
|
+
git clone https://github.com/maxzuo/retrieval-heads.git
|
|
46
|
+
pip install -e ./retrieval-heads
|
|
47
|
+
```
|
|
48
|
+
Tested using Python 3.12 and vLLM 0.19.0.
|
|
49
|
+
|
|
50
|
+
## Usage
|
|
51
|
+
|
|
52
|
+
### NIAH sweep
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
retrieval-heads.niah --config configs/qwen3_5_9b.yaml
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Runs the needle-in-a-haystack evaluation across a grid of context lengths and
|
|
59
|
+
document depths. Results are written to `output_dir` as `results.jsonl` (one
|
|
60
|
+
JSON record per cell) alongside the resolved `config.yaml`.
|
|
61
|
+
|
|
62
|
+
Any config field can be overridden via CLI flags:
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
retrieval-heads.niah --config configs/qwen3_5_9b.yaml \
|
|
66
|
+
--model.max-model-len 16384 \
|
|
67
|
+
--output-dir ./results/short
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### Retrieval head detection
|
|
71
|
+
|
|
72
|
+
```bash
|
|
73
|
+
retrieval-heads.detect --config configs/detect.yaml
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Takes NIAH result files as input, traces each forward pass with nnsight to
|
|
77
|
+
capture per-head query/key matrices, and scores each head on whether it attends
|
|
78
|
+
to the needle span. Outputs `detected.json` and `detected-agg.json`.
|
|
79
|
+
|
|
80
|
+
### Visualization
|
|
81
|
+
|
|
82
|
+
```bash
|
|
83
|
+
retrieval-heads.visualize niah --results results/qwen3_5_9b/results.jsonl
|
|
84
|
+
retrieval-heads.visualize detect --results results/detect/detected-agg.json
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
## Configuration
|
|
88
|
+
|
|
89
|
+
Configs are YAML files with the following sections:
|
|
90
|
+
|
|
91
|
+
```yaml
|
|
92
|
+
model:
|
|
93
|
+
model: Qwen/Qwen3.5-9B
|
|
94
|
+
max_model_len: 32768
|
|
95
|
+
dtype: bfloat16
|
|
96
|
+
chat_template: path/to/template.jinja
|
|
97
|
+
language_model_only: true
|
|
98
|
+
|
|
99
|
+
haystack:
|
|
100
|
+
haystack_dir: ./PaulGrahamEssays
|
|
101
|
+
needle: "\nThe best thing to do in San Francisco is eat a sandwich...\n"
|
|
102
|
+
retrieval_question: "What is the best thing to do in San Francisco?"
|
|
103
|
+
|
|
104
|
+
sweep:
|
|
105
|
+
context_lengths: {min: 1000, max: 32000, intervals: 31}
|
|
106
|
+
document_depths: {min: 0, max: 100, intervals: 10}
|
|
107
|
+
|
|
108
|
+
output_dir: ./results/qwen3_5_9b
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
Sweep dimensions accept either a `{min, max, intervals}` shorthand or an
|
|
112
|
+
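explicit list of values. The shorthand expands to `min`, `min + step`, ... up to at most `max`, where `step = (max - min) // intervals`.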
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# retrieval-heads
|
|
2
|
+
|
|
3
|
+
Retrieval head detection in LLMs using vLLM and nnsight activation tracing.
|
|
4
|
+
|
|
5
|
+
This is my attempt to faithfully reproduce [Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574). It should work out of the box with any model that uses vLLM's `Attention` or `GatedDeltaNetAttention` implementations.
|
|
6
|
+
|
|
7
|
+
Two main workflows:
|
|
8
|
+
|
|
9
|
+
1. **Needle-in-a-haystack (NIAH)** – insert a known fact into a long context at varying depths and lengths, then measure retrieval accuracy (ROUGE-L).
|
|
10
|
+
2. **Retrieval head detection** – trace query/key activations through every attention head on NIAH results to identify which heads are responsible for retrieval.
|
|
11
|
+
|
|
12
|
+
## Example Results
|
|
13
|
+
|
|
14
|
+
### NIAH Heatmap
|
|
15
|
+
|
|
16
|
+

|
|
17
|
+
|
|
18
|
+
### Retrieval Head Detection
|
|
19
|
+
|
|
20
|
+

|
|
21
|
+
|
|
22
|
+
## Setup
|
|
23
|
+
|
|
24
|
+
Installation:
|
|
25
|
+
```bash
|
|
26
|
+
git clone https://github.com/maxzuo/retrieval-heads.git
|
|
27
|
+
pip install -e ./retrieval-heads
|
|
28
|
+
```
|
|
29
|
+
Tested using Python 3.12 and vLLM 0.19.0.
|
|
30
|
+
|
|
31
|
+
## Usage
|
|
32
|
+
|
|
33
|
+
### NIAH sweep
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
retrieval-heads.niah --config configs/qwen3_5_9b.yaml
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
Runs the needle-in-a-haystack evaluation across a grid of context lengths and
|
|
40
|
+
document depths. Results are written to `output_dir` as `results.jsonl` (one
|
|
41
|
+
JSON record per cell) alongside the resolved `config.yaml`.
|
|
42
|
+
|
|
43
|
+
Any config field can be overridden via CLI flags:
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
retrieval-heads.niah --config configs/qwen3_5_9b.yaml \
|
|
47
|
+
--model.max-model-len 16384 \
|
|
48
|
+
--output-dir ./results/short
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
### Retrieval head detection
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
retrieval-heads.detect --config configs/detect.yaml
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
Takes NIAH result files as input, traces each forward pass with nnsight to
|
|
58
|
+
capture per-head query/key matrices, and scores each head on whether it attends
|
|
59
|
+
to the needle span. Outputs `detected.json` and `detected-agg.json`.
|
|
60
|
+
|
|
61
|
+
### Visualization
|
|
62
|
+
|
|
63
|
+
```bash
|
|
64
|
+
retrieval-heads.visualize niah --results results/qwen3_5_9b/results.jsonl
|
|
65
|
+
retrieval-heads.visualize detect --results results/detect/detected-agg.json
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
## Configuration
|
|
69
|
+
|
|
70
|
+
Configs are YAML files with the following sections:
|
|
71
|
+
|
|
72
|
+
```yaml
|
|
73
|
+
model:
|
|
74
|
+
model: Qwen/Qwen3.5-9B
|
|
75
|
+
max_model_len: 32768
|
|
76
|
+
dtype: bfloat16
|
|
77
|
+
chat_template: path/to/template.jinja
|
|
78
|
+
language_model_only: true
|
|
79
|
+
|
|
80
|
+
haystack:
|
|
81
|
+
haystack_dir: ./PaulGrahamEssays
|
|
82
|
+
needle: "\nThe best thing to do in San Francisco is eat a sandwich...\n"
|
|
83
|
+
retrieval_question: "What is the best thing to do in San Francisco?"
|
|
84
|
+
|
|
85
|
+
sweep:
|
|
86
|
+
context_lengths: {min: 1000, max: 32000, intervals: 31}
|
|
87
|
+
document_depths: {min: 0, max: 100, intervals: 10}
|
|
88
|
+
|
|
89
|
+
output_dir: ./results/qwen3_5_9b
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
Sweep dimensions accept either a `{min, max, intervals}` shorthand or an
|
|
93
|
+
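explicit list of values. The shorthand expands to `min`, `min + step`, ... up to at most `max`, where `step = (max - min) // intervals`.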
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=75"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "retrieval-heads"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Retrieval Head detection in LLMs with vLLM"
|
|
9
|
+
authors = [{name = "Max Zuo", email = "zuo@brown.edu"}]
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
requires-python = ">=3.12"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"matplotlib>=3.11.0",
|
|
14
|
+
"nnsight>=0.7.0",
|
|
15
|
+
"pyyaml>=6.0.3",
|
|
16
|
+
"rouge-score>=0.1.2",
|
|
17
|
+
"seaborn>=0.13.2",
|
|
18
|
+
"torch>=2.10.0",
|
|
19
|
+
"tqdm>=4.68.2",
|
|
20
|
+
"tyro>=1.0.13",
|
|
21
|
+
"vllm==0.19.0",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[project.scripts]
|
|
25
|
+
"retrieval-heads.niah" = "retrieval_heads.cli:niah_cli"
|
|
26
|
+
"retrieval-heads.detect" = "retrieval_heads.cli:detect_cli"
|
|
27
|
+
"retrieval-heads.visualize" = "retrieval_heads.visualize:cli"
|
|
28
|
+
|
|
29
|
+
[dependency-groups]
|
|
30
|
+
dev = [
|
|
31
|
+
"pytest>=8",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.setuptools.packages.find]
|
|
35
|
+
include = ["retrieval_heads*"]
|
|
36
|
+
|
|
37
|
+
[tool.pytest.ini_options]
|
|
38
|
+
pythonpath = ["."]
|
|
39
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
from dataclasses import asdict, dataclass, field
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import tqdm
|
|
9
|
+
import tyro
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
|
+
from .configs import (
|
|
13
|
+
HaystackConfig,
|
|
14
|
+
ModelConfig,
|
|
15
|
+
SweepConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Experiment-level configs
|
|
20
|
+
@dataclass
class NIAHConfig:
    """Top-level config for a needle-in-a-haystack run.

    Instances are built from a YAML file by ``load_niah_config`` and/or from
    command-line flags by ``tyro.cli`` in ``niah_cli``.
    """
    # Model settings; handed to NeedleInAHaystack as ``model_config``.
    model: ModelConfig
    # Haystack and needle settings; handed over as ``haystack_config``.
    haystack: HaystackConfig
    # Grid of context lengths and document depths; handed over as
    # ``sweep_config``. Falls back to SweepConfig's own defaults.
    sweep: SweepConfig = field(default_factory=SweepConfig)
    # Directory that receives ``results.jsonl`` and ``config.yaml``. Optional
    # at construction time, but ``niah_cli`` refuses to run while it is None.
    output_dir: str | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class DetectConfig:
    """Top-level config for retrieval-head detection.

    Instances are built from a YAML file by ``load_detect_config`` and/or
    from command-line flags by ``tyro.cli`` in ``detect_cli``.
    """
    # Model settings; handed to the detector class as ``model_config``.
    model: ModelConfig
    # One or more NIAH results JSONL paths. ``detect_main`` rejects an empty
    # tuple.
    results_files: tuple[str, ...] = ()
    # Destination passed to ``detector.save``. ``detect_main`` rejects None.
    output_dir: str | None = None
    # Optional filter: only records whose ``context_length`` and
    # ``document_depth`` both appear in this sweep are traced. None keeps
    # every loaded record.
    sweep: SweepConfig | None = None
    # Optional minimum ``rougeL`` value a record needs in order to be traced.
    # None disables the score filter.
    score_threshold: float | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Config loading functions
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_niah_config(path: str | os.PathLike) -> NIAHConfig:
    """Load a NIAHConfig from a YAML file.

    Args:
        path: Location of a YAML file with top-level ``model`` and
            ``haystack`` mappings and optional ``sweep`` and ``output_dir``
            entries.

    Returns:
        The populated NIAHConfig. A missing or empty ``sweep`` entry yields
        the SweepConfig defaults.

    Raises:
        KeyError: If the ``model`` or ``haystack`` section is missing.
    """
    # Read as UTF-8 explicitly so needles/questions containing non-ASCII text
    # load the same way regardless of the platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        raw = yaml.safe_load(f) or {}

    return NIAHConfig(
        model=ModelConfig(**raw['model']),
        haystack=HaystackConfig(**raw['haystack']),
        # A bare ``sweep:`` key parses as None; treat it like an absent key
        # instead of crashing on ``**None``.
        sweep=SweepConfig(**(raw.get('sweep') or {})),
        output_dir=raw.get('output_dir'),
    )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def load_detect_config(path: str | os.PathLike) -> DetectConfig:
    """Load a DetectConfig from a YAML file.

    Args:
        path: Location of a YAML file with a top-level ``model`` mapping and
            optional ``results_files``, ``output_dir``, ``sweep`` and
            ``score_threshold`` entries.

    Returns:
        The populated DetectConfig. ``results_files`` may be given in the
        YAML as a single string or as a list; it is always stored as a
        tuple.

    Raises:
        KeyError: If the ``model`` section is missing.
    """
    # Read as UTF-8 explicitly rather than relying on the locale encoding.
    with open(path, encoding='utf-8') as f:
        raw = yaml.safe_load(f) or {}

    # A bare ``results_files:`` key parses as None; treat it as "not given"
    # so the caller gets the regular "results_files is required" error
    # instead of a TypeError from ``tuple(None)``.
    results_files = raw.get('results_files') or ()
    if isinstance(results_files, str):
        results_files = (results_files,)
    else:
        results_files = tuple(results_files)

    # Likewise a bare ``sweep:`` key means "no sweep filter".
    sweep = raw.get('sweep')

    return DetectConfig(
        model=ModelConfig(**raw['model']),
        results_files=results_files,
        output_dir=raw.get('output_dir'),
        sweep=SweepConfig(**sweep) if sweep is not None else None,
        score_threshold=raw.get('score_threshold'),
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Load results.jsonl file
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def load_results(path: str | os.PathLike) -> list[dict[str, Any]]:
|
|
79
|
+
"""Load and validate NIAH result records from a JSONL file."""
|
|
80
|
+
required_fields = {
|
|
81
|
+
'context_length',
|
|
82
|
+
'document_depth',
|
|
83
|
+
'needle',
|
|
84
|
+
'prompt',
|
|
85
|
+
'token_ids',
|
|
86
|
+
}
|
|
87
|
+
results = []
|
|
88
|
+
|
|
89
|
+
with open(path) as f:
|
|
90
|
+
for line_number, line in enumerate(f, start=1):
|
|
91
|
+
if not line.strip():
|
|
92
|
+
continue
|
|
93
|
+
try:
|
|
94
|
+
result = json.loads(line)
|
|
95
|
+
except json.JSONDecodeError as error:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
f'Invalid JSON in {path} at line {line_number}: {error.msg}'
|
|
98
|
+
) from error
|
|
99
|
+
|
|
100
|
+
if not isinstance(result, dict):
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f'Expected a JSON object in {path} at line {line_number}.')
|
|
103
|
+
|
|
104
|
+
missing = sorted(required_fields - result.keys())
|
|
105
|
+
if missing:
|
|
106
|
+
raise ValueError(
|
|
107
|
+
f'Missing required fields in {path} at line {line_number}: '
|
|
108
|
+
f'{", ".join(missing)}')
|
|
109
|
+
results.append(result)
|
|
110
|
+
|
|
111
|
+
if not results:
|
|
112
|
+
raise ValueError(f'No result records found in {path}.')
|
|
113
|
+
|
|
114
|
+
return results
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def filter_results(
    results: list[dict[str, Any]],
    sweep: SweepConfig | None,
) -> list[dict[str, Any]]:
    """Keep only the records that sit on a cell of the given sweep grid.

    Args:
        results: Result records, each carrying ``context_length`` and
            ``document_depth`` keys.
        sweep: Grid to match against, or None to disable filtering.

    Returns:
        ``results`` itself when ``sweep`` is None; otherwise a new list with
        the records whose context length and document depth both appear in
        the sweep, in their original order.
    """
    if sweep is None:
        return results

    # Sets give constant-time membership checks for every record.
    wanted_lengths = set(sweep.context_lengths)
    wanted_depths = set(sweep.document_depths)

    kept = []
    for record in results:
        if record['context_length'] not in wanted_lengths:
            continue
        if record['document_depth'] not in wanted_depths:
            continue
        kept.append(record)
    return kept
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Main experiment functions
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def niah_main(config: NIAHConfig):
    """Run a NIAH sweep and save results + resolved config.

    Writes ``results.jsonl`` (via ``NeedleInAHaystack.run``) and a YAML dump
    of ``config`` named ``config.yaml`` into ``config.output_dir``.

    Args:
        config: Fully resolved run configuration.

    Raises:
        ValueError: If ``config.output_dir`` is None.
    """
    # Fail before the runner is constructed rather than with a TypeError from
    # ``os.path.join`` afterwards; mirrors the check in ``detect_main``.
    if config.output_dir is None:
        raise ValueError(
            'output_dir is required: set it in the --config YAML or pass '
            '--output-dir.')

    # Imported lazily so that importing this module stays lightweight.
    from .niah import NeedleInAHaystack

    # Guarantee the directory exists before anything is written into it.
    os.makedirs(config.output_dir, exist_ok=True)

    niah = NeedleInAHaystack(
        haystack_config=config.haystack,
        sweep_config=config.sweep,
        model_config=config.model,
    )
    niah.run(output_path=os.path.join(config.output_dir, 'results.jsonl'))

    config_path = os.path.join(config.output_dir, 'config.yaml')
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(asdict(config), f)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def detect_main(config: DetectConfig, detector_cls=None):
    """Trace selected NIAH results and save aggregate retrieval-head scores.

    Loads every file in ``config.results_files``, narrows the records with
    the optional sweep and score filters, feeds each survivor to the
    detector, and finally asks the detector to save into
    ``config.output_dir``.

    Args:
        config: Detection configuration.
        detector_cls: Class instantiated as ``detector_cls(model_config=...)``.
            Defaults to ``RetrievalHeadDetector``, which is imported only
            when no class is supplied.

    Raises:
        ValueError: If ``results_files`` is empty, ``output_dir`` is None, no
            record survives a filter, or a score threshold is set while a
            record lacks ``rougeL``.
    """
    if not config.results_files:
        raise ValueError(
            'results_files is required: set it in the --config YAML or pass --results-files.'
        )
    if config.output_dir is None:
        raise ValueError(
            'output_dir is required: set it in the --config YAML or pass --output-dir.'
        )

    loaded = [
        record
        for results_path in config.results_files
        for record in load_results(results_path)
    ]

    selected = filter_results(loaded, config.sweep)
    if not selected:
        raise ValueError('No results matched the configured sweep.')

    threshold = config.score_threshold
    if threshold is not None:
        # Validate every record before filtering so a missing score is
        # reported instead of surfacing as a KeyError.
        if any('rougeL' not in record for record in selected):
            raise ValueError(
                'score_threshold is set but result record is missing "rougeL".')
        selected = [
            record for record in selected if record['rougeL'] >= threshold
        ]
        if not selected:
            raise ValueError('No results met the score threshold.')

    if detector_cls is None:
        from .detect import RetrievalHeadDetector
        detector_cls = RetrievalHeadDetector

    detector = detector_cls(model_config=config.model)
    progress = tqdm.tqdm(selected, desc='Tracing results')
    for record in progress:
        detector.calculate(
            prompt=record['prompt'],
            needle=record['needle'],
            completion_tokens=record['token_ids'],
        )

    detector.save(config.output_dir)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# cli entry points
|
|
198
|
+
|
|
199
|
+
def niah_cli():
    """Entry point for ``retrieval-heads.niah``.

    ``--config`` is peeled off with a small argparse pre-parser; when given,
    the YAML it points to becomes the default that ``tyro.cli`` starts from,
    and the remaining command-line flags are parsed by tyro on top of it.
    Exits with an error message when no output directory was provided.
    """
    bootstrap = argparse.ArgumentParser(add_help=False)
    bootstrap.add_argument(
        '--config',
        type=str,
        default=None,
        help=('Path to a YAML config with top-level "model", "haystack", '
              '"sweep" and optional "output_dir" sections.'),
    )
    parsed, passthrough = bootstrap.parse_known_args()

    if parsed.config:
        base_config = load_niah_config(parsed.config)
    else:
        # No YAML given: every required field must come from the CLI.
        base_config = tyro.MISSING_NONPROP

    config = tyro.cli(
        NIAHConfig,
        description=('Run Needle-in-a-Haystack (NIAH) file.\n\n'
                     'Pass --config with a YAML file to load config from file.'),
        args=passthrough,
        default=base_config,
    )

    if config.output_dir is None:
        raise SystemExit(
            'output_dir is required: set it in the --config YAML or pass --output-dir.'
        )

    niah_main(config)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def detect_cli():
    """Entry point for ``retrieval-heads.detect``.

    ``--config`` is peeled off with a small argparse pre-parser; when given,
    the YAML it points to becomes the default that ``tyro.cli`` starts from,
    and the remaining command-line flags are parsed by tyro on top of it.
    A ``ValueError`` raised by ``detect_main`` is converted into a
    ``SystemExit`` carrying the same message.
    """
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument(
        '--config',
        type=str,
        default=None,
        # Keys listed here match what ``load_detect_config`` actually reads.
        help='Path to a YAML config with a top-level "model" section, '
        '"results_files", and optional "output_dir", "sweep" and '
        '"score_threshold" entries.',
    )
    known, remaining = pre.parse_known_args()

    default = load_detect_config(
        known.config) if known.config else tyro.MISSING_NONPROP
    config = tyro.cli(
        DetectConfig,
        description='Run retrieval-head detection over saved NIAH JSONL results.'
        '\n\nPass --config with a YAML file to load config defaults.',
        args=remaining,
        default=default,
    )

    try:
        detect_main(config)
    except ValueError as e:
        raise SystemExit(str(e)) from e
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
class ModelConfig:
    """Model-related settings shared by the NIAH and detection configs.

    NOTE(review): the field names mirror vLLM engine arguments and are
    presumably forwarded to vLLM by the code that consumes this config;
    confirm against niah.py / detect.py before relying on that.
    """
    # Model identifier; the README example uses "Qwen/Qwen3.5-9B".
    model: str
    # Optional cap on the model's context length; None leaves it unset.
    max_model_len: int | None = None
    trust_remote_code: bool = False
    # Weight/activation dtype name; the README example uses "bfloat16".
    dtype: str = 'auto'
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    enable_prefix_caching: bool = True
    # Optional chat template; the README example points at a .jinja file.
    chat_template: str | None = None
    enforce_eager: bool = False
    language_model_only: bool = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class RangeConfig:
|
|
20
|
+
min: int
|
|
21
|
+
max: int
|
|
22
|
+
intervals: int
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class SweepConfig:
|
|
27
|
+
context_lengths: list[int] = field(
|
|
28
|
+
default_factory=lambda: list(range(1_000, 50_000, 1_000)))
|
|
29
|
+
document_depths: list[int] = field(
|
|
30
|
+
default_factory=lambda: list(range(0, 100, 10)))
|
|
31
|
+
|
|
32
|
+
def __post_init__(self):
|
|
33
|
+
|
|
34
|
+
def expand_range(r: list[float] | list[int] | RangeConfig) -> list[int]:
|
|
35
|
+
if isinstance(r, RangeConfig):
|
|
36
|
+
return list(range(r.min, r.max + 1, (r.max - r.min) // r.intervals))
|
|
37
|
+
elif isinstance(r, dict):
|
|
38
|
+
return list(
|
|
39
|
+
range(
|
|
40
|
+
r['min'],
|
|
41
|
+
r['max'] + 1,
|
|
42
|
+
(r['max'] - r['min']) // r['intervals'],
|
|
43
|
+
))
|
|
44
|
+
else:
|
|
45
|
+
return sorted(list(r))
|
|
46
|
+
|
|
47
|
+
self.context_lengths = expand_range(self.context_lengths)
|
|
48
|
+
self.document_depths = expand_range(self.document_depths)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class HaystackConfig:
    """Haystack, needle and question used for a needle-in-a-haystack run."""
    # Directory holding the haystack documents; the README example uses
    # "./PaulGrahamEssays".
    haystack_dir: str
    # The fact to be retrieved. The default is wrapped in newlines.
    needle: str = '\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n'
    # Question whose answer is contained in ``needle``.
    retrieval_question: str = 'What is the best thing to do in San Francisco?'
|