sae-lens 6.26.1__tar.gz → 6.28.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {sae_lens-6.26.1 → sae_lens-6.28.0}/PKG-INFO +3 -1
  2. {sae_lens-6.26.1 → sae_lens-6.28.0}/README.md +2 -0
  3. {sae_lens-6.26.1 → sae_lens-6.28.0}/pyproject.toml +2 -1
  4. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/__init__.py +3 -1
  5. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py +12 -5
  6. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/config.py +2 -0
  7. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
  8. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py +1 -0
  9. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py +3 -0
  10. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py +2 -0
  11. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py +1 -0
  12. sae_lens-6.28.0/sae_lens/synthetic/__init__.py +89 -0
  13. sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py +215 -0
  14. sae_lens-6.28.0/sae_lens/synthetic/correlation.py +170 -0
  15. sae_lens-6.28.0/sae_lens/synthetic/evals.py +141 -0
  16. sae_lens-6.28.0/sae_lens/synthetic/feature_dictionary.py +138 -0
  17. sae_lens-6.28.0/sae_lens/synthetic/firing_probabilities.py +104 -0
  18. sae_lens-6.28.0/sae_lens/synthetic/hierarchy.py +335 -0
  19. sae_lens-6.28.0/sae_lens/synthetic/initialization.py +40 -0
  20. sae_lens-6.28.0/sae_lens/synthetic/plotting.py +230 -0
  21. sae_lens-6.28.0/sae_lens/synthetic/training.py +145 -0
  22. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/activations_store.py +51 -91
  23. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/mixing_buffer.py +14 -5
  24. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/sae_trainer.py +1 -1
  25. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/util.py +26 -1
  26. {sae_lens-6.26.1 → sae_lens-6.28.0}/LICENSE +0 -0
  27. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/__init__.py +0 -0
  28. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  29. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  30. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/constants.py +0 -0
  31. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/evals.py +0 -0
  32. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/llm_sae_training_runner.py +0 -0
  33. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/load_model.py +0 -0
  34. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/__init__.py +0 -0
  35. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  36. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/pretokenize_runner.py +0 -0
  37. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/pretrained_saes.yaml +0 -0
  38. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/registry.py +0 -0
  39. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/__init__.py +0 -0
  40. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  41. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
  42. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  43. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/sae.py +0 -0
  44. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/topk_sae.py +0 -0
  45. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/transcoder.py +0 -0
  46. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/tokenization_and_batching.py +1 -1
  47. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/__init__.py +0 -0
  48. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/activation_scaler.py +0 -0
  49. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/optim.py +0 -0
  50. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/types.py +0 -0
  51. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  52. {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/tutorial/tsea.py +0 -0

{sae_lens-6.26.1 → sae_lens-6.28.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.26.1
+Version: 6.28.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)

 ## Join the Slack!


{sae_lens-6.26.1 → sae_lens-6.28.0}/README.md
@@ -41,6 +41,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)

 ## Join the Slack!


{sae_lens-6.26.1 → sae_lens-6.28.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.26.1"
+version = "6.28.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -58,6 +58,7 @@ eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
 trio = "^0.30.0"
 dictionary-learning = "^0.1.0"
+kaleido = "^1.2.0"

 [tool.poetry.extras]
 mamba = ["mamba-lens"]

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.26.1"
+__version__ = "6.28.0"

 import logging

@@ -63,6 +63,7 @@ from .loading.pretrained_sae_loaders import (
 from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
 from .registry import register_sae_class, register_sae_training_class
 from .training.activations_store import ActivationsStore
+from .training.sae_trainer import SAETrainer
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

 __all__ = [
@@ -102,6 +103,7 @@ __all__ = [
     "JumpReLUTrainingSAE",
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
+    "SAETrainer",
     "LoggingConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py
@@ -263,14 +263,21 @@ class CacheActivationsRunner:

         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                buffer = self.activations_store.get_raw_buffer(
-                    self.cfg.n_batches_in_buffer, shuffle=False
-                )
-                shard = self._create_shard(buffer)
+                # Accumulate n_batches_in_buffer batches into one shard
+                buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
+                for _ in range(self.cfg.n_batches_in_buffer):
+                    buffers.append(self.activations_store.get_raw_llm_batch())
+                # Concatenate all batches
+                acts = torch.cat([b[0] for b in buffers], dim=0)
+                token_ids: torch.Tensor | None = None
+                if buffers[0][1] is not None:
+                    # All batches have token_ids if the first one does
+                    token_ids = torch.cat([b[1] for b in buffers], dim=0)  # type: ignore[arg-type]
+                shard = self._create_shard((acts, token_ids))
                 shard.save_to_disk(
                     f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                 )
-                del buffer, shard
+                del buffers, acts, token_ids, shard
             except StopIteration:
                 logger.warning(
                     f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/config.py
@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
         disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
         sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
+        activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
         device (str): The device to use. Usually "cuda".
         act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
         special_token_field(default="bos")
    )
+    activations_mixing_fraction: float = 0.5

     # Misc
     device: str = "cpu"
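
For context on the new `activations_mixing_fraction` option documented above: a minimal sketch of the temporal-mixing idea it controls. This is an illustration only, not the actual implementation in `sae_lens/training/mixing_buffer.py`; the function below and its iterator protocol are assumptions.

```python
from collections.abc import Iterable, Iterator

import torch


def mixed_batches(
    batches: Iterable[torch.Tensor], mixing_fraction: float = 0.5
) -> Iterator[torch.Tensor]:
    """Hypothetical sketch: hold back `mixing_fraction` of each shuffled pool
    and fold it into the next incoming batch, so that served activations are
    shuffled across time rather than arriving in generation order."""
    held: torch.Tensor | None = None
    for batch in batches:
        if mixing_fraction <= 0:
            yield batch  # no temporal mixing: serve activations in order
            continue
        pool = batch if held is None else torch.cat([held, batch])
        pool = pool[torch.randperm(pool.shape[0])]
        n_hold = int(pool.shape[0] * mixing_fraction)
        held, out = pool[:n_hold], pool[n_hold:]
        yield out
    if held is not None and held.shape[0] > 0:
        yield held  # flush the held-back remainder at end of stream
```

A larger fraction keeps more activations in flight (better shuffling, more memory and latency); at 0 the buffer degenerates to a pass-through, matching the docstring above.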

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py
@@ -959,7 +959,7 @@ def get_dictionary_learning_config_1_from_hf(
     architecture = "standard"
     if trainer["dict_class"] == "GatedAutoEncoder":
         architecture = "gated"
-    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+    elif trainer["dict_class"] in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
         architecture = "jumprelu"

     return {
@@ -1831,6 +1831,7 @@ def temporal_sae_huggingface_loader(
     Load TemporalSAE from canrager/temporalSAEs format (safetensors version).

     Expects folder_name to contain:
+
     - conf.yaml (configuration)
     - latest_ckpt.safetensors (model weights)
     """

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py
@@ -118,6 +118,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     """
     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
     It implements:
+
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py
@@ -105,6 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     activation function (e.g., ReLU etc.).

     It implements:
+
     - initialize_weights: sets up parameters, including a threshold.
     - encode: computes the feature activations using JumpReLU.
     - decode: reconstructs the input from the feature activations.
@@ -216,10 +217,12 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

     Similar to the inference-only JumpReLUSAE, but with:
+
     - A learnable log-threshold parameter (instead of a raw threshold).
     - A specialized auxiliary loss term for sparsity (L0 or similar).

     Methods of interest include:
+
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
     - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py
@@ -34,6 +34,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
     using a simple linear encoder and decoder.

     It implements the required abstract methods from BaseSAE:
+
     - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
     - encode: computes the feature activations from an input.
     - decode: reconstructs the input from the feature activations.
@@ -99,6 +100,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
+
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py
@@ -167,6 +167,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
     """TemporalSAE: Sparse Autoencoder with temporal attention.

     This SAE decomposes each activation x_t into:
+
     - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
     - x_novel: Novel information at position t (encoded sparsely)


sae_lens-6.28.0/sae_lens/synthetic/__init__.py
@@ -0,0 +1,89 @@
+"""
+Synthetic data utilities for SAE experiments.
+
+This module provides tools for creating feature dictionaries and generating
+synthetic activations for testing and experimenting with SAEs.
+
+Main components:
+
+- FeatureDictionary: Maps sparse feature activations to dense hidden activations
+- ActivationGenerator: Generates batches of synthetic feature activations
+- HierarchyNode: Enforces hierarchical structure on feature activations
+- Training utilities: Helpers for training and evaluating SAEs on synthetic data
+- Plotting utilities: Visualization helpers for understanding SAE behavior
+"""
+
+from sae_lens.synthetic.activation_generator import (
+    ActivationGenerator,
+    ActivationsModifier,
+    ActivationsModifierInput,
+)
+from sae_lens.synthetic.correlation import (
+    create_correlation_matrix_from_correlations,
+    generate_random_correlation_matrix,
+    generate_random_correlations,
+)
+from sae_lens.synthetic.evals import (
+    SyntheticDataEvalResult,
+    eval_sae_on_synthetic_data,
+    mean_correlation_coefficient,
+)
+from sae_lens.synthetic.feature_dictionary import (
+    FeatureDictionary,
+    FeatureDictionaryInitializer,
+    orthogonal_initializer,
+    orthogonalize_embeddings,
+)
+from sae_lens.synthetic.firing_probabilities import (
+    linear_firing_probabilities,
+    random_firing_probabilities,
+    zipfian_firing_probabilities,
+)
+from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier
+from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict
+from sae_lens.synthetic.plotting import (
+    find_best_feature_ordering,
+    find_best_feature_ordering_across_saes,
+    find_best_feature_ordering_from_sae,
+    plot_sae_feature_similarity,
+)
+from sae_lens.synthetic.training import (
+    SyntheticActivationIterator,
+    train_toy_sae,
+)
+from sae_lens.util import cosine_similarities
+
+__all__ = [
+    # Main classes
+    "FeatureDictionary",
+    "HierarchyNode",
+    "hierarchy_modifier",
+    "ActivationGenerator",
+    # Activation generation
+    "zipfian_firing_probabilities",
+    "linear_firing_probabilities",
+    "random_firing_probabilities",
+    "create_correlation_matrix_from_correlations",
+    "generate_random_correlations",
+    "generate_random_correlation_matrix",
+    # Feature modifiers
+    "ActivationsModifier",
+    "ActivationsModifierInput",
+    # Utilities
+    "orthogonalize_embeddings",
+    "orthogonal_initializer",
+    "FeatureDictionaryInitializer",
+    "cosine_similarities",
+    # Training utilities
+    "SyntheticActivationIterator",
+    "SyntheticDataEvalResult",
+    "train_toy_sae",
+    "eval_sae_on_synthetic_data",
+    "mean_correlation_coefficient",
+    "init_sae_to_match_feature_dict",
+    # Plotting utilities
+    "find_best_feature_ordering",
+    "find_best_feature_ordering_from_sae",
+    "find_best_feature_ordering_across_saes",
+    "plot_sae_feature_similarity",
+]
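
A minimal usage sketch for the new module. It relies only on the `ActivationGenerator` constructor and `sample` signatures shown in `activation_generator.py` below; the printed firing rate is an expected value, not an exact output.

```python
import torch
from sae_lens.synthetic import ActivationGenerator

# 64 features, each firing independently with probability 0.02, unit mean
# firing magnitude and zero magnitude noise (the constructor defaults).
gen = ActivationGenerator(num_features=64, firing_probabilities=0.02)
feature_acts = gen.sample(batch_size=4096)

print(feature_acts.shape)                        # torch.Size([4096, 64])
print((feature_acts > 0).float().mean().item())  # ~0.02
```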

sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py
@@ -0,0 +1,215 @@
+"""
+Functions for generating synthetic feature activations.
+"""
+
+from collections.abc import Callable, Sequence
+
+import torch
+from scipy.stats import norm
+from torch import nn
+from torch.distributions import MultivariateNormal
+
+from sae_lens.util import str_to_dtype
+
+ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+
+
+class ActivationGenerator(nn.Module):
+    """
+    Generator for synthetic feature activations.
+
+    This module provides a generator for synthetic feature activations with controlled properties.
+    """
+
+    num_features: int
+    firing_probabilities: torch.Tensor
+    std_firing_magnitudes: torch.Tensor
+    mean_firing_magnitudes: torch.Tensor
+    modify_activations: ActivationsModifier | None
+    correlation_matrix: torch.Tensor | None
+    correlation_thresholds: torch.Tensor | None
+
+    def __init__(
+        self,
+        num_features: int,
+        firing_probabilities: torch.Tensor | float,
+        std_firing_magnitudes: torch.Tensor | float = 0.0,
+        mean_firing_magnitudes: torch.Tensor | float = 1.0,
+        modify_activations: ActivationsModifierInput = None,
+        correlation_matrix: torch.Tensor | None = None,
+        device: torch.device | str = "cpu",
+        dtype: torch.dtype | str = "float32",
+    ):
+        super().__init__()
+        self.num_features = num_features
+        self.firing_probabilities = _to_tensor(
+            firing_probabilities, num_features, device, dtype
+        )
+        self.std_firing_magnitudes = _to_tensor(
+            std_firing_magnitudes, num_features, device, dtype
+        )
+        self.mean_firing_magnitudes = _to_tensor(
+            mean_firing_magnitudes, num_features, device, dtype
+        )
+        self.modify_activations = _normalize_modifiers(modify_activations)
+        self.correlation_thresholds = None
+        if correlation_matrix is not None:
+            _validate_correlation_matrix(correlation_matrix, num_features)
+            self.correlation_thresholds = torch.tensor(
+                [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
+                device=device,
+                dtype=self.firing_probabilities.dtype,
+            )
+        self.correlation_matrix = correlation_matrix
+
+    def sample(self, batch_size: int) -> torch.Tensor:
+        """
+        Generate a batch of feature activations with controlled properties.
+
+        This is the main function for generating synthetic training data for SAEs.
+        Features fire independently according to their firing probabilities unless
+        a correlation matrix is provided.
+
+        Args:
+            batch_size: Number of samples to generate
+
+        Returns:
+            Tensor of shape [batch_size, num_features] with non-negative activations
+        """
+        # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
+        # are on the same device from __init__ via _to_tensor()
+        device = self.firing_probabilities.device
+
+        if self.correlation_matrix is not None:
+            assert self.correlation_thresholds is not None
+            firing_features = _generate_correlated_features(
+                batch_size,
+                self.correlation_matrix,
+                self.correlation_thresholds,
+                device,
+            )
+        else:
+            firing_features = torch.bernoulli(
+                self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            )
+
+        firing_magnitude_delta = torch.normal(
+            torch.zeros_like(self.firing_probabilities)
+            .unsqueeze(0)
+            .expand(batch_size, -1),
+            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+        )
+        firing_magnitude_delta[firing_features == 0] = 0
+        feature_activations = (
+            firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
+        ).relu()
+
+        if self.modify_activations is not None:
+            feature_activations = self.modify_activations(feature_activations).relu()
+        return feature_activations
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return self.sample(batch_size)
+
+
+def _generate_correlated_features(
+    batch_size: int,
+    correlation_matrix: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generate correlated binary features using multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach: sample from a multivariate normal
+    distribution, then threshold to get binary features.
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_matrix: Correlation matrix between features
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Binary feature matrix of shape [batch_size, num_features]
+    """
+    num_features = correlation_matrix.shape[0]
+
+    mvn = MultivariateNormal(
+        loc=torch.zeros(num_features, device=device, dtype=thresholds.dtype),
+        covariance_matrix=correlation_matrix.to(device=device, dtype=thresholds.dtype),
+    )
+
+    gaussian_samples = mvn.sample((batch_size,))
+    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+
+
+def _to_tensor(
+    value: torch.Tensor | float,
+    num_features: int,
+    device: torch.device | str,
+    dtype: torch.dtype | str,
+) -> torch.Tensor:
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+    if not isinstance(value, torch.Tensor):
+        value = value * torch.ones(num_features, device=device, dtype=dtype)
+    if value.shape != (num_features,):
+        raise ValueError(
+            f"Value must be a tensor of shape ({num_features},) or a float. Got {value.shape}"
+        )
+    return value.to(device, dtype)
+
+
+def _normalize_modifiers(
+    modify_activations: ActivationsModifierInput,
+) -> ActivationsModifier | None:
+    """Convert modifier input to a single modifier or None."""
+    if modify_activations is None:
+        return None
+    if callable(modify_activations):
+        return modify_activations
+    # It's a sequence of modifiers - chain them
+    modifiers = list(modify_activations)
+    if len(modifiers) == 0:
+        return None
+    if len(modifiers) == 1:
+        return modifiers[0]
+
+    def chained(activations: torch.Tensor) -> torch.Tensor:
+        result = activations
+        for modifier in modifiers:
+            result = modifier(result)
+        return result
+
+    return chained
+
+
+def _validate_correlation_matrix(
+    correlation_matrix: torch.Tensor, num_features: int
+) -> None:
+    """Validate that a correlation matrix has correct properties.
+
+    Args:
+        correlation_matrix: The matrix to validate
+        num_features: Expected number of features (matrix should be [num_features, num_features])
+
+    Raises:
+        ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
+    """
+    expected_shape = (num_features, num_features)
+    if correlation_matrix.shape != expected_shape:
+        raise ValueError(
+            f"Correlation matrix must have shape {expected_shape}, "
+            f"got {tuple(correlation_matrix.shape)}"
+        )
+
+    diagonal = torch.diag(correlation_matrix)
+    if not torch.allclose(diagonal, torch.ones_like(diagonal)):
+        raise ValueError("Correlation matrix diagonal must be all 1s")
+
+    try:
+        torch.linalg.cholesky(correlation_matrix)
+    except RuntimeError as e:
+        raise ValueError("Correlation matrix must be positive definite") from e

sae_lens-6.28.0/sae_lens/synthetic/correlation.py
@@ -0,0 +1,170 @@
+import random
+
+import torch
+
+
+def create_correlation_matrix_from_correlations(
+    num_features: int,
+    correlations: dict[tuple[int, int], float] | None = None,
+    default_correlation: float = 0.0,
+) -> torch.Tensor:
+    """
+    Create a correlation matrix with specified pairwise correlations.
+
+    Args:
+        num_features: Number of features
+        correlations: Dict mapping (i, j) pairs to correlation values.
+            Pairs should have i < j.
+        default_correlation: Default correlation for unspecified pairs
+
+    Returns:
+        Correlation matrix of shape [num_features, num_features]
+    """
+    matrix = torch.eye(num_features) + default_correlation * (
+        1 - torch.eye(num_features)
+    )
+
+    if correlations is not None:
+        for (i, j), corr in correlations.items():
+            matrix[i, j] = corr
+            matrix[j, i] = corr
+
+    # Ensure matrix is symmetric (numerical precision)
+    matrix = (matrix + matrix.T) / 2
+
+    # Check positive definiteness and fix if necessary
+    # Use eigvalsh for symmetric matrices (returns real eigenvalues)
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def _fix_correlation_matrix(
+    matrix: torch.Tensor, min_eigenval: float = 1e-6
+) -> torch.Tensor:
+    """Fix a correlation matrix to be positive semi-definite."""
+    eigenvals, eigenvecs = torch.linalg.eigh(matrix)
+    eigenvals = torch.clamp(eigenvals, min=min_eigenval)
+    fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
+
+    diag_vals = torch.diag(fixed_matrix)
+    fixed_matrix = fixed_matrix / torch.sqrt(
+        diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
+    )
+    fixed_matrix.fill_diagonal_(1.0)
+
+    return fixed_matrix
+
+
+def generate_random_correlations(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> dict[tuple[int, int], float]:
+    """
+    Generate random correlations between features with specified constraints.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Dictionary mapping (i, j) pairs to correlation values
+    """
+    # Use local random number generator to avoid side effects on global state
+    rng = random.Random(seed)
+
+    # Validate inputs
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+    # Generate all possible feature pairs (i, j) where i < j
+    all_pairs = [
+        (i, j) for i in range(num_features) for j in range(i + 1, num_features)
+    ]
+    total_pairs = len(all_pairs)
+
+    if total_pairs == 0:
+        return {}
+
+    # Determine how many pairs to correlate vs leave uncorrelated
+    num_uncorrelated = int(total_pairs * uncorrelated_ratio)
+    num_correlated = total_pairs - num_uncorrelated
+
+    # Randomly select which pairs to correlate
+    correlated_pairs = rng.sample(all_pairs, num_correlated)
+
+    # For correlated pairs, determine positive vs negative
+    num_positive = int(num_correlated * positive_ratio)
+    num_negative = num_correlated - num_positive
+
+    # Assign signs
+    signs = [1] * num_positive + [-1] * num_negative
+    rng.shuffle(signs)
+
+    # Generate correlation strengths
+    correlations = {}
+    for pair, sign in zip(correlated_pairs, signs):
+        # Sample correlation strength uniformly from range
+        strength = rng.uniform(min_correlation_strength, max_correlation_strength)
+        correlations[pair] = sign * strength
+
+    return correlations
+
+
+def generate_random_correlation_matrix(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> torch.Tensor:
+    """
+    Generate a random correlation matrix with specified constraints.
+
+    This is a convenience function that combines generate_random_correlations()
+    and create_correlation_matrix_from_correlations() into a single call.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Random correlation matrix of shape [num_features, num_features]
+    """
+    # Generate random correlations
+    correlations = generate_random_correlations(
+        num_features=num_features,
+        positive_ratio=positive_ratio,
+        uncorrelated_ratio=uncorrelated_ratio,
+        min_correlation_strength=min_correlation_strength,
+        max_correlation_strength=max_correlation_strength,
+        seed=seed,
+    )
+
+    # Create and return correlation matrix
+    return create_correlation_matrix_from_correlations(
+        num_features=num_features, correlations=correlations
+    )
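
A short demonstration of the repair path in `create_correlation_matrix_from_correlations`: the pairwise spec below is internally inconsistent (the raw matrix has eigenvalues 1.9, 1.9, and -0.8), so `_fix_correlation_matrix` is triggered and the result is a nearby valid correlation matrix. The specific values are an assumption of this sketch, not taken from the package's tests.

```python
import torch
from sae_lens.synthetic import create_correlation_matrix_from_correlations

# corr(0,1)=0.9 and corr(1,2)=0.9 force corr(0,2) to be strongly positive,
# so requesting -0.9 there yields a matrix that is not positive semi-definite.
m = create_correlation_matrix_from_correlations(
    num_features=3,
    correlations={(0, 1): 0.9, (1, 2): 0.9, (0, 2): -0.9},
)

print(torch.linalg.eigvalsh(m))        # the -0.8 eigenvalue has been clamped away
print(m[0, 1].item(), m[0, 2].item())  # near, but not exactly, 0.9 and -0.9
```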