mi-crow 0.1.1.post12__py3-none-any.whl → 0.1.1.post14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58):
  1. {amber → mi_crow}/__init__.py +2 -2
  2. mi_crow/datasets/__init__.py +11 -0
  3. {amber → mi_crow}/datasets/base_dataset.py +2 -2
  4. {amber → mi_crow}/datasets/classification_dataset.py +3 -3
  5. {amber → mi_crow}/datasets/text_dataset.py +3 -3
  6. mi_crow/hooks/__init__.py +20 -0
  7. {amber → mi_crow}/hooks/controller.py +3 -3
  8. {amber → mi_crow}/hooks/detector.py +2 -2
  9. {amber → mi_crow}/hooks/hook.py +1 -1
  10. {amber → mi_crow}/hooks/implementations/function_controller.py +2 -2
  11. {amber → mi_crow}/hooks/implementations/layer_activation_detector.py +3 -3
  12. {amber → mi_crow}/hooks/implementations/model_input_detector.py +2 -2
  13. {amber → mi_crow}/hooks/implementations/model_output_detector.py +2 -2
  14. {amber → mi_crow}/hooks/utils.py +1 -1
  15. {amber → mi_crow}/language_model/activations.py +9 -9
  16. {amber → mi_crow}/language_model/context.py +3 -3
  17. {amber → mi_crow}/language_model/hook_metadata.py +1 -1
  18. {amber → mi_crow}/language_model/inference.py +7 -7
  19. {amber → mi_crow}/language_model/initialization.py +3 -3
  20. {amber → mi_crow}/language_model/language_model.py +11 -11
  21. {amber → mi_crow}/language_model/layers.py +4 -4
  22. {amber → mi_crow}/language_model/persistence.py +7 -7
  23. {amber → mi_crow}/language_model/tokenizer.py +1 -1
  24. {amber → mi_crow}/mechanistic/sae/autoencoder_context.py +1 -1
  25. {amber → mi_crow}/mechanistic/sae/concepts/autoencoder_concepts.py +8 -8
  26. {amber → mi_crow}/mechanistic/sae/concepts/concept_dictionary.py +1 -1
  27. {amber → mi_crow}/mechanistic/sae/concepts/input_tracker.py +2 -2
  28. mi_crow/mechanistic/sae/modules/__init__.py +5 -0
  29. {amber → mi_crow}/mechanistic/sae/modules/l1_sae.py +17 -17
  30. {amber → mi_crow}/mechanistic/sae/modules/topk_sae.py +18 -18
  31. {amber → mi_crow}/mechanistic/sae/sae.py +9 -9
  32. {amber → mi_crow}/mechanistic/sae/sae_trainer.py +7 -7
  33. mi_crow/mechanistic/sae/training/__init__.py +6 -0
  34. {amber → mi_crow}/mechanistic/sae/training/wandb_logger.py +2 -2
  35. mi_crow/store/__init__.py +5 -0
  36. {amber → mi_crow}/store/local_store.py +1 -1
  37. {amber → mi_crow}/store/store_dataloader.py +2 -2
  38. {amber → mi_crow}/utils.py +1 -1
  39. {mi_crow-0.1.1.post12.dist-info → mi_crow-0.1.1.post14.dist-info}/METADATA +2 -2
  40. mi_crow-0.1.1.post14.dist-info/RECORD +52 -0
  41. mi_crow-0.1.1.post14.dist-info/top_level.txt +1 -0
  42. amber/datasets/__init__.py +0 -11
  43. amber/hooks/__init__.py +0 -20
  44. amber/mechanistic/sae/modules/__init__.py +0 -5
  45. amber/store/__init__.py +0 -5
  46. mi_crow-0.1.1.post12.dist-info/RECORD +0 -51
  47. mi_crow-0.1.1.post12.dist-info/top_level.txt +0 -1
  48. {amber → mi_crow}/datasets/loading_strategy.py +0 -0
  49. {amber → mi_crow}/hooks/implementations/__init__.py +0 -0
  50. {amber → mi_crow}/language_model/__init__.py +0 -0
  51. {amber → mi_crow}/language_model/contracts.py +0 -0
  52. {amber → mi_crow}/language_model/utils.py +0 -0
  53. {amber → mi_crow}/mechanistic/__init__.py +0 -0
  54. {amber → mi_crow}/mechanistic/sae/__init__.py +0 -0
  55. {amber → mi_crow}/mechanistic/sae/concepts/__init__.py +0 -0
  56. {amber → mi_crow}/mechanistic/sae/concepts/concept_models.py +0 -0
  57. {amber → mi_crow}/store/store.py +0 -0
  58. {mi_crow-0.1.1.post12.dist-info → mi_crow-0.1.1.post14.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- """Amber: helper package for the Engineer Thesis project.
1
+ """mi_crow: helper package for the Engineer Thesis project.
2
2
 
3
3
  This module is intentionally minimal. It exists to define the top-level package
4
4
  and to enable code coverage to include the package. Importing it should succeed
@@ -6,7 +6,7 @@ without side effects.
6
6
  """
7
7
 
8
8
  # A tiny bit of executable code to make the package measurable by coverage.
9
- PACKAGE_NAME = "amber"
9
+ PACKAGE_NAME = "mi_crow"
10
10
  __version__ = "0.0.0"
11
11
 
12
12
 
@@ -0,0 +1,11 @@
1
+ from mi_crow.datasets.base_dataset import BaseDataset
2
+ from mi_crow.datasets.text_dataset import TextDataset
3
+ from mi_crow.datasets.classification_dataset import ClassificationDataset
4
+ from mi_crow.datasets.loading_strategy import LoadingStrategy
5
+
6
+ __all__ = [
7
+ "BaseDataset",
8
+ "TextDataset",
9
+ "ClassificationDataset",
10
+ "LoadingStrategy",
11
+ ]
@@ -9,8 +9,8 @@ from typing import Any, Dict, Iterator, List, Optional, Union
9
9
 
10
10
  from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
11
11
 
12
- from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
13
- from amber.store.store import Store
12
+ from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
13
+ from mi_crow.store.store import Store
14
14
 
15
15
 
16
16
  class BaseDataset(ABC):
@@ -5,9 +5,9 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
5
5
 
6
6
  from datasets import Dataset, IterableDataset, load_dataset
7
7
 
8
- from amber.datasets.base_dataset import BaseDataset
9
- from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
10
- from amber.store.store import Store
8
+ from mi_crow.datasets.base_dataset import BaseDataset
9
+ from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
10
+ from mi_crow.store.store import Store
11
11
 
12
12
 
13
13
  class ClassificationDataset(BaseDataset):
@@ -5,9 +5,9 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
5
5
 
6
6
  from datasets import Dataset, IterableDataset, load_dataset
7
7
 
8
- from amber.datasets.base_dataset import BaseDataset
9
- from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
10
- from amber.store.store import Store
8
+ from mi_crow.datasets.base_dataset import BaseDataset
9
+ from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
10
+ from mi_crow.store.store import Store
11
11
 
12
12
 
13
13
  class TextDataset(BaseDataset):
@@ -0,0 +1,20 @@
1
+ from mi_crow.hooks.hook import Hook, HookType, HookError
2
+ from mi_crow.hooks.detector import Detector
3
+ from mi_crow.hooks.controller import Controller
4
+ from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
5
+ from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
6
+ from mi_crow.hooks.implementations.model_output_detector import ModelOutputDetector
7
+ from mi_crow.hooks.implementations.function_controller import FunctionController
8
+
9
+ __all__ = [
10
+ "Hook",
11
+ "HookType",
12
+ "HookError",
13
+ "Detector",
14
+ "Controller",
15
+ "LayerActivationDetector",
16
+ "ModelInputDetector",
17
+ "ModelOutputDetector",
18
+ "FunctionController",
19
+ ]
20
+
@@ -6,9 +6,9 @@ from typing import TYPE_CHECKING
6
6
  import torch
7
7
  import torch.nn as nn
8
8
 
9
- from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
10
- from amber.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
11
- from amber.utils import get_logger
9
+ from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
10
+ from mi_crow.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
11
+ from mi_crow.utils import get_logger
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  pass
@@ -5,8 +5,8 @@ from typing import Any, TYPE_CHECKING, Dict
5
5
 
6
6
  import torch
7
7
 
8
- from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
9
- from amber.store.store import Store
8
+ from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
9
+ from mi_crow.store.store import Store
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  pass
@@ -10,7 +10,7 @@ from torch import nn, Tensor
10
10
  from torch.types import _TensorOrTensors
11
11
 
12
12
  if TYPE_CHECKING:
13
- from amber.language_model.context import LanguageModelContext
13
+ from mi_crow.language_model.context import LanguageModelContext
14
14
 
15
15
 
16
16
  class HookType(str, Enum):
@@ -3,8 +3,8 @@ from __future__ import annotations
3
3
  from typing import Callable, TYPE_CHECKING
4
4
  import torch
5
5
 
6
- from amber.hooks.controller import Controller
7
- from amber.hooks.hook import HookType
6
+ from mi_crow.hooks.controller import Controller
7
+ from mi_crow.hooks.hook import HookType
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from torch import nn
@@ -3,9 +3,9 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
  import torch
5
5
 
6
- from amber.hooks.detector import Detector
7
- from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
8
- from amber.hooks.utils import extract_tensor_from_output
6
+ from mi_crow.hooks.detector import Detector
7
+ from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
8
+ from mi_crow.hooks.utils import extract_tensor_from_output
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from torch import nn
@@ -3,8 +3,8 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Dict, Set, List, Optional
4
4
  import torch
5
5
 
6
- from amber.hooks.detector import Detector
7
- from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
6
+ from mi_crow.hooks.detector import Detector
7
+ from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from torch import nn
@@ -3,8 +3,8 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
  import torch
5
5
 
6
- from amber.hooks.detector import Detector
7
- from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
6
+ from mi_crow.hooks.detector import Detector
7
+ from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from torch import nn
@@ -6,7 +6,7 @@ from typing import Any
6
6
 
7
7
  import torch
8
8
 
9
- from amber.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
9
+ from mi_crow.hooks.hook import HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
10
10
 
11
11
 
12
12
  def extract_tensor_from_input(input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
@@ -4,16 +4,16 @@ from typing import TYPE_CHECKING, Any, Dict, Sequence
4
4
  import torch
5
5
  from torch import nn
6
6
 
7
- from amber.datasets import BaseDataset
8
- from amber.hooks import HookType
9
- from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
10
- from amber.hooks.implementations.model_input_detector import ModelInputDetector
11
- from amber.store.store import Store
12
- from amber.utils import get_logger
13
- from amber.language_model.utils import get_device_from_model
7
+ from mi_crow.datasets import BaseDataset
8
+ from mi_crow.hooks import HookType
9
+ from mi_crow.hooks.implementations.layer_activation_detector import LayerActivationDetector
10
+ from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
11
+ from mi_crow.store.store import Store
12
+ from mi_crow.utils import get_logger
13
+ from mi_crow.language_model.utils import get_device_from_model
14
14
 
15
15
  if TYPE_CHECKING:
16
- from amber.language_model.context import LanguageModelContext
16
+ from mi_crow.language_model.context import LanguageModelContext
17
17
 
18
18
  logger = get_logger(__name__)
19
19
 
@@ -170,7 +170,7 @@ class LanguageModelActivations:
170
170
  meta: Metadata dictionary
171
171
  verbose: Whether to log
172
172
  """
173
- from amber.language_model.inference import InferenceEngine
173
+ from mi_crow.language_model.inference import InferenceEngine
174
174
  InferenceEngine._save_run_metadata(store, run_name, meta, verbose)
175
175
 
176
176
  def _process_batch(
@@ -2,11 +2,11 @@ from dataclasses import dataclass, field
2
2
  from typing import Optional, Dict, Any, TYPE_CHECKING, List, Set
3
3
 
4
4
  if TYPE_CHECKING:
5
- from amber.language_model.language_model import LanguageModel
5
+ from mi_crow.language_model.language_model import LanguageModel
6
6
  from torch import nn
7
7
  from transformers import PreTrainedTokenizerBase
8
- from amber.hooks.hook import Hook
9
- from amber.store.store import Store
8
+ from mi_crow.hooks.hook import Hook
9
+ from mi_crow.store.store import Store
10
10
 
11
11
 
12
12
  @dataclass
@@ -6,7 +6,7 @@ from collections import defaultdict
6
6
  from typing import Dict, List, Any, TYPE_CHECKING
7
7
 
8
8
  if TYPE_CHECKING:
9
- from amber.language_model.context import LanguageModelContext
9
+ from mi_crow.language_model.context import LanguageModelContext
10
10
 
11
11
 
12
12
  def collect_hooks_metadata(context: "LanguageModelContext") -> Dict[str, List[Dict[str, Any]]]:
@@ -8,14 +8,14 @@ from typing import Sequence, Any, Dict, List, TYPE_CHECKING
8
8
  import torch
9
9
  from torch import nn
10
10
 
11
- from amber.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
12
- from amber.utils import get_logger
11
+ from mi_crow.language_model.utils import get_device_from_model, move_tensors_to_device, extract_logits_from_output
12
+ from mi_crow.utils import get_logger
13
13
 
14
14
  if TYPE_CHECKING:
15
- from amber.language_model.language_model import LanguageModel
16
- from amber.hooks.controller import Controller
17
- from amber.datasets import BaseDataset
18
- from amber.store.store import Store
15
+ from mi_crow.language_model.language_model import LanguageModel
16
+ from mi_crow.hooks.controller import Controller
17
+ from mi_crow.datasets import BaseDataset
18
+ from mi_crow.store.store import Store
19
19
 
20
20
  logger = get_logger(__name__)
21
21
 
@@ -80,7 +80,7 @@ class InferenceEngine:
80
80
  Args:
81
81
  enc: Encoded inputs dictionary
82
82
  """
83
- from amber.hooks.implementations.model_input_detector import ModelInputDetector
83
+ from mi_crow.hooks.implementations.model_input_detector import ModelInputDetector
84
84
 
85
85
  detectors = self.lm.layers.get_detectors()
86
86
  for detector in detectors:
@@ -9,11 +9,11 @@ import torch
9
9
  from torch import nn
10
10
  from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
11
11
 
12
- from amber.store.store import Store
13
- from amber.language_model.utils import extract_model_id
12
+ from mi_crow.store.store import Store
13
+ from mi_crow.language_model.utils import extract_model_id
14
14
 
15
15
  if TYPE_CHECKING:
16
- from amber.language_model.language_model import LanguageModel
16
+ from mi_crow.language_model.language_model import LanguageModel
17
17
 
18
18
 
19
19
  def initialize_model_id(
@@ -8,18 +8,18 @@ import torch
8
8
  from torch import nn, Tensor
9
9
  from transformers import PreTrainedTokenizerBase
10
10
 
11
- from amber.language_model.layers import LanguageModelLayers
12
- from amber.language_model.tokenizer import LanguageModelTokenizer
13
- from amber.language_model.activations import LanguageModelActivations
14
- from amber.language_model.context import LanguageModelContext
15
- from amber.language_model.inference import InferenceEngine
16
- from amber.language_model.persistence import save_model, load_model_from_saved_file
17
- from amber.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
18
- from amber.store.store import Store
19
- from amber.utils import get_logger
11
+ from mi_crow.language_model.layers import LanguageModelLayers
12
+ from mi_crow.language_model.tokenizer import LanguageModelTokenizer
13
+ from mi_crow.language_model.activations import LanguageModelActivations
14
+ from mi_crow.language_model.context import LanguageModelContext
15
+ from mi_crow.language_model.inference import InferenceEngine
16
+ from mi_crow.language_model.persistence import save_model, load_model_from_saved_file
17
+ from mi_crow.language_model.initialization import initialize_model_id, create_from_huggingface, create_from_local_torch
18
+ from mi_crow.store.store import Store
19
+ from mi_crow.utils import get_logger
20
20
 
21
21
  if TYPE_CHECKING:
22
- from amber.mechanistic.sae.concepts.input_tracker import InputTracker
22
+ from mi_crow.mechanistic.sae.concepts.input_tracker import InputTracker
23
23
 
24
24
  logger = get_logger(__name__)
25
25
 
@@ -308,7 +308,7 @@ class LanguageModel:
308
308
  if self._input_tracker is not None:
309
309
  return self._input_tracker
310
310
 
311
- from amber.mechanistic.sae.concepts.input_tracker import InputTracker
311
+ from mi_crow.mechanistic.sae.concepts.input_tracker import InputTracker
312
312
 
313
313
  self._input_tracker = InputTracker(language_model=self)
314
314
 
@@ -2,12 +2,12 @@ from typing import Dict, List, Callable, TYPE_CHECKING
2
2
 
3
3
  from torch import nn
4
4
 
5
- from amber.hooks.hook import Hook, HookType
6
- from amber.hooks.detector import Detector
7
- from amber.hooks.controller import Controller
5
+ from mi_crow.hooks.hook import Hook, HookType
6
+ from mi_crow.hooks.detector import Detector
7
+ from mi_crow.hooks.controller import Controller
8
8
 
9
9
  if TYPE_CHECKING:
10
- from amber.language_model.context import LanguageModelContext
10
+ from mi_crow.language_model.context import LanguageModelContext
11
11
 
12
12
 
13
13
  class LanguageModelLayers:
@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING
9
9
  import torch
10
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
11
 
12
- from amber.language_model.contracts import ModelMetadata
13
- from amber.language_model.hook_metadata import collect_hooks_metadata
14
- from amber.language_model.utils import extract_model_id
12
+ from mi_crow.language_model.contracts import ModelMetadata
13
+ from mi_crow.language_model.hook_metadata import collect_hooks_metadata
14
+ from mi_crow.language_model.utils import extract_model_id
15
15
 
16
16
  if TYPE_CHECKING:
17
- from amber.language_model.language_model import LanguageModel
18
- from amber.store.store import Store
17
+ from mi_crow.language_model.language_model import LanguageModel
18
+ from mi_crow.store.store import Store
19
19
 
20
20
 
21
21
  def save_model(
@@ -78,7 +78,7 @@ def save_model(
78
78
  f"Failed to save model to {save_path}. Error: {e}"
79
79
  ) from e
80
80
 
81
- from amber.utils import get_logger
81
+ from mi_crow.utils import get_logger
82
82
  logger = get_logger(__name__)
83
83
  logger.info(f"Saved model to {save_path}")
84
84
 
@@ -169,7 +169,7 @@ def load_model_from_saved_file(
169
169
  # Note: Hooks are not automatically restored as they require hook instances
170
170
  # The hook metadata is available in metadata_dict["hooks"] if needed
171
171
 
172
- from amber.utils import get_logger
172
+ from mi_crow.utils import get_logger
173
173
  logger = get_logger(__name__)
174
174
  logger.info(f"Loaded model from {saved_path} (model_id: {model_id})")
175
175
 
@@ -4,7 +4,7 @@ from torch import nn
4
4
  from transformers import AutoTokenizer
5
5
 
6
6
  if TYPE_CHECKING:
7
- from amber.language_model.context import LanguageModelContext
7
+ from mi_crow.language_model.context import LanguageModelContext
8
8
 
9
9
 
10
10
  class LanguageModelTokenizer:
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Optional, TYPE_CHECKING
3
3
 
4
- from amber.store.store import Store
4
+ from mi_crow.store.store import Store
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  pass
@@ -9,12 +9,12 @@ import heapq
9
9
  import torch
10
10
  from torch import nn
11
11
 
12
- from amber.mechanistic.sae.concepts.concept_models import NeuronText
13
- from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
14
- from amber.utils import get_logger
12
+ from mi_crow.mechanistic.sae.concepts.concept_models import NeuronText
13
+ from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
14
+ from mi_crow.utils import get_logger
15
15
 
16
16
  if TYPE_CHECKING:
17
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
17
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
18
18
 
19
19
  logger = get_logger(__name__)
20
20
 
@@ -59,12 +59,12 @@ class AutoencoderConcepts:
59
59
 
60
60
  def _ensure_dictionary(self):
61
61
  if self.dictionary is None:
62
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
62
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
63
63
  self.dictionary = ConceptDictionary(self._n_size)
64
64
  return self.dictionary
65
65
 
66
66
  def load_concepts_from_csv(self, csv_filepath: str | Path):
67
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
67
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
68
68
  self.dictionary = ConceptDictionary.from_csv(
69
69
  csv_filepath=csv_filepath,
70
70
  n_size=self._n_size,
@@ -72,7 +72,7 @@ class AutoencoderConcepts:
72
72
  )
73
73
 
74
74
  def load_concepts_from_json(self, json_filepath: str | Path):
75
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
75
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
76
76
  self.dictionary = ConceptDictionary.from_json(
77
77
  json_filepath=json_filepath,
78
78
  n_size=self._n_size,
@@ -84,7 +84,7 @@ class AutoencoderConcepts:
84
84
  if self._top_texts_heaps is None:
85
85
  raise ValueError("No top texts available. Enable text tracking and run inference first.")
86
86
 
87
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
87
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
88
88
  neuron_texts = self.get_all_top_texts()
89
89
 
90
90
  self.dictionary = ConceptDictionary.from_llm(
@@ -6,7 +6,7 @@ from typing import Dict, Sequence, TYPE_CHECKING, Optional
6
6
  import json
7
7
  import csv
8
8
 
9
- from amber.store.store import Store
9
+ from mi_crow.store.store import Store
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  pass
@@ -1,9 +1,9 @@
1
1
  from typing import TYPE_CHECKING, Sequence
2
2
 
3
- from amber.utils import get_logger
3
+ from mi_crow.utils import get_logger
4
4
 
5
5
  if TYPE_CHECKING:
6
- from amber.language_model.language_model import LanguageModel
6
+ from mi_crow.language_model.language_model import LanguageModel
7
7
 
8
8
  logger = get_logger(__name__)
9
9
 
@@ -0,0 +1,5 @@
1
+ from mi_crow.mechanistic.sae.modules.topk_sae import TopKSae, TopKSaeTrainingConfig
2
+ from mi_crow.mechanistic.sae.modules.l1_sae import L1Sae, L1SaeTrainingConfig
3
+
4
+ __all__ = ["TopKSae", "TopKSaeTrainingConfig", "L1Sae", "L1SaeTrainingConfig"]
5
+
@@ -3,11 +3,11 @@ from typing import Any
3
3
 
4
4
  import torch
5
5
  from overcomplete import SAE as OvercompleteSAE
6
- from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
7
- from amber.mechanistic.sae.sae import Sae
8
- from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
9
- from amber.store.store import Store
10
- from amber.utils import get_logger
6
+ from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
7
+ from mi_crow.mechanistic.sae.sae import Sae
8
+ from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
9
+ from mi_crow.store.store import Store
10
+ from mi_crow.utils import get_logger
11
11
 
12
12
  logger = get_logger(__name__)
13
13
 
@@ -293,7 +293,7 @@ class L1Sae(Sae):
293
293
  # Save overcomplete model state dict
294
294
  sae_state_dict = self.sae_engine.state_dict()
295
295
 
296
- amber_metadata = {
296
+ mi_crow_metadata = {
297
297
  "concepts_state": {
298
298
  'multiplication': self.concepts.multiplication.data,
299
299
  'bias': self.concepts.bias.data,
@@ -307,7 +307,7 @@ class L1Sae(Sae):
307
307
 
308
308
  payload = {
309
309
  "sae_state_dict": sae_state_dict,
310
- "amber_metadata": amber_metadata,
310
+ "mi_crow_metadata": mi_crow_metadata,
311
311
  }
312
312
 
313
313
  torch.save(payload, save_path)
@@ -336,16 +336,16 @@ class L1Sae(Sae):
336
336
  payload = torch.load(p, map_location=map_location)
337
337
 
338
338
  # Extract our metadata
339
- if "amber_metadata" not in payload:
340
- raise ValueError(f"Invalid L1SAE save format: missing 'amber_metadata' key in {p}")
341
-
342
- amber_meta = payload["amber_metadata"]
343
- n_latents = int(amber_meta["n_latents"])
344
- n_inputs = int(amber_meta["n_inputs"])
345
- device = amber_meta.get("device", "cpu")
346
- layer_signature = amber_meta.get("layer_signature")
347
- model_id = amber_meta.get("model_id")
348
- concepts_state = amber_meta.get("concepts_state", {})
339
+ if "mi_crow_metadata" not in payload:
340
+ raise ValueError(f"Invalid L1SAE save format: missing 'mi_crow_metadata' key in {p}")
341
+
342
+ mi_crow_meta = payload["mi_crow_metadata"]
343
+ n_latents = int(mi_crow_meta["n_latents"])
344
+ n_inputs = int(mi_crow_meta["n_inputs"])
345
+ device = mi_crow_meta.get("device", "cpu")
346
+ layer_signature = mi_crow_meta.get("layer_signature")
347
+ model_id = mi_crow_meta.get("model_id")
348
+ concepts_state = mi_crow_meta.get("concepts_state", {})
349
349
 
350
350
  # Create L1Sae instance
351
351
  l1_sae = L1Sae(
@@ -7,11 +7,11 @@ from overcomplete import (
7
7
  TopKSAE as OvercompleteTopkSAE,
8
8
  SAE as OvercompleteSAE
9
9
  )
10
- from amber.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
11
- from amber.mechanistic.sae.sae import Sae
12
- from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
13
- from amber.store.store import Store
14
- from amber.utils import get_logger
10
+ from mi_crow.hooks.hook import HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
11
+ from mi_crow.mechanistic.sae.sae import Sae
12
+ from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
13
+ from mi_crow.store.store import Store
14
+ from mi_crow.utils import get_logger
15
15
 
16
16
  logger = get_logger(__name__)
17
17
 
@@ -338,7 +338,7 @@ class TopKSae(Sae):
338
338
  # Save overcomplete model state dict
339
339
  sae_state_dict = self.sae_engine.state_dict()
340
340
 
341
- amber_metadata = {
341
+ mi_crow_metadata = {
342
342
  "concepts_state": {
343
343
  'multiplication': self.concepts.multiplication.data,
344
344
  'bias': self.concepts.bias.data,
@@ -353,7 +353,7 @@ class TopKSae(Sae):
353
353
 
354
354
  payload = {
355
355
  "sae_state_dict": sae_state_dict,
356
- "amber_metadata": amber_metadata,
356
+ "mi_crow_metadata": mi_crow_metadata,
357
357
  }
358
358
 
359
359
  torch.save(payload, save_path)
@@ -382,17 +382,17 @@ class TopKSae(Sae):
382
382
  payload = torch.load(p, map_location=map_location)
383
383
 
384
384
  # Extract our metadata
385
- if "amber_metadata" not in payload:
386
- raise ValueError(f"Invalid TopKSAE save format: missing 'amber_metadata' key in {p}")
387
-
388
- amber_meta = payload["amber_metadata"]
389
- n_latents = int(amber_meta["n_latents"])
390
- n_inputs = int(amber_meta["n_inputs"])
391
- k = int(amber_meta["k"])
392
- device = amber_meta.get("device", "cpu")
393
- layer_signature = amber_meta.get("layer_signature")
394
- model_id = amber_meta.get("model_id")
395
- concepts_state = amber_meta.get("concepts_state", {})
385
+ if "mi_crow_metadata" not in payload:
386
+ raise ValueError(f"Invalid TopKSAE save format: missing 'mi_crow_metadata' key in {p}")
387
+
388
+ mi_crow_meta = payload["mi_crow_metadata"]
389
+ n_latents = int(mi_crow_meta["n_latents"])
390
+ n_inputs = int(mi_crow_meta["n_inputs"])
391
+ k = int(mi_crow_meta["k"])
392
+ device = mi_crow_meta.get("device", "cpu")
393
+ layer_signature = mi_crow_meta.get("layer_signature")
394
+ model_id = mi_crow_meta.get("model_id")
395
+ concepts_state = mi_crow_meta.get("concepts_state", {})
396
396
 
397
397
  # Create TopKSAE instance
398
398
  topk_sae = TopKSae(
@@ -5,15 +5,15 @@ from typing import Any, TYPE_CHECKING, Literal
5
5
  import torch
6
6
  from torch import nn
7
7
 
8
- from amber.hooks.controller import Controller
9
- from amber.hooks.detector import Detector
10
- from amber.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
11
- from amber.mechanistic.sae.autoencoder_context import AutoencoderContext
12
- from amber.mechanistic.sae.concepts.autoencoder_concepts import AutoencoderConcepts
13
- from amber.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
14
- from amber.mechanistic.sae.sae_trainer import SaeTrainer
15
- from amber.store.store import Store
16
- from amber.utils import get_logger
8
+ from mi_crow.hooks.controller import Controller
9
+ from mi_crow.hooks.detector import Detector
10
+ from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
11
+ from mi_crow.mechanistic.sae.autoencoder_context import AutoencoderContext
12
+ from mi_crow.mechanistic.sae.concepts.autoencoder_concepts import AutoencoderConcepts
13
+ from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
14
+ from mi_crow.mechanistic.sae.sae_trainer import SaeTrainer
15
+ from mi_crow.store.store import Store
16
+ from mi_crow.utils import get_logger
17
17
 
18
18
  from overcomplete.sae import SAE as OvercompleteSAE
19
19
 
@@ -9,12 +9,12 @@ import gc
9
9
 
10
10
  import torch
11
11
 
12
- from amber.store.store_dataloader import StoreDataloader
13
- from amber.utils import get_logger
12
+ from mi_crow.store.store_dataloader import StoreDataloader
13
+ from mi_crow.utils import get_logger
14
14
 
15
15
  if TYPE_CHECKING:
16
- from amber.mechanistic.sae.sae import Sae
17
- from amber.store.store import Store
16
+ from mi_crow.mechanistic.sae.sae import Sae
17
+ from mi_crow.store.store import Store
18
18
 
19
19
 
20
20
  @dataclass
@@ -528,7 +528,7 @@ class SaeTrainer:
528
528
 
529
529
  sae_state_dict = self.sae.sae_engine.state_dict()
530
530
 
531
- amber_metadata = {
531
+ mi_crow_metadata = {
532
532
  "concepts_state": {
533
533
  'multiplication': self.sae.concepts.multiplication.data.cpu().clone(),
534
534
  'bias': self.sae.concepts.bias.data.cpu().clone(),
@@ -541,11 +541,11 @@ class SaeTrainer:
541
541
  }
542
542
 
543
543
  if hasattr(self.sae, 'k'):
544
- amber_metadata["k"] = self.sae.k
544
+ mi_crow_metadata["k"] = self.sae.k
545
545
 
546
546
  payload = {
547
547
  "sae_state_dict": sae_state_dict,
548
- "amber_metadata": amber_metadata,
548
+ "mi_crow_metadata": mi_crow_metadata,
549
549
  }
550
550
 
551
551
  torch.save(payload, model_path)
@@ -0,0 +1,6 @@
1
+ """Training utilities for SAE models."""
2
+
3
+ from mi_crow.mechanistic.sae.training.wandb_logger import WandbLogger
4
+
5
+ __all__ = ["WandbLogger"]
6
+
@@ -2,8 +2,8 @@
2
2
 
3
3
  from typing import Any, Optional
4
4
 
5
- from amber.mechanistic.sae.sae_trainer import SaeTrainingConfig
6
- from amber.utils import get_logger
5
+ from mi_crow.mechanistic.sae.sae_trainer import SaeTrainingConfig
6
+ from mi_crow.utils import get_logger
7
7
 
8
8
  logger = get_logger(__name__)
9
9
 
@@ -0,0 +1,5 @@
1
+ from mi_crow.store.store import Store
2
+ from mi_crow.store.local_store import LocalStore
3
+
4
+ __all__ = ["Store", "LocalStore"]
5
+
@@ -5,7 +5,7 @@ import shutil
5
5
 
6
6
  import torch
7
7
 
8
- from amber.store.store import Store, TensorMetadata
8
+ from mi_crow.store.store import Store, TensorMetadata
9
9
  import safetensors.torch as storch
10
10
 
11
11
 
@@ -2,12 +2,12 @@ from typing import Optional, Iterator, TYPE_CHECKING
2
2
 
3
3
  import torch
4
4
 
5
- from amber.utils import get_logger
5
+ from mi_crow.utils import get_logger
6
6
 
7
7
  logger = get_logger(__name__)
8
8
 
9
9
  if TYPE_CHECKING:
10
- from amber.store.store import Store
10
+ from mi_crow.store.store import Store
11
11
 
12
12
 
13
13
  class StoreDataloader:
@@ -27,7 +27,7 @@ def set_seed(seed: int, deterministic: bool = True) -> None:
27
27
  torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
28
28
 
29
29
 
30
- def get_logger(name: str = "amber", level: int | str = logging.INFO) -> logging.Logger:
30
+ def get_logger(name: str = "mi_crow", level: int | str = logging.INFO) -> logging.Logger:
31
31
  """Get a configured logger with a simple format. Idempotent."""
32
32
  logger = logging.getLogger(name)
33
33
  if isinstance(level, str):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mi-crow
3
- Version: 0.1.1.post12
3
+ Version: 0.1.1.post14
4
4
  Summary: Engineer Thesis: Explaining and modifying LLM responses using SAE and concepts.
5
5
  Author-email: Hubert Kowalski <your.email@example.com>, Adam Kaniasty <adam.kaniasty@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -96,7 +96,7 @@ uv run --group server pytest tests/server/test_api.py --cov=server --cov-fail-un
96
96
 
97
97
  ### SAE API usage
98
98
 
99
- - Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/amber_artifacts` (defaults to `~/.cache/amber_server`)
99
+ - Configure artifact location (optional): `export SERVER_ARTIFACT_BASE_PATH=/path/to/mi_crow_artifacts` (defaults to `~/.cache/mi_crow_server`)
100
100
  - Load a model: `curl -X POST http://localhost:8000/models/load -H "Content-Type: application/json" -d '{"model_id":"bielik"}'`
101
101
  - Save activations from dataset (stored in `LocalStore` under `activations/<model>/<run_id>`):
102
102
  - HF dataset: `{"dataset":{"type":"hf","name":"ag_news","split":"train","text_field":"text"}}`
@@ -0,0 +1,52 @@
1
+ mi_crow/__init__.py,sha256=J7aXVlAicbjvk5630rhDxx0ATsvZnihud5u_aQpAwY8,487
2
+ mi_crow/utils.py,sha256=LTfh2Ep87lAgPBaZkrQPP9caXFJoS9zUxu4qFuV4kzM,1549
3
+ mi_crow/datasets/__init__.py,sha256=lCAc3nFlvoERrBPAan6C9YFmDx86W2gbIAy267Rb2Sk,349
4
+ mi_crow/datasets/base_dataset.py,sha256=vYx-oj3jVhLZD1-xGSO4K4ZIsQtYpHP5zHmg7jd4FE0,22512
5
+ mi_crow/datasets/classification_dataset.py,sha256=nL_xndJHyf8hlLxKBe_ZO2YLYsXQjGyeY6csqGTTzEY,21706
6
+ mi_crow/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
7
+ mi_crow/datasets/text_dataset.py,sha256=5FzHWkMWWK0yP69O48S3fUj5KgHb8qo3mkvvZihHFuU,16781
8
+ mi_crow/hooks/__init__.py,sha256=KYy5qcbEpnJceNH86ofy43Suu_36QXjj0HYl79rVyls,693
9
+ mi_crow/hooks/controller.py,sha256=eo8LMERORXYUjH4-_R6DHk5JKN6O8SW6PlnuBFrlNqg,6063
10
+ mi_crow/hooks/detector.py,sha256=Bj3xz56cSgRvbcoQBsHIdlJdf0dtgVLw3l1pOSRvRAg,3114
11
+ mi_crow/hooks/hook.py,sha256=JrCyPptXzHICAToxug7FD8zKWKZcLxbXfIe7UHCkh34,7542
12
+ mi_crow/hooks/utils.py,sha256=GdUAqL9InCsthjBVWbZtVQp2VtQLLOMaJy8S8Tc7WX0,2024
13
+ mi_crow/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ mi_crow/hooks/implementations/function_controller.py,sha256=Rwu1Ghffqm-Jc4mqniRdIsXQeV24JKlx9CMO8jP8iv0,3092
15
+ mi_crow/hooks/implementations/layer_activation_detector.py,sha256=mFqLkWrBlNYzYYvwI8O-ymnpwP1UJSX7eWIQ7Fp-lxE,3167
16
+ mi_crow/hooks/implementations/model_input_detector.py,sha256=4AFA88eQ-zM9Kz8bpv4Bl5cvlZF5ZyIRXNi7Znd6V98,10078
17
+ mi_crow/hooks/implementations/model_output_detector.py,sha256=BmD8oJ1k_nv4hIzG-9SGxSCRARQkLRI4d_XQmKs6jjE,4883
18
+ mi_crow/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
+ mi_crow/language_model/activations.py,sha256=mJhuXkQ_JKzQsrFossnXVTdq4WgBI10Gi0ZYeQ9USFY,16989
20
+ mi_crow/language_model/context.py,sha256=TJSe6IruGktePpQ0dtHIoeatSY72qAhEKeTvLizCdlw,1124
21
+ mi_crow/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
22
+ mi_crow/language_model/hook_metadata.py,sha256=GACZjZUneo2l5j7DCFycLAunTm0etdMQ2YB_xgueUuk,1394
23
+ mi_crow/language_model/inference.py,sha256=-Kpm85jM8y6-GyDgrvIczitBIwGh8grJP8aYuXsLV-g,19082
24
+ mi_crow/language_model/initialization.py,sha256=e_Vkk-p9KWRt6-Hmkm6I29dTf20jzEAyNF9CG4nc48M,3704
25
+ mi_crow/language_model/language_model.py,sha256=a6CcklVA65oYtFxGXiwQrOKMPZj6eb7LOiT1zJ5-guo,13965
26
+ mi_crow/language_model/layers.py,sha256=1yExHodMyqr_Yk4W-2HiSGnRs2sYOA7swsxI8u0Uvfk,15914
27
+ mi_crow/language_model/persistence.py,sha256=9wQE6tRvLg7BgdLlkKRTOfRwXb5Q0LsEgg8B9J7Yos0,5881
28
+ mi_crow/language_model/tokenizer.py,sha256=uZbMDVNnzu8WZINUaR1tLFXiuk9V5pAoahwnJOUvEuE,7379
29
+ mi_crow/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
30
+ mi_crow/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
+ mi_crow/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
+ mi_crow/mechanistic/sae/autoencoder_context.py,sha256=u5WzSlLb8_HaJF9LwGqe-J_rE-iRCoXbvwEhTZIArkw,947
33
+ mi_crow/mechanistic/sae/sae.py,sha256=R2IckZ-UDVXFURUSRoIdX0MGLe0TWU5JCVeDXPz1Wv4,6053
34
+ mi_crow/mechanistic/sae/sae_trainer.py,sha256=VqN9UgEPtgYItzAS0RqMfOhTnTxIEuModO0noHYieFA,26327
35
+ mi_crow/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
+ mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=SE7A6y6cZiMGhylljNdZz0D8lhrulUh1FYiKhnPPRvY,14387
37
+ mi_crow/mechanistic/sae/concepts/concept_dictionary.py,sha256=aUXTe4Fy7Oe7iYJOcQImEXYMZAXhnkfAszZNwivk4eg,7673
38
+ mi_crow/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
39
+ mi_crow/mechanistic/sae/concepts/input_tracker.py,sha256=kIiqt7guv_-9-UPYtefAFJbHkWtAS_mnqYVvRU4eb2o,1890
40
+ mi_crow/mechanistic/sae/modules/__init__.py,sha256=e0lkCALQZcJN7KpYyTtXx3OD2NhBxV_kOZLLJ6EWaTE,243
41
+ mi_crow/mechanistic/sae/modules/l1_sae.py,sha256=qqw0iTWLSmWAlz5kgfw_mex8LeecFWM1FobyUteMqmM,15388
42
+ mi_crow/mechanistic/sae/modules/topk_sae.py,sha256=pK_ajKTQb0wGAftzb6AE5ZZthV3aFLr6G3avOVclSHE,17313
43
+ mi_crow/mechanistic/sae/training/__init__.py,sha256=5flCJVkOyKizY0FZy1OP5v0EI6bPEayunpnUPp82a6s,140
44
+ mi_crow/mechanistic/sae/training/wandb_logger.py,sha256=YlSJd5CaNa35RmIgf1FD_gSEDyhGRa2UdHo_Ofrplos,8558
45
+ mi_crow/store/__init__.py,sha256=DrYTpdgzrRzjHm9bigy-GiP0BGxzjmD3-lJCthtgxbE,123
46
+ mi_crow/store/local_store.py,sha256=XmguFvdrUi6NHzvV_bLaDJzpk5KWU_-ObkzhICcLu6g,17216
47
+ mi_crow/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
48
+ mi_crow/store/store_dataloader.py,sha256=UkZhHCOTg56ozomPtU9vHBhxIMOPcOiyfMqiAxgqtQs,4341
49
+ mi_crow-0.1.1.post14.dist-info/METADATA,sha256=Dd5lhIR9XmTNgdTDgZBhSsFVAvih80juP77KxHHfjFQ,6584
50
+ mi_crow-0.1.1.post14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
51
+ mi_crow-0.1.1.post14.dist-info/top_level.txt,sha256=DTuNo2VWgrH6jQKY19NciReSpLwGKKIRzJ3WbpspLlE,8
52
+ mi_crow-0.1.1.post14.dist-info/RECORD,,
@@ -0,0 +1 @@
1
+ mi_crow
@@ -1,11 +0,0 @@
1
- from amber.datasets.base_dataset import BaseDataset
2
- from amber.datasets.text_dataset import TextDataset
3
- from amber.datasets.classification_dataset import ClassificationDataset
4
- from amber.datasets.loading_strategy import LoadingStrategy
5
-
6
- __all__ = [
7
- "BaseDataset",
8
- "TextDataset",
9
- "ClassificationDataset",
10
- "LoadingStrategy",
11
- ]
amber/hooks/__init__.py DELETED
@@ -1,20 +0,0 @@
1
- from amber.hooks.hook import Hook, HookType, HookError
2
- from amber.hooks.detector import Detector
3
- from amber.hooks.controller import Controller
4
- from amber.hooks.implementations.layer_activation_detector import LayerActivationDetector
5
- from amber.hooks.implementations.model_input_detector import ModelInputDetector
6
- from amber.hooks.implementations.model_output_detector import ModelOutputDetector
7
- from amber.hooks.implementations.function_controller import FunctionController
8
-
9
- __all__ = [
10
- "Hook",
11
- "HookType",
12
- "HookError",
13
- "Detector",
14
- "Controller",
15
- "LayerActivationDetector",
16
- "ModelInputDetector",
17
- "ModelOutputDetector",
18
- "FunctionController",
19
- ]
20
-
@@ -1,5 +0,0 @@
1
- from amber.mechanistic.sae.modules.topk_sae import TopKSae, TopKSaeTrainingConfig
2
- from amber.mechanistic.sae.modules.l1_sae import L1Sae, L1SaeTrainingConfig
3
-
4
- __all__ = ["TopKSae", "TopKSaeTrainingConfig", "L1Sae", "L1SaeTrainingConfig"]
5
-
amber/store/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from amber.store.store import Store
2
- from amber.store.local_store import LocalStore
3
-
4
- __all__ = ["Store", "LocalStore"]
5
-
@@ -1,51 +0,0 @@
1
- amber/__init__.py,sha256=5nh0D8qvFgOhBEQj00Rm06T1iY5VcSiifAg9SoY1LLA,483
2
- amber/utils.py,sha256=oER2LA_alUjaIk_xCAyP2V54ywjqsg00I4KvitYnJPc,1547
3
- amber/datasets/__init__.py,sha256=zhqgbm5zMBsRbmPNfjlYNJwGWOLuCNf5jEj0P8aopRU,341
4
- amber/datasets/base_dataset.py,sha256=X2wt3GdjgAOY24_vOqrD5gVFxGplSRMCb69CoQtj0xw,22508
5
- amber/datasets/classification_dataset.py,sha256=x_ZQ4dMzoY3Nn8V1I01xvzJK_IcHcDcm8dIxYeXzV5g,21700
6
- amber/datasets/loading_strategy.py,sha256=17VM3Td8lqDllGIx9DHI6WiXmSKKQHDHbfe4ZeM8ATA,1206
7
- amber/datasets/text_dataset.py,sha256=ly0GHCS28Rg5ZluaafjavhcbvSD9-6ryovd_Y1ZIMms,16775
8
- amber/hooks/__init__.py,sha256=9H08ZVoTK6TzYJXjEP2aqdHfoyLfdXvg6eOv3K1zNps,679
9
- amber/hooks/controller.py,sha256=hc8FrrDosFYLrEGsEZmx1KsJ77F4p_gMKcF2WzHiURY,6057
10
- amber/hooks/detector.py,sha256=5drJFrdrjseVjRNT-cq-U8XCt8AXV04YY2YrkQz4eFk,3110
11
- amber/hooks/hook.py,sha256=-Qi-GJqRuIskXMuHUzp9_ESbbZi5tLSAFMWoBmrP3io,7540
12
- amber/hooks/utils.py,sha256=wtsrjsMt-bXR3NshkwyZmfLre3IE3S4E5EoKppQrYOo,2022
13
- amber/hooks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- amber/hooks/implementations/function_controller.py,sha256=66FFx_7sU7b0_FFQFFkAOqQm3aGsiyIRuZ64hIv-0w8,3088
15
- amber/hooks/implementations/layer_activation_detector.py,sha256=bzoW6V8NNDNgRASs1YN_1TjEXaK3ahoNWiZ-ODfjB6I,3161
16
- amber/hooks/implementations/model_input_detector.py,sha256=cYRVfyBEHi-1qg6F-4Q0vKEae6gYtq_3g1j3rOOCQdA,10074
17
- amber/hooks/implementations/model_output_detector.py,sha256=iN-twt7Chc9ODmj-iei7_Ah7GqvE-knTVWi4C9kNye4,4879
18
- amber/language_model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- amber/language_model/activations.py,sha256=7VRDlCM-IctF1ee0H68_Dd7xeXHyUSQe1SOPb28Ej18,16971
20
- amber/language_model/context.py,sha256=koslpikCcu9Svsopboa1wd7Fv0R2-4sI2whOCvVWvT8,1118
21
- amber/language_model/contracts.py,sha256=6ij7rzJcpSAKgYx-fiefg0Fi8TsFugaM5IlihEm_2VM,268
22
- amber/language_model/hook_metadata.py,sha256=9Xyfiu4ekCZj79zG4gZfLk-850AO2iKDE24FDXe7q7s,1392
23
- amber/language_model/inference.py,sha256=l8BASS8E9B4VWJHucEqF_G_zqOzlKeG4KEvGameBbMw,19068
24
- amber/language_model/initialization.py,sha256=hfrKdI_fsmaxk0p9q4wN7EFxq_lSXs8BXGlxwKJ21Qw,3698
25
- amber/language_model/language_model.py,sha256=MXoaXYbNBUxHw4sxtWkVFLbXiw9tFEx9GI78sCiESuQ,13943
26
- amber/language_model/layers.py,sha256=Ob7QZl8i236ALLklY9o_xtjDZSt6FD8sqdmFy_YLgN0,15906
27
- amber/language_model/persistence.py,sha256=i2ibDH1OABM5-ZNNLh7h4rOYWPsg3aaeYhmB_xWYDZw,5867
28
- amber/language_model/tokenizer.py,sha256=9eKNOHvUjIJhJbj7M-tN7jWU5lWhOeCY_cssa4exQ1g,7377
29
- amber/language_model/utils.py,sha256=5Y7scRvvudUjKDV8QPhC3HAc2S-dCuqbm6xEjRr0fRM,2630
30
- amber/mechanistic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
- amber/mechanistic/sae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
- amber/mechanistic/sae/autoencoder_context.py,sha256=cn0mv9COqT3jNcvXBfce70ankVEmw9kNE3Mu-knugoc,945
33
- amber/mechanistic/sae/sae.py,sha256=ha6rXGsOXE59E_ohTH0vJh6M4rQh3Xw0GfmCkSgeYS4,6035
34
- amber/mechanistic/sae/sae_trainer.py,sha256=GMrPz9SpSuANA0tJt3IkyIOOqVr2k7apE0w4CqL92gM,26311
35
- amber/mechanistic/sae/concepts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
- amber/mechanistic/sae/concepts/autoencoder_concepts.py,sha256=dTB9v5zK5lcKc4_zbYtWoCy8fA6ycj1-FJid4-7hQ04,14371
37
- amber/mechanistic/sae/concepts/concept_dictionary.py,sha256=9px845gODuufW2koym-5426q0ijdpxvS7avH1vu8-Ls,7671
38
- amber/mechanistic/sae/concepts/concept_models.py,sha256=HGyPoMSmj8CAg9joIa6fV_FfWAY2XN4LFPIkNQiJWbY,133
39
- amber/mechanistic/sae/concepts/input_tracker.py,sha256=81FrOv9AAC7ejhryOWDTZ7Hlt3B2WoANx-wiO0KLr24,1886
40
- amber/mechanistic/sae/modules/__init__.py,sha256=xpoz0HtPWoJD4dPj1qHaxtXDr7J0ERn30CX3m1dz21s,239
41
- amber/mechanistic/sae/modules/l1_sae.py,sha256=_BebvpB9iUCTDjSiYNzDBM4l9sU_wadAdypwXaSb4ww,15352
42
- amber/mechanistic/sae/modules/topk_sae.py,sha256=GQA8hYb6Fw7U1e5ExZjzZBjAZRhGk-VwqqYREVCQ_u8,17275
43
- amber/mechanistic/sae/training/wandb_logger.py,sha256=d3vVBIQrnsJurX5HNVu7OYW4DqNgv18UZDpV8ddfN9k,8554
44
- amber/store/__init__.py,sha256=UW4Hqyu-_qgnZ-gN_mk97OaWSrlPERcNi5YjnXMKeOU,119
45
- amber/store/local_store.py,sha256=1pJbizZKrzNt_IQFnCFYjApPXs9ot-G1H8adeR7Qi50,17214
46
- amber/store/store.py,sha256=VuDe9Git0glND3TTHh0zhDJNxdQY3dCp0cURhApYQbU,9334
47
- amber/store/store_dataloader.py,sha256=QyYHSgOos8e-yzaEE_rySSVlGKaRNybURSDCgNrTIVM,4337
48
- mi_crow-0.1.1.post12.dist-info/METADATA,sha256=1_SPBcnK8j_scOIWwnPqQGjlUpNnXxGGtvkYRZHeDQ8,6580
49
- mi_crow-0.1.1.post12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
50
- mi_crow-0.1.1.post12.dist-info/top_level.txt,sha256=FNP1x_ePvcW9Jsr7J9gCBARdDC-gqxIYtWF6HGNxtnI,6
51
- mi_crow-0.1.1.post12.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- amber
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes