sae-lens 6.13.0__tar.gz → 6.13.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.13.0 → sae_lens-6.13.1}/PKG-INFO +1 -1
- {sae_lens-6.13.0 → sae_lens-6.13.1}/pyproject.toml +1 -1
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/tokenization_and_batching.py +20 -5
- {sae_lens-6.13.0 → sae_lens-6.13.1}/LICENSE +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/README.md +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/config.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.13.0 → sae_lens-6.13.1}/sae_lens/util.py +0 -0
|
@@ -68,7 +68,7 @@ def concat_and_batch_sequences(
|
|
|
68
68
|
) -> Generator[torch.Tensor, None, None]:
|
|
69
69
|
"""
|
|
70
70
|
Generator to concat token sequences together from the tokens_iterator, yielding
|
|
71
|
-
|
|
71
|
+
sequences of size `context_size`. Batching across the batch dimension is handled by the caller.
|
|
72
72
|
|
|
73
73
|
Args:
|
|
74
74
|
tokens_iterator: An iterator which returns 1D tensors of tokens
|
|
@@ -76,13 +76,28 @@ def concat_and_batch_sequences(
|
|
|
76
76
|
begin_batch_token_id: If provided, this token will be at position 0 of each batch
|
|
77
77
|
begin_sequence_token_id: If provided, this token will be the first token of each sequence
|
|
78
78
|
sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
|
|
79
|
-
disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
|
|
79
|
+
disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size (including BOS token if present)
|
|
80
80
|
max_batches: If not provided, the iterator will be run to completion.
|
|
81
81
|
"""
|
|
82
82
|
if disable_concat_sequences:
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
83
|
+
if begin_batch_token_id and not begin_sequence_token_id:
|
|
84
|
+
begin_sequence_token_id = begin_batch_token_id
|
|
85
|
+
for sequence in tokens_iterator:
|
|
86
|
+
if (
|
|
87
|
+
begin_sequence_token_id is not None
|
|
88
|
+
and sequence[0] != begin_sequence_token_id
|
|
89
|
+
and len(sequence) >= context_size - 1
|
|
90
|
+
):
|
|
91
|
+
begin_sequence_token_id_tensor = torch.tensor(
|
|
92
|
+
[begin_sequence_token_id],
|
|
93
|
+
dtype=torch.long,
|
|
94
|
+
device=sequence.device,
|
|
95
|
+
)
|
|
96
|
+
sequence = torch.cat(
|
|
97
|
+
[begin_sequence_token_id_tensor, sequence[: context_size - 1]]
|
|
98
|
+
)
|
|
99
|
+
if len(sequence) >= context_size:
|
|
100
|
+
yield sequence[:context_size]
|
|
86
101
|
return
|
|
87
102
|
|
|
88
103
|
batch: torch.Tensor | None = None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|