mi-crow 0.1.1.post12__py3-none-any.whl → 0.1.1.post14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {amber → mi_crow}/__init__.py +2 -2
- mi_crow/datasets/__init__.py +11 -0
- {amber → mi_crow}/datasets/base_dataset.py +2 -2
- {amber → mi_crow}/datasets/classification_dataset.py +3 -3
- {amber → mi_crow}/datasets/text_dataset.py +3 -3
- mi_crow/hooks/__init__.py +20 -0
- {amber → mi_crow}/hooks/controller.py +3 -3
- {amber → mi_crow}/hooks/detector.py +2 -2
- {amber → mi_crow}/hooks/hook.py +1 -1
- {amber → mi_crow}/hooks/implementations/function_controller.py +2 -2
- {amber → mi_crow}/hooks/implementations/layer_activation_detector.py +3 -3
- {amber → mi_crow}/hooks/implementations/model_input_detector.py +2 -2
- {amber → mi_crow}/hooks/implementations/model_output_detector.py +2 -2
- {amber → mi_crow}/hooks/utils.py +1 -1
- {amber → mi_crow}/language_model/activations.py +9 -9
- {amber → mi_crow}/language_model/context.py +3 -3
- {amber → mi_crow}/language_model/hook_metadata.py +1 -1
- {amber → mi_crow}/language_model/inference.py +7 -7
- {amber → mi_crow}/language_model/initialization.py +3 -3
- {amber → mi_crow}/language_model/language_model.py +11 -11
- {amber → mi_crow}/language_model/layers.py +4 -4
- {amber → mi_crow}/language_model/persistence.py +7 -7
- {amber → mi_crow}/language_model/tokenizer.py +1 -1
- {amber → mi_crow}/mechanistic/sae/autoencoder_context.py +1 -1
- {amber → mi_crow}/mechanistic/sae/concepts/autoencoder_concepts.py +8 -8
- {amber → mi_crow}/mechanistic/sae/concepts/concept_dictionary.py +1 -1
- {amber → mi_crow}/mechanistic/sae/concepts/input_tracker.py +2 -2
- mi_crow/mechanistic/sae/modules/__init__.py +5 -0
- {amber → mi_crow}/mechanistic/sae/modules/l1_sae.py +17 -17
- {amber → mi_crow}/mechanistic/sae/modules/topk_sae.py +18 -18
- {amber → mi_crow}/mechanistic/sae/sae.py +9 -9
- {amber → mi_crow}/mechanistic/sae/sae_trainer.py +7 -7
- mi_crow/mechanistic/sae/training/__init__.py +6 -0
- {amber → mi_crow}/mechanistic/sae/training/wandb_logger.py +2 -2
- mi_crow/store/__init__.py +5 -0
- {amber → mi_crow}/store/local_store.py +1 -1
- {amber → mi_crow}/store/store_dataloader.py +2 -2
- {amber → mi_crow}/utils.py +1 -1
- {mi_crow-0.1.1.post12.dist-info → mi_crow-0.1.1.post14.dist-info}/METADATA +2 -2
- mi_crow-0.1.1.post14.dist-info/RECORD +52 -0
- mi_crow-0.1.1.post14.dist-info/top_level.txt +1 -0
- amber/datasets/__init__.py +0 -11
- amber/hooks/__init__.py +0 -20
- amber/mechanistic/sae/modules/__init__.py +0 -5
- amber/store/__init__.py +0 -5
- mi_crow-0.1.1.post12.dist-info/RECORD +0 -51
- mi_crow-0.1.1.post12.dist-info/top_level.txt +0 -1
- {amber → mi_crow}/datasets/loading_strategy.py +0 -0
- {amber → mi_crow}/hooks/implementations/__init__.py +0 -0
- {amber → mi_crow}/language_model/__init__.py +0 -0
- {amber → mi_crow}/language_model/contracts.py +0 -0
- {amber → mi_crow}/language_model/utils.py +0 -0
- {amber → mi_crow}/mechanistic/__init__.py +0 -0
- {amber → mi_crow}/mechanistic/sae/__init__.py +0 -0
- {amber → mi_crow}/mechanistic/sae/concepts/__init__.py +0 -0
- {amber → mi_crow}/mechanistic/sae/concepts/concept_models.py +0 -0
- {amber → mi_crow}/store/store.py +0 -0
- {mi_crow-0.1.1.post12.dist-info → mi_crow-0.1.1.post14.dist-info}/WHEEL +0 -0
{amber → mi_crow}/__init__.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""mi_crow: helper package for the Engineer Thesis project.
|
|
2
2
|
|
|
3
3
|
This module is intentionally minimal. It exists to define the top-level package
|
|
4
4
|
and to enable code coverage to include the package. Importing it should succeed
|
|
@@ -6,7 +6,7 @@ without side effects.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
# A tiny bit of executable code to make the package measurable by coverage.
|
|
9
|
-
PACKAGE_NAME = "
|
|
9
|
+
PACKAGE_NAME = "mi_crow"
|
|
10
10
|
__version__ = "0.0.0"
|
|
11
11
|
|
|
12
12
|
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from mi_crow.datasets.base_dataset import BaseDataset
|
|
2
|
+
from mi_crow.datasets.text_dataset import TextDataset
|
|
3
|
+
from mi_crow.datasets.classification_dataset import ClassificationDataset
|
|
4
|
+
from mi_crow.datasets.loading_strategy import LoadingStrategy
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"BaseDataset",
|
|
8
|
+
"TextDataset",
|
|
9
|
+
"ClassificationDataset",
|
|
10
|
+
"LoadingStrategy",
|
|
11
|
+
]
|
|
@@ -9,8 +9,8 @@ from typing import Any, Dict, Iterator, List, Optional, Union
|
|
|
9
9
|
|
|
10
10
|
from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
12
|
+
from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
|
|
13
|
+
from mi_crow.store.store import Store
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class BaseDataset(ABC):
|
|
@@ -5,9 +5,9 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
|
|
5
5
|
|
|
6
6
|
from datasets import Dataset, IterableDataset, load_dataset
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
8
|
+
from mi_crow.datasets.base_dataset import BaseDataset
|
|
9
|
+
from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
|
|
10
|
+
from mi_crow.store.store import Store
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class ClassificationDataset(BaseDataset):
|
|
@@ -5,9 +5,9 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
|
|
5
5
|
|
|
6
6
|
from datasets import Dataset, IterableDataset, load_dataset
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
8
|
+
from mi_crow.datasets.base_dataset import BaseDataset
|
|
9
|
+
from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
|
|
10
|
+
from mi_crow.store.store import Store
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TextDataset(BaseDataset):
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from mi_crow.hooks.hook import Hook, HookType, HookError
|
|
2
|
+
from mi_crow.hooks.detector import Detector
|
|
3
|
+
from mi_crow.hooks.controller import Controller
|
|
4
|
+
from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
|
|
5
|
+
from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
|
|
6
|
+
from mi_crow.hooks.implementations.model_output_detector import ModelOutputDetector
|
|
7
|
+
from mi_crow.hooks.implementations.function_controller import FunctionController
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"Hook",
|
|
11
|
+
"HookType",
|
|
12
|
+
"HookError",
|
|
13
|
+
"Detector",
|
|
14
|
+
"Controller",
|
|
15
|
+
"LayerActivationDetector",
|
|
16
|
+
"ModelInputDetector",
|
|
17
|
+
"ModelOutputDetector",
|
|
18
|
+
"FunctionController",
|
|
19
|
+
]
|
|
20
|
+
|
|
@@ -6,9 +6,9 @@ from typing import TYPE_CHECKING
|
|
|
6
6
|
import torch
|
|
7
7
|
import torch.nn as nn
|
|
8
8
|
|
|
9
|
-
from
|
|
10
|
-
from
|
|
11
|
-
from
|
|
9
|
+
from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
10
|
+
from mi_crow.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
|
|
11
|
+
from mi_crow.utils import get_logger
|
|
12
12
|
|
|
13
13
|
if TYPE_CHECKING:
|
|
14
14
|
pass
|
|
@@ -5,8 +5,8 @@ from typing import Any, TYPE_CHECKING, Dict
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
8
|
+
from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
9
|
+
from mi_crow.store.store import Store
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
12
|
pass
|
{amber → mi_crow}/hooks/hook.py
RENAMED
|
@@ -10,7 +10,7 @@ from torch import nn, Tensor
|
|
|
10
10
|
from torch.types import _TensorOrTensors
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
-
from
|
|
13
|
+
from mi_crow.language_model.context import LanguageModelContext
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class HookType(str, Enum):
|
|
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import Callable, TYPE_CHECKING
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from
|
|
6
|
+
from mi_crow.hooks.controller import Controller
|
|
7
|
+
from mi_crow.hooks.hook import HookType
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from torch import nn
|
|
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
from
|
|
6
|
+
from mi_crow.hooks.detector import Detector
|
|
7
|
+
from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
|
+
from mi_crow.hooks.utils import extract_tensor_from_output
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
11
11
|
from torch import nn
|
|
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING, Dict, Set, List, Optional
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from
|
|
6
|
+
from mi_crow.hooks.detector import Detector
|
|
7
|
+
from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from torch import nn
|
|
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from
|
|
6
|
+
from mi_crow.hooks.detector import Detector
|
|
7
|
+
from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from torch import nn
|
{amber → mi_crow}/hooks/utils.py
RENAMED
|
@@ -6,7 +6,7 @@ from typing import Any
|
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
8
|
|
|
9
|
-
from
|
|
9
|
+
from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def extract_tensor_from_input(input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
|
|
@@ -4,16 +4,16 @@ from typing import TYPE_CHECKING, Any, Dict, Sequence
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import nn
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
7
|
+
from mi_crow.datasets import BaseDataset
|
|
8
|
+
from mi_crow.hooks import HookType
|
|
9
|
+
from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
|
|
10
|
+
from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
|
|
11
|
+
from mi_crow.store.store import Store
|
|
12
|
+
from mi_crow.utils import get_logger
|
|
13
|
+
from mi_crow.language_model.utils import get_device_from_model
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from
|
|
16
|
+
from mi_crow.language_model.context import LanguageModelContext
|
|
17
17
|
|
|
18
18
|
logger = get_logger(__name__)
|
|
19
19
|
|
|
@@ -170,7 +170,7 @@ class LanguageModelActivations:
|
|
|
170
170
|
meta: Metadata dictionary
|
|
171
171
|
verbose: Whether to log
|
|
172
172
|
"""
|
|
173
|
-
from
|
|
173
|
+
from mi_crow.language_model.inference import InferenceEngine
|
|
174
174
|
InferenceEngine._save_run_metadata(store, run_name, meta, verbose)
|
|
175
175
|
|
|
176
176
|
def _process_batch(
|
|
@@ -2,11 +2,11 @@ from dataclasses import dataclass, field
|
|
|
2
2
|
from typing import Optional, Dict, Any, TYPE_CHECKING, List, Set
|
|
3
3
|
|
|
4
4
|
if TYPE_CHECKING:
|
|
5
|
-
from
|
|
5
|
+
from mi_crow.language_model.language_model import LanguageModel
|
|
6
6
|
from torch import nn
|
|
7
7
|
from transformers import PreTrainedTokenizerBase
|
|
8
|
-
from
|
|
9
|
-
from
|
|
8
|
+
from mi_crow.hooks.hook import Hook
|
|
9
|
+
from mi_crow.store.store import Store
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
@dataclass
|
|
@@ -6,7 +6,7 @@ from collections import defaultdict
|
|
|
6
6
|
from typing import Dict, List, Any, TYPE_CHECKING
|
|
7
7
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
|
-
from
|
|
9
|
+
from mi_crow.language_model.context import LanguageModelContext
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def collect_hooks_metadata(context: "LanguageModelContext") -> Dict[str, List[Dict[str, Any]]]:
|
|
@@ -8,14 +8,14 @@ from typing import Sequence, Any, Dict, List, TYPE_CHECKING
|
|
|
8
8
|
import torch
|
|
9
9
|
from torch import nn
|
|
10
10
|
|
|
11
|
-
from
|
|
12
|
-
from
|
|
11
|
+
from mi_crow.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
|
|
12
|
+
from mi_crow.utils import get_logger
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
|
-
from
|
|
16
|
-
from
|
|
17
|
-
from
|
|
18
|
-
from
|
|
15
|
+
from mi_crow.language_model.language_model import LanguageModel
|
|
16
|
+
from mi_crow.hooks.controller import Controller
|
|
17
|
+
from mi_crow.datasets import BaseDataset
|
|
18
|
+
from mi_crow.store.store import Store
|
|
19
19
|
|
|
20
20
|
logger = get_logger(__name__)
|
|
21
21
|
|
|
@@ -80,7 +80,7 @@ class InferenceEngine:
|
|
|
80
80
|
Args:
|
|
81
81
|
enc: Encoded inputs dictionary
|
|
82
82
|
"""
|
|
83
|
-
from
|
|
83
|
+
from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
|
|
84
84
|
|
|
85
85
|
detectors = self.lm.layers.get_detectors()
|
|
86
86
|
for detector in detectors:
|
|
@@ -9,11 +9,11 @@ import torch
|
|
|
9
9
|
from torch import nn
|
|
10
10
|
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
12
|
+
from mi_crow.store.store import Store
|
|
13
|
+
from mi_crow.language_model.utils import extract_model_id
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from
|
|
16
|
+
from mi_crow.language_model.language_model import LanguageModel
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def initialize_model_id(
|
|
@@ -8,18 +8,18 @@ import torch
|
|
|
8
8
|
from torch import nn, Tensor
|
|
9
9
|
from transformers import PreTrainedTokenizerBase
|
|
10
10
|
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
15
|
-
from
|
|
16
|
-
from
|
|
17
|
-
from
|
|
18
|
-
from
|
|
19
|
-
from
|
|
11
|
+
from mi_crow.language_model.layers import LanguageModelLayers
|
|
12
|
+
from mi_crow.language_model.tokenizer import LanguageModelTokenizer
|
|
13
|
+
from mi_crow.language_model.activations import LanguageModelActivations
|
|
14
|
+
from mi_crow.language_model.context import LanguageModelContext
|
|
15
|
+
from mi_crow.language_model.inference import InferenceEngine
|
|
16
|
+
from mi_crow.language_model.persistence import save_model, load_model_from_saved_file
|
|
17
|
+
from mi_crow.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
|
|
18
|
+
from mi_crow.store.store import Store
|
|
19
|
+
from mi_crow.utils import get_logger
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
|
-
from
|
|
22
|
+
from mi_crow.mechanistic.sae.concepts.input_tracker import InputTracker
|
|
23
23
|
|
|
24
24
|
logger = get_logger(__name__)
|
|
25
25
|
|
|
@@ -308,7 +308,7 @@ class LanguageModel:
|
|
|
308
308
|
if self._input_tracker is not None:
|
|
309
309
|
return self._input_tracker
|
|
310
310
|
|
|
311
|
-
from
|
|
311
|
+
from mi_crow.mechanistic.sae.concepts.input_tracker import InputTracker
|
|
312
312
|
|
|
313
313
|
self._input_tracker = InputTracker(language_model=self)
|
|
314
314
|
|
|
@@ -2,12 +2,12 @@ from typing import Dict, List, Callable, TYPE_CHECKING
|
|
|
2
2
|
|
|
3
3
|
from torch import nn
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from
|
|
5
|
+
from mi_crow.hooks.hook import Hook, HookType
|
|
6
|
+
from mi_crow.hooks.detector import Detector
|
|
7
|
+
from mi_crow.hooks.controller import Controller
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
|
-
from
|
|
10
|
+
from mi_crow.language_model.context import LanguageModelContext
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class LanguageModelLayers:
|
|
@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING
|
|
|
9
9
|
import torch
|
|
10
10
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
12
|
+
from mi_crow.language_model.contracts import ModelMetadata
|
|
13
|
+
from mi_crow.language_model.hook_metadata import collect_hooks_metadata
|
|
14
|
+
from mi_crow.language_model.utils import extract_model_id
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
|
-
from
|
|
18
|
-
from
|
|
17
|
+
from mi_crow.language_model.language_model import LanguageModel
|
|
18
|
+
from mi_crow.store.store import Store
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
def save_model(
|
|
@@ -78,7 +78,7 @@ def save_model(
|
|
|
78
78
|
f"Failed to save model to {save_path}. Error: {e}"
|
|
79
79
|
) from e
|
|
80
80
|
|
|
81
|
-
from
|
|
81
|
+
from mi_crow.utils import get_logger
|
|
82
82
|
logger = get_logger(__name__)
|
|
83
83
|
logger.info(f"Saved model to {save_path}")
|
|
84
84
|
|
|
@@ -169,7 +169,7 @@ def load_model_from_saved_file(
|
|
|
169
169
|
# Note: Hooks are not automatically restored as they require hook instances
|
|
170
170
|
# The hook metadata is available in metadata_dict["hooks"] if needed
|
|
171
171
|
|
|
172
|
-
from
|
|
172
|
+
from mi_crow.utils import get_logger
|
|
173
173
|
logger = get_logger(__name__)
|
|
174
174
|
logger.info(f"Loaded model from {saved_path} (model_id: {model_id})")
|
|
175
175
|
|
|
@@ -9,12 +9,12 @@ import heapq
|
|
|
9
9
|
import torch
|
|
10
10
|
from torch import nn
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
12
|
+
from mi_crow.mechanistic.sae.concepts.concept_models import NeuronText
|
|
13
|
+
from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
|
|
14
|
+
from mi_crow.utils import get_logger
|
|
15
15
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
|
-
from
|
|
17
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
18
18
|
|
|
19
19
|
logger = get_logger(__name__)
|
|
20
20
|
|
|
@@ -59,12 +59,12 @@ class AutoencoderConcepts:
|
|
|
59
59
|
|
|
60
60
|
def _ensure_dictionary(self):
|
|
61
61
|
if self.dictionary is None:
|
|
62
|
-
from
|
|
62
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
63
63
|
self.dictionary = ConceptDictionary(self._n_size)
|
|
64
64
|
return self.dictionary
|
|
65
65
|
|
|
66
66
|
def load_concepts_from_csv(self, csv_filepath: str | Path):
|
|
67
|
-
from
|
|
67
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
68
68
|
self.dictionary = ConceptDictionary.from_csv(
|
|
69
69
|
csv_filepath=csv_filepath,
|
|
70
70
|
n_size=self._n_size,
|
|
@@ -72,7 +72,7 @@ class AutoencoderConcepts:
|
|
|
72
72
|
)
|
|
73
73
|
|
|
74
74
|
def load_concepts_from_json(self, json_filepath: str | Path):
|
|
75
|
-
from
|
|
75
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
76
76
|
self.dictionary = ConceptDictionary.from_json(
|
|
77
77
|
json_filepath=json_filepath,
|
|
78
78
|
n_size=self._n_size,
|
|
@@ -84,7 +84,7 @@ class AutoencoderConcepts:
|
|
|
84
84
|
if self._top_texts_heaps is None:
|
|
85
85
|
raise ValueError("No top texts available. Enable text tracking and run inference first.")
|
|
86
86
|
|
|
87
|
-
from
|
|
87
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
88
88
|
neuron_texts = self.get_all_top_texts()
|
|
89
89
|
|
|
90
90
|
self.dictionary = ConceptDictionary.from_llm(
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, Sequence
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from mi_crow.utils import get_logger
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
|
-
from
|
|
6
|
+
from mi_crow.language_model.language_model import LanguageModel
|
|
7
7
|
|
|
8
8
|
logger = get_logger(__name__)
|
|
9
9
|
|
|
@@ -3,11 +3,11 @@ from typing import Any
|
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from overcomplete import SAE as OvercompleteSAE
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
6
|
+
from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
7
|
+
from mi_crow.mechanistic.sae.sae import Sae
|
|
8
|
+
from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
|
|
9
|
+
from mi_crow.store.store import Store
|
|
10
|
+
from mi_crow.utils import get_logger
|
|
11
11
|
|
|
12
12
|
logger = get_logger(__name__)
|
|
13
13
|
|
|
@@ -293,7 +293,7 @@ class L1Sae(Sae):
|
|
|
293
293
|
# Save overcomplete model state dict
|
|
294
294
|
sae_state_dict = self.sae_engine.state_dict()
|
|
295
295
|
|
|
296
|
-
|
|
296
|
+
mi_crow_metadata = {
|
|
297
297
|
"concepts_state": {
|
|
298
298
|
'multiplication': self.concepts.multiplication.data,
|
|
299
299
|
'bias': self.concepts.bias.data,
|
|
@@ -307,7 +307,7 @@ class L1Sae(Sae):
|
|
|
307
307
|
|
|
308
308
|
payload = {
|
|
309
309
|
"sae_state_dict": sae_state_dict,
|
|
310
|
-
"
|
|
310
|
+
"mi_crow_metadata": mi_crow_metadata,
|
|
311
311
|
}
|
|
312
312
|
|
|
313
313
|
torch.save(payload, save_path)
|
|
@@ -336,16 +336,16 @@ class L1Sae(Sae):
|
|
|
336
336
|
payload = torch.load(p, map_location=map_location)
|
|
337
337
|
|
|
338
338
|
# Extract our metadata
|
|
339
|
-
if "
|
|
340
|
-
raise ValueError(f"Invalid L1SAE save format: missing '
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
n_latents = int(
|
|
344
|
-
n_inputs = int(
|
|
345
|
-
device =
|
|
346
|
-
layer_signature =
|
|
347
|
-
model_id =
|
|
348
|
-
concepts_state =
|
|
339
|
+
if "mi_crow_metadata" not in payload:
|
|
340
|
+
raise ValueError(f"Invalid L1SAE save format: missing 'mi_crow_metadata' key in {p}")
|
|
341
|
+
|
|
342
|
+
mi_crow_meta = payload["mi_crow_metadata"]
|
|
343
|
+
n_latents = int(mi_crow_meta["n_latents"])
|
|
344
|
+
n_inputs = int(mi_crow_meta["n_inputs"])
|
|
345
|
+
device = mi_crow_meta.get("device", "cpu")
|
|
346
|
+
layer_signature = mi_crow_meta.get("layer_signature")
|
|
347
|
+
model_id = mi_crow_meta.get("model_id")
|
|
348
|
+
concepts_state = mi_crow_meta.get("concepts_state", {})
|
|
349
349
|
|
|
350
350
|
# Create L1Sae instance
|
|
351
351
|
l1_sae = L1Sae(
|
|
@@ -7,11 +7,11 @@ from overcomplete import (
|
|
|
7
7
|
TopKSAE as OvercompleteTopkSAE,
|
|
8
8
|
SAE as OvercompleteSAE
|
|
9
9
|
)
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
10
|
+
from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
11
|
+
from mi_crow.mechanistic.sae.sae import Sae
|
|
12
|
+
from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
|
|
13
|
+
from mi_crow.store.store import Store
|
|
14
|
+
from mi_crow.utils import get_logger
|
|
15
15
|
|
|
16
16
|
logger = get_logger(__name__)
|
|
17
17
|
|
|
@@ -338,7 +338,7 @@ class TopKSae(Sae):
|
|
|
338
338
|
# Save overcomplete model state dict
|
|
339
339
|
sae_state_dict = self.sae_engine.state_dict()
|
|
340
340
|
|
|
341
|
-
|
|
341
|
+
mi_crow_metadata = {
|
|
342
342
|
"concepts_state": {
|
|
343
343
|
'multiplication': self.concepts.multiplication.data,
|
|
344
344
|
'bias': self.concepts.bias.data,
|
|
@@ -353,7 +353,7 @@ class TopKSae(Sae):
|
|
|
353
353
|
|
|
354
354
|
payload = {
|
|
355
355
|
"sae_state_dict": sae_state_dict,
|
|
356
|
-
"
|
|
356
|
+
"mi_crow_metadata": mi_crow_metadata,
|
|
357
357
|
}
|
|
358
358
|
|
|
359
359
|
torch.save(payload, save_path)
|
|
@@ -382,17 +382,17 @@ class TopKSae(Sae):
|
|
|
382
382
|
payload = torch.load(p, map_location=map_location)
|
|
383
383
|
|
|
384
384
|
# Extract our metadata
|
|
385
|
-
if "
|
|
386
|
-
raise ValueError(f"Invalid TopKSAE save format: missing '
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
n_latents = int(
|
|
390
|
-
n_inputs = int(
|
|
391
|
-
k = int(
|
|
392
|
-
device =
|
|
393
|
-
layer_signature =
|
|
394
|
-
model_id =
|
|
395
|
-
concepts_state =
|
|
385
|
+
if "mi_crow_metadata" not in payload:
|
|
386
|
+
raise ValueError(f"Invalid TopKSAE save format: missing 'mi_crow_metadata' key in {p}")
|
|
387
|
+
|
|
388
|
+
mi_crow_meta = payload["mi_crow_metadata"]
|
|
389
|
+
n_latents = int(mi_crow_meta["n_latents"])
|
|
390
|
+
n_inputs = int(mi_crow_meta["n_inputs"])
|
|
391
|
+
k = int(mi_crow_meta["k"])
|
|
392
|
+
device = mi_crow_meta.get("device", "cpu")
|
|
393
|
+
layer_signature = mi_crow_meta.get("layer_signature")
|
|
394
|
+
model_id = mi_crow_meta.get("model_id")
|
|
395
|
+
concepts_state = mi_crow_meta.get("concepts_state", {})
|
|
396
396
|
|
|
397
397
|
# Create TopKSAE instance
|
|
398
398
|
topk_sae = TopKSae(
|
|
@@ -5,15 +5,15 @@ from typing import Any, TYPE_CHECKING, Literal
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
15
|
-
from
|
|
16
|
-
from
|
|
8
|
+
from mi_crow.hooks.controller import Controller
|
|
9
|
+
from mi_crow.hooks.detector import Detector
|
|
10
|
+
from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
|
|
11
|
+
from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
|
|
12
|
+
from mi_crow.mechanistic.sae.concepts.autoencoder_concepts import AutoencoderConcepts
|
|
13
|
+
from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
|
|
14
|
+
from mi_crow.mechanistic.sae.sae_trainer import SaeTrainer
|
|
15
|
+
from mi_crow.store.store import Store
|
|
16
|
+
from mi_crow.utils import get_logger
|
|
17
17
|
|
|
18
18
|
from overcomplete.sae import SAE as OvercompleteSAE
|
|
19
19
|
|
|
@@ -9,12 +9,12 @@ import gc
|
|
|
9
9
|
|
|
10
10
|
import torch
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
12
|
+
from mi_crow.store.store_dataloader import StoreDataloader
|
|
13
|
+
from mi_crow.utils import get_logger
|
|
14
14
|
|
|
15
15
|
if TYPE_CHECKING:
|
|
16
|
-
from
|
|
17
|
-
from
|
|
16
|
+
from mi_crow.mechanistic.sae.sae import Sae
|
|
17
|
+
from mi_crow.store.store import Store
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
@dataclass
|
|
@@ -528,7 +528,7 @@ class SaeTrainer:
|
|
|
528
528
|
|
|
529
529
|
sae_state_dict = self.sae.sae_engine.state_dict()
|
|
530
530
|
|
|
531
|
-
|
|
531
|
+
mi_crow_metadata = {
|
|
532
532
|
"concepts_state": {
|
|
533
533
|
'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
|
|
534
534
|
'bias': self.sae.concepts.bias.data.cpu().clone(),
|
|
@@ -541,11 +541,11 @@ class SaeTrainer:
|
|
|
541
541
|
}
|
|
542
542
|
|
|
543
543
|
if hasattr(self.sae, 'k'):
|
|
544
|
-
|
|
544
|
+
mi_crow_metadata["k"] = self.sae.k
|
|
545
545
|
|
|
546
546
|
payload = {
|
|
547
547
|
"sae_state_dict": sae_state_dict,
|
|
548
|
-
"
|
|
548
|
+
"mi_crow_metadata": mi_crow_metadata,
|
|
549
549
|
}
|
|
550
550
|
|
|
551
551
|
torch.save(payload, model_path)
|
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Optional
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
|
|
6
|
+
from mi_crow.utils import get_logger
|
|
7
7
|
|
|
8
8
|
logger = get_logger(__name__)
|
|
9
9
|
|
|
@@ -2,12 +2,12 @@ from typing import Optional, Iterator, TYPE_CHECKING
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from
|
|
5
|
+
from mi_crow.utils import get_logger
|
|
6
6
|
|
|
7
7
|
logger = get_logger(__name__)
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
|
-
from
|
|
10
|
+
from mi_crow.store.store import Store
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class StoreDataloader:
|
{amber → mi_crow}/utils.py
RENAMED
|
@@ -27,7 +27,7 @@ def set_seed(seed: int, deterministic: bool = True) -> None:
|
|
|
27
27
|
torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
def get_logger(name: str = "
|
|
30
|
+
def get_logger(name: str = "mi_crow", level: int | str = logging.INFO) -> logging.Logger:
|
|
31
31
|
"""Get a configured logger with a simple format. Idempotent."""
|
|
32
32
|
logger = logging.getLogger(name)
|
|
33
33
|
if isinstance(level, str):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mi-crow
|
|
3
|
-
Version: 0.1.1.
|
|
3
|
+
Version: 0.1.1.post14
|
|
4
4
|
Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
|
|
5
5
|
Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -96,7 +96,7 @@ uv run --group server pytest tests/server/test_api.py --cov=server --cov-fail-un
|
|
|
96
96
|
|
|
97
97
|
### SAE API usage
|
|
98
98
|
|
|
99
|
-
- Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/
|
|
99
|
+
- Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/mi_crow_artifacts` (defaults to `~/.cache/mi_crow_server`)
|
|
100
100
|
- Load a model: `curl -X POST http://localhost:8000/models/load -H "Content-Type: application/json" -d '{"model_id":"bielik"}'`
|
|
101
101
|
- Save activations from dataset (stored in `LocalStore` under `activations/<model>/<run_id>`):
|
|
102
102
|
- HF dataset: `{"dataset":{"type":"hf","name":"ag_news","split":"train","text_field":"text"}}`
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
mi_crow/__init__.py,sha256=J7aXVlAicbjvk5630rhDxx0ATsvZnihud5u_aQpAwY8,487
|
|
2
|
+
mi_crow/utils.py,sha256=LTfh2Ep87lAgPBaZkrQPP9caXFJoS9zUxu4qFuV4kzM,1549
|
|
3
|
+
mi_crow/datasets/__init__.py,sha256=lCAc3nFlvoERrBPAan6C9YFmDx86W2gbIAy267Rb2Sk,349
|
|
4
|
+
mi_crow/datasets/base_dataset.py,sha256=vYx-oj3jVhLZD1-xGSO4K4ZIsQtYpHP5zHmg7jd4FE0,22512
|
|
5
|
+
mi_crow/datasets/classification_dataset.py,sha256=nL_xndJHyf8hlLxKBe_ZO2YLYsXQjGyeY6csqGTTzEY,21706
|
|
6
|
+
mi_crow/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
|
|
7
|
+
mi_crow/datasets/text_dataset.py,sha256=5FzHWkMWWK0yP69O48S3fUj5KgHb8qo3mkvvZihHFuU,16781
|
|
8
|
+
mi_crow/hooks/__init__.py,sha256=KYy5qcbEpnJceNH86ofy43Suu_36QXjj0HYl79rVyls,693
|
|
9
|
+
mi_crow/hooks/controller.py,sha256=eo8LMERORXYUjH4-_R6DHk5JKN6O8SW6PlnuBFrlNqg,6063
|
|
10
|
+
mi_crow/hooks/detector.py,sha256=Bj3xz56cSgRvbcoQBsHIdlJdf0dtgVLw3l1pOSRvRAg,3114
|
|
11
|
+
mi_crow/hooks/hook.py,sha256=JrCyPptXzHICAToxug7FD8zKWKZcLxbXfIe7UHCkh34,7542
|
|
12
|
+
mi_crow/hooks/utils.py,sha256=GdUAqL9InCsthjBVWbZtVQp2VtQLLOMaJy8S8Tc7WX0,2024
|
|
13
|
+
mi_crow/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
mi_crow/hooks/implementations/function_controller.py,sha256=Rwu1Ghffqm-Jc4mqniRdIsXQeV24JKlx9CMO8jP8iv0,3092
|
|
15
|
+
mi_crow/hooks/implementations/layer_activation_detector.py,sha256=mFqLkWrBlNYzYYvwI8O-ymnpwP1UJSX7eWIQ7Fp-lxE,3167
|
|
16
|
+
mi_crow/hooks/implementations/model_input_detector.py,sha256=4AFA88eQ-zM9Kz8bpv4Bl5cvlZF5ZyIRXNi7Znd6V98,10078
|
|
17
|
+
mi_crow/hooks/implementations/model_output_detector.py,sha256=BmD8oJ1k_nv4hIzG-9SGxSCRARQkLRI4d_XQmKs6jjE,4883
|
|
18
|
+
mi_crow/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
|
+
mi_crow/language_model/activations.py,sha256=mJhuXkQ_JKzQsrFossnXVTdq4WgBI10Gi0ZYeQ9USFY,16989
|
|
20
|
+
mi_crow/language_model/context.py,sha256=TJSe6IruGktePpQ0dtHIoeatSY72qAhEKeTvLizCdlw,1124
|
|
21
|
+
mi_crow/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
|
|
22
|
+
mi_crow/language_model/hook_metadata.py,sha256=GACZjZUneo2l5j7DCFycLAunTm0etdMQ2YB_xgueUuk,1394
|
|
23
|
+
mi_crow/language_model/inference.py,sha256=-Kpm85jM8y6-GyDgrvIczitBIwGh8grJP8aYuXsLV-g,19082
|
|
24
|
+
mi_crow/language_model/initialization.py,sha256=e_Vkk-p9KWRt6-Hmkm6I29dTf20jzEAyNF9CG4nc48M,3704
|
|
25
|
+
mi_crow/language_model/language_model.py,sha256=a6CcklVA65oYtFxGXiwQrOKMPZj6eb7LOiT1zJ5-guo,13965
|
|
26
|
+
mi_crow/language_model/layers.py,sha256=1yExHodMyqr_Yk4W-2HiSGnRs2sYOA7swsxI8u0Uvfk,15914
|
|
27
|
+
mi_crow/language_model/persistence.py,sha256=9wQE6tRvLg7BgdLlkKRTOfRwXb5Q0LsEgg8B9J7Yos0,5881
|
|
28
|
+
mi_crow/language_model/tokenizer.py,sha256=uZbMDVNnzu8WZINUaR1tLFXiuk9V5pAoahwnJOUvEuE,7379
|
|
29
|
+
mi_crow/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
|
|
30
|
+
mi_crow/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
31
|
+
mi_crow/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
|
+
mi_crow/mechanistic/sae/autoencoder_context.py,sha256=u5WzSlLb8_HaJF9LwGqe-J_rE-iRCoXbvwEhTZIArkw,947
|
|
33
|
+
mi_crow/mechanistic/sae/sae.py,sha256=R2IckZ-UDVXFURUSRoIdX0MGLe0TWU5JCVeDXPz1Wv4,6053
|
|
34
|
+
mi_crow/mechanistic/sae/sae_trainer.py,sha256=VqN9UgEPtgYItzAS0RqMfOhTnTxIEuModO0noHYieFA,26327
|
|
35
|
+
mi_crow/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
|
+
mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=SE7A6y6cZiMGhylljNdZz0D8lhrulUh1FYiKhnPPRvY,14387
|
|
37
|
+
mi_crow/mechanistic/sae/concepts/concept_dictionary.py,sha256=aUXTe4Fy7Oe7iYJOcQImEXYMZAXhnkfAszZNwivk4eg,7673
|
|
38
|
+
mi_crow/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
|
|
39
|
+
mi_crow/mechanistic/sae/concepts/input_tracker.py,sha256=kIiqt7guv_-9-UPYtefAFJbHkWtAS_mnqYVvRU4eb2o,1890
|
|
40
|
+
mi_crow/mechanistic/sae/modules/__init__.py,sha256=e0lkCALQZcJN7KpYyTtXx3OD2NhBxV_kOZLLJ6EWaTE,243
|
|
41
|
+
mi_crow/mechanistic/sae/modules/l1_sae.py,sha256=qqw0iTWLSmWAlz5kgfw_mex8LeecFWM1FobyUteMqmM,15388
|
|
42
|
+
mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=pK_ajKTQb0wGAftzb6AE5ZZthV3aFLr6G3avOVclSHE,17313
|
|
43
|
+
mi_crow/mechanistic/sae/training/__init__.py,sha256=5flCJVkOyKizY0FZy1OP5v0EI6bPEayunpnUPp82a6s,140
|
|
44
|
+
mi_crow/mechanistic/sae/training/wandb_logger.py,sha256=YlSJd5CaNa35RmIgf1FD_gSEDyhGRa2UdHo_Ofrplos,8558
|
|
45
|
+
mi_crow/store/__init__.py,sha256=DrYTpdgzrRzjHm9bigy-GiP0BGxzjmD3-lJCthtgxbE,123
|
|
46
|
+
mi_crow/store/local_store.py,sha256=XmguFvdrUi6NHzvV_bLaDJzpk5KWU_-ObkzhICcLu6g,17216
|
|
47
|
+
mi_crow/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
|
|
48
|
+
mi_crow/store/store_dataloader.py,sha256=UkZhHCOTg56ozomPtU9vHBhxIMOPcOiyfMqiAxgqtQs,4341
|
|
49
|
+
mi_crow-0.1.1.post14.dist-info/METADATA,sha256=Dd5lhIR9XmTNgdTDgZBhSsFVAvih80juP77KxHHfjFQ,6584
|
|
50
|
+
mi_crow-0.1.1.post14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
51
|
+
mi_crow-0.1.1.post14.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
|
|
52
|
+
mi_crow-0.1.1.post14.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
mi_crow
|
amber/datasets/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
from amber.datasets.base_dataset import BaseDataset
|
|
2
|
-
from amber.datasets.text_dataset import TextDataset
|
|
3
|
-
from amber.datasets.classification_dataset import ClassificationDataset
|
|
4
|
-
from amber.datasets.loading_strategy import LoadingStrategy
|
|
5
|
-
|
|
6
|
-
__all__ = [
|
|
7
|
-
"BaseDataset",
|
|
8
|
-
"TextDataset",
|
|
9
|
-
"ClassificationDataset",
|
|
10
|
-
"LoadingStrategy",
|
|
11
|
-
]
|
amber/hooks/__init__.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
from amber.hooks.hook import Hook, HookType, HookError
|
|
2
|
-
from amber.hooks.detector import Detector
|
|
3
|
-
from amber.hooks.controller import Controller
|
|
4
|
-
from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
|
|
5
|
-
from amber.hooks.implementations.model_input_detector import ModelInputDetector
|
|
6
|
-
from amber.hooks.implementations.model_output_detector import ModelOutputDetector
|
|
7
|
-
from amber.hooks.implementations.function_controller import FunctionController
|
|
8
|
-
|
|
9
|
-
__all__ = [
|
|
10
|
-
"Hook",
|
|
11
|
-
"HookType",
|
|
12
|
-
"HookError",
|
|
13
|
-
"Detector",
|
|
14
|
-
"Controller",
|
|
15
|
-
"LayerActivationDetector",
|
|
16
|
-
"ModelInputDetector",
|
|
17
|
-
"ModelOutputDetector",
|
|
18
|
-
"FunctionController",
|
|
19
|
-
]
|
|
20
|
-
|
amber/store/__init__.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
|
1
|
-
amber/__init__.py,sha256=5nh0D8qvFgOhBEQj00Rm06T1iY5VcSiifAg9SoY1LLA,483
|
|
2
|
-
amber/utils.py,sha256=oER2LA_alUjaIk_xCAyP2V54ywjqsg00I4KvitYnJPc,1547
|
|
3
|
-
amber/datasets/__init__.py,sha256=zhqgbm5zMBsRbmPNfjlYNJwGWOLuCNf5jEj0P8aopRU,341
|
|
4
|
-
amber/datasets/base_dataset.py,sha256=X2wt3GdjgAOY24_vOqrD5gVFxGplSRMCb69CoQtj0xw,22508
|
|
5
|
-
amber/datasets/classification_dataset.py,sha256=x_ZQ4dMzoY3Nn8V1I01xvzJK_IcHcDcm8dIxYeXzV5g,21700
|
|
6
|
-
amber/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
|
|
7
|
-
amber/datasets/text_dataset.py,sha256=ly0GHCS28Rg5ZluaafjavhcbvSD9-6ryovd_Y1ZIMms,16775
|
|
8
|
-
amber/hooks/__init__.py,sha256=9H08ZVoTK6TzYJXjEP2aqdHfoyLfdXvg6eOv3K1zNps,679
|
|
9
|
-
amber/hooks/controller.py,sha256=hc8FrrDosFYLrEGsEZmx1KsJ77F4p_gMKcF2WzHiURY,6057
|
|
10
|
-
amber/hooks/detector.py,sha256=5drJFrdrjseVjRNT-cq-U8XCt8AXV04YY2YrkQz4eFk,3110
|
|
11
|
-
amber/hooks/hook.py,sha256=-Qi-GJqRuIskXMuHUzp9_ESbbZi5tLSAFMWoBmrP3io,7540
|
|
12
|
-
amber/hooks/utils.py,sha256=wtsrjsMt-bXR3NshkwyZmfLre3IE3S4E5EoKppQrYOo,2022
|
|
13
|
-
amber/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
amber/hooks/implementations/function_controller.py,sha256=66FFx_7sU7b0_FFQFFkAOqQm3aGsiyIRuZ64hIv-0w8,3088
|
|
15
|
-
amber/hooks/implementations/layer_activation_detector.py,sha256=bzoW6V8NNDNgRASs1YN_1TjEXaK3ahoNWiZ-ODfjB6I,3161
|
|
16
|
-
amber/hooks/implementations/model_input_detector.py,sha256=cYRVfyBEHi-1qg6F-4Q0vKEae6gYtq_3g1j3rOOCQdA,10074
|
|
17
|
-
amber/hooks/implementations/model_output_detector.py,sha256=iN-twt7Chc9ODmj-iei7_Ah7GqvE-knTVWi4C9kNye4,4879
|
|
18
|
-
amber/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
|
-
amber/language_model/activations.py,sha256=7VRDlCM-IctF1ee0H68_Dd7xeXHyUSQe1SOPb28Ej18,16971
|
|
20
|
-
amber/language_model/context.py,sha256=koslpikCcu9Svsopboa1wd7Fv0R2-4sI2whOCvVWvT8,1118
|
|
21
|
-
amber/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
|
|
22
|
-
amber/language_model/hook_metadata.py,sha256=9Xyfiu4ekCZj79zG4gZfLk-850AO2iKDE24FDXe7q7s,1392
|
|
23
|
-
amber/language_model/inference.py,sha256=l8BASS8E9B4VWJHucEqF_G_zqOzlKeG4KEvGameBbMw,19068
|
|
24
|
-
amber/language_model/initialization.py,sha256=hfrKdI_fsmaxk0p9q4wN7EFxq_lSXs8BXGlxwKJ21Qw,3698
|
|
25
|
-
amber/language_model/language_model.py,sha256=MXoaXYbNBUxHw4sxtWkVFLbXiw9tFEx9GI78sCiESuQ,13943
|
|
26
|
-
amber/language_model/layers.py,sha256=Ob7QZl8i236ALLklY9o_xtjDZSt6FD8sqdmFy_YLgN0,15906
|
|
27
|
-
amber/language_model/persistence.py,sha256=i2ibDH1OABM5-ZNNLh7h4rOYWPsg3aaeYhmB_xWYDZw,5867
|
|
28
|
-
amber/language_model/tokenizer.py,sha256=9eKNOHvUjIJhJbj7M-tN7jWU5lWhOeCY_cssa4exQ1g,7377
|
|
29
|
-
amber/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
|
|
30
|
-
amber/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
31
|
-
amber/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
|
-
amber/mechanistic/sae/autoencoder_context.py,sha256=cn0mv9COqT3jNcvXBfce70ankVEmw9kNE3Mu-knugoc,945
|
|
33
|
-
amber/mechanistic/sae/sae.py,sha256=ha6rXGsOXE59E_ohTH0vJh6M4rQh3Xw0GfmCkSgeYS4,6035
|
|
34
|
-
amber/mechanistic/sae/sae_trainer.py,sha256=GMrPz9SpSuANA0tJt3IkyIOOqVr2k7apE0w4CqL92gM,26311
|
|
35
|
-
amber/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
|
-
amber/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=dTB9v5zK5lcKc4_zbYtWoCy8fA6ycj1-FJid4-7hQ04,14371
|
|
37
|
-
amber/mechanistic/sae/concepts/concept_dictionary.py,sha256=9px845gODuufW2koym-5426q0ijdpxvS7avH1vu8-Ls,7671
|
|
38
|
-
amber/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
|
|
39
|
-
amber/mechanistic/sae/concepts/input_tracker.py,sha256=81FrOv9AAC7ejhryOWDTZ7Hlt3B2WoANx-wiO0KLr24,1886
|
|
40
|
-
amber/mechanistic/sae/modules/__init__.py,sha256=xpoz0HtPWoJD4dPj1qHaxtXDr7J0ERn30CX3m1dz21s,239
|
|
41
|
-
amber/mechanistic/sae/modules/l1_sae.py,sha256=_BebvpB9iUCTDjSiYNzDBM4l9sU_wadAdypwXaSb4ww,15352
|
|
42
|
-
amber/mechanistic/sae/modules/topk_sae.py,sha256=GQA8hYb6Fw7U1e5ExZjzZBjAZRhGk-VwqqYREVCQ_u8,17275
|
|
43
|
-
amber/mechanistic/sae/training/wandb_logger.py,sha256=d3vVBIQrnsJurX5HNVu7OYW4DqNgv18UZDpV8ddfN9k,8554
|
|
44
|
-
amber/store/__init__.py,sha256=UW4Hqyu-_qgnZ-gN_mk97OaWSrlPERcNi5YjnXMKeOU,119
|
|
45
|
-
amber/store/local_store.py,sha256=1pJbizZKrzNt_IQFnCFYjApPXs9ot-G1H8adeR7Qi50,17214
|
|
46
|
-
amber/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
|
|
47
|
-
amber/store/store_dataloader.py,sha256=QyYHSgOos8e-yzaEE_rySSVlGKaRNybURSDCgNrTIVM,4337
|
|
48
|
-
mi_crow-0.1.1.post12.dist-info/METADATA,sha256=1_SPBcnK8j_scOIWwnPqQGjlUpNnXxGGtvkYRZHeDQ8,6580
|
|
49
|
-
mi_crow-0.1.1.post12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
50
|
-
mi_crow-0.1.1.post12.dist-info/top_level.txt,sha256=FNP1x_ePvcW9Jsr7J9gCBARdDC-gqxIYtWF6HGNxtnI,6
|
|
51
|
-
mi_crow-0.1.1.post12.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
amber
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{amber → mi_crow}/store/store.py
RENAMED
|
File without changes
|
|
File without changes
|