mi-crow 0.1.1.post13__py3-none-any.whl → 0.1.1.post15__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -454,7 +454,7 @@ class BaseDataset(ABC):
454
454
  limit: Optional[int] = None,
455
455
  stratify_by: Optional[str] = None,
456
456
  stratify_seed: Optional[int] = None,
457
- **kwargs,
457
+ **kwargs: Any,
458
458
  ) -> "BaseDataset":
459
459
  """
460
460
  Load dataset from HuggingFace Hub.
@@ -534,7 +534,7 @@ class BaseDataset(ABC):
534
534
  stratify_by: Optional[str] = None,
535
535
  stratify_seed: Optional[int] = None,
536
536
  drop_na_columns: Optional[List[str]] = None,
537
- **kwargs,
537
+ **kwargs: Any,
538
538
  ) -> "BaseDataset":
539
539
  """
540
540
  Load dataset from CSV file.
@@ -593,7 +593,7 @@ class BaseDataset(ABC):
593
593
  stratify_by: Optional[str] = None,
594
594
  stratify_seed: Optional[int] = None,
595
595
  drop_na_columns: Optional[List[str]] = None,
596
- **kwargs,
596
+ **kwargs: Any,
597
597
  ) -> "BaseDataset":
598
598
  """
599
599
  Load dataset from JSON or JSONL file.
@@ -368,7 +368,7 @@ class ClassificationDataset(BaseDataset):
368
368
  stratify_seed: Optional[int] = None,
369
369
  streaming: Optional[bool] = None,
370
370
  drop_na: bool = False,
371
- **kwargs,
371
+ **kwargs: Any,
372
372
  ) -> "ClassificationDataset":
373
373
  """
374
374
  Load classification dataset from HuggingFace Hub.
@@ -459,7 +459,7 @@ class ClassificationDataset(BaseDataset):
459
459
  stratify_by: Optional[str] = None,
460
460
  stratify_seed: Optional[int] = None,
461
461
  drop_na: bool = False,
462
- **kwargs,
462
+ **kwargs: Any,
463
463
  ) -> "ClassificationDataset":
464
464
  """
465
465
  Load classification dataset from CSV file.
@@ -519,7 +519,7 @@ class ClassificationDataset(BaseDataset):
519
519
  stratify_by: Optional[str] = None,
520
520
  stratify_seed: Optional[int] = None,
521
521
  drop_na: bool = False,
522
- **kwargs,
522
+ **kwargs: Any,
523
523
  ) -> "ClassificationDataset":
524
524
  """
525
525
  Load classification dataset from JSON/JSONL file.
@@ -233,7 +233,7 @@ class TextDataset(BaseDataset):
233
233
  stratify_seed: Optional[int] = None,
234
234
  streaming: Optional[bool] = None,
235
235
  drop_na: bool = False,
236
- **kwargs,
236
+ **kwargs: Any,
237
237
  ) -> "TextDataset":
238
238
  """
239
239
  Load text dataset from HuggingFace Hub.
@@ -312,7 +312,7 @@ class TextDataset(BaseDataset):
312
312
  stratify_by: Optional[str] = None,
313
313
  stratify_seed: Optional[int] = None,
314
314
  drop_na: bool = False,
315
- **kwargs,
315
+ **kwargs: Any,
316
316
  ) -> "TextDataset":
317
317
  """
318
318
  Load text dataset from CSV file.
@@ -365,7 +365,7 @@ class TextDataset(BaseDataset):
365
365
  stratify_by: Optional[str] = None,
366
366
  stratify_seed: Optional[int] = None,
367
367
  drop_na: bool = False,
368
- **kwargs,
368
+ **kwargs: Any,
369
369
  ) -> "TextDataset":
370
370
  """
371
371
  Load text dataset from JSON/JSONL file.
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
4
  from pathlib import Path
5
- from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set
5
+ from typing import Sequence, Any, Dict, List, TYPE_CHECKING, Set, Tuple
6
6
 
7
7
  import torch
8
8
  from torch import nn, Tensor
@@ -134,7 +134,7 @@ class LanguageModel:
134
134
  """Set the store instance."""
135
135
  self.context.store = value
136
136
 
137
- def tokenize(self, texts: Sequence[str], **kwargs: Any):
137
+ def tokenize(self, texts: Sequence[str], **kwargs: Any) -> Any:
138
138
  """
139
139
  Tokenize texts using the language model tokenizer.
140
140
 
@@ -154,7 +154,7 @@ class LanguageModel:
154
154
  autocast: bool = True,
155
155
  autocast_dtype: torch.dtype | None = None,
156
156
  with_controllers: bool = True,
157
- ):
157
+ ) -> Tuple[Any, Any]:
158
158
  """
159
159
  Run forward pass on texts.
160
160
 
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Callable, TYPE_CHECKING
1
+ from typing import Dict, List, Callable, TYPE_CHECKING, Any
2
2
 
3
3
  from torch import nn
4
4
 
@@ -121,7 +121,7 @@ class LanguageModelLayers:
121
121
  layer_signature: str | int,
122
122
  hook: Callable,
123
123
  hook_args: dict = None
124
- ):
124
+ ) -> Any:
125
125
  """
126
126
  Register a forward hook directly on a layer.
127
127
 
@@ -141,7 +141,7 @@ class LanguageModelLayers:
141
141
  layer_signature: str | int,
142
142
  hook: Callable,
143
143
  hook_args: dict = None
144
- ):
144
+ ) -> Any:
145
145
  """
146
146
  Register a pre-forward hook directly on a layer.
147
147
 
@@ -26,6 +26,8 @@ class TopKSaeTrainingConfig(SaeTrainingConfig):
26
26
 
27
27
  Args:
28
28
  k: Number of top activations to keep (required for TopK SAE training)
29
+
30
+ Note:
29
31
  All other parameters are inherited from SaeTrainingConfig.
30
32
 
31
33
  Attributes:
@@ -0,0 +1,6 @@
1
+ """Training utilities for SAE models."""
2
+
3
+ from mi_crow.mechanistic.sae.training.wandb_logger import WandbLogger
4
+
5
+ __all__ = ["WandbLogger"]
6
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mi-crow
3
- Version: 0.1.1.post13
3
+ Version: 0.1.1.post15
4
4
  Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
5
5
  Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -1,10 +1,10 @@
1
1
  mi_crow/__init__.py,sha256=J7aXVlAicbjvk5630rhDxx0ATsvZnihud5u_aQpAwY8,487
2
2
  mi_crow/utils.py,sha256=LTfh2Ep87lAgPBaZkrQPP9caXFJoS9zUxu4qFuV4kzM,1549
3
3
  mi_crow/datasets/__init__.py,sha256=lCAc3nFlvoERrBPAan6C9YFmDx86W2gbIAy267Rb2Sk,349
4
- mi_crow/datasets/base_dataset.py,sha256=vYx-oj3jVhLZD1-xGSO4K4ZIsQtYpHP5zHmg7jd4FE0,22512
5
- mi_crow/datasets/classification_dataset.py,sha256=nL_xndJHyf8hlLxKBe_ZO2YLYsXQjGyeY6csqGTTzEY,21706
4
+ mi_crow/datasets/base_dataset.py,sha256=fSs63o0LDNXt1-i0HmKTxaP33CMcS6rWJ3C4vpB8qR0,22527
5
+ mi_crow/datasets/classification_dataset.py,sha256=c2peu0kWVY3N8XlEi9GjRuE6MmVV8f8vz8gtPmxI70s,21721
6
6
  mi_crow/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
7
- mi_crow/datasets/text_dataset.py,sha256=5FzHWkMWWK0yP69O48S3fUj5KgHb8qo3mkvvZihHFuU,16781
7
+ mi_crow/datasets/text_dataset.py,sha256=dk5RzWy-T7ssQNwv84FrgEZ83zuDgpq91XgUfbwvk8c,16796
8
8
  mi_crow/hooks/__init__.py,sha256=KYy5qcbEpnJceNH86ofy43Suu_36QXjj0HYl79rVyls,693
9
9
  mi_crow/hooks/controller.py,sha256=eo8LMERORXYUjH4-_R6DHk5JKN6O8SW6PlnuBFrlNqg,6063
10
10
  mi_crow/hooks/detector.py,sha256=Bj3xz56cSgRvbcoQBsHIdlJdf0dtgVLw3l1pOSRvRAg,3114
@@ -22,8 +22,8 @@ mi_crow/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5Ilih
22
22
  mi_crow/language_model/hook_metadata.py,sha256=GACZjZUneo2l5j7DCFycLAunTm0etdMQ2YB_xgueUuk,1394
23
23
  mi_crow/language_model/inference.py,sha256=-Kpm85jM8y6-GyDgrvIczitBIwGh8grJP8aYuXsLV-g,19082
24
24
  mi_crow/language_model/initialization.py,sha256=e_Vkk-p9KWRt6-Hmkm6I29dTf20jzEAyNF9CG4nc48M,3704
25
- mi_crow/language_model/language_model.py,sha256=a6CcklVA65oYtFxGXiwQrOKMPZj6eb7LOiT1zJ5-guo,13965
26
- mi_crow/language_model/layers.py,sha256=1yExHodMyqr_Yk4W-2HiSGnRs2sYOA7swsxI8u0Uvfk,15914
25
+ mi_crow/language_model/language_model.py,sha256=b4KniweHXauPHPgHba9cAlLK86z2atfeoetnG-UHoSo,13998
26
+ mi_crow/language_model/layers.py,sha256=7RdVU1mXt5oy2OhzNoKDeD9Omm899ZC-BJqpcTfxP2w,15933
27
27
  mi_crow/language_model/persistence.py,sha256=9wQE6tRvLg7BgdLlkKRTOfRwXb5Q0LsEgg8B9J7Yos0,5881
28
28
  mi_crow/language_model/tokenizer.py,sha256=uZbMDVNnzu8WZINUaR1tLFXiuk9V5pAoahwnJOUvEuE,7379
29
29
  mi_crow/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
@@ -39,13 +39,14 @@ mi_crow/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_
39
39
  mi_crow/mechanistic/sae/concepts/input_tracker.py,sha256=kIiqt7guv_-9-UPYtefAFJbHkWtAS_mnqYVvRU4eb2o,1890
40
40
  mi_crow/mechanistic/sae/modules/__init__.py,sha256=e0lkCALQZcJN7KpYyTtXx3OD2NhBxV_kOZLLJ6EWaTE,243
41
41
  mi_crow/mechanistic/sae/modules/l1_sae.py,sha256=qqw0iTWLSmWAlz5kgfw_mex8LeecFWM1FobyUteMqmM,15388
42
- mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=pK_ajKTQb0wGAftzb6AE5ZZthV3aFLr6G3avOVclSHE,17313
42
+ mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=TIueJ4ftJm0XfU2UyHm0n3bIwL-W1RUgnpYym4xBxpg,17328
43
+ mi_crow/mechanistic/sae/training/__init__.py,sha256=5flCJVkOyKizY0FZy1OP5v0EI6bPEayunpnUPp82a6s,140
43
44
  mi_crow/mechanistic/sae/training/wandb_logger.py,sha256=YlSJd5CaNa35RmIgf1FD_gSEDyhGRa2UdHo_Ofrplos,8558
44
45
  mi_crow/store/__init__.py,sha256=DrYTpdgzrRzjHm9bigy-GiP0BGxzjmD3-lJCthtgxbE,123
45
46
  mi_crow/store/local_store.py,sha256=XmguFvdrUi6NHzvV_bLaDJzpk5KWU_-ObkzhICcLu6g,17216
46
47
  mi_crow/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
47
48
  mi_crow/store/store_dataloader.py,sha256=UkZhHCOTg56ozomPtU9vHBhxIMOPcOiyfMqiAxgqtQs,4341
48
- mi_crow-0.1.1.post13.dist-info/METADATA,sha256=5KEjRwTvthwSs5Jed9Nr0TCKDpfzvLicbVKKm6KkBnQ,6584
49
- mi_crow-0.1.1.post13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
50
- mi_crow-0.1.1.post13.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
51
- mi_crow-0.1.1.post13.dist-info/RECORD,,
49
+ mi_crow-0.1.1.post15.dist-info/METADATA,sha256=lsLbGfYLFGEQuQALB9AOqOGyZovTk7Of7gvbKuI5fOY,6584
50
+ mi_crow-0.1.1.post15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
+ mi_crow-0.1.1.post15.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
52
+ mi_crow-0.1.1.post15.dist-info/RECORD,,