mi-crow 0.1.1.post13__py3-none-any.whl → 0.1.1.post15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +3 -3
- mi_crow/datasets/classification_dataset.py +3 -3
- mi_crow/datasets/text_dataset.py +3 -3
- mi_crow/language_model/language_model.py +3 -3
- mi_crow/language_model/layers.py +3 -3
- mi_crow/mechanistic/sae/modules/topk_sae.py +2 -0
- mi_crow/mechanistic/sae/training/__init__.py +6 -0
- {mi_crow-0.1.1.post13.dist-info → mi_crow-0.1.1.post15.dist-info}/METADATA +1 -1
- {mi_crow-0.1.1.post13.dist-info → mi_crow-0.1.1.post15.dist-info}/RECORD +11 -10
- {mi_crow-0.1.1.post13.dist-info → mi_crow-0.1.1.post15.dist-info}/WHEEL +0 -0
- {mi_crow-0.1.1.post13.dist-info → mi_crow-0.1.1.post15.dist-info}/top_level.txt +0 -0
mi_crow/datasets/base_dataset.py
CHANGED
|
@@ -454,7 +454,7 @@ class BaseDataset(ABC):
|
|
|
454
454
|
limit: Optional[int] = None,
|
|
455
455
|
stratify_by: Optional[str] = None,
|
|
456
456
|
stratify_seed: Optional[int] = None,
|
|
457
|
-
**kwargs,
|
|
457
|
+
**kwargs: Any,
|
|
458
458
|
) -> "BaseDataset":
|
|
459
459
|
"""
|
|
460
460
|
Load dataset from HuggingFace Hub.
|
|
@@ -534,7 +534,7 @@ class BaseDataset(ABC):
|
|
|
534
534
|
stratify_by: Optional[str] = None,
|
|
535
535
|
stratify_seed: Optional[int] = None,
|
|
536
536
|
drop_na_columns: Optional[List[str]] = None,
|
|
537
|
-
**kwargs,
|
|
537
|
+
**kwargs: Any,
|
|
538
538
|
) -> "BaseDataset":
|
|
539
539
|
"""
|
|
540
540
|
Load dataset from CSV file.
|
|
@@ -593,7 +593,7 @@ class BaseDataset(ABC):
|
|
|
593
593
|
stratify_by: Optional[str] = None,
|
|
594
594
|
stratify_seed: Optional[int] = None,
|
|
595
595
|
drop_na_columns: Optional[List[str]] = None,
|
|
596
|
-
**kwargs,
|
|
596
|
+
**kwargs: Any,
|
|
597
597
|
) -> "BaseDataset":
|
|
598
598
|
"""
|
|
599
599
|
Load dataset from JSON or JSONL file.
|
|
@@ -368,7 +368,7 @@ class ClassificationDataset(BaseDataset):
|
|
|
368
368
|
stratify_seed: Optional[int] = None,
|
|
369
369
|
streaming: Optional[bool] = None,
|
|
370
370
|
drop_na: bool = False,
|
|
371
|
-
**kwargs,
|
|
371
|
+
**kwargs: Any,
|
|
372
372
|
) -> "ClassificationDataset":
|
|
373
373
|
"""
|
|
374
374
|
Load classification dataset from HuggingFace Hub.
|
|
@@ -459,7 +459,7 @@ class ClassificationDataset(BaseDataset):
|
|
|
459
459
|
stratify_by: Optional[str] = None,
|
|
460
460
|
stratify_seed: Optional[int] = None,
|
|
461
461
|
drop_na: bool = False,
|
|
462
|
-
**kwargs,
|
|
462
|
+
**kwargs: Any,
|
|
463
463
|
) -> "ClassificationDataset":
|
|
464
464
|
"""
|
|
465
465
|
Load classification dataset from CSV file.
|
|
@@ -519,7 +519,7 @@ class ClassificationDataset(BaseDataset):
|
|
|
519
519
|
stratify_by: Optional[str] = None,
|
|
520
520
|
stratify_seed: Optional[int] = None,
|
|
521
521
|
drop_na: bool = False,
|
|
522
|
-
**kwargs,
|
|
522
|
+
**kwargs: Any,
|
|
523
523
|
) -> "ClassificationDataset":
|
|
524
524
|
"""
|
|
525
525
|
Load classification dataset from JSON/JSONL file.
|
mi_crow/datasets/text_dataset.py
CHANGED
|
@@ -233,7 +233,7 @@ class TextDataset(BaseDataset):
|
|
|
233
233
|
stratify_seed: Optional[int] = None,
|
|
234
234
|
streaming: Optional[bool] = None,
|
|
235
235
|
drop_na: bool = False,
|
|
236
|
-
**kwargs,
|
|
236
|
+
**kwargs: Any,
|
|
237
237
|
) -> "TextDataset":
|
|
238
238
|
"""
|
|
239
239
|
Load text dataset from HuggingFace Hub.
|
|
@@ -312,7 +312,7 @@ class TextDataset(BaseDataset):
|
|
|
312
312
|
stratify_by: Optional[str] = None,
|
|
313
313
|
stratify_seed: Optional[int] = None,
|
|
314
314
|
drop_na: bool = False,
|
|
315
|
-
**kwargs,
|
|
315
|
+
**kwargs: Any,
|
|
316
316
|
) -> "TextDataset":
|
|
317
317
|
"""
|
|
318
318
|
Load text dataset from CSV file.
|
|
@@ -365,7 +365,7 @@ class TextDataset(BaseDataset):
|
|
|
365
365
|
stratify_by: Optional[str] = None,
|
|
366
366
|
stratify_seed: Optional[int] = None,
|
|
367
367
|
drop_na: bool = False,
|
|
368
|
-
**kwargs,
|
|
368
|
+
**kwargs: Any,
|
|
369
369
|
) -> "TextDataset":
|
|
370
370
|
"""
|
|
371
371
|
Load text dataset from JSON/JSONL file.
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set
|
|
5
|
+
from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set, Tuple
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
8
|
from torch import nn, Tensor
|
|
@@ -134,7 +134,7 @@ class LanguageModel:
|
|
|
134
134
|
"""Set the store instance."""
|
|
135
135
|
self.context.store = value
|
|
136
136
|
|
|
137
|
-
def tokenize(self, texts: Sequence[str], **kwargs: Any):
|
|
137
|
+
def tokenize(self, texts: Sequence[str], **kwargs: Any) -> Any:
|
|
138
138
|
"""
|
|
139
139
|
Tokenize texts using the language model tokenizer.
|
|
140
140
|
|
|
@@ -154,7 +154,7 @@ class LanguageModel:
|
|
|
154
154
|
autocast: bool = True,
|
|
155
155
|
autocast_dtype: torch.dtype | None = None,
|
|
156
156
|
with_controllers: bool = True,
|
|
157
|
-
):
|
|
157
|
+
) -> Tuple[Any, Any]:
|
|
158
158
|
"""
|
|
159
159
|
Run forward pass on texts.
|
|
160
160
|
|
mi_crow/language_model/layers.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Dict, List, Callable, TYPE_CHECKING
|
|
1
|
+
from typing import Dict, List, Callable, TYPE_CHECKING, Any
|
|
2
2
|
|
|
3
3
|
from torch import nn
|
|
4
4
|
|
|
@@ -121,7 +121,7 @@ class LanguageModelLayers:
|
|
|
121
121
|
layer_signature: str | int,
|
|
122
122
|
hook: Callable,
|
|
123
123
|
hook_args: dict = None
|
|
124
|
-
):
|
|
124
|
+
) -> Any:
|
|
125
125
|
"""
|
|
126
126
|
Register a forward hook directly on a layer.
|
|
127
127
|
|
|
@@ -141,7 +141,7 @@ class LanguageModelLayers:
|
|
|
141
141
|
layer_signature: str | int,
|
|
142
142
|
hook: Callable,
|
|
143
143
|
hook_args: dict = None
|
|
144
|
-
):
|
|
144
|
+
) -> Any:
|
|
145
145
|
"""
|
|
146
146
|
Register a pre-forward hook directly on a layer.
|
|
147
147
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mi-crow
|
|
3
|
-
Version: 0.1.1.post13
|
|
3
|
+
Version: 0.1.1.post15
|
|
4
4
|
Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
|
|
5
5
|
Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
mi_crow/__init__.py,sha256=J7aXVlAicbjvk5630rhDxx0ATsvZnihud5u_aQpAwY8,487
|
|
2
2
|
mi_crow/utils.py,sha256=LTfh2Ep87lAgPBaZkrQPP9caXFJoS9zUxu4qFuV4kzM,1549
|
|
3
3
|
mi_crow/datasets/__init__.py,sha256=lCAc3nFlvoERrBPAan6C9YFmDx86W2gbIAy267Rb2Sk,349
|
|
4
|
-
mi_crow/datasets/base_dataset.py,sha256=
|
|
5
|
-
mi_crow/datasets/classification_dataset.py,sha256=
|
|
4
|
+
mi_crow/datasets/base_dataset.py,sha256=fSs63o0LDNXt1-i0HmKTxaP33CMcS6rWJ3C4vpB8qR0,22527
|
|
5
|
+
mi_crow/datasets/classification_dataset.py,sha256=c2peu0kWVY3N8XlEi9GjRuE6MmVV8f8vz8gtPmxI70s,21721
|
|
6
6
|
mi_crow/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
|
|
7
|
-
mi_crow/datasets/text_dataset.py,sha256=
|
|
7
|
+
mi_crow/datasets/text_dataset.py,sha256=dk5RzWy-T7ssQNwv84FrgEZ83zuDgpq91XgUfbwvk8c,16796
|
|
8
8
|
mi_crow/hooks/__init__.py,sha256=KYy5qcbEpnJceNH86ofy43Suu_36QXjj0HYl79rVyls,693
|
|
9
9
|
mi_crow/hooks/controller.py,sha256=eo8LMERORXYUjH4-_R6DHk5JKN6O8SW6PlnuBFrlNqg,6063
|
|
10
10
|
mi_crow/hooks/detector.py,sha256=Bj3xz56cSgRvbcoQBsHIdlJdf0dtgVLw3l1pOSRvRAg,3114
|
|
@@ -22,8 +22,8 @@ mi_crow/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5Ilih
|
|
|
22
22
|
mi_crow/language_model/hook_metadata.py,sha256=GACZjZUneo2l5j7DCFycLAunTm0etdMQ2YB_xgueUuk,1394
|
|
23
23
|
mi_crow/language_model/inference.py,sha256=-Kpm85jM8y6-GyDgrvIczitBIwGh8grJP8aYuXsLV-g,19082
|
|
24
24
|
mi_crow/language_model/initialization.py,sha256=e_Vkk-p9KWRt6-Hmkm6I29dTf20jzEAyNF9CG4nc48M,3704
|
|
25
|
-
mi_crow/language_model/language_model.py,sha256=
|
|
26
|
-
mi_crow/language_model/layers.py,sha256=
|
|
25
|
+
mi_crow/language_model/language_model.py,sha256=b4KniweHXauPHPgHba9cAlLK86z2atfeoetnG-UHoSo,13998
|
|
26
|
+
mi_crow/language_model/layers.py,sha256=7RdVU1mXt5oy2OhzNoKDeD9Omm899ZC-BJqpcTfxP2w,15933
|
|
27
27
|
mi_crow/language_model/persistence.py,sha256=9wQE6tRvLg7BgdLlkKRTOfRwXb5Q0LsEgg8B9J7Yos0,5881
|
|
28
28
|
mi_crow/language_model/tokenizer.py,sha256=uZbMDVNnzu8WZINUaR1tLFXiuk9V5pAoahwnJOUvEuE,7379
|
|
29
29
|
mi_crow/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
|
|
@@ -39,13 +39,14 @@ mi_crow/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_
|
|
|
39
39
|
mi_crow/mechanistic/sae/concepts/input_tracker.py,sha256=kIiqt7guv_-9-UPYtefAFJbHkWtAS_mnqYVvRU4eb2o,1890
|
|
40
40
|
mi_crow/mechanistic/sae/modules/__init__.py,sha256=e0lkCALQZcJN7KpYyTtXx3OD2NhBxV_kOZLLJ6EWaTE,243
|
|
41
41
|
mi_crow/mechanistic/sae/modules/l1_sae.py,sha256=qqw0iTWLSmWAlz5kgfw_mex8LeecFWM1FobyUteMqmM,15388
|
|
42
|
-
mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=
|
|
42
|
+
mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=TIueJ4ftJm0XfU2UyHm0n3bIwL-W1RUgnpYym4xBxpg,17328
|
|
43
|
+
mi_crow/mechanistic/sae/training/__init__.py,sha256=5flCJVkOyKizY0FZy1OP5v0EI6bPEayunpnUPp82a6s,140
|
|
43
44
|
mi_crow/mechanistic/sae/training/wandb_logger.py,sha256=YlSJd5CaNa35RmIgf1FD_gSEDyhGRa2UdHo_Ofrplos,8558
|
|
44
45
|
mi_crow/store/__init__.py,sha256=DrYTpdgzrRzjHm9bigy-GiP0BGxzjmD3-lJCthtgxbE,123
|
|
45
46
|
mi_crow/store/local_store.py,sha256=XmguFvdrUi6NHzvV_bLaDJzpk5KWU_-ObkzhICcLu6g,17216
|
|
46
47
|
mi_crow/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
|
|
47
48
|
mi_crow/store/store_dataloader.py,sha256=UkZhHCOTg56ozomPtU9vHBhxIMOPcOiyfMqiAxgqtQs,4341
|
|
48
|
-
mi_crow-0.1.1.
|
|
49
|
-
mi_crow-0.1.1.post13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
50
|
-
mi_crow-0.1.1.post13.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
|
|
51
|
-
mi_crow-0.1.1.post13.dist-info/RECORD,,
|
|
49
|
+
mi_crow-0.1.1.post15.dist-info/METADATA,sha256=lsLbGfYLFGEQuQALB9AOqOGyZovTk7Of7gvbKuI5fOY,6584
|
|
50
|
+
mi_crow-0.1.1.post15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
51
|
+
mi_crow-0.1.1.post15.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
|
|
52
|
+
mi_crow-0.1.1.post15.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|