interpkit 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {interpkit-0.3.0 → interpkit-0.4.0}/PKG-INFO +41 -4
- {interpkit-0.3.0 → interpkit-0.4.0}/README.md +40 -3
- interpkit-0.4.0/interpkit/__main__.py +19 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/cli/main.py +110 -6
- interpkit-0.4.0/interpkit/core/inputs.py +403 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/loader.py +35 -5
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/model.py +149 -4
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/render.py +11 -11
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/theme.py +11 -8
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/attribute.py +73 -4
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/batch.py +4 -4
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/circuits.py +2 -2
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/dla.py +23 -2
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/find_circuit.py +9 -5
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/report.py +55 -10
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/sae.py +199 -19
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/scan.py +28 -6
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/steer.py +48 -2
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/PKG-INFO +41 -4
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/SOURCES.txt +4 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/pyproject.toml +1 -1
- interpkit-0.4.0/tests/test_chat.py +217 -0
- interpkit-0.4.0/tests/test_inputs.py +251 -0
- interpkit-0.4.0/tests/test_robustness_audit.py +763 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_sae.py +155 -0
- interpkit-0.4.0/tests/test_steer.py +91 -0
- interpkit-0.3.0/interpkit/core/inputs.py +0 -130
- interpkit-0.3.0/tests/test_steer.py +0 -30
- {interpkit-0.3.0 → interpkit-0.4.0}/LICENSE +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/__init__.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/cli/__init__.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/__init__.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/cache.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/discovery.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/html.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/plot.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/registry.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/core/tl_compat.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/__init__.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/ablate.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/activations.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/attention.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/diff.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/heads.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/inspect.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/lens.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/patch.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/probe.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit/ops/trace.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/dependency_links.txt +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/entry_points.txt +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/requires.txt +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/interpkit.egg-info/top_level.txt +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/setup.cfg +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_ablate.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_activations.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_architectures.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_attention.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_attribute.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_cache.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_cli.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_diff.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_discovery.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_discovery_units.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_error_handling.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_html.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_inspect.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_invariants.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_lens.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_load_params.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_multi_arch.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_ops.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_param_variants.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_patch.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_plot_internals.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_plots.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_probe.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_registry.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_regressions.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_render_internals.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_tl_compat.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_tl_ops.py +0 -0
- {interpkit-0.3.0 → interpkit-0.4.0}/tests/test_trace.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: interpkit
|
|
3
|
-
Version: 0.3.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Mech interp for any HuggingFace model.
|
|
5
5
|
Author: Davide Zani
|
|
6
6
|
License-Expression: MIT
|
|
@@ -111,6 +111,25 @@ model = interpkit.load("google/vit-base-patch16-224")
|
|
|
111
111
|
model = interpkit.load("bert-base-uncased")
|
|
112
112
|
```
|
|
113
113
|
|
|
114
|
+
### Chat models
|
|
115
|
+
|
|
116
|
+
Instruction-tuned models work too — interpkit applies the tokenizer's chat template automatically.
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
chat = interpkit.load("HuggingFaceTB/SmolLM2-360M-Instruct")
|
|
120
|
+
|
|
121
|
+
result = chat.chat("Write a haiku about cats.", max_new_tokens=64)
|
|
122
|
+
print(result["response"])
|
|
123
|
+
|
|
124
|
+
# Run any other op on the templated prompt
|
|
125
|
+
chat.dla(result["prompt"])
|
|
126
|
+
|
|
127
|
+
# Or pass a message list directly to any op
|
|
128
|
+
chat.dla([{"role": "user", "content": "Capital of France?"}])
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full walkthrough including chat-style steering.
|
|
132
|
+
|
|
114
133
|
---
|
|
115
134
|
|
|
116
135
|
## Operations
|
|
@@ -118,6 +137,7 @@ model = interpkit.load("bert-base-uncased")
|
|
|
118
137
|
| Operation | What it does | Works on |
|
|
119
138
|
|-----------|-------------|----------|
|
|
120
139
|
| **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
|
|
140
|
+
| **`chat`** | Send a message through the tokenizer's chat template and generate a reply | Chat / instruct LMs |
|
|
121
141
|
| **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
|
|
122
142
|
| `inspect` | Module tree with types, param counts, shapes | Any model |
|
|
123
143
|
| `patch` | Activation patching at a module, head, or position | Any model |
|
|
@@ -328,10 +348,12 @@ results = model.dla_batch(["The capital of France is", "The CEO of Apple is"])
|
|
|
328
348
|
## Steering
|
|
329
349
|
|
|
330
350
|
```python
|
|
331
|
-
vector = model.steer_vector("
|
|
351
|
+
vector = model.steer_vector(" love", " hate", at="transformer.h.8")
|
|
332
352
|
model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)
|
|
333
353
|
```
|
|
334
354
|
|
|
355
|
+
> Note the leading spaces. BPE tokenizers (GPT-2, Llama, ...) treat `" love"` and `"love"` as different tokens, and the leading-space variant is the one the model actually sees in normal text. interpkit prints a warning if you forget.
|
|
356
|
+
|
|
335
357
|
## Linear Probe
|
|
336
358
|
|
|
337
359
|
```python
|
|
@@ -422,7 +444,7 @@ interpkit lens gpt2 "The capital of France is"
|
|
|
422
444
|
interpkit lens gpt2 "The capital of France is" --position -1
|
|
423
445
|
interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
|
|
424
446
|
interpkit attribute gpt2 "The capital of France is"
|
|
425
|
-
interpkit steer gpt2 "The weather is" --positive
|
|
447
|
+
interpkit steer gpt2 "The weather is" --positive " love" --negative " hate" --at transformer.h.8
|
|
426
448
|
interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
|
|
427
449
|
interpkit decompose gpt2 "The capital of France is"
|
|
428
450
|
interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
|
|
@@ -430,6 +452,10 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
|
|
|
430
452
|
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
|
|
431
453
|
interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
|
|
432
454
|
|
|
455
|
+
# Chat / instruct models — applies the tokenizer's chat template automatically
|
|
456
|
+
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
|
|
457
|
+
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
|
|
458
|
+
|
|
433
459
|
# Interactive HTML output
|
|
434
460
|
interpkit attention gpt2 "hello world" --html attention.html
|
|
435
461
|
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
|
|
@@ -439,7 +465,17 @@ interpkit attribute gpt2 "The capital of France is" --html attribution.html
|
|
|
439
465
|
interpkit attribute microsoft/resnet-50 cat.jpg --target 281
|
|
440
466
|
```
|
|
441
467
|
|
|
442
|
-
Run `interpkit` with no arguments for a full command reference.
|
|
468
|
+
Run `interpkit` with no arguments for a full command reference, or
|
|
469
|
+
`interpkit --extensive` for a beginner-friendly walkthrough of every command.
|
|
470
|
+
|
|
471
|
+
If the `interpkit` console script isn't on your `PATH` (e.g. fresh
|
|
472
|
+
environments, sandboxed installs, or running from a checkout without
|
|
473
|
+
re-installing), every command also works as `python -m interpkit ...`:
|
|
474
|
+
|
|
475
|
+
```bash
|
|
476
|
+
python -m interpkit scan gpt2 "The capital of France is"
|
|
477
|
+
python -m interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Hello!"
|
|
478
|
+
```
|
|
443
479
|
|
|
444
480
|
---
|
|
445
481
|
|
|
@@ -501,6 +537,7 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
|
501
537
|
| `07_vision_models` | ResNet/ViT attribution, ablation, activations |
|
|
502
538
|
| `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
|
|
503
539
|
| `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
|
|
540
|
+
| `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
|
|
504
541
|
|
|
505
542
|
---
|
|
506
543
|
|
|
@@ -63,6 +63,25 @@ model = interpkit.load("google/vit-base-patch16-224")
|
|
|
63
63
|
model = interpkit.load("bert-base-uncased")
|
|
64
64
|
```
|
|
65
65
|
|
|
66
|
+
### Chat models
|
|
67
|
+
|
|
68
|
+
Instruction-tuned models work too — interpkit applies the tokenizer's chat template automatically.
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
chat = interpkit.load("HuggingFaceTB/SmolLM2-360M-Instruct")
|
|
72
|
+
|
|
73
|
+
result = chat.chat("Write a haiku about cats.", max_new_tokens=64)
|
|
74
|
+
print(result["response"])
|
|
75
|
+
|
|
76
|
+
# Run any other op on the templated prompt
|
|
77
|
+
chat.dla(result["prompt"])
|
|
78
|
+
|
|
79
|
+
# Or pass a message list directly to any op
|
|
80
|
+
chat.dla([{"role": "user", "content": "Capital of France?"}])
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full walkthrough including chat-style steering.
|
|
84
|
+
|
|
66
85
|
---
|
|
67
86
|
|
|
68
87
|
## Operations
|
|
@@ -70,6 +89,7 @@ model = interpkit.load("bert-base-uncased")
|
|
|
70
89
|
| Operation | What it does | Works on |
|
|
71
90
|
|-----------|-------------|----------|
|
|
72
91
|
| **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
|
|
92
|
+
| **`chat`** | Send a message through the tokenizer's chat template and generate a reply | Chat / instruct LMs |
|
|
73
93
|
| **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
|
|
74
94
|
| `inspect` | Module tree with types, param counts, shapes | Any model |
|
|
75
95
|
| `patch` | Activation patching at a module, head, or position | Any model |
|
|
@@ -280,10 +300,12 @@ results = model.dla_batch(["The capital of France is", "The CEO of Apple is"])
|
|
|
280
300
|
## Steering
|
|
281
301
|
|
|
282
302
|
```python
|
|
283
|
-
vector = model.steer_vector("
|
|
303
|
+
vector = model.steer_vector(" love", " hate", at="transformer.h.8")
|
|
284
304
|
model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)
|
|
285
305
|
```
|
|
286
306
|
|
|
307
|
+
> Note the leading spaces. BPE tokenizers (GPT-2, Llama, ...) treat `" love"` and `"love"` as different tokens, and the leading-space variant is the one the model actually sees in normal text. interpkit prints a warning if you forget.
|
|
308
|
+
|
|
287
309
|
## Linear Probe
|
|
288
310
|
|
|
289
311
|
```python
|
|
@@ -374,7 +396,7 @@ interpkit lens gpt2 "The capital of France is"
|
|
|
374
396
|
interpkit lens gpt2 "The capital of France is" --position -1
|
|
375
397
|
interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
|
|
376
398
|
interpkit attribute gpt2 "The capital of France is"
|
|
377
|
-
interpkit steer gpt2 "The weather is" --positive
|
|
399
|
+
interpkit steer gpt2 "The weather is" --positive " love" --negative " hate" --at transformer.h.8
|
|
378
400
|
interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
|
|
379
401
|
interpkit decompose gpt2 "The capital of France is"
|
|
380
402
|
interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
|
|
@@ -382,6 +404,10 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
|
|
|
382
404
|
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
|
|
383
405
|
interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
|
|
384
406
|
|
|
407
|
+
# Chat / instruct models — applies the tokenizer's chat template automatically
|
|
408
|
+
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
|
|
409
|
+
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
|
|
410
|
+
|
|
385
411
|
# Interactive HTML output
|
|
386
412
|
interpkit attention gpt2 "hello world" --html attention.html
|
|
387
413
|
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
|
|
@@ -391,7 +417,17 @@ interpkit attribute gpt2 "The capital of France is" --html attribution.html
|
|
|
391
417
|
interpkit attribute microsoft/resnet-50 cat.jpg --target 281
|
|
392
418
|
```
|
|
393
419
|
|
|
394
|
-
Run `interpkit` with no arguments for a full command reference.
|
|
420
|
+
Run `interpkit` with no arguments for a full command reference, or
|
|
421
|
+
`interpkit --extensive` for a beginner-friendly walkthrough of every command.
|
|
422
|
+
|
|
423
|
+
If the `interpkit` console script isn't on your `PATH` (e.g. fresh
|
|
424
|
+
environments, sandboxed installs, or running from a checkout without
|
|
425
|
+
re-installing), every command also works as `python -m interpkit ...`:
|
|
426
|
+
|
|
427
|
+
```bash
|
|
428
|
+
python -m interpkit scan gpt2 "The capital of France is"
|
|
429
|
+
python -m interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Hello!"
|
|
430
|
+
```
|
|
395
431
|
|
|
396
432
|
---
|
|
397
433
|
|
|
@@ -453,6 +489,7 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
|
453
489
|
| `07_vision_models` | ResNet/ViT attribution, ablation, activations |
|
|
454
490
|
| `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
|
|
455
491
|
| `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
|
|
492
|
+
| `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
|
|
456
493
|
|
|
457
494
|
---
|
|
458
495
|
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Entry point so ``python -m interpkit`` invokes the Typer CLI.
|
|
2
|
+
|
|
3
|
+
Mirrors the ``[project.scripts] interpkit = "interpkit.cli.main:app"``
|
|
4
|
+
console script declared in :file:`pyproject.toml`, so users without the
|
|
5
|
+
console script on their ``$PATH`` (e.g. just-installed in a fresh
|
|
6
|
+
environment, vendored copies, ad-hoc subprocess invocations) can still
|
|
7
|
+
reach every CLI command via ``python -m interpkit ...``.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from interpkit.cli.main import app
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def main() -> None:
|
|
14
|
+
"""Invoke the Typer app — separate function makes patching easier in tests."""
|
|
15
|
+
app()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
if __name__ == "__main__":
|
|
19
|
+
main()
|
|
@@ -6,6 +6,7 @@ import json as _json
|
|
|
6
6
|
from importlib.metadata import version as _pkg_version
|
|
7
7
|
|
|
8
8
|
import typer
|
|
9
|
+
import typer.rich_utils as _ru
|
|
9
10
|
from rich.console import Console
|
|
10
11
|
from rich.panel import Panel
|
|
11
12
|
from rich.table import Table
|
|
@@ -14,6 +15,18 @@ from rich_gradient import Text as GradientText
|
|
|
14
15
|
|
|
15
16
|
from interpkit.core.theme import ACCENT, ACCENT_DIM, BRAND_COLORS
|
|
16
17
|
|
|
18
|
+
_ru.STYLE_OPTION = f"bold {ACCENT}"
|
|
19
|
+
_ru.STYLE_SWITCH = f"bold {ACCENT}"
|
|
20
|
+
_ru.STYLE_METAVAR = f"bold {ACCENT}"
|
|
21
|
+
_ru.STYLE_USAGE = ACCENT
|
|
22
|
+
_ru.STYLE_USAGE_COMMAND = "bold"
|
|
23
|
+
_ru.STYLE_COMMANDS_TABLE_FIRST_COLUMN = f"bold {ACCENT}"
|
|
24
|
+
_ru.STYLE_OPTIONS_PANEL_BORDER = ACCENT_DIM
|
|
25
|
+
_ru.STYLE_COMMANDS_PANEL_BORDER = ACCENT_DIM
|
|
26
|
+
_ru.STYLE_REQUIRED_SHORT = ACCENT
|
|
27
|
+
_ru.STYLE_REQUIRED_LONG = ACCENT_DIM
|
|
28
|
+
_ru.STYLE_NEGATIVE_OPTION = f"bold {ACCENT}"
|
|
29
|
+
|
|
17
30
|
app = typer.Typer(
|
|
18
31
|
name="interpkit",
|
|
19
32
|
help="Mech interp for any HuggingFace model.",
|
|
@@ -110,6 +123,27 @@ def _show_extensive_help() -> None:
|
|
|
110
123
|
padding=(0, 2),
|
|
111
124
|
))
|
|
112
125
|
|
|
126
|
+
console.print()
|
|
127
|
+
console.print(Panel(
|
|
128
|
+
f"[bold {ACCENT}]chat[/bold {ACCENT}] "
|
|
129
|
+
"[dim]interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct 'Write a haiku.'[/dim]\n\n"
|
|
130
|
+
"Send a message to an instruction-tuned chat model and print its reply. The message is"
|
|
131
|
+
" routed through the tokenizer's chat template (e.g. ChatML, Llama-2 Inst, Qwen, Gemma)"
|
|
132
|
+
" with [dim]add_generation_prompt=True[/dim] before generation, so any HF chat model that"
|
|
133
|
+
" ships a template just works.\n\n"
|
|
134
|
+
" Errors clearly when the model has no chat template (i.e. a base/non-instruct model) —"
|
|
135
|
+
" in that case load an instruct variant or call any other command with a plain string.\n\n"
|
|
136
|
+
" [bold]Key options:[/bold]\n"
|
|
137
|
+
" [bold green]--system 'be brief'[/bold green] Optional system prompt prepended to the conversation.\n"
|
|
138
|
+
" [bold green]--max-new-tokens N[/bold green] Generation budget (default 128).\n"
|
|
139
|
+
" [bold green]--sample / --no-sample[/bold green] Sampling vs greedy decoding (default greedy).\n"
|
|
140
|
+
" [bold green]--temperature / --top-p[/bold green] Standard sampling controls (used when --sample).\n"
|
|
141
|
+
" [bold green]--show-prompt[/bold green] Print the chat-templated prompt before generating.",
|
|
142
|
+
title="chat",
|
|
143
|
+
border_style=ACCENT_DIM,
|
|
144
|
+
padding=(0, 2),
|
|
145
|
+
))
|
|
146
|
+
|
|
113
147
|
# ── Core Operations ───────────────────────────────────────────
|
|
114
148
|
console.print()
|
|
115
149
|
console.print(Rule("[bold]Core Operations[/bold]", style=ACCENT))
|
|
@@ -261,7 +295,7 @@ def _show_extensive_help() -> None:
|
|
|
261
295
|
),
|
|
262
296
|
(
|
|
263
297
|
"steer",
|
|
264
|
-
"interpkit steer gpt2 'The sky is' --positive
|
|
298
|
+
"interpkit steer gpt2 'The sky is' --positive ' love' --negative ' hate' --at transformer.h.8",
|
|
265
299
|
"Activation steering. Computes a 'steering vector' as the mean-difference between"
|
|
266
300
|
" activations for contrasting concepts ([bold green]--positive[/bold green] vs"
|
|
267
301
|
" [bold green]--negative[/bold green]), then adds a scaled copy of it to the activations"
|
|
@@ -422,6 +456,7 @@ def main(
|
|
|
422
456
|
quick_start = _cmd_table([
|
|
423
457
|
("scan", "One-command overview \u2014 DLA, lens, attention, attribution"),
|
|
424
458
|
("report", "Generate an interactive HTML report"),
|
|
459
|
+
("chat", "Send a message to a chat / instruct model"),
|
|
425
460
|
])
|
|
426
461
|
|
|
427
462
|
core_ops = _cmd_table([
|
|
@@ -469,6 +504,10 @@ def main(
|
|
|
469
504
|
console.print()
|
|
470
505
|
console.print(" [dim]\u25b8[/dim] Most commands accept [bold green]--save[/bold green] and [bold green]--html[/bold green] for exports.")
|
|
471
506
|
console.print(f" [dim]\u25b8[/dim] Run [bold {ACCENT}]interpkit <command> --help[/bold {ACCENT}] for detailed usage.")
|
|
507
|
+
console.print(
|
|
508
|
+
f" [dim]\u25b8[/dim] No console script on PATH? [bold {ACCENT}]python -m interpkit[/bold {ACCENT}]"
|
|
509
|
+
" works the same everywhere."
|
|
510
|
+
)
|
|
472
511
|
console.print(f" [dim]\u25b8[/dim] New here? Try [bold {ACCENT}]interpkit --extensive[/bold {ACCENT}] for a plain-English walkthrough.")
|
|
473
512
|
console.print()
|
|
474
513
|
|
|
@@ -781,7 +820,8 @@ def features(
|
|
|
781
820
|
model_name: str = typer.Argument(..., help="HuggingFace model ID (e.g. gpt2)"),
|
|
782
821
|
input_data: str | None = typer.Argument(None, help="Input text (omit when using --positive-file / --negative-file)"),
|
|
783
822
|
at: str = typer.Option(..., "--at", help="Module name to decompose (e.g. transformer.h.8)"),
|
|
784
|
-
sae: str = typer.Option(..., "--sae", help="SAE source: HuggingFace repo ID
|
|
823
|
+
sae: str = typer.Option(..., "--sae", help="SAE source: HuggingFace repo ID, local file path (.safetensors / .pt), or 'org/repo/subfolder' shorthand"),
|
|
824
|
+
sae_subfolder: str | None = typer.Option(None, "--sae-subfolder", help="Subfolder inside the SAE repo (e.g. 'blocks.8.hook_resid_pre'). Equivalent to appending it to --sae."),
|
|
785
825
|
top_k: int = typer.Option(20, "--top-k", help="Number of top features to display"),
|
|
786
826
|
positive_file: str | None = typer.Option(None, "--positive-file", help="Text file with positive examples for contrastive analysis, one per line"),
|
|
787
827
|
negative_file: str | None = typer.Option(None, "--negative-file", help="Text file with negative examples for contrastive analysis, one per line"),
|
|
@@ -800,13 +840,19 @@ def features(
|
|
|
800
840
|
pos_inputs = read_examples_file(positive_file)
|
|
801
841
|
neg_inputs = read_examples_file(negative_file)
|
|
802
842
|
m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
|
|
803
|
-
result = m.contrastive_features(
|
|
843
|
+
result = m.contrastive_features(
|
|
844
|
+
pos_inputs, neg_inputs, at=at, sae=sae, top_k=top_k,
|
|
845
|
+
sae_subfolder=sae_subfolder,
|
|
846
|
+
)
|
|
804
847
|
else:
|
|
805
848
|
if input_data is None:
|
|
806
849
|
raise typer.BadParameter("Provide input text or use --positive-file / --negative-file for contrastive mode")
|
|
807
850
|
m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
|
|
808
851
|
with console.status(" Decomposing features..."):
|
|
809
|
-
result = m.features(
|
|
852
|
+
result = m.features(
|
|
853
|
+
input_data, at=at, sae=sae, top_k=top_k,
|
|
854
|
+
sae_subfolder=sae_subfolder,
|
|
855
|
+
)
|
|
810
856
|
|
|
811
857
|
if _output_format == "json":
|
|
812
858
|
_json_dump(result)
|
|
@@ -847,8 +893,9 @@ def dla(
|
|
|
847
893
|
top_k: int = typer.Option(10, "--top-k", help="Number of top/bottom contributors to show"),
|
|
848
894
|
save: str | None = typer.Option(None, "--save", help="Save bar chart to file (e.g. dla.png)"),
|
|
849
895
|
html_path: str | None = typer.Option(None, "--html", help="Save interactive HTML to file"),
|
|
850
|
-
sae: str | None = typer.Option(None, "--sae", help="SAE source: HuggingFace repo ID
|
|
896
|
+
sae: str | None = typer.Option(None, "--sae", help="SAE source: HuggingFace repo ID, local file path (.safetensors / .pt), or 'org/repo/subfolder' shorthand"),
|
|
851
897
|
sae_at: str | None = typer.Option(None, "--sae-at", help="Module to decompose through the SAE (e.g. transformer.h.11.attn)"),
|
|
898
|
+
sae_subfolder: str | None = typer.Option(None, "--sae-subfolder", help="Subfolder inside the SAE repo (e.g. 'blocks.8.hook_resid_pre'). Equivalent to appending it to --sae."),
|
|
852
899
|
device: str | None = typer.Option(None, help="Device"),
|
|
853
900
|
dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
|
|
854
901
|
device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
|
|
@@ -865,7 +912,7 @@ def dla(
|
|
|
865
912
|
result = m.dla(
|
|
866
913
|
input_data, token=parsed_token, position=position,
|
|
867
914
|
top_k=top_k, save=save, html=html_path,
|
|
868
|
-
sae=sae, sae_at=sae_at,
|
|
915
|
+
sae=sae, sae_at=sae_at, sae_subfolder=sae_subfolder,
|
|
869
916
|
)
|
|
870
917
|
if _output_format == "json":
|
|
871
918
|
_json_dump(result)
|
|
@@ -959,5 +1006,62 @@ def report(
|
|
|
959
1006
|
_json_dump(result)
|
|
960
1007
|
|
|
961
1008
|
|
|
1009
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1010
|
+
# chat
|
|
1011
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
@app.command()
|
|
1015
|
+
def chat(
|
|
1016
|
+
model_name: str = typer.Argument(..., help="HuggingFace chat/instruct model ID (e.g. HuggingFaceTB/SmolLM2-360M-Instruct)"),
|
|
1017
|
+
message: str = typer.Argument(..., help="User message to send"),
|
|
1018
|
+
system: str | None = typer.Option(None, "--system", help="Optional system prompt"),
|
|
1019
|
+
max_new_tokens: int = typer.Option(128, "--max-new-tokens", help="Max generation length"),
|
|
1020
|
+
sample: bool = typer.Option(False, "--sample/--no-sample", help="Sample (True) or use greedy decoding (False, default)"),
|
|
1021
|
+
temperature: float = typer.Option(1.0, "--temperature", help="Sampling temperature (used when --sample)"),
|
|
1022
|
+
top_p: float = typer.Option(1.0, "--top-p", help="Nucleus sampling cutoff (used when --sample)"),
|
|
1023
|
+
show_prompt: bool = typer.Option(False, "--show-prompt", help="Print the chat-templated prompt before generating"),
|
|
1024
|
+
device: str | None = typer.Option(None, help="Device"),
|
|
1025
|
+
dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
|
|
1026
|
+
device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
|
|
1027
|
+
) -> None:
|
|
1028
|
+
"""Send a chat message and print the model's response.
|
|
1029
|
+
|
|
1030
|
+
Routes the message through the tokenizer's chat template
|
|
1031
|
+
(``apply_chat_template`` with ``add_generation_prompt=True``) and
|
|
1032
|
+
calls ``model.generate``. Errors clearly when the loaded model has
|
|
1033
|
+
no chat template (i.e. is a base/non-instruct model).
|
|
1034
|
+
"""
|
|
1035
|
+
m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
|
|
1036
|
+
with console.status(" Generating response..."):
|
|
1037
|
+
result = m.chat(
|
|
1038
|
+
message,
|
|
1039
|
+
system=system,
|
|
1040
|
+
max_new_tokens=max_new_tokens,
|
|
1041
|
+
do_sample=sample,
|
|
1042
|
+
temperature=temperature,
|
|
1043
|
+
top_p=top_p,
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1046
|
+
if show_prompt:
|
|
1047
|
+
console.print(Panel(
|
|
1048
|
+
result["prompt"],
|
|
1049
|
+
title="[bold]Prompt[/bold]",
|
|
1050
|
+
border_style=ACCENT_DIM,
|
|
1051
|
+
padding=(0, 1),
|
|
1052
|
+
))
|
|
1053
|
+
|
|
1054
|
+
console.print()
|
|
1055
|
+
console.print(Panel(
|
|
1056
|
+
result["response"],
|
|
1057
|
+
title=f"[bold]{model_name}[/bold]",
|
|
1058
|
+
border_style=ACCENT,
|
|
1059
|
+
padding=(0, 2),
|
|
1060
|
+
))
|
|
1061
|
+
|
|
1062
|
+
if _output_format == "json":
|
|
1063
|
+
_json_dump({k: v for k, v in result.items() if k not in {"input_ids", "output_ids"}})
|
|
1064
|
+
|
|
1065
|
+
|
|
962
1066
|
if __name__ == "__main__":
|
|
963
1067
|
app()
|