lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +15 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +6 -2
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
- lalamo-0.5.4.dist-info/RECORD +88 -0
- lalamo-0.5.2.dist-info/RECORD +0 -80
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/{language_model.py → models/language_model.py}
RENAMED

@@ -1,8 +1,7 @@
-import json
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass
 from pathlib import Path
-from typing import NamedTuple
+from typing import NamedTuple
 
 import equinox as eqx
 import jax
@@ -10,14 +9,19 @@ import jax.numpy as jnp
 from einops import rearrange
 from jax import vmap
 from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
-from tokenizers import Tokenizer
 
-from lalamo.
-from lalamo.
-
-
+from lalamo.message_processor import AssistantMessage, Message, MessageProcessor
+from lalamo.modules import (
+    Decoder,
+    DecoderConfig,
+    DecoderForwardPassConfig,
+    ForwardPassMode,
+    LalamoModule,
+    State,
+)
 from lalamo.sampling import SamplingPolicy, make_policy
-
+
+from .common import TextModel, TextModelConfig
 
 __all__ = [
     "ForwardPassConfig",
@@ -71,46 +75,25 @@ class GenerationConfig:
 
 
 @dataclass(frozen=True)
-class LanguageModelConfig:
-    decoder_config: DecoderConfig
-    message_processor_config: MessageProcessorConfig
+class LanguageModelConfig(TextModelConfig[DecoderConfig]):
     generation_config: GenerationConfig
 
-
-
-
-
+    def init(
+        self,
+        model: LalamoModule,
+        message_processor: MessageProcessor,
+    ) -> "LanguageModel":
+        assert isinstance(model, Decoder)
+        return LanguageModel(self, model, message_processor)
 
     @classmethod
-    def
-
-
-
-        config_json = json.load(config_file)
-        config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-        with open_safetensors(path / "model.safetensors") as (weights_dict, _):
-            weights = unflatten_parameters(weights_dict)
-        decoder = config.decoder_config.empty().import_weights(weights)
-        tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
-        message_processor = MessageProcessor(config.message_processor_config, tokenizer)
-        return cls(config, decoder, message_processor)
-
-    @property
-    def activation_precision(self) -> DTypeLike:
-        return self.decoder.activation_precision
-
-    def export_weights(self) -> ParameterTree:
-        return self.decoder.export_weights()
+    def load_model(cls, path: Path | str) -> "LanguageModel":
+        result = super().load_model(path)
+        assert isinstance(result, LanguageModel)
+        return result
 
-    def import_weights(
-        self,
-        weights: ParameterTree[Array],
-    ) -> Self:
-        return replace(
-            self,
-            decoder=self.decoder.import_weights(weights),
-        )
 
+class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
     @property
     def stop_token_ids(self) -> tuple[int, ...]:
         return self.config.generation_config.stop_token_ids
@@ -129,11 +112,11 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
         batch_size, sequence_length = token_ids.shape
         token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
         if state_capacity is not None:
-            state = self.
+            state = self.model.init_static_state(batch_size, state_capacity)
         else:
             state = None
 
-        decoder_outputs = self.
+        decoder_outputs = self.model(
            token_ids,
            token_positions,
            state,
@@ -220,7 +203,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
        else:
            forward_pass_mode = ForwardPassMode.MULTI_TOKEN
 
-        decoder_outputs = self.
+        decoder_outputs = self.model(
            next_token_ids[:, None],
            next_token_indices[:, None],
            state.state,
@@ -272,7 +255,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
        key: PRNGKeyArray | None = None,
    ) -> AssistantMessage:
        formatted_messages = self.message_processor.render_request(messages)
-        token_ids = jnp.array(self.message_processor.
+        token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)[None, :]
        response_ids = self.generate_tokens(
            token_ids,
            sampling_policy,
@@ -292,7 +275,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
        key: PRNGKeyArray | None = None,
    ) -> Iterable[str]:
        formatted_messages = self.message_processor.render_request(messages)
-        token_ids = jnp.array(self.message_processor.
+        token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)
        for token_id in self.stream_tokens(
            token_ids,
            sampling_policy,
@@ -352,7 +335,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
            return
 
        next_token_indices = state.last_token_indices + 1
-        decoder_outputs = self.
+        decoder_outputs = self.model(
            next_token_id.reshape(1, 1),
            next_token_indices.reshape(1, 1),
            state.state,
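Note (not part of the diff): the new `load_model` classmethod above replaces the removed hand-rolled JSON/safetensors/tokenizer loading with the shared `TextModelConfig.load_model` path, and the `decoder` attribute becomes the generic `model`. A minimal loading sketch against the new API; the directory path is hypothetical, and the import path follows the file rename:

from lalamo.models.language_model import LanguageModel

# load_model() now delegates to the shared TextModelConfig.load_model()
model = LanguageModel.load_model("path/to/exported-model")  # hypothetical path
print(model.stop_token_ids)  # still read from generation_config, as before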
lalamo/models/router.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax import Array
|
|
7
|
+
from jax import numpy as jnp
|
|
8
|
+
from jaxtyping import Float
|
|
9
|
+
|
|
10
|
+
from lalamo.message_processor import Message, MessageProcessor
|
|
11
|
+
from lalamo.modules import Classifier, ClassifierConfig, LalamoModule
|
|
12
|
+
|
|
13
|
+
from .common import TextModel, TextModelConfig
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"Router",
|
|
17
|
+
"RouterConfig",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True)
|
|
22
|
+
class RouterConfig(TextModelConfig[ClassifierConfig]):
|
|
23
|
+
def init(
|
|
24
|
+
self,
|
|
25
|
+
model: LalamoModule,
|
|
26
|
+
message_processor: MessageProcessor,
|
|
27
|
+
) -> "Router":
|
|
28
|
+
assert isinstance(model, Classifier)
|
|
29
|
+
return Router(self, model, message_processor)
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def load_model(cls, path: Path | str) -> "Router":
|
|
33
|
+
result = super().load_model(path)
|
|
34
|
+
assert isinstance(result, Router)
|
|
35
|
+
return result
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Router(TextModel[RouterConfig, Classifier]):
|
|
39
|
+
def label_output_logits(self, logits: Float[Array, "batch logits"]) -> dict[str, Float[Array, " batch"]]:
|
|
40
|
+
output_labels = self.model.config.output_labels
|
|
41
|
+
probabilities = jax.nn.sigmoid(logits)
|
|
42
|
+
|
|
43
|
+
if output_labels is None:
|
|
44
|
+
output_labels = [f"class_{idx}" for idx in range(self.model.config.num_labels)]
|
|
45
|
+
|
|
46
|
+
assert probabilities.ndim == 2, f"Expected 2D array, got array of shape {logits.shape}"
|
|
47
|
+
|
|
48
|
+
return dict(zip(output_labels, jnp.unstack(probabilities, axis=1), strict=True))
|
|
49
|
+
|
|
50
|
+
def classify_chat(
|
|
51
|
+
self,
|
|
52
|
+
messages: Iterable[Message],
|
|
53
|
+
) -> dict[str, float]:
|
|
54
|
+
token_ids = jnp.array(self.message_processor.tokenize_request(messages), dtype=jnp.int32)[None, :]
|
|
55
|
+
_, sequence_length = token_ids.shape
|
|
56
|
+
token_positions = jnp.arange(sequence_length, dtype=jnp.int32)[None, :]
|
|
57
|
+
classifier_output = self.model(token_ids=token_ids, token_positions=token_positions)
|
|
58
|
+
|
|
59
|
+
return {k: float(v.item()) for k, v in self.label_output_logits(classifier_output.logits).items()}
|
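Usage sketch for the new `Router` API (not part of the diff; the model path is hypothetical and `messages` stands for any `Iterable[Message]` built elsewhere):

from lalamo.models.router import Router

router = Router.load_model("path/to/router-model")  # hypothetical path
scores = router.classify_chat(messages)
# scores maps each configured output label (or "class_{i}" when
# output_labels is None) to a sigmoid probability as a Python float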
lalamo/modules/__init__.py
CHANGED
@@ -1,12 +1,17 @@
 from .activations import GELU, Activation, Identity, SiLU
-from .
-from .
-
-
-
-
-
-
+from .classifier import Classifier, ClassifierConfig
+from .common import (
+    ForwardPassMode,
+    LalamoModule,
+    PositionalEmbeddingSelector,
+    config_converter,
+)
+from .decoder import (
+    Decoder,
+    DecoderActivationTrace,
+    DecoderConfig,
+    DecoderForwardPassConfig,
+    DecoderResult,
 )
 from .embedding import (
     EmbeddingBase,
@@ -45,7 +50,7 @@ from .mlp import (
     RoutingFunction,
     SoftmaxRouting,
 )
-from .normalization import
+from .normalization import Normalization, NormalizationConfig, UpcastMode
 from .rope import (
     LinearScalingRoPEConfig,
     LlamaRoPEConfig,
@@ -67,21 +72,26 @@ from .token_mixers import (
     State,
     StaticKVCacheLayer,
 )
+from .transformer import Transformer, TransformerConfig
+from .transformer_layer import (
+    TransformerLayer,
+    TransformerLayerActivationTrace,
+    TransformerLayerConfig,
+    TransformerLayerForwardPassConfig,
+    TransformerLayerResult,
+)
 
 __all__ = [
     "GELU",
     "Activation",
     "Attention",
     "AttentionConfig",
+    "Classifier",
+    "ClassifierConfig",
     "Decoder",
     "DecoderActivationTrace",
     "DecoderConfig",
     "DecoderForwardPassConfig",
-    "DecoderLayer",
-    "DecoderLayerActivationTrace",
-    "DecoderLayerConfig",
-    "DecoderLayerForwardPassConfig",
-    "DecoderLayerResult",
     "DecoderResult",
     "DenseMLP",
     "DenseMLPConfig",
@@ -113,14 +123,14 @@ __all__ = [
     "Mamba2Config",
     "MixtureOfExperts",
     "MixtureOfExpertsConfig",
+    "Normalization",
+    "NormalizationConfig",
     "PositionalEmbeddingSelector",
     "PositionalEmbeddings",
     "QLoRALinear",
     "QLoRALinearConfig",
     "QuantizedTiedEmbedding",
     "QuantizedTiedEmbeddingConfig",
-    "RMSNorm",
-    "RMSNormConfig",
     "RoPE",
     "RoPEConfig",
     "RoutingFunction",
@@ -132,6 +142,13 @@ __all__ = [
     "StaticKVCacheLayer",
     "TiedEmbedding",
     "TiedEmbeddingConfig",
+    "Transformer",
+    "TransformerConfig",
+    "TransformerLayer",
+    "TransformerLayerActivationTrace",
+    "TransformerLayerConfig",
+    "TransformerLayerForwardPassConfig",
+    "TransformerLayerResult",
     "UnscaledRoPEConfig",
     "UntiedEmbedding",
     "UntiedEmbeddingConfig",
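Migration note for downstream imports (not part of the diff): the `DecoderLayer*` exports from 0.5.2 are gone, their equivalents now ship under the `TransformerLayer*` names, and `RMSNorm`/`RMSNormConfig` are superseded by `Normalization`/`NormalizationConfig`:

# 0.5.2
from lalamo.modules import DecoderLayer, DecoderLayerConfig, RMSNorm
# 0.5.4
from lalamo.modules import TransformerLayer, TransformerLayerConfig, Normalization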
lalamo/modules/classifier.py
ADDED

@@ -0,0 +1,339 @@
+from collections.abc import Mapping
+from dataclasses import dataclass, replace
+from enum import StrEnum
+from typing import Self
+
+import equinox as eqx
+import jax
+from jax import numpy as jnp
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterTree
+from lalamo.modules import Activation
+from lalamo.modules.normalization import NormalizationConfig
+from lalamo.modules.transformer import (
+    Normalization,
+    Transformer,
+    TransformerConfig,
+    TransformerForwardPassConfig,
+)
+from lalamo.modules.utils import vmap_twice
+
+from .common import ForwardPassMode, LalamoModule
+from .embedding import EmbeddingBase, EmbeddingConfig
+from .linear import LinearBase, LinearConfig
+from .rope import PositionalEmbeddings
+from .transformer_layer import TransformerLayerResult
+
+__all__ = [
+    "Classifier",
+    "ClassifierActivationTrace",
+    "ClassifierConfig",
+    "ClassifierResult",
+]
+
+
+class PoolingType(StrEnum):
+    CLS = "cls"
+    MEAN = "mean"
+
+
+@dataclass(frozen=True)
+class PredictionHeadConfig:
+    dense_config: LinearConfig
+    activation: Activation
+    normalization_config: NormalizationConfig
+    readout_config: LinearConfig
+    use_dense_bias: bool
+
+    def empty(self, input_size: int, num_labels: int) -> "PredictionHead":
+        dense_layer = self.dense_config.empty(
+            input_dim=input_size,
+            output_dims=(input_size,),
+            has_biases=self.use_dense_bias,
+        )
+        norm = self.normalization_config.empty(input_size)
+        readout = self.readout_config.empty(input_dim=input_size, output_dims=(num_labels,), has_biases=True)
+
+        return PredictionHead(
+            config=self,
+            dense=dense_layer,
+            activation=self.activation,
+            norm=norm,
+            readout=readout,
+        )
+
+    def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
+        dense_key, readout_key = jax.random.split(key)
+        dense_layer = self.dense_config.random_init(
+            input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+        )
+        norm = self.normalization_config.empty(input_size)
+        readout = self.readout_config.random_init(
+            input_dim=input_size,
+            output_dims=(num_labels,),
+            has_biases=True,
+            key=readout_key,
+        )
+
+        return PredictionHead(
+            config=self,
+            dense=dense_layer,
+            activation=self.activation,
+            norm=norm,
+            readout=readout,
+        )
+
+
+class PredictionHead(LalamoModule[PredictionHeadConfig]):
+    dense: LinearBase
+    activation: Activation
+    norm: Normalization
+    readout: LinearBase
+
+    def __call__(self, inner_features: Float[Array, "batch channels"]) -> Float[Array, "batch logits"]:
+        return vmap(self.call_unbatched)(inner_features)
+
+    def call_unbatched(
+        self,
+        inner_features: Float[Array, " in_channels"],
+    ) -> Float[Array, " logits"]:
+        (dense_outs,) = self.dense(inner_features)
+        dense_outs = self.activation(dense_outs)
+        norm_outs = self.norm(dense_outs)
+        (result,) = self.readout(norm_outs)
+        return result
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.dense.activation_precision
+
+    def export_weights(self) -> ParameterTree:
+        result = dict(
+            dense=self.dense.export_weights(),
+            norm=self.norm.export_weights(),
+            readout=self.readout.export_weights(),
+        )
+        return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["dense"], Mapping)
+        assert isinstance(weights["norm"], Mapping)
+        assert isinstance(weights["readout"], Mapping)
+        return replace(
+            self,
+            dense=self.dense.import_weights(weights["dense"]),
+            norm=self.norm.import_weights(weights["norm"]),
+            readout=self.readout.import_weights(weights["readout"]),
+        )
+
+
+class ClassifierActivationTrace(eqx.Module):
+    token_ids: Int[Array, "batch tokens"]
+    token_positions: Int[Array, "batch tokens"]
+
+    local_positional_embeddings: PositionalEmbeddings
+    global_positional_embeddings: PositionalEmbeddings
+
+    embedding_norm_output: Float[Array, "batch tokens channels"]
+    layer_results: tuple[TransformerLayerResult, ...]
+    output_norm: Float[Array, "batch tokens channels"]
+    output_pooling: Float[Array, "batch channels"]
+    logits: Float[Array, "batch logits"]
+
+    def export(self) -> ParameterTree:
+        result = dict(
+            token_ids=self.token_ids,
+            token_positions=self.token_positions,
+            local_positional_embeddings=self.local_positional_embeddings.export(),
+            global_positional_embeddings=self.global_positional_embeddings.export(),
+            layer_results=[layer_result.export() for layer_result in self.layer_results],
+            output_norm=self.output_norm,
+            output_pooling=self.output_pooling,
+            logits=self.logits,
+        )
+        return result
+
+
+class ClassifierResult(eqx.Module):
+    logits: Float[Array, "batch logits"]
+    activation_trace: ClassifierActivationTrace | None = None
+
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
+            logits=self.logits,
+        )
+        if self.activation_trace is not None:
+            result["activation_trace"] = self.activation_trace.export()
+        return result
+
+
+@dataclass(frozen=True)
+class ClassifierConfig:
+    embedding_config: EmbeddingConfig
+    embedding_norm_config: NormalizationConfig
+    transformer_config: TransformerConfig
+    prediction_head_config: PredictionHeadConfig
+    readout_config: LinearConfig
+
+    vocab_size: int
+    model_dim: int
+    hidden_dim: int
+    attention_scale: float | None
+    num_layers: int
+    context_length: int
+    num_labels: int
+    classifier_pooling: PoolingType
+
+    output_labels: tuple[str, ...] | None
+
+    def random_init(
+        self,
+        *,
+        key: PRNGKeyArray,
+    ) -> "Classifier":
+        embedding_key, transformer_key, prediction_head_key = jax.random.split(key, num=3)
+        embedding = self.embedding_config.random_init(
+            vocab_size=self.vocab_size,
+            model_dim=self.model_dim,
+            key=embedding_key,
+        )
+        embedding_norm = self.embedding_norm_config.empty(self.model_dim)
+        transformer = self.transformer_config.random_init(
+            key=transformer_key,
+        )
+        prediction_head = self.prediction_head_config.random_init(
+            input_size=self.hidden_dim,
+            num_labels=self.num_labels,
+            key=prediction_head_key,
+        )
+        return Classifier(
+            self,
+            embedding=embedding,
+            embedding_norm=embedding_norm,
+            transformer=transformer,
+            prediction_head=prediction_head,
+        )
+
+    def empty(self) -> "Classifier":
+        embedding = self.embedding_config.empty(
+            vocab_size=self.vocab_size,
+            model_dim=self.model_dim,
+        )
+        embedding_norm = self.embedding_norm_config.empty(self.model_dim)
+        transformer = self.transformer_config.empty()
+        prediction_head = self.prediction_head_config.empty(
+            input_size=self.hidden_dim,
+            num_labels=self.num_labels,
+        )
+        return Classifier(
+            self,
+            embedding=embedding,
+            embedding_norm=embedding_norm,
+            transformer=transformer,
+            prediction_head=prediction_head,
+        )
+
+
+class Classifier(LalamoModule[ClassifierConfig]):
+    embedding: EmbeddingBase
+    embedding_norm: Normalization
+    transformer: Transformer
+    prediction_head: PredictionHead
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.embedding.activation_precision
+
+    def __post_init__(self) -> None:
+        if self.config.output_labels is not None and len(self.config.output_labels) != self.config.num_labels:
+            raise ValueError("Number of output logits is different from provided list of labels")
+
+    @eqx.filter_jit
+    def __call__(
+        self,
+        token_ids: Int[Array, "batch tokens"],
+        token_positions: Int[Array, "batch tokens"],
+        return_activation_trace: bool = False,
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: TransformerForwardPassConfig | None = None,
+    ) -> ClassifierResult:
+        inner_features = self.embedding.embed(token_ids)
+        normalized_embeddings = vmap_twice(self.embedding_norm)(inner_features)
+
+        transformer_result = self.transformer(
+            inner_features=normalized_embeddings,
+            token_positions=token_positions,
+            state=None,
+            return_updated_state=False,
+            return_layer_results=return_activation_trace,
+            return_positional_embeddings=return_activation_trace,
+            lengths_without_padding=lengths_without_padding,
+            forward_pass_mode=forward_pass_mode,
+            forward_pass_config=forward_pass_config,
+        )
+
+        if self.config.classifier_pooling == PoolingType.CLS:
+            pooled_output = transformer_result.outputs[:, 0, :]
+        elif self.config.classifier_pooling == PoolingType.MEAN:
+            attention_mask = jnp.ones((*token_ids.shape, 1), dtype=transformer_result.outputs.dtype)
+            pooled_output = (transformer_result.outputs * attention_mask).sum(axis=1) / attention_mask.sum(axis=1)
+        else:
+            raise TypeError(f"classifier_pooling of unknown type: {self.config.classifier_pooling}")
+
+        logits = self.prediction_head(pooled_output)
+
+        if return_activation_trace:
+            assert transformer_result.layer_results is not None
+            assert transformer_result.global_positional_embeddings is not None
+            assert transformer_result.local_positional_embeddings is not None
+            activation_trace = ClassifierActivationTrace(
+                token_ids=token_ids,
+                token_positions=token_positions,
+                global_positional_embeddings=transformer_result.global_positional_embeddings,
+                local_positional_embeddings=transformer_result.local_positional_embeddings,
+                embedding_norm_output=normalized_embeddings,
+                layer_results=tuple(transformer_result.layer_results),
+                output_norm=transformer_result.outputs,
+                output_pooling=pooled_output,
+                logits=logits,
+            )
+        else:
+            activation_trace = None
+
+        return ClassifierResult(
+            logits=logits,
+            activation_trace=activation_trace,
+        )
+
+    def export_weights(self) -> ParameterTree:
+        result = dict(
+            embedding=self.embedding.export_weights(),
+            embedding_norm=self.embedding_norm.export_weights(),
+            transformer=self.transformer.export_weights(),
+            prediction_head=self.prediction_head.export_weights(),
+        )
+        return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["embedding"], Mapping)
+        assert isinstance(weights["embedding_norm"], Mapping)
+        assert isinstance(weights["transformer"], Mapping)
+        assert isinstance(weights["prediction_head"], Mapping)
+        return replace(
+            self,
+            embedding=self.embedding.import_weights(weights["embedding"]),
+            embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
+            transformer=self.transformer.import_weights(weights["transformer"]),
+            prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+        )
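Shape sketch for `Classifier.__call__` (not part of the diff; `classifier` stands for an instance built via `ClassifierConfig.empty()` or `random_init()`, and the values are illustrative):

import jax.numpy as jnp

token_ids = jnp.zeros((2, 16), dtype=jnp.int32)                    # [batch, tokens]
token_positions = jnp.broadcast_to(jnp.arange(16, dtype=jnp.int32), (2, 16))
result = classifier(token_ids, token_positions)
# ClassifierResult.logits is [batch, num_labels]
assert result.logits.shape == (2, classifier.config.num_labels)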
lalamo/modules/common.py
CHANGED
@@ -2,7 +2,7 @@ from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from types import UnionType
-from typing import Any, Self
+from typing import Any, Generic, Self, TypeVar
 
 import equinox as eqx
 from cattrs import Converter
@@ -32,8 +32,11 @@ class ForwardPassMode(Enum):
     SINGLE_TOKEN = "single_token"
 
 
-
-
+ConfigT_co = TypeVar("ConfigT_co", covariant=True)
+
+
+class LalamoModule(eqx.Module, Generic[ConfigT_co]):  # noqa: UP046
+    config: ConfigT_co = eqx.field(static=True)
 
     @property
     @abstractmethod
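With `LalamoModule` now generic in its config type, each subclass pins `config` to its own config class statically. A minimal sketch of the pattern (`MyConfig` and `MyModule` are illustrative names, not part of the package):

from dataclasses import dataclass

@dataclass(frozen=True)
class MyConfig:
    hidden_dim: int

class MyModule(LalamoModule[MyConfig]):
    # inherits config: MyConfig, stored as a static eqx.field;
    # abstract members still need to be implemented before instantiation
    ...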