entropy-profiler 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- entropy_profiler-0.2.0/.github/workflows/ci.yml +46 -0
- entropy_profiler-0.2.0/.github/workflows/publish.yml +56 -0
- entropy_profiler-0.2.0/.gitignore +46 -0
- entropy_profiler-0.2.0/LICENSE +21 -0
- entropy_profiler-0.2.0/PKG-INFO +347 -0
- entropy_profiler-0.2.0/README.md +307 -0
- entropy_profiler-0.2.0/entropy_profiler/__init__.py +100 -0
- entropy_profiler-0.2.0/entropy_profiler/analysis.py +558 -0
- entropy_profiler-0.2.0/entropy_profiler/distances.py +328 -0
- entropy_profiler-0.2.0/entropy_profiler/estimators.py +121 -0
- entropy_profiler-0.2.0/entropy_profiler/profiler.py +620 -0
- entropy_profiler-0.2.0/entropy_profiler/py.typed +0 -0
- entropy_profiler-0.2.0/entropy_profiler/viz.py +451 -0
- entropy_profiler-0.2.0/notebooks/api_tour.ipynb +870 -0
- entropy_profiler-0.2.0/notebooks/exploration.ipynb +447 -0
- entropy_profiler-0.2.0/pyproject.toml +73 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main, master]
|
|
6
|
+
pull_request:
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
lint:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
steps:
|
|
12
|
+
- uses: actions/checkout@v4
|
|
13
|
+
|
|
14
|
+
- uses: actions/setup-python@v5
|
|
15
|
+
with:
|
|
16
|
+
python-version: "3.12"
|
|
17
|
+
|
|
18
|
+
- name: Install uv
|
|
19
|
+
uses: astral-sh/setup-uv@v4
|
|
20
|
+
|
|
21
|
+
- name: Install dependencies
|
|
22
|
+
run: uv sync --extra dev
|
|
23
|
+
|
|
24
|
+
- name: Ruff check
|
|
25
|
+
run: uv run ruff check .
|
|
26
|
+
|
|
27
|
+
- name: Ruff format check
|
|
28
|
+
run: uv run ruff format --check .
|
|
29
|
+
|
|
30
|
+
typecheck:
|
|
31
|
+
runs-on: ubuntu-latest
|
|
32
|
+
steps:
|
|
33
|
+
- uses: actions/checkout@v4
|
|
34
|
+
|
|
35
|
+
- uses: actions/setup-python@v5
|
|
36
|
+
with:
|
|
37
|
+
python-version: "3.12"
|
|
38
|
+
|
|
39
|
+
- name: Install uv
|
|
40
|
+
uses: astral-sh/setup-uv@v4
|
|
41
|
+
|
|
42
|
+
- name: Install dependencies
|
|
43
|
+
run: uv sync --extra dev
|
|
44
|
+
|
|
45
|
+
- name: Import check
|
|
46
|
+
run: uv run python -c "import entropy_profiler; print(entropy_profiler.__version__)"
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
build:
|
|
9
|
+
runs-on: ubuntu-latest
|
|
10
|
+
steps:
|
|
11
|
+
- uses: actions/checkout@v4
|
|
12
|
+
|
|
13
|
+
- uses: actions/setup-python@v5
|
|
14
|
+
with:
|
|
15
|
+
python-version: "3.12"
|
|
16
|
+
|
|
17
|
+
- name: Install build tool
|
|
18
|
+
run: pip install build
|
|
19
|
+
|
|
20
|
+
- name: Build sdist and wheel
|
|
21
|
+
run: python -m build
|
|
22
|
+
|
|
23
|
+
- uses: actions/upload-artifact@v4
|
|
24
|
+
with:
|
|
25
|
+
name: dist
|
|
26
|
+
path: dist/
|
|
27
|
+
|
|
28
|
+
publish:
|
|
29
|
+
needs: build
|
|
30
|
+
runs-on: ubuntu-latest
|
|
31
|
+
environment: pypi
|
|
32
|
+
permissions:
|
|
33
|
+
id-token: write
|
|
34
|
+
steps:
|
|
35
|
+
- uses: actions/download-artifact@v4
|
|
36
|
+
with:
|
|
37
|
+
name: dist
|
|
38
|
+
path: dist/
|
|
39
|
+
|
|
40
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
41
|
+
|
|
42
|
+
test-publish:
|
|
43
|
+
needs: build
|
|
44
|
+
runs-on: ubuntu-latest
|
|
45
|
+
environment: testpypi
|
|
46
|
+
permissions:
|
|
47
|
+
id-token: write
|
|
48
|
+
steps:
|
|
49
|
+
- uses: actions/download-artifact@v4
|
|
50
|
+
with:
|
|
51
|
+
name: dist
|
|
52
|
+
path: dist/
|
|
53
|
+
|
|
54
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
55
|
+
with:
|
|
56
|
+
repository-url: https://test.pypi.org/legacy/
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.so
|
|
6
|
+
*.egg-info/
|
|
7
|
+
*.egg
|
|
8
|
+
dist/
|
|
9
|
+
build/
|
|
10
|
+
*.whl
|
|
11
|
+
|
|
12
|
+
# Virtual environments
|
|
13
|
+
.venv/
|
|
14
|
+
venv/
|
|
15
|
+
env/
|
|
16
|
+
|
|
17
|
+
# IDE
|
|
18
|
+
.idea/
|
|
19
|
+
.vscode/
|
|
20
|
+
*.swp
|
|
21
|
+
*.swo
|
|
22
|
+
*~
|
|
23
|
+
|
|
24
|
+
# Jupyter
|
|
25
|
+
.ipynb_checkpoints/
|
|
26
|
+
*.ipynb_checkpoints
|
|
27
|
+
|
|
28
|
+
# Testing / linting
|
|
29
|
+
.pytest_cache/
|
|
30
|
+
.ruff_cache/
|
|
31
|
+
.mypy_cache/
|
|
32
|
+
htmlcov/
|
|
33
|
+
.coverage
|
|
34
|
+
|
|
35
|
+
# OS
|
|
36
|
+
.DS_Store
|
|
37
|
+
Thumbs.db
|
|
38
|
+
|
|
39
|
+
# Models (don't commit downloaded weights)
|
|
40
|
+
*.bin
|
|
41
|
+
*.safetensors
|
|
42
|
+
*.pt
|
|
43
|
+
*.pth
|
|
44
|
+
|
|
45
|
+
# uv
|
|
46
|
+
uv.lock
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 entropy-profiler contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: entropy-profiler
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Extract, analyze, and visualize entropy profiles from transformer models using the logit-lens technique.
|
|
5
|
+
Project-URL: Homepage, https://github.com/TheGitCommit/entropy-profiler
|
|
6
|
+
Project-URL: Documentation, https://github.com/TheGitCommit/entropy-profiler/blob/master/README.md
|
|
7
|
+
Project-URL: Repository, https://github.com/TheGitCommit/entropy-profiler
|
|
8
|
+
Project-URL: Issues, https://github.com/TheGitCommit/entropy-profiler/issues
|
|
9
|
+
Author: entropy-profiler contributors
|
|
10
|
+
License-Expression: MIT
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Keywords: entropy,interpretability,logit-lens,machine-learning,nlp,profiling,transformer
|
|
13
|
+
Classifier: Development Status :: 4 - Beta
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Typing :: Typed
|
|
23
|
+
Requires-Python: >=3.10
|
|
24
|
+
Requires-Dist: matplotlib>=3.7.0
|
|
25
|
+
Requires-Dist: numpy>=1.24.0
|
|
26
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
27
|
+
Requires-Dist: scipy>=1.11.0
|
|
28
|
+
Requires-Dist: seaborn>=0.12.0
|
|
29
|
+
Requires-Dist: torch>=2.1.0
|
|
30
|
+
Requires-Dist: transformers>=4.40.0
|
|
31
|
+
Provides-Extra: dev
|
|
32
|
+
Requires-Dist: pytest>=7.0.0; extra == 'dev'
|
|
33
|
+
Requires-Dist: ruff>=0.4.0; extra == 'dev'
|
|
34
|
+
Provides-Extra: notebook
|
|
35
|
+
Requires-Dist: datasets>=3.0.0; extra == 'notebook'
|
|
36
|
+
Requires-Dist: ipywidgets>=8.0.0; extra == 'notebook'
|
|
37
|
+
Requires-Dist: jupyter>=1.0.0; extra == 'notebook'
|
|
38
|
+
Requires-Dist: notebook>=7.0.0; extra == 'notebook'
|
|
39
|
+
Description-Content-Type: text/markdown
|
|
40
|
+
|
|
41
|
+
# entropy-profiler
|
|
42
|
+
|
|
43
|
+
Extract, analyze, and visualize entropy profiles from transformer models using
|
|
44
|
+
the logit-lens technique.
|
|
45
|
+
|
|
46
|
+
`entropy-profiler` computes per-layer Shannon or Rényi entropy by passing
|
|
47
|
+
hidden states through the model's own unembedding head (layer norm + lm_head).
|
|
48
|
+
It works on any HuggingFace `CausalLM` without architecture-specific hooks.
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from entropy_profiler import EntropyProfiler, plot_profile
|
|
52
|
+
import torch
|
|
53
|
+
|
|
54
|
+
profiler = EntropyProfiler("gpt2", dtype=torch.float32)
|
|
55
|
+
profile = profiler.profile_text("The meaning of life is", max_new_tokens=32)
|
|
56
|
+
plot_profile(profile)
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## Installation
|
|
62
|
+
|
|
63
|
+
### From source (recommended for development)
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
git clone https://github.com/TheGitCommit/entropy-profiler
|
|
67
|
+
cd entropy-profiler
|
|
68
|
+
|
|
69
|
+
# Using uv (fast, handles venvs automatically)
|
|
70
|
+
uv sync # core dependencies
|
|
71
|
+
uv sync --extra notebook # + Jupyter support
|
|
72
|
+
uv sync --extra dev # + pytest, ruff
|
|
73
|
+
|
|
74
|
+
# Or using pip
|
|
75
|
+
pip install -e .
|
|
76
|
+
pip install -e ".[notebook]"
|
|
77
|
+
pip install -e ".[dev]"
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
### From PyPI
|
|
81
|
+
|
|
82
|
+
```bash
|
|
83
|
+
pip install entropy-profiler
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
---
|
|
87
|
+
|
|
88
|
+
## Quick Start
|
|
89
|
+
|
|
90
|
+
### Profile a single prompt
|
|
91
|
+
|
|
92
|
+
```python
|
|
93
|
+
from entropy_profiler import EntropyProfiler, plot_profile
|
|
94
|
+
import torch
|
|
95
|
+
|
|
96
|
+
profiler = EntropyProfiler("gpt2", dtype=torch.float32)
|
|
97
|
+
profile = profiler.profile_text("The capital of France is", max_new_tokens=32)
|
|
98
|
+
|
|
99
|
+
print(profile.entropy.shape) # (n_tokens, n_layers)
|
|
100
|
+
print(profile.mean_profile()) # (n_layers,) tensor
|
|
101
|
+
plot_profile(profile)
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
### Profile multiple prompts
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
from entropy_profiler import plot_aggregated
|
|
108
|
+
|
|
109
|
+
agg = profiler.profile_batch([
|
|
110
|
+
"The stock market experienced significant",
|
|
111
|
+
"In quantum mechanics, the wave function",
|
|
112
|
+
"Modern neural networks learn by",
|
|
113
|
+
], max_new_tokens=24)
|
|
114
|
+
|
|
115
|
+
print(agg.to_matrix().shape) # (3, n_layers)
|
|
116
|
+
plot_aggregated(agg)
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
### Compare prompts with distances
|
|
120
|
+
|
|
121
|
+
```python
|
|
122
|
+
from entropy_profiler import profile_distance
|
|
123
|
+
|
|
124
|
+
p1 = profiler.profile_text("Water boils at", max_new_tokens=24)
|
|
125
|
+
p2 = profiler.profile_text("Once upon a time", max_new_tokens=24)
|
|
126
|
+
|
|
127
|
+
result = profile_distance(p1, p2, metric="jsd")
|
|
128
|
+
print(f"JSD distance: {result.aggregate:.4f}")
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
### Analyze layer dynamics
|
|
132
|
+
|
|
133
|
+
```python
|
|
134
|
+
from entropy_profiler import LayerAnalyzer
|
|
135
|
+
|
|
136
|
+
profile, hidden_states = profiler.profile_text_with_states(
|
|
137
|
+
"Hello world", max_new_tokens=32
|
|
138
|
+
)
|
|
139
|
+
analyzer = LayerAnalyzer(profiler, profile, hidden_states=hidden_states)
|
|
140
|
+
|
|
141
|
+
print(analyzer.layer_entropy()) # (n_layers,)
|
|
142
|
+
print(analyzer.information_velocity()) # (n_layers,)
|
|
143
|
+
print(analyzer.layer_mi(method="cka")) # (n_layers,)
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
---
|
|
147
|
+
|
|
148
|
+
## Core Concepts
|
|
149
|
+
|
|
150
|
+
### Logit-Lens Decoding
|
|
151
|
+
|
|
152
|
+
At each transformer layer, the hidden state is projected through the model's
|
|
153
|
+
final layer norm and language model head to produce a vocabulary distribution.
|
|
154
|
+
The entropy of this distribution measures how "decided" the model is at that
|
|
155
|
+
layer — low entropy means a peaked distribution (confident prediction), high
|
|
156
|
+
entropy means a flat distribution (uncertain).
|
|
157
|
+
|
|
158
|
+
### Entropy Profiles
|
|
159
|
+
|
|
160
|
+
An **entropy profile** is a matrix of shape `(n_tokens, n_layers)` where each
|
|
161
|
+
entry is the entropy of the vocabulary distribution at that token position and
|
|
162
|
+
layer depth. The **mean profile** `(n_layers,)` averages across tokens to give
|
|
163
|
+
a single curve showing how entropy evolves through the network.
|
|
164
|
+
|
|
165
|
+
### Why Rényi Entropy?
|
|
166
|
+
|
|
167
|
+
Shannon entropy (`alpha=1`) is the default, but Rényi entropy at other orders
|
|
168
|
+
provides complementary views:
|
|
169
|
+
- `alpha < 1` — emphasizes rare events (tail sensitivity)
|
|
170
|
+
- `alpha = 1` — Shannon entropy (standard)
|
|
171
|
+
- `alpha = 2` — collision entropy (sensitive to mode)
|
|
172
|
+
- `alpha > 2` — increasingly dominated by the most probable token
|
|
173
|
+
|
|
174
|
+
---
|
|
175
|
+
|
|
176
|
+
## API Reference
|
|
177
|
+
|
|
178
|
+
### Core Module (`entropy_profiler.profiler`)
|
|
179
|
+
|
|
180
|
+
| Symbol | Description |
|
|
181
|
+
|--------|-------------|
|
|
182
|
+
| `EntropyProfiler(model, dtype, alpha, layer_stride)` | Main class. Loads model, runs generation, computes entropy. |
|
|
183
|
+
| `EntropyProfile` | Dataclass: `entropy`, `token_ids`, `layer_indices`, `alpha`, `model_name`, `metadata`. |
|
|
184
|
+
| `AggregatedProfile` | Collection of profiles with `mean_profile()` and `to_matrix()`. |
|
|
185
|
+
| `shannon_entropy(probs)` | `H(p) = -sum(p log p)` on the last dimension. |
|
|
186
|
+
| `renyi_entropy(probs, alpha)` | Rényi entropy of order α. Falls back to Shannon when α ≈ 1. |
|
|
187
|
+
|
|
188
|
+
**`EntropyProfiler` methods:**
|
|
189
|
+
|
|
190
|
+
| Method | Returns | Description |
|
|
191
|
+
|--------|---------|-------------|
|
|
192
|
+
| `profile_text(prompt, max_new_tokens, ...)` | `EntropyProfile` | Profile generated text. |
|
|
193
|
+
| `profile_text_with_states(prompt, ...)` | `(EntropyProfile, Tensor)` | Profile + raw hidden states. |
|
|
194
|
+
| `profile_batch(prompts, ...)` | `AggregatedProfile` | Profile multiple prompts. |
|
|
195
|
+
| `unload()` | `None` | Free model memory. |
|
|
196
|
+
|
|
197
|
+
**`EntropyProfile` attributes and methods:**
|
|
198
|
+
|
|
199
|
+
| Member | Type | Description |
|
|
200
|
+
|--------|------|-------------|
|
|
201
|
+
| `entropy` | `Tensor (n_tokens, n_layers)` | Per-token, per-layer entropy. |
|
|
202
|
+
| `token_ids` | `Tensor (n_tokens,)` | Generated token IDs. |
|
|
203
|
+
| `n_layers` | `int` | Number of profiled layers. |
|
|
204
|
+
| `n_tokens` | `int` | Number of profiled tokens. |
|
|
205
|
+
| `mean_profile()` | `Tensor (n_layers,)` | Mean entropy at each layer. |
|
|
206
|
+
| `to_numpy()` | `ndarray` | Convert to NumPy (float32). |
|
|
207
|
+
|
|
208
|
+
### Distances (`entropy_profiler.distances`)
|
|
209
|
+
|
|
210
|
+
| Function | Returns | Description |
|
|
211
|
+
|----------|------|-------------|
|
|
212
|
+
| `profile_distance(p1, p2, metric, aggregation)` | `DistanceResult` | Unified entry point. |
|
|
213
|
+
| `pairwise_distances(profiles, metric)` | `ndarray (N, N)` | Symmetric distance matrix. |
|
|
214
|
+
| `jsd_layer(p1, p2, n_bins)` | `ndarray (n_layers,)` | Per-layer Jensen-Shannon divergence. |
|
|
215
|
+
| `wasserstein_layer(p1, p2)` | `ndarray (n_layers,)` | Per-layer Wasserstein-1 distance. |
|
|
216
|
+
| `fisher_rao_distance(p1, p2)` | `float` | Geodesic on probability simplex. |
|
|
217
|
+
| `srvf_distance(p1, p2)` | `float` | Elastic SRVF curve distance. |
|
|
218
|
+
|
|
219
|
+
**Available metrics for `profile_distance`:** `"jsd"`, `"wasserstein"`, `"fisher_rao"`, `"srvf"`.
|
|
220
|
+
|
|
221
|
+
**Aggregation methods:** `"mean"`, `"max"`, `"sum"` (for layer-wise metrics).
|
|
222
|
+
|
|
223
|
+
### Layer Analysis (`entropy_profiler.analysis`)
|
|
224
|
+
|
|
225
|
+
| Symbol | Description |
|
|
226
|
+
|--------|-------------|
|
|
227
|
+
| `LayerAnalyzer(profiler, profile, hidden_states)` | Per-layer metric computation. |
|
|
228
|
+
|
|
229
|
+
Additional functions available via `from entropy_profiler.analysis import ...`:
|
|
230
|
+
`compare_models`, `plot_layer_importance`, `plot_information_plane`, `plot_velocity_entropy`.
|
|
231
|
+
|
|
232
|
+
**`LayerAnalyzer` methods:**
|
|
233
|
+
|
|
234
|
+
| Method | Returns | Description |
|
|
235
|
+
|--------|---------|-------------|
|
|
236
|
+
| `layer_entropy()` | `ndarray (n_layers,)` | Mean Shannon entropy per layer. |
|
|
237
|
+
| `information_velocity()` | `ndarray (n_layers,)` | Wasserstein between consecutive layers. |
|
|
238
|
+
| `distance_to_output()` | `ndarray (n_layers,)` | Fisher-Rao distance to final layer. |
|
|
239
|
+
| `jsd_to_output(n_bins)` | `ndarray (n_layers,)` | JSD from each layer to final. |
|
|
240
|
+
| `layer_mi(method)` | `ndarray (n_layers,)` | MI with final layer (Rényi or CKA). |
|
|
241
|
+
| `layer_importance()` | `dict` | All four non-MI metrics. |
|
|
242
|
+
|
|
243
|
+
### Visualization (`entropy_profiler.viz`)
|
|
244
|
+
|
|
245
|
+
| Function | Description |
|
|
246
|
+
|----------|-------------|
|
|
247
|
+
| `plot_profile(profile, ax, ...)` | Line plot with ±1 std fill. |
|
|
248
|
+
| `plot_profiles(profiles, labels, ...)` | Overlay multiple profiles. |
|
|
249
|
+
| `plot_heatmap(profile, ax, ...)` | Token × layer entropy heatmap. |
|
|
250
|
+
| `plot_aggregated(agg, ax, ...)` | Aggregated mean ± std curve. |
|
|
251
|
+
| `plot_cluster(profiles, labels, method, feature, metric, ...)` | 2D scatter via t-SNE/UMAP/PCA. |
|
|
252
|
+
|
|
253
|
+
### Estimators (`entropy_profiler.estimators`)
|
|
254
|
+
|
|
255
|
+
| Symbol | Description |
|
|
256
|
+
|--------|-------------|
|
|
257
|
+
| `MatrixRenyiMI(alpha, device)` | Matrix-based Rényi MI via Gram matrices. Used by `LayerAnalyzer.layer_mi()`. |
|
|
258
|
+
|
|
259
|
+
---
|
|
260
|
+
|
|
261
|
+
## Supported Models
|
|
262
|
+
|
|
263
|
+
Any HuggingFace `AutoModelForCausalLM` is supported. The profiler automatically
|
|
264
|
+
detects the unembedding architecture:
|
|
265
|
+
|
|
266
|
+
| Model Family | Layer Norm Path | Status |
|
|
267
|
+
|-------------|-----------------|--------|
|
|
268
|
+
| GPT-2 | `transformer.ln_f` | Tested |
|
|
269
|
+
| LLaMA / LLaMA 2 / LLaMA 3 | `model.norm` | Tested |
|
|
270
|
+
| Mistral | `model.norm` | Tested |
|
|
271
|
+
| Gemma / Gemma 2 | `model.norm` | Tested |
|
|
272
|
+
| Qwen / Qwen 2 | `model.norm` | Tested |
|
|
273
|
+
| OPT | `model.norm` (fallback) | Tested |
|
|
274
|
+
|
|
275
|
+
To add a new architecture, add a resolution pattern to
|
|
276
|
+
`_get_unembedding()` in `entropy_profiler/profiler.py`.
|
|
277
|
+
|
|
278
|
+
### Tips for large models
|
|
279
|
+
|
|
280
|
+
```python
|
|
281
|
+
# Use float16 for 7B+ models to fit in GPU memory
|
|
282
|
+
profiler = EntropyProfiler("meta-llama/Llama-2-7b-hf", dtype=torch.float16)
|
|
283
|
+
|
|
284
|
+
# Profile every other layer to reduce computation
|
|
285
|
+
profiler = EntropyProfiler("gpt2", layer_stride=2)
|
|
286
|
+
|
|
287
|
+
# Use context manager to auto-unload
|
|
288
|
+
with EntropyProfiler("gpt2") as profiler:
|
|
289
|
+
    profile = profiler.profile_text("Hello world")
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
---
|
|
293
|
+
|
|
294
|
+
## Design Decisions
|
|
295
|
+
|
|
296
|
+
**No hooks.** HuggingFace's `output_hidden_states=True` returns all hidden
|
|
297
|
+
states without hook infrastructure. This works across all CausalLM
|
|
298
|
+
architectures with no architecture-specific code for hidden-state extraction.
|
|
299
|
+
|
|
300
|
+
**Logit-lens, not probing.** The unembedding head is the model's own decoder.
|
|
301
|
+
No training of linear probes is needed — the entropy values are directly
|
|
302
|
+
interpretable as "how peaked is the vocabulary distribution at this layer."
|
|
303
|
+
|
|
304
|
+
**Float32 entropy.** Entropy is always computed in float32 regardless of model
|
|
305
|
+
dtype, avoiding numerical issues with half-precision softmax.
|
|
306
|
+
|
|
307
|
+
**Dataclass outputs.** `EntropyProfile` and `DistanceResult` are plain
|
|
308
|
+
dataclasses — easy to inspect, serialize, and compose.
|
|
309
|
+
|
|
310
|
+
---
|
|
311
|
+
|
|
312
|
+
## Notebooks
|
|
313
|
+
|
|
314
|
+
| Notebook | Description |
|
|
315
|
+
|----------|-------------|
|
|
316
|
+
| `exploration.ipynb` | Multi-model exploration: entropy curves, heatmaps, velocities, distances, clustering. Requires GPU and gated-model access. |
|
|
317
|
+
| `api_tour.ipynb` | Complete API tour exercising every public function with GPT-2. |
|
|
318
|
+
|
|
319
|
+
```bash
|
|
320
|
+
uv sync --extra notebook
|
|
321
|
+
uv run jupyter notebook notebooks/
|
|
322
|
+
```
|
|
323
|
+
|
|
324
|
+
---
|
|
325
|
+
|
|
326
|
+
## Development
|
|
327
|
+
|
|
328
|
+
```bash
|
|
329
|
+
# Install dev dependencies
|
|
330
|
+
uv sync --extra dev
|
|
331
|
+
|
|
332
|
+
# Lint
|
|
333
|
+
uv run ruff check .
|
|
334
|
+
uv run ruff check --fix .
|
|
335
|
+
|
|
336
|
+
# Test
|
|
337
|
+
uv run pytest
|
|
338
|
+
|
|
339
|
+
# Run a script without activating venv
|
|
340
|
+
uv run python your_script.py
|
|
341
|
+
```
|
|
342
|
+
|
|
343
|
+
---
|
|
344
|
+
|
|
345
|
+
## License
|
|
346
|
+
|
|
347
|
+
MIT
|