nnterp 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,10 @@
1
+ build/
2
+ dist/
3
+ public/
4
+ *.egg-info/
5
+ .installed.cfg
6
+ .vscode
7
+ __pycache__/
8
+ .DS_STORE
9
+ config.yaml
10
+ .nfs*
nnterp-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.1
2
+ Name: nnterp
3
+ Version: 0.1.0
4
+ Summary: Utils and mechanistic interpretability interventions using nnsight
5
+ Author-email: Clément Dumas <butanium.contact@gmail.com>
6
+ Project-URL: Homepage, https://github.com/butanium/nnterp
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.7
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: nnsight
13
+
14
+ # nnterp
15
+
16
+ This might become a real package at some point, for now it's just a package in which I dump my nnsight code.
17
+
18
+ nnsight_utils.py basically allows you to deal with TL and HF models in a similar way.
19
+
20
+ interventions.py is a module that contains tools like logit lens, patchscope lens and other interventions.
nnterp-0.1.0/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # nnterp
2
+
3
+ This might become a real package at some point, for now it's just a package in which I dump my nnsight code.
4
+
5
+ nnsight_utils.py basically allows you to deal with TL and HF models in a similar way.
6
+
7
+ interventions.py is a module that contains tools like logit lens, patchscope lens and other interventions.
File without changes
@@ -0,0 +1,609 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ import torch as th
4
+ from torch.utils.data import DataLoader
5
+ from warnings import warn
6
+ from nnsight_utils import (
7
+ get_layer,
8
+ get_layer_output,
9
+ get_layer_input,
10
+ get_attention,
11
+ get_attention_output,
12
+ get_next_token_probs,
13
+ collect_activations,
14
+ collect_activations_batched,
15
+ get_num_layers,
16
+ NNLanguageModel,
17
+ GetModuleOutput,
18
+ project_on_vocab,
19
+ )
20
+ from typing import Optional
21
+
22
+ __all__ = [
23
+ "logit_lens",
24
+ "TargetPrompt",
25
+ "repeat_prompt",
26
+ "TargetPromptBatch",
27
+ "patchscope_lens",
28
+ "patchscope_generate",
29
+ "steer",
30
+ "skip_layers",
31
+ "patch_attention_lens",
32
+ "patch_object_attn_lens",
33
+ "object_lens",
34
+ ]
35
+
36
+
37
@th.no_grad
def logit_lens(
    nn_model: NNLanguageModel, prompts: list[str] | str, scan=True, remote=False
):
    """
    Apply the logit lens: decode every layer's residual stream into
    next-token probabilities for the last token of each prompt.

    Args:
        nn_model: NNSight language model (TransformerLens or HuggingFace wrapper)
        prompts: List of prompts or a single prompt
        scan: Forwarded to nnsight's trace (runs the shape-validation pass)
        remote: If True, run the trace on the remote NDIF server

    Returns:
        A tensor of shape (num_prompts, num_layers, vocab_size) containing the
        probabilities of the next token for each prompt at each layer.
        Tensor is on the CPU.
    """
    with nn_model.trace(prompts, scan=scan, remote=remote) as tracer:
        # Reuse the already-open trace context (open_context=False) to grab
        # the last-token hidden state at every layer.
        hiddens_l = collect_activations(nn_model, prompts, open_context=False)
        probs_l = []
        for hiddens in hiddens_l:
            # Project each layer's hidden state through the final norm + unembed.
            logits = project_on_vocab(nn_model, hiddens)
            probs = logits.softmax(-1).cpu()
            probs_l.append(probs)
        # Stack to (num_layers, num_prompts, vocab), then swap to prompts-first.
        probs = th.stack(probs_l).transpose(0, 1).save()
    return probs.value
62
+
63
+
64
@dataclass
class TargetPrompt:
    """A prompt together with the index of the token to patch."""

    prompt: str
    index_to_patch: int


def repeat_prompt(
    nn_model=None, words=None, rel=" ", sep="\n", placeholder="?"
) -> TargetPrompt:
    """
    Build the identity/repeat prompt from the patchscopes paper, used to make
    the model predict (copy) the patched token.
    https://github.com/PAIR-code/interpretability/blob/master/patchscopes/code/next_token_prediction.ipynb

    Args:
        nn_model: Optional model, used only to sanity-check the placeholder
            tokenizes to a single token.
        words: Example words to repeat; defaults to a small built-in set.
        rel: Separator between a word and its repetition.
        sep: Separator between examples.
        placeholder: Final token whose hidden state will be patched.

    Returns:
        A TargetPrompt patching the last token (index -1).
    """
    if words is None:
        words = ["king", "1135", "hello"]
    assert nn_model is None or (
        len(nn_model.tokenizer.tokenize(placeholder)) == 1
    ), "Using a placeholder that is not a single token sounds like a bad idea"
    pairs = (f"{w}{rel}{w}" for w in words)
    full_prompt = sep.join(pairs) + sep + placeholder
    return TargetPrompt(full_prompt, index_to_patch=-1)
89
+
90
+
91
@dataclass
class TargetPromptBatch:
    """
    A batch of target prompts with potentially different indices to patch.
    """

    # One prompt per batch element.
    prompts: list[str]
    # Token index to patch in each prompt (negative = from the end).
    index_to_patch: th.Tensor

    @classmethod
    def from_target_prompts(cls, prompts_: list[TargetPrompt], tokenizer=None):
        """Build a batch from a list of TargetPrompt objects."""
        prompts = [p.prompt for p in prompts_]
        index_to_patch = th.tensor([p.index_to_patch for p in prompts_])
        if index_to_patch.min() < 0:
            # NOTE(review): the tokenizer is only validated here, never used
            # to resolve negative indices to absolute positions — confirm
            # whether the resolution step is missing or intentionally elsewhere.
            if tokenizer is None:
                raise ValueError(
                    "If using negative index_to_patch, a tokenizer must be provided"
                )
        return cls(prompts, index_to_patch)

    @classmethod
    def from_target_prompt(cls, prompt: TargetPrompt, batch_size: int):
        """Repeat a single TargetPrompt batch_size times."""
        prompts = [prompt.prompt] * batch_size
        index_to_patch = th.tensor([prompt.index_to_patch] * batch_size)
        return cls(prompts, index_to_patch)

    @classmethod
    def from_prompts(
        cls, prompts: str | list[str], index_to_patch: int | list[int] | th.Tensor
    ):
        """Build a batch from raw prompt strings and patch indices."""
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(index_to_patch, int):
            # Broadcast a single index to every prompt.
            index_to_patch = th.tensor([index_to_patch] * len(prompts))
        elif isinstance(index_to_patch, list):
            index_to_patch = th.tensor(index_to_patch)
        elif not isinstance(index_to_patch, th.Tensor):
            raise ValueError(
                f"index_to_patch must be an int, a list of ints or a tensor, got {type(index_to_patch)}"
            )
        return cls(prompts, index_to_patch)

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return TargetPrompt(self.prompts[idx], self.index_to_patch[idx])

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    @staticmethod
    def auto(
        target_prompt: TargetPrompt | list[TargetPrompt] | TargetPromptBatch,
        batch_size: int,
    ):
        """Coerce any accepted prompt spec into a TargetPromptBatch of size batch_size."""
        if isinstance(target_prompt, TargetPrompt):
            target_prompt = TargetPromptBatch.from_target_prompt(
                target_prompt, batch_size
            )
        elif isinstance(target_prompt, list):
            target_prompt = TargetPromptBatch.from_target_prompts(target_prompt)
        elif not isinstance(target_prompt, TargetPromptBatch):
            # Fix: the original message (and type hint) claimed a plain str was
            # accepted, but the dispatch above never handled one.
            raise ValueError(
                f"patch_prompts must be a TargetPrompt, a list of TargetPrompt or a TargetPromptBatch, got {type(target_prompt)}"
            )
        return target_prompt
159
+
160
+
161
@th.no_grad
def patchscope_lens(
    nn_model: NNLanguageModel,
    source_prompts: list[str] | str | None = None,
    target_patch_prompts: (
        TargetPromptBatch | list[TargetPrompt] | TargetPrompt | None
    ) = None,
    layers=None,
    latents=None,
    remote=False,
):
    """
    Patchscope lens: replace the hidden state of the target prompt's
    index_to_patch token with the hidden state of the last token of each
    source prompt, at each layer, and read off the next-token distribution.

    Args:
        nn_model: The NNSight model
        source_prompts: Prompt(s) whose last-token hidden states are the patch
            source. Mutually exclusive with `latents`.
        target_patch_prompts: TargetPrompt(s) / TargetPromptBatch with the
            prompt to patch and the index of the token to patch. Defaults to
            the patchscopes repeat prompt.
        layers: List of layers to intervene on. If None, all layers are
            intervened on.
        latents: Pre-collected hidden states (indexable by layer), used in
            place of `source_prompts`.
        remote: If True, run on the NDIF server. See `nnsight.net/status` to
            check which models are available.

    Returns:
        A tensor of shape (num_prompts, num_layers, vocab_size) containing the
        probabilities of the next token for each prompt at each layer.
        Tensor is on the CPU.
    """
    if target_patch_prompts is None:
        target_patch_prompts = repeat_prompt()
    if latents is not None:
        # Every layer must contribute the same number of source activations.
        if len(set([len(h) for h in latents])) > 1:
            raise ValueError("Inconsistent number of hiddens")
        num_sources = len(latents[0])
    else:
        if source_prompts is None:
            raise ValueError("Either source_prompts or hiddens must be provided")
        if isinstance(source_prompts, str):
            source_prompts = [source_prompts]
        num_sources = len(source_prompts)
    # Broadcast a single TargetPrompt to the batch size if needed.
    target_patch_prompts = TargetPromptBatch.auto(target_patch_prompts, num_sources)
    if len(target_patch_prompts) != num_sources:
        raise ValueError(
            f"Number of sources ({num_sources}) does not match number of patch prompts ({len(target_patch_prompts)})"
        )
    if latents is None:
        latents = collect_activations(nn_model, source_prompts, remote=remote)
    elif source_prompts is not None:
        raise ValueError("You cannot provide both source_prompts and hiddens")

    probs_l = []
    if layers is None:
        layers = list(range(get_num_layers(nn_model)))
    for layer in layers:
        with nn_model.trace(
            target_patch_prompts.prompts,
            # Only scan the first trace; shapes are identical afterwards.
            scan=layer == 0,
            remote=remote,
        ):
            # Overwrite each target token's hidden state with the matching
            # source's last-token hidden state at this layer.
            get_layer_output(nn_model, layer)[
                th.arange(num_sources), target_patch_prompts.index_to_patch
            ] = latents[layer]
            probs_l.append(get_next_token_probs(nn_model).cpu().save())
    probs = th.cat([p.value for p in probs_l], dim=0)
    # (num_layers * num_sources, vocab) -> (num_sources, num_layers, vocab)
    return probs.reshape(len(layers), num_sources, -1).transpose(0, 1)
223
+
224
+
225
@th.no_grad
def patchscope_generate(
    nn_model: NNLanguageModel,
    prompts: list[str] | str,
    target_patch_prompt: TargetPrompt,
    max_length: int = 50,
    layers=None,
    remote=False,
    max_batch_size=32,
):
    """
    Patchscope generation: replace the hidden state of the target prompt's
    index_to_patch token with the hidden state of the last token of each
    prompt at each layer, then let the model generate.

    Args:
        nn_model: The NNSight LanguageModel with llama architecture
        prompts: List of prompts or a single prompt to get the hidden states of
            the last token
        target_patch_prompt: A TargetPrompt object containing the prompt to
            patch and the index of the token to patch
        max_length: The maximum number of tokens to generate
        layers: List of layers to intervene on. If None, all layers are
            intervened on.
        remote: If True, run on the NDIF server. See `nnsight.net/status` to
            check which models are available.
        max_batch_size: The maximum number of prompts to intervene on at once.

    Returns:
        A dict mapping each layer to the generated token ids (on CPU).
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    if len(prompts) > max_batch_size:
        warn(
            f"Number of prompts ({len(prompts)}) exceeds max_batch_size ({max_batch_size}). This may cause memory errors."
        )
    # Fix: default to all layers. The original forwarded layers=None straight
    # to DataLoader, which raises a TypeError.
    if layers is None:
        layers = list(range(get_num_layers(nn_model)))
    acts = collect_activations(nn_model, prompts, remote=remote, layers=layers)
    # Fix: map layer number -> activations. collect_activations returns a
    # positionally-indexed stack, which would misindex `hiddens[layer]` for a
    # partial `layers` list.
    hiddens = {layer: act for layer, act in zip(layers, acts)}
    generations = {}
    gen_kwargs = dict(remote=remote, max_new_tokens=max_length)
    # Batch several layer-interventions into a single generate() call.
    layer_loader = DataLoader(layers, batch_size=max(max_batch_size // len(prompts), 1))
    for layer_batch in layer_loader:
        with nn_model.generate(**gen_kwargs) as tracer:
            for layer in layer_batch:
                layer = layer.item()
                with tracer.invoke(
                    [target_patch_prompt.prompt] * len(prompts),
                    scan=layer == 0,
                ):
                    get_layer_output(nn_model, layer)[
                        :, target_patch_prompt.index_to_patch
                    ] = hiddens[layer]
                    gen = nn_model.generator.output.save()
                    generations[layer] = gen
    for k, v in generations.items():
        generations[k] = v.cpu()
    return generations
277
+
278
+
279
def steer(
    nn_model: NNLanguageModel,
    layers: int | list[int],
    steering_vector: th.Tensor,
    factor: float = 1,
    position: int = -1,
    get_module: GetModuleOutput = get_layer_output,
):
    """
    Add a (scaled) steering vector to a module's output at the given token
    position of one or more layers.

    Args:
        nn_model: The NNSight model
        layers: The layer(s) to steer
        steering_vector: The steering vector to apply
        factor: The factor to multiply the steering vector by
        position: Token position to steer (default: last token)
        get_module: Accessor for the module output to modify
    """
    target_layers = [layers] if isinstance(layers, int) else layers
    delta = factor * steering_vector
    for layer in target_layers:
        get_module(nn_model, layer)[:, position] += delta
299
+
300
+
301
def skip_layers(
    nn_model: NNLanguageModel,
    layers_to_skip: int | list[int],
    position: int = -1,
):
    """
    Skip the computation of the specified layers by copying each layer's input
    hidden state straight to its output at the given position.

    Args:
        nn_model: The NNSight model
        layers_to_skip: The layer(s) to skip
        position: Token position to short-circuit (default: last token)
    """
    skip_list = (
        [layers_to_skip] if isinstance(layers_to_skip, int) else layers_to_skip
    )
    for layer in skip_list:
        residual_in = get_layer_input(nn_model, layer)[:, position]
        get_layer_output(nn_model, layer)[:, position] = residual_in
318
+
319
+
320
def patch_object_attn_lens(
    nn_model: NNLanguageModel,
    source_prompts: list[str] | str,
    target_prompts: list[str] | str,
    attn_idx_patch: int,
    num_patches: int = 5,
    scan=True,
):
    """
    A complex lens that makes the model attend to the hidden states of the
    last token of the source prompts instead of the attn_idx_patch token of
    the target prompts at last-token prediction. For each layer, this
    intervention is performed for num_patches consecutive layers.

    Args:
        nn_model: The NNSight model
        source_prompts: The prompts to get the hidden states of the last token from
        target_prompts: The prompts to predict the next token for
        attn_idx_patch: The index of the token to patch in the target prompts
        num_patches: The number of layers to patch for each layer
        scan: Whether to scan the first traced layer

    Returns:
        A tensor of shape (num_target_prompts, num_layers, vocab_size)
        containing the probabilities of the next token for each target prompt
        at each layer. Tensor is on the CPU.
    """
    if isinstance(source_prompts, str):
        source_prompts = [source_prompts]
    if isinstance(target_prompts, str):
        target_prompts = [target_prompts]
    # Fix: dropped the original `global probs_l`, which needlessly leaked the
    # accumulator into the module namespace; a local list suffices.
    num_layers = get_num_layers(nn_model)
    probs_l = []

    def get_act(model, layer):
        # Attention-module input hidden states at this layer.
        return get_attention(model, layer).input[1]["hidden_states"]

    source_hiddens = collect_activations(
        nn_model,
        source_prompts,
        get_activations=get_act,
    )
    for layer in range(num_layers):
        with nn_model.trace(target_prompts, scan=layer == 0 and scan):
            # Patch this layer and the next num_patches - 1 layers.
            for next_layer in range(layer, min(num_layers, layer + num_patches)):
                get_attention(nn_model, next_layer).input[1]["hidden_states"][
                    :, attn_idx_patch
                ] = source_hiddens[next_layer]
            probs = get_next_token_probs(nn_model).cpu().save()
            probs_l.append(probs)
    return (
        th.cat([p.value for p in probs_l], dim=0)
        .reshape(num_layers, len(target_prompts), -1)
        .transpose(0, 1)
    )
370
+
371
+
372
@dataclass
class LatentPrompt:
    """
    A prompt containing placeholder ("latent") spots that will later be
    overwritten with latent vectors.
    """

    # The raw prompt string, placeholders included.
    prompt: str
    # Placeholder positions as negative offsets from the end of the token list.
    latent_spots: list[int]

    @classmethod
    def from_string(cls, prompt: str, tokenizer, placeholder_token: str | None = None):
        """
        Create a LatentPrompt object from a string prompt.

        Args:
            prompt: The prompt string
            tokenizer: The tokenizer to use
            placeholder_token: The token to use as a placeholder. If None, the
                tokenizer's bos_token is used.
        """
        placeholder = (
            tokenizer.bos_token if placeholder_token is None else placeholder_token
        )
        tokens = tokenizer.tokenize(prompt)
        n_tokens = len(tokens)
        spots = []
        for pos, tok in enumerate(tokens):
            if tok == placeholder:
                spots.append(pos - n_tokens)
        return cls(prompt, spots)
398
+
399
+
400
@dataclass
class LatentPromptBatch:
    """
    A class to batch multiple LatentPrompt objects and modify them at the
    token level.
    """

    # Tokenizer output for the batched prompts (exposes .input_ids).
    inputs: dict
    # The LatentPrompt objects the batch was built from.
    latent_prompts: list[LatentPrompt]

    @classmethod
    def from_latent_prompts(cls, latent_prompts: list[LatentPrompt], tokenizer):
        # NOTE(review): no padding is requested here, so return_tensors="pt"
        # presumably requires same-length prompts — confirm with callers.
        prompts = [lp.prompt for lp in latent_prompts]
        inputs = tokenizer(prompts, return_tensors="pt")
        return cls(inputs, latent_prompts)

    def replace_tokens(self, token: int, replacements: list[int] | int):
        # Replace every occurrence of `token` across the batch, in place.
        # When `replacements` is a list it is consumed left-to-right via
        # pop(0), mutating the caller's list; an int replaces all occurrences.
        for tokens in self.inputs.input_ids:
            for i, t in enumerate(tokens):
                if t == token:
                    if isinstance(replacements, int):
                        tokens[i] = replacements
                    else:
                        tokens[i] = replacements.pop(0)
        return self
424
+
425
+
426
def run_latent_prompt(
    nn_model: NNLanguageModel,
    latent_prompts: list[LatentPrompt] | LatentPrompt | LatentPromptBatch,
    prompts: list[str] | str | None = None,
    latents: list[th.Tensor] | th.Tensor | None = None,
    collect_from_single_layer: int | bool = False,
    patch_from_layer: int = 0,
    patch_until_layer: int | None = None,
    remote=False,
    scan=True,
    batch_size=32,
):
    """
    Perform a forward pass on latent prompts, writing latent vectors into the
    placeholder spots, and return the next-token probabilities.

    Args:
        nn_model: The NNSight model
        latent_prompts: A (list of) LatentPrompt object(s) / a LatentPromptBatch
            to run the forward pass on.
        prompts: Prompts whose last-token activations fill the latent spots.
            Mutually exclusive with `latents`.
        latents: Pre-collected latent vectors of shape
            (num_layers, num_patches, hidden_size), or
            (1, num_patches, hidden_size) when collect_from_single_layer is set.
            Mutually exclusive with `prompts`.
        collect_from_single_layer: If True, the provided latents come from a
            single layer and are reused for every patched layer. If an int,
            activations are collected from that layer only.
        patch_from_layer: The first layer to patch.
        patch_until_layer: The last layer to patch; defaults to the last layer.
        remote: Whether to run the model on the remote NDIF server.
        scan: Whether to use nnsight's scan when tracing the model.
        batch_size: Batch size used when collecting activations from prompts.

    Returns:
        The probabilities of the next token for each latent prompt, of shape
        (num_latent_prompts, vocab_size).
    """
    if patch_until_layer is None:
        patch_until_layer = get_num_layers(nn_model) - 1
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(latent_prompts, LatentPrompt):
        latent_prompts = [latent_prompts]
    if isinstance(latent_prompts, LatentPromptBatch):
        inputs = latent_prompts.inputs
        latent_prompts = latent_prompts.latent_prompts
    else:
        inputs = [lp.prompt for lp in latent_prompts]
    # Fix: the original `latents is None == prompts is None` is a chained
    # comparison that only raised when BOTH were None, silently accepting
    # both being provided. Require exactly one (messages consistent with
    # latent_prompt_lens).
    if latents is None and prompts is None:
        raise ValueError("Either prompts or latents must be provided")
    if latents is not None and prompts is not None:
        raise ValueError("Only one of prompts or latents can be provided")
    if collect_from_single_layer is True and latents is None:
        raise ValueError(
            "When collecting from a single layer, latents must be provided"
        )
    if latents is not None and latents.dim() != 3:
        raise ValueError(
            f"Latents must be of shape (num_layers, num_patches, hidden_size), got {latents.shape}"
        )
    if collect_from_single_layer and latents is not None:
        if latents.shape[0] != 1:
            raise ValueError(
                f"Latents must be of shape (1, num_patches, hidden_size) when collect_from_single_layer is True, got {latents.shape}"
            )
    n_patches = len(prompts) if prompts is not None else latents.shape[1]
    num_spots = sum(len(lp.latent_spots) for lp in latent_prompts)
    if num_spots != n_patches:
        raise ValueError(
            f"Number of latent spots does not match number of prompts/latents: got {num_spots} spots and {n_patches} prompts/latents"
        )
    if latents is None:
        prompt_loader = DataLoader(prompts, batch_size=batch_size)
        latents = [[] for _ in range(patch_until_layer + 1)]
        for prompt_batch in prompt_loader:
            acts = collect_activations(
                nn_model,
                prompt_batch,
                layers=(
                    list(range(patch_from_layer, patch_until_layer + 1))
                    if not collect_from_single_layer
                    # Fix: collect_activations iterates `layers`, so a bare int
                    # would crash; wrap the single layer in a list.
                    else [collect_from_single_layer]
                ),
                remote=remote,
            )  # [layer, batch, d]
            # Fix: acts is positionally indexed starting at patch_from_layer,
            # not 0 — offset so latents[layer] lines up with the patch loop
            # below (the original misindexed when patch_from_layer > 0).
            offset = 0 if collect_from_single_layer else patch_from_layer
            for i, act in enumerate(acts):
                latents[offset + i].extend(act)

    with nn_model.trace(inputs, scan=scan, remote=remote):
        h_index = 0
        for i, lp in enumerate(latent_prompts):
            for spot in lp.latent_spots:
                for layer in range(patch_from_layer, patch_until_layer + 1):
                    get_layer_output(nn_model, layer)[i, spot] = (
                        latents[layer][h_index]
                        if not collect_from_single_layer
                        else latents[0][h_index]
                    )
                h_index += 1
        probs = get_next_token_probs(nn_model).cpu().save()
    return probs.value
518
+
519
+
520
def latent_prompt_lens(
    nn_model: NNLanguageModel,
    latent_prompts: list[LatentPrompt] | LatentPrompt,
    prompts: list[str] | str | None = None,
    latents: list[th.Tensor] | th.Tensor | None = None,
    collect_from_single_layer: bool = True,
    patch_from_layer: int | None = 0,
    patch_until_layer: int | None = None,
    remote=False,
    scan=True,
    batch_size=32,
):
    """
    Layer sweep over run_latent_prompt: for every layer of the model, fill the
    latent prompts with activations and collect next-token probabilities.

    Args:
        nn_model: The NNSight model
        latent_prompts: LatentPrompt(s) to run
        prompts: Prompts to collect activations from (exclusive with latents)
        latents: Pre-collected activations (exclusive with prompts)
        collect_from_single_layer: If True, the sweep step for layer L patches
            using activations from layer L only.
        patch_from_layer: First layer to patch; None means "start at the
            currently swept layer".
        patch_until_layer: Last layer to patch; must be None unless
            collect_from_single_layer is True.
        remote: Whether to run the model on the remote NDIF server
        scan: Whether to scan on the first traced layer
        batch_size: Batch size for activation collection

    Returns:
        A tensor of shape (num_latent_prompts, num_layers, vocab_size), on CPU.
    """
    if not collect_from_single_layer and patch_until_layer is not None:
        raise ValueError(
            "When collecting from multiple layers, patch_until_layer must be None"
        )
    if prompts is None and latents is None:
        raise ValueError("Either prompts or latents must be provided")
    if prompts is not None and latents is not None:
        raise ValueError("Only one of prompts or latents can be provided")
    if isinstance(prompts, str):
        prompts = [prompts]
    if prompts is not None:
        latents = collect_activations_batched(
            nn_model,
            prompts,
            remote=remote,
            batch_size=batch_size,
        )

    probs = []
    for layer in range(get_num_layers(nn_model)):
        if collect_from_single_layer:
            # Use only this layer's activations, shaped (1, num_patches, d).
            latents_ = latents[layer].unsqueeze(0)
            if patch_until_layer is None:
                patch_until_layer_ = layer
            else:
                patch_until_layer_ = patch_until_layer
        else:
            patch_until_layer_ = layer
            latents_ = latents
        if patch_from_layer is None:
            patch_from_layer_ = layer
        else:
            patch_from_layer_ = patch_from_layer
        probs.append(
            run_latent_prompt(
                nn_model,
                latent_prompts,
                latents=latents_,
                collect_from_single_layer=collect_from_single_layer,
                patch_from_layer=patch_from_layer_,
                patch_until_layer=patch_until_layer_,
                remote=remote,
                # Only scan the first layer's trace.
                scan=scan and layer == 0,
            )
        )
    # (num_layers, num_latent_prompts, vocab) -> (num_latent_prompts, num_layers, vocab)
    return th.stack(probs).transpose(0, 1)
578
+
579
+
580
@dataclass
class Intervention:
    """
    A class that contains an intervention on a model: write `latent` into the
    output of `get_output(nn_model, layer)` at `position`.

    Fix: the original class lacked the @dataclass decorator, so the annotated
    fields produced no __init__, __post_init__ never ran, and the
    `cls(h, l, position, get_output)` call in from_prompts raised a TypeError.
    """

    # The activation value to write into the model.
    latent: th.Tensor
    # The layer whose module output is patched.
    layer: int
    # Sequence position to patch. NOTE(review): None results in `[:, None]`
    # new-axis indexing in apply(), which looks unintended — confirm callers
    # always pass an int.
    position: int | None = None
    # Accessor for the module output to patch; defaults to get_layer_output.
    get_output: GetModuleOutput | None = None

    def __post_init__(self):
        if self.get_output is None:
            self.get_output = get_layer_output

    def apply(self, nn_model: NNLanguageModel):
        """
        Perform the intervention on the model.
        """
        self.get_output(nn_model, self.layer)[:, self.position] = self.latent

    @classmethod
    def from_prompts(
        cls, nn_model, prompts, layers, position=None, remote=False, get_output=None
    ):
        """
        Collect last-token activations for `prompts` at `layers` and build one
        Intervention per layer.

        NOTE(review): collection always uses get_layer_output even when a
        custom get_output is supplied for apply() — confirm this asymmetry is
        intended.
        """
        if isinstance(layers, int):
            layers = [layers]
        hiddens = collect_activations(
            nn_model, prompts, layers, remote=remote, get_activations=get_layer_output
        )
        return [cls(h, l, position, get_output) for h, l in zip(hiddens, layers)]
@@ -0,0 +1,242 @@
1
+ from nnsight.models.UnifiedTransformer import UnifiedTransformer
2
+ from nnsight.models.LanguageModel import LanguageModelProxy, LanguageModel
3
+ from nnsight.envoy import Envoy
4
+ import torch as th
5
+ from torch.utils.data import DataLoader
6
+ from typing import Union, Callable
7
+ from contextlib import nullcontext
8
+
9
+ NNLanguageModel = Union[UnifiedTransformer, LanguageModel]
10
+ GetModuleOutput = Callable[[NNLanguageModel, int], LanguageModelProxy]
11
+
12
+
13
def get_num_layers(nn_model: NNLanguageModel):
    """
    Return the number of transformer layers in the model.

    Args:
        nn_model: The NNSight model
    Returns:
        The number of layers in the model
    """
    layer_list = (
        nn_model.blocks
        if isinstance(nn_model, UnifiedTransformer)
        else nn_model.model.layers
    )
    return len(layer_list)
25
+
26
+
27
def get_layer(nn_model: NNLanguageModel, layer: int) -> Envoy:
    """
    Return the Envoy for one transformer layer of the model.

    Args:
        nn_model: The NNSight model
        layer: The layer to get
    Returns:
        The Envoy for the layer
    """
    layer_list = (
        nn_model.blocks
        if isinstance(nn_model, UnifiedTransformer)
        else nn_model.model.layers
    )
    return layer_list[layer]
40
+
41
+
42
def get_layer_input(nn_model: NNLanguageModel, layer: int) -> LanguageModelProxy:
    """
    Return the proxy for the hidden-state input of a layer.

    Args:
        nn_model: The NNSight model
        layer: The layer to get the input of
    Returns:
        The Proxy for the input of the layer
    """
    layer_envoy = get_layer(nn_model, layer)
    return layer_envoy.input[0][0]
52
+
53
+
54
def get_layer_output(nn_model: NNLanguageModel, layer: int) -> LanguageModelProxy:
    """
    Return the proxy for the hidden-state output of a layer.

    Args:
        nn_model: The NNSight model
        layer: The layer to get the output of
    Returns:
        The Proxy for the output of the layer
    """
    output = get_layer(nn_model, layer).output
    # TL blocks expose the hidden states directly; HF layers return a tuple
    # whose first element holds them.
    return output if isinstance(nn_model, UnifiedTransformer) else output[0]
68
+
69
+
70
def get_attention(nn_model: NNLanguageModel, layer: int) -> Envoy:
    """
    Return the Envoy for the attention module of a layer.

    Args:
        nn_model: The NNSight model
        layer: The layer to get the attention module of
    Returns:
        The Envoy for the attention module of the layer
    """
    return (
        nn_model.blocks[layer].attn
        if isinstance(nn_model, UnifiedTransformer)
        else nn_model.model.layers[layer].self_attn
    )
83
+
84
+
85
def get_attention_output(nn_model: NNLanguageModel, layer: int) -> LanguageModelProxy:
    """
    Return the proxy for the output of the attention block of a layer.

    Args:
        nn_model: The NNSight model
        layer: The layer to get the output of
    Returns:
        The Proxy for the output of the attention block of the layer
    """
    output = get_attention(nn_model, layer).output
    # TL attention returns the hidden states directly; HF attention returns a
    # tuple whose first element holds them.
    return output if isinstance(nn_model, UnifiedTransformer) else output[0]
99
+
100
+
101
def get_logits(nn_model: NNLanguageModel) -> LanguageModelProxy:
    """
    Return the proxy for the logits of the model.

    Args:
        nn_model: The NNSight model
    Returns:
        The Proxy for the logits of the model
    """
    head = (
        nn_model.unembed
        if isinstance(nn_model, UnifiedTransformer)
        else nn_model.lm_head
    )
    return head.output
113
+
114
+
115
def project_on_vocab(
    nn_model: NNLanguageModel, h: LanguageModelProxy
) -> LanguageModelProxy:
    """
    Project hidden states onto the vocabulary, after applying the model's last
    layer norm.

    Args:
        nn_model: The NNSight model
        h: The hidden states to project
    Returns:
        The Proxy for the hidden states projected on the vocabulary
    """
    if isinstance(nn_model, UnifiedTransformer):
        norm_fn, unembed_fn = nn_model.ln_final, nn_model.unembed
    else:
        norm_fn, unembed_fn = nn_model.model.norm, nn_model.lm_head
    return unembed_fn(norm_fn(h))
132
+
133
+
134
def get_next_token_probs(nn_model: NNLanguageModel) -> LanguageModelProxy:
    """
    Return the proxy for the model's next-token probabilities
    (softmax over the last position's logits).

    Args:
        nn_model: The NNSight model
    Returns:
        The Proxy for the next-token probabilities of the model
    """
    last_token_logits = get_logits(nn_model)[:, -1, :]
    return last_token_logits.softmax(-1)
143
+
144
+
145
@th.no_grad
def collect_activations(
    nn_model: NNLanguageModel,
    prompts,
    layers=None,
    get_activations: GetModuleOutput | None = None,
    remote=False,
    idx=None,
    open_context=True,
):
    """
    Collect the hidden states of one token of each prompt at each layer.

    Args:
        nn_model: The NNSight model
        prompts: The prompts to collect activations for
        layers: The layers to collect activations for, default to all layers
        get_activations: The function to get the activations, default to layer output
        remote: Whether to run the model on the remote NDIF server
        idx: The index of the token to collect activations for. None selects
            the last non-padding token; negative values count back from it.
            Non-negative values are rejected (left padding makes them ambiguous).
        open_context: Whether to open a trace context to collect activations.
            Set to False if calling from inside an already-open trace context.

    Returns:
        The hidden states stacked over layers; dimensions are
        (num_layers, num_prompts, hidden_size). Values are moved to CPU and
        saved when open_context is True; with open_context=False the entries
        are unsaved proxies belonging to the caller's trace.
    """
    if get_activations is None:
        get_activations = get_layer_output
    tok_prompts = nn_model.tokenizer(prompts, return_tensors="pt", padding=True)
    # Todo?: This is a hacky way to get the last token index but it works for both left and right padding
    last_token_index = tok_prompts.attention_mask.flip(1).cumsum(1).bool().int().sum(1)
    if idx is None:
        idx = last_token_index.sub(1)  # Default to the last token
    elif idx < 0:
        # Offset back from the last non-padding token, per prompt.
        idx = last_token_index + idx
    else:
        raise ValueError(
            "positive index is currently not supported due to left padding"
        )
    if layers is None:
        layers = range(get_num_layers(nn_model))

    def wrap(h):
        # Only move to CPU and save when we own the trace context; otherwise
        # leave the proxy for the caller's context to resolve.
        if open_context:
            return h.cpu().save()
        return h

    # Collect the hidden states of the selected token of each prompt at each layer
    context = nn_model.trace(prompts, remote=remote) if open_context else nullcontext()
    with context:
        acts = [
            wrap(
                get_activations(nn_model, layer)[
                    th.arange(len(tok_prompts.input_ids)),
                    idx,
                ]
            )
            for layer in layers
        ]
    return th.stack(acts)
206
+
207
+
208
def collect_activations_batched(
    nn_model: NNLanguageModel,
    prompts,
    batch_size,
    layers=None,
    get_activations: GetModuleOutput | None = None,
    remote=False,
    idx=None,
    tqdm=None,
):
    """
    Collect the hidden states of the last token of each prompt at each layer,
    processing the prompts in batches.

    Args:
        nn_model: The NNSight model
        prompts: The prompts to collect activations for
        batch_size: The batch size to use
        layers: The layers to collect activations for, default to all layers
        get_activations: The function to get the activations, default to layer output
        remote: Whether to run the model on the remote NDIF server
        idx: The index of the token to collect activations for
        tqdm: Optional progress-bar constructor wrapped around the batch loop

    Returns:
        The hidden states of the last token of each prompt at each layer,
        moved to CPU. Dimensions are (num_layers, num_prompts, hidden_size).
    """
    batches = DataLoader(prompts, batch_size=batch_size)
    if tqdm is not None:
        batches = tqdm(batches)
    all_acts = []
    for prompt_batch in batches:
        all_acts.append(
            collect_activations(
                nn_model, prompt_batch, layers, get_activations, remote, idx
            )
        )
    # Concatenate along the prompt dimension.
    return th.cat(all_acts, dim=1)
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.1
2
+ Name: nnterp
3
+ Version: 0.1.0
4
+ Summary: Utils and mechanistic interpretability interventions using nnsight
5
+ Author-email: Clément Dumas <butanium.contact@gmail.com>
6
+ Project-URL: Homepage, https://github.com/butanium/nnterp
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.7
11
+ Description-Content-Type: text/markdown
12
+ Requires-Dist: nnsight
13
+
14
+ # nnterp
15
+
16
+ This might become a real package at some point, for now it's just a package in which I dump my nnsight code.
17
+
18
+ nnsight_utils.py basically allows you to deal with TL and HF models in a similar way.
19
+
20
+ interventions.py is a module that contains tools like logit lens, patchscope lens and other interventions.
@@ -0,0 +1,11 @@
1
+ .gitignore
2
+ README.md
3
+ pyproject.toml
4
+ nnterp/__init__.py
5
+ nnterp/interventions.py
6
+ nnterp/nnsight_utils.py
7
+ nnterp.egg-info/PKG-INFO
8
+ nnterp.egg-info/SOURCES.txt
9
+ nnterp.egg-info/dependency_links.txt
10
+ nnterp.egg-info/requires.txt
11
+ nnterp.egg-info/top_level.txt
@@ -0,0 +1 @@
1
+ nnsight
@@ -0,0 +1 @@
1
+ nnterp
@@ -0,0 +1,25 @@
1
+ [build-system]
2
+ requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3
+ build-backend = "setuptools.build_meta"
4
+ [tool.setuptools_scm]
5
+
6
+ [project]
7
+ dynamic = ["version"]
8
+ name = "nnterp"
9
+ authors = [
10
+ { name="Clément Dumas", email="butanium.contact@gmail.com" },
11
+ ]
12
+ description = "Utils and mechanistic interpretability interventions using nnsight"
13
+ readme = "README.md"
14
+ requires-python = ">=3.7"
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "License :: OSI Approved :: MIT License",
18
+ "Operating System :: OS Independent",
19
+ ]
20
+ dependencies = [
21
+ "nnsight"
22
+ ]
23
+ [project.urls]
24
+ "Homepage" = "https://github.com/butanium/nnterp"
25
+
nnterp-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+