interpkit 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. {interpkit-0.2.0 → interpkit-0.3.0}/PKG-INFO +39 -18
  2. {interpkit-0.2.0 → interpkit-0.3.0}/README.md +27 -14
  3. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/cli/main.py +216 -194
  4. interpkit-0.3.0/interpkit/core/cache.py +36 -0
  5. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/discovery.py +116 -10
  6. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/html.py +1 -2
  7. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/inputs.py +17 -10
  8. interpkit-0.3.0/interpkit/core/loader.py +292 -0
  9. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/model.py +61 -326
  10. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/plot.py +5 -6
  11. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/registry.py +18 -4
  12. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/render.py +294 -178
  13. interpkit-0.3.0/interpkit/core/theme.py +33 -0
  14. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/tl_compat.py +3 -3
  15. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/ablate.py +1 -1
  16. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/activations.py +5 -2
  17. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/attention.py +15 -15
  18. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/attribute.py +15 -15
  19. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/batch.py +14 -13
  20. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/circuits.py +14 -15
  21. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/diff.py +8 -4
  22. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/dla.py +158 -11
  23. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/find_circuit.py +6 -6
  24. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/heads.py +4 -3
  25. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/inspect.py +1 -1
  26. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/lens.py +7 -3
  27. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/patch.py +11 -11
  28. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/probe.py +4 -3
  29. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/report.py +1 -1
  30. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/sae.py +87 -24
  31. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/scan.py +50 -34
  32. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/steer.py +11 -7
  33. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/trace.py +5 -4
  34. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/PKG-INFO +39 -18
  35. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/SOURCES.txt +8 -0
  36. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/requires.txt +9 -0
  37. {interpkit-0.2.0 → interpkit-0.3.0}/pyproject.toml +40 -4
  38. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_activations.py +3 -1
  39. interpkit-0.3.0/tests/test_architectures.py +286 -0
  40. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_attention.py +3 -1
  41. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_cli.py +2 -2
  42. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_discovery_units.py +4 -6
  43. interpkit-0.3.0/tests/test_invariants.py +108 -0
  44. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_load_params.py +13 -15
  45. interpkit-0.3.0/tests/test_multi_arch.py +140 -0
  46. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_ops.py +61 -1
  47. interpkit-0.3.0/tests/test_param_variants.py +140 -0
  48. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_plot_internals.py +22 -25
  49. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_registry.py +0 -1
  50. interpkit-0.3.0/tests/test_regressions.py +90 -0
  51. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_render_internals.py +0 -3
  52. interpkit-0.3.0/tests/test_sae.py +219 -0
  53. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_tl_compat.py +0 -2
  54. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_tl_ops.py +0 -1
  55. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_trace.py +0 -2
  56. interpkit-0.2.0/tests/test_sae.py +0 -115
  57. {interpkit-0.2.0 → interpkit-0.3.0}/LICENSE +0 -0
  58. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/__init__.py +0 -0
  59. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/cli/__init__.py +0 -0
  60. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/core/__init__.py +0 -0
  61. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit/ops/__init__.py +0 -0
  62. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/dependency_links.txt +0 -0
  63. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/entry_points.txt +0 -0
  64. {interpkit-0.2.0 → interpkit-0.3.0}/interpkit.egg-info/top_level.txt +0 -0
  65. {interpkit-0.2.0 → interpkit-0.3.0}/setup.cfg +0 -0
  66. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_ablate.py +0 -0
  67. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_attribute.py +0 -0
  68. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_cache.py +0 -0
  69. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_diff.py +0 -0
  70. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_discovery.py +0 -0
  71. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_error_handling.py +0 -0
  72. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_html.py +0 -0
  73. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_inspect.py +0 -0
  74. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_lens.py +0 -0
  75. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_patch.py +0 -0
  76. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_plots.py +0 -0
  77. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_probe.py +0 -0
  78. {interpkit-0.2.0 → interpkit-0.3.0}/tests/test_steer.py +0 -0
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: interpkit
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Mech interp for any HuggingFace model.
5
5
  Author: Davide Zani
6
6
  License-Expression: MIT
7
- Project-URL: Homepage, https://github.com/z4nix/MechKit
8
- Project-URL: Repository, https://github.com/z4nix/MechKit
9
- Project-URL: Issues, https://github.com/z4nix/MechKit/issues
7
+ Project-URL: Homepage, https://github.com/z4nix/interpkit
8
+ Project-URL: Repository, https://github.com/z4nix/interpkit
9
+ Project-URL: Issues, https://github.com/z4nix/interpkit/issues
10
10
  Keywords: mechanistic-interpretability,pytorch,transformers,mech-interp,interpretability
11
11
  Classifier: Development Status :: 3 - Alpha
12
12
  Classifier: Intended Audience :: Science/Research
@@ -23,6 +23,7 @@ Requires-Dist: torch>=2.1
23
23
  Requires-Dist: transformers>=4.36
24
24
  Requires-Dist: safetensors>=0.4
25
25
  Requires-Dist: rich>=13.0
26
+ Requires-Dist: rich-gradient>=0.3
26
27
  Requires-Dist: typer>=0.9
27
28
  Requires-Dist: Pillow>=10.0
28
29
  Requires-Dist: matplotlib>=3.8
@@ -34,20 +35,20 @@ Requires-Dist: scikit-learn>=1.3; extra == "probe"
34
35
  Provides-Extra: dev
35
36
  Requires-Dist: pytest>=7.0; extra == "dev"
36
37
  Requires-Dist: pytest-timeout>=2.2; extra == "dev"
38
+ Requires-Dist: pytest-cov>=5.0; extra == "dev"
37
39
  Requires-Dist: scikit-learn>=1.3; extra == "dev"
38
40
  Requires-Dist: torchvision>=0.16; extra == "dev"
41
+ Requires-Dist: ruff>=0.4; extra == "dev"
42
+ Requires-Dist: mypy>=1.8; extra == "dev"
43
+ Provides-Extra: docs
44
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
45
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
46
+ Requires-Dist: mkdocstrings[python]>=0.24; extra == "docs"
39
47
  Dynamic: license-file
40
48
 
41
- ```
42
- IIIII tt KK KK iii tt
43
- III nn nnn tt eee rr rr pp pp KK KK tt
44
- III nnn nn tttt ee e rrr r ppp pp KKKK iii tttt
45
- III nn nn tt eeeee rr pppppp KK KK iii tt
46
- IIIII nn nn tttt eeeee rr pp KK KK iii tttt
47
- pp
48
- ```
49
-
50
- > Mech interp for any HuggingFace model.
49
+ <p align="center">
50
+ <img src="assets/logo.svg" alt="InterpKit" width="680">
51
+ </p>
51
52
 
52
53
  [![PyPI version](https://img.shields.io/pypi/v/interpkit.svg)](https://pypi.org/project/interpkit/)
53
54
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -75,8 +76,8 @@ pip install interpkit[probe]
75
76
  Or install from source for development:
76
77
 
77
78
  ```bash
78
- git clone https://github.com/davidezani/InterpKit.git
79
- cd InterpKit
79
+ git clone https://github.com/z4nix/interpkit.git
80
+ cd interpkit
80
81
  pip install -e ".[dev]"
81
82
  ```
82
83
 
@@ -117,7 +118,7 @@ model = interpkit.load("bert-base-uncased")
117
118
  | Operation | What it does | Works on |
118
119
  |-----------|-------------|----------|
119
120
  | **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
120
- | **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution | LMs |
121
+ | **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
121
122
  | `inspect` | Module tree with types, param counts, shapes | Any model |
122
123
  | `patch` | Activation patching at a module, head, or position | Any model |
123
124
  | `trace` | Causal tracing — module-level or position-aware (Meng et al.) heatmap | Any model |
@@ -172,6 +173,16 @@ model.dla("The capital of France is", token="Paris")
172
173
 
173
174
  # Save a bar chart
174
175
  model.dla("The capital of France is", save="dla.png")
176
+
177
+ # Feature-level DLA — decompose a component through an SAE
178
+ # to see which individual features drive the prediction
179
+ result = model.dla(
180
+ "The capital of France is",
181
+ sae="jbloom/GPT2-Small-SAEs-Reformatted",
182
+ sae_at="transformer.h.11.attn",
183
+ )
184
+ # result["feature_contributions"]["features"]
185
+ # — per-feature logit attributions at the specified component
175
186
  ```
176
187
 
177
188
  ## Causal Tracing
@@ -342,14 +353,22 @@ interpkit.diff(base, finetuned, "The capital of France is")
342
353
 
343
354
  ## SAE Features
344
355
 
345
- Decompose activations into interpretable features using pre-trained Sparse Autoencoders from HuggingFace:
356
+ Decompose activations into interpretable features using pre-trained Sparse Autoencoders:
346
357
 
347
358
  ```python
359
+ # From HuggingFace
348
360
  model.features(
349
361
  "The capital of France is",
350
362
  at="transformer.h.8",
351
363
  sae="jbloom/GPT2-Small-SAEs-Reformatted",
352
364
  )
365
+
366
+ # From a local file (.safetensors or .pt)
367
+ model.features(
368
+ "The capital of France is",
369
+ at="transformer.h.8",
370
+ sae="/path/to/sae_weights.safetensors",
371
+ )
353
372
  ```
354
373
 
355
374
  No SAELens dependency — weights are loaded directly via `safetensors`.
@@ -408,6 +427,8 @@ interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
408
427
  interpkit decompose gpt2 "The capital of France is"
409
428
  interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
410
429
  interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted
430
+ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
431
+ interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
411
432
 
412
433
  # Interactive HTML output
413
434
  interpkit attention gpt2 "hello world" --html attention.html
@@ -1,13 +1,6 @@
1
- ```
2
- IIIII tt KK KK iii tt
3
- III nn nnn tt eee rr rr pp pp KK KK tt
4
- III nnn nn tttt ee e rrr r ppp pp KKKK iii tttt
5
- III nn nn tt eeeee rr pppppp KK KK iii tt
6
- IIIII nn nn tttt eeeee rr pp KK KK iii tttt
7
- pp
8
- ```
9
-
10
- > Mech interp for any HuggingFace model.
1
+ <p align="center">
2
+ <img src="assets/logo.svg" alt="InterpKit" width="680">
3
+ </p>
11
4
 
12
5
  [![PyPI version](https://img.shields.io/pypi/v/interpkit.svg)](https://pypi.org/project/interpkit/)
13
6
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
@@ -35,8 +28,8 @@ pip install interpkit[probe]
35
28
  Or install from source for development:
36
29
 
37
30
  ```bash
38
- git clone https://github.com/davidezani/InterpKit.git
39
- cd InterpKit
31
+ git clone https://github.com/z4nix/interpkit.git
32
+ cd interpkit
40
33
  pip install -e ".[dev]"
41
34
  ```
42
35
 
@@ -77,7 +70,7 @@ model = interpkit.load("bert-base-uncased")
77
70
  | Operation | What it does | Works on |
78
71
  |-----------|-------------|----------|
79
72
  | **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
80
- | **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution | LMs |
73
+ | **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
81
74
  | `inspect` | Module tree with types, param counts, shapes | Any model |
82
75
  | `patch` | Activation patching at a module, head, or position | Any model |
83
76
  | `trace` | Causal tracing — module-level or position-aware (Meng et al.) heatmap | Any model |
@@ -132,6 +125,16 @@ model.dla("The capital of France is", token="Paris")
132
125
 
133
126
  # Save a bar chart
134
127
  model.dla("The capital of France is", save="dla.png")
128
+
129
+ # Feature-level DLA — decompose a component through an SAE
130
+ # to see which individual features drive the prediction
131
+ result = model.dla(
132
+ "The capital of France is",
133
+ sae="jbloom/GPT2-Small-SAEs-Reformatted",
134
+ sae_at="transformer.h.11.attn",
135
+ )
136
+ # result["feature_contributions"]["features"]
137
+ # — per-feature logit attributions at the specified component
135
138
  ```
136
139
 
137
140
  ## Causal Tracing
@@ -302,14 +305,22 @@ interpkit.diff(base, finetuned, "The capital of France is")
302
305
 
303
306
  ## SAE Features
304
307
 
305
- Decompose activations into interpretable features using pre-trained Sparse Autoencoders from HuggingFace:
308
+ Decompose activations into interpretable features using pre-trained Sparse Autoencoders:
306
309
 
307
310
  ```python
311
+ # From HuggingFace
308
312
  model.features(
309
313
  "The capital of France is",
310
314
  at="transformer.h.8",
311
315
  sae="jbloom/GPT2-Small-SAEs-Reformatted",
312
316
  )
317
+
318
+ # From a local file (.safetensors or .pt)
319
+ model.features(
320
+ "The capital of France is",
321
+ at="transformer.h.8",
322
+ sae="/path/to/sae_weights.safetensors",
323
+ )
313
324
  ```
314
325
 
315
326
  No SAELens dependency — weights are loaded directly via `safetensors`.
@@ -368,6 +379,8 @@ interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
368
379
  interpkit decompose gpt2 "The capital of France is"
369
380
  interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
370
381
  interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted
382
+ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
383
+ interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
371
384
 
372
385
  # Interactive HTML output
373
386
  interpkit attention gpt2 "hello world" --html attention.html