lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +15 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +6 -2
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
  49. lalamo-0.5.4.dist-info/RECORD +88 -0
  50. lalamo-0.5.2.dist-info/RECORD +0 -80
  51. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/{language_model.py → models/language_model.py} RENAMED
@@ -1,8 +1,7 @@
- import json
  from collections.abc import Iterable
- from dataclasses import dataclass, replace
+ from dataclasses import dataclass
  from pathlib import Path
- from typing import NamedTuple, Self
+ from typing import NamedTuple

  import equinox as eqx
  import jax
@@ -10,14 +9,19 @@ import jax.numpy as jnp
  from einops import rearrange
  from jax import vmap
  from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
- from tokenizers import Tokenizer

- from lalamo.common import DTypeLike, ParameterTree, unflatten_parameters
- from lalamo.message_processor import AssistantMessage, Message, MessageProcessor, MessageProcessorConfig
- from lalamo.modules import Decoder, DecoderConfig, ForwardPassMode, LalamoModule, State, config_converter
- from lalamo.modules.decoder import DecoderForwardPassConfig
+ from lalamo.message_processor import AssistantMessage, Message, MessageProcessor
+ from lalamo.modules import (
+     Decoder,
+     DecoderConfig,
+     DecoderForwardPassConfig,
+     ForwardPassMode,
+     LalamoModule,
+     State,
+ )
  from lalamo.sampling import SamplingPolicy, make_policy
- from lalamo.utils import open_safetensors
+
+ from .common import TextModel, TextModelConfig

  __all__ = [
      "ForwardPassConfig",
@@ -71,46 +75,25 @@ class GenerationConfig:


  @dataclass(frozen=True)
- class LanguageModelConfig:
-     decoder_config: DecoderConfig
-     message_processor_config: MessageProcessorConfig
+ class LanguageModelConfig(TextModelConfig[DecoderConfig]):
      generation_config: GenerationConfig

-
- class LanguageModel(LalamoModule[LanguageModelConfig]):
-     decoder: Decoder
-     message_processor: MessageProcessor = eqx.field(static=True)
+     def init(
+         self,
+         model: LalamoModule,
+         message_processor: MessageProcessor,
+     ) -> "LanguageModel":
+         assert isinstance(model, Decoder)
+         return LanguageModel(self, model, message_processor)

      @classmethod
-     def load(cls, path: Path | str) -> Self:
-         if isinstance(path, str):
-             path = Path(path)
-         with open(path / "config.json") as config_file:
-             config_json = json.load(config_file)
-         config = config_converter.structure(config_json["model_config"], LanguageModelConfig)
-         with open_safetensors(path / "model.safetensors") as (weights_dict, _):
-             weights = unflatten_parameters(weights_dict)
-         decoder = config.decoder_config.empty().import_weights(weights)
-         tokenizer = Tokenizer.from_file(str(path / "tokenizer.json"))
-         message_processor = MessageProcessor(config.message_processor_config, tokenizer)
-         return cls(config, decoder, message_processor)
-
-     @property
-     def activation_precision(self) -> DTypeLike:
-         return self.decoder.activation_precision
-
-     def export_weights(self) -> ParameterTree:
-         return self.decoder.export_weights()
+     def load_model(cls, path: Path | str) -> "LanguageModel":
+         result = super().load_model(path)
+         assert isinstance(result, LanguageModel)
+         return result

-     def import_weights(
-         self,
-         weights: ParameterTree[Array],
-     ) -> Self:
-         return replace(
-             self,
-             decoder=self.decoder.import_weights(weights),
-         )

+ class LanguageModel(TextModel[LanguageModelConfig, Decoder]):
      @property
      def stop_token_ids(self) -> tuple[int, ...]:
          return self.config.generation_config.stop_token_ids
@@ -129,11 +112,11 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
          batch_size, sequence_length = token_ids.shape
          token_positions = jnp.repeat(jnp.arange(sequence_length, dtype=jnp.int32)[None, ...], batch_size, axis=0)
          if state_capacity is not None:
-             state = self.decoder.init_static_state(batch_size, state_capacity)
+             state = self.model.init_static_state(batch_size, state_capacity)
          else:
              state = None

-         decoder_outputs = self.decoder(
+         decoder_outputs = self.model(
              token_ids,
              token_positions,
              state,
@@ -220,7 +203,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
          else:
              forward_pass_mode = ForwardPassMode.MULTI_TOKEN

-         decoder_outputs = self.decoder(
+         decoder_outputs = self.model(
              next_token_ids[:, None],
              next_token_indices[:, None],
              state.state,
@@ -272,7 +255,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
          key: PRNGKeyArray | None = None,
      ) -> AssistantMessage:
          formatted_messages = self.message_processor.render_request(messages)
-         token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)[None, :]
+         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)[None, :]
          response_ids = self.generate_tokens(
              token_ids,
              sampling_policy,
@@ -292,7 +275,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
          key: PRNGKeyArray | None = None,
      ) -> Iterable[str]:
          formatted_messages = self.message_processor.render_request(messages)
-         token_ids = jnp.array(self.message_processor.tokenize(formatted_messages), dtype=jnp.int32)
+         token_ids = jnp.array(self.message_processor.tokenize_text(formatted_messages), dtype=jnp.int32)
          for token_id in self.stream_tokens(
              token_ids,
              sampling_policy,
@@ -352,7 +335,7 @@ class LanguageModel(LalamoModule[LanguageModelConfig]):
              return

          next_token_indices = state.last_token_indices + 1
-         decoder_outputs = self.decoder(
+         decoder_outputs = self.model(
              next_token_id.reshape(1, 1),
              next_token_indices.reshape(1, 1),
              state.state,
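Net effect of the hunks above: LanguageModel no longer owns checkpoint loading, weight import/export, or tokenizer wiring; those move into the shared TextModel/TextModelConfig base in lalamo/models/common.py, the decoder is reached through the generic `self.model` attribute, and `MessageProcessor.tokenize` becomes `tokenize_text`. A minimal loading sketch, assuming TextModelConfig.load_model reads the exported directory and dispatches to `init`; the checkpoint path and the re-export of `LanguageModelConfig` from lalamo.models are assumptions:

```python
# Sketch, not a verbatim API reference: only init()/load_model() come from the
# diff above; the path and the re-export location are assumptions.
from lalamo.models import LanguageModelConfig

model = LanguageModelConfig.load_model("path/to/exported/model")
# load_model() builds the model via the base class; the isinstance assertion
# in the override narrows the result to LanguageModel for type checkers.
assert model.stop_token_ids == model.config.generation_config.stop_token_ids
```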
lalamo/models/router.py ADDED
@@ -0,0 +1,59 @@
+ from collections.abc import Iterable
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import jax
+ from jax import Array
+ from jax import numpy as jnp
+ from jaxtyping import Float
+
+ from lalamo.message_processor import Message, MessageProcessor
+ from lalamo.modules import Classifier, ClassifierConfig, LalamoModule
+
+ from .common import TextModel, TextModelConfig
+
+ __all__ = [
+     "Router",
+     "RouterConfig",
+ ]
+
+
+ @dataclass(frozen=True)
+ class RouterConfig(TextModelConfig[ClassifierConfig]):
+     def init(
+         self,
+         model: LalamoModule,
+         message_processor: MessageProcessor,
+     ) -> "Router":
+         assert isinstance(model, Classifier)
+         return Router(self, model, message_processor)
+
+     @classmethod
+     def load_model(cls, path: Path | str) -> "Router":
+         result = super().load_model(path)
+         assert isinstance(result, Router)
+         return result
+
+
+ class Router(TextModel[RouterConfig, Classifier]):
+     def label_output_logits(self, logits: Float[Array, "batch logits"]) -> dict[str, Float[Array, " batch"]]:
+         output_labels = self.model.config.output_labels
+         probabilities = jax.nn.sigmoid(logits)
+
+         if output_labels is None:
+             output_labels = [f"class_{idx}" for idx in range(self.model.config.num_labels)]
+
+         assert probabilities.ndim == 2, f"Expected 2D array, got array of shape {logits.shape}"
+
+         return dict(zip(output_labels, jnp.unstack(probabilities, axis=1), strict=True))
+
+     def classify_chat(
+         self,
+         messages: Iterable[Message],
+     ) -> dict[str, float]:
+         token_ids = jnp.array(self.message_processor.tokenize_request(messages), dtype=jnp.int32)[None, :]
+         _, sequence_length = token_ids.shape
+         token_positions = jnp.arange(sequence_length, dtype=jnp.int32)[None, :]
+         classifier_output = self.model(token_ids=token_ids, token_positions=token_positions)
+
+         return {k: float(v.item()) for k, v in self.label_output_logits(classifier_output.logits).items()}
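Because Router applies an elementwise sigmoid rather than a softmax, the per-label scores are independent probabilities and need not sum to one (a multi-label readout). A hypothetical call sketch; the checkpoint path, the message list, and the RouterConfig re-export from lalamo.models are assumptions, while classify_chat and its return shape come from the code above:

```python
# Hypothetical usage of the new Router; see caveats in the lead-in.
from lalamo.models import RouterConfig

router = RouterConfig.load_model("path/to/router-checkpoint")
scores = router.classify_chat(messages)  # messages: an Iterable[Message] built elsewhere
# scores looks like {"class_0": 0.91, "class_1": 0.07}; keys fall back to
# "class_{idx}" when the classifier config carries no output_labels.
```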
lalamo/modules/__init__.py CHANGED
@@ -1,12 +1,17 @@
  from .activations import GELU, Activation, Identity, SiLU
- from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector, config_converter
- from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderForwardPassConfig, DecoderResult
- from .decoder_layer import (
-     DecoderLayer,
-     DecoderLayerActivationTrace,
-     DecoderLayerConfig,
-     DecoderLayerForwardPassConfig,
-     DecoderLayerResult,
+ from .classifier import Classifier, ClassifierConfig
+ from .common import (
+     ForwardPassMode,
+     LalamoModule,
+     PositionalEmbeddingSelector,
+     config_converter,
+ )
+ from .decoder import (
+     Decoder,
+     DecoderActivationTrace,
+     DecoderConfig,
+     DecoderForwardPassConfig,
+     DecoderResult,
  )
  from .embedding import (
      EmbeddingBase,
@@ -45,7 +50,7 @@ from .mlp import (
      RoutingFunction,
      SoftmaxRouting,
  )
- from .normalization import RMSNorm, RMSNormConfig, UpcastMode
+ from .normalization import Normalization, NormalizationConfig, UpcastMode
  from .rope import (
      LinearScalingRoPEConfig,
      LlamaRoPEConfig,
@@ -67,21 +72,26 @@ from .token_mixers import (
      State,
      StaticKVCacheLayer,
  )
+ from .transformer import Transformer, TransformerConfig
+ from .transformer_layer import (
+     TransformerLayer,
+     TransformerLayerActivationTrace,
+     TransformerLayerConfig,
+     TransformerLayerForwardPassConfig,
+     TransformerLayerResult,
+ )

  __all__ = [
      "GELU",
      "Activation",
      "Attention",
      "AttentionConfig",
+     "Classifier",
+     "ClassifierConfig",
      "Decoder",
      "DecoderActivationTrace",
      "DecoderConfig",
      "DecoderForwardPassConfig",
-     "DecoderLayer",
-     "DecoderLayerActivationTrace",
-     "DecoderLayerConfig",
-     "DecoderLayerForwardPassConfig",
-     "DecoderLayerResult",
      "DecoderResult",
      "DenseMLP",
      "DenseMLPConfig",
@@ -113,14 +123,14 @@ __all__ = [
      "Mamba2Config",
      "MixtureOfExperts",
      "MixtureOfExpertsConfig",
+     "Normalization",
+     "NormalizationConfig",
      "PositionalEmbeddingSelector",
      "PositionalEmbeddings",
      "QLoRALinear",
      "QLoRALinearConfig",
      "QuantizedTiedEmbedding",
      "QuantizedTiedEmbeddingConfig",
-     "RMSNorm",
-     "RMSNormConfig",
      "RoPE",
      "RoPEConfig",
      "RoutingFunction",
@@ -132,6 +142,13 @@ __all__ = [
      "StaticKVCacheLayer",
      "TiedEmbedding",
      "TiedEmbeddingConfig",
+     "Transformer",
+     "TransformerConfig",
+     "TransformerLayer",
+     "TransformerLayerActivationTrace",
+     "TransformerLayerConfig",
+     "TransformerLayerForwardPassConfig",
+     "TransformerLayerResult",
      "UnscaledRoPEConfig",
      "UntiedEmbedding",
      "UntiedEmbeddingConfig",
lalamo/modules/classifier.py ADDED
@@ -0,0 +1,339 @@
+ from collections.abc import Mapping
+ from dataclasses import dataclass, replace
+ from enum import StrEnum
+ from typing import Self
+
+ import equinox as eqx
+ import jax
+ from jax import numpy as jnp
+ from jax import vmap
+ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+ from lalamo.common import ParameterTree
+ from lalamo.modules import Activation
+ from lalamo.modules.normalization import NormalizationConfig
+ from lalamo.modules.transformer import (
+     Normalization,
+     Transformer,
+     TransformerConfig,
+     TransformerForwardPassConfig,
+ )
+ from lalamo.modules.utils import vmap_twice
+
+ from .common import ForwardPassMode, LalamoModule
+ from .embedding import EmbeddingBase, EmbeddingConfig
+ from .linear import LinearBase, LinearConfig
+ from .rope import PositionalEmbeddings
+ from .transformer_layer import TransformerLayerResult
+
+ __all__ = [
+     "Classifier",
+     "ClassifierActivationTrace",
+     "ClassifierConfig",
+     "ClassifierResult",
+ ]
+
+
+ class PoolingType(StrEnum):
+     CLS = "cls"
+     MEAN = "mean"
+
+
+ @dataclass(frozen=True)
+ class PredictionHeadConfig:
+     dense_config: LinearConfig
+     activation: Activation
+     normalization_config: NormalizationConfig
+     readout_config: LinearConfig
+     use_dense_bias: bool
+
+     def empty(self, input_size: int, num_labels: int) -> "PredictionHead":
+         dense_layer = self.dense_config.empty(
+             input_dim=input_size,
+             output_dims=(input_size,),
+             has_biases=self.use_dense_bias,
+         )
+         norm = self.normalization_config.empty(input_size)
+         readout = self.readout_config.empty(input_dim=input_size, output_dims=(num_labels,), has_biases=True)
+
+         return PredictionHead(
+             config=self,
+             dense=dense_layer,
+             activation=self.activation,
+             norm=norm,
+             readout=readout,
+         )
+
+     def random_init(self, input_size: int, num_labels: int, key: PRNGKeyArray) -> "PredictionHead":
+         dense_key, readout_key = jax.random.split(key)
+         dense_layer = self.dense_config.random_init(
+             input_size, (input_size,), has_biases=self.use_dense_bias, key=dense_key
+         )
+         norm = self.normalization_config.empty(input_size)
+         readout = self.readout_config.random_init(
+             input_dim=input_size,
+             output_dims=(num_labels,),
+             has_biases=True,
+             key=readout_key,
+         )
+
+         return PredictionHead(
+             config=self,
+             dense=dense_layer,
+             activation=self.activation,
+             norm=norm,
+             readout=readout,
+         )
+
+
+ class PredictionHead(LalamoModule[PredictionHeadConfig]):
+     dense: LinearBase
+     activation: Activation
+     norm: Normalization
+     readout: LinearBase
+
+     def __call__(self, inner_features: Float[Array, "batch channels"]) -> Float[Array, "batch logits"]:
+         return vmap(self.call_unbatched)(inner_features)
+
+     def call_unbatched(
+         self,
+         inner_features: Float[Array, " in_channels"],
+     ) -> Float[Array, " logits"]:
+         (dense_outs,) = self.dense(inner_features)
+         dense_outs = self.activation(dense_outs)
+         norm_outs = self.norm(dense_outs)
+         (result,) = self.readout(norm_outs)
+         return result
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.dense.activation_precision
+
+     def export_weights(self) -> ParameterTree:
+         result = dict(
+             dense=self.dense.export_weights(),
+             norm=self.norm.export_weights(),
+             readout=self.readout.export_weights(),
+         )
+         return result
+
+     def import_weights(
+         self,
+         weights: ParameterTree[Array],
+     ) -> Self:
+         assert isinstance(weights, Mapping)
+         assert isinstance(weights["dense"], Mapping)
+         assert isinstance(weights["norm"], Mapping)
+         assert isinstance(weights["readout"], Mapping)
+         return replace(
+             self,
+             dense=self.dense.import_weights(weights["dense"]),
+             norm=self.norm.import_weights(weights["norm"]),
+             readout=self.readout.import_weights(weights["readout"]),
+         )
+
+
+ class ClassifierActivationTrace(eqx.Module):
+     token_ids: Int[Array, "batch tokens"]
+     token_positions: Int[Array, "batch tokens"]
+
+     local_positional_embeddings: PositionalEmbeddings
+     global_positional_embeddings: PositionalEmbeddings
+
+     embedding_norm_output: Float[Array, "batch tokens channels"]
+     layer_results: tuple[TransformerLayerResult, ...]
+     output_norm: Float[Array, "batch tokens channels"]
+     output_pooling: Float[Array, "batch channels"]
+     logits: Float[Array, "batch logits"]
+
+     def export(self) -> ParameterTree:
+         result = dict(
+             token_ids=self.token_ids,
+             token_positions=self.token_positions,
+             local_positional_embeddings=self.local_positional_embeddings.export(),
+             global_positional_embeddings=self.global_positional_embeddings.export(),
+             layer_results=[layer_result.export() for layer_result in self.layer_results],
+             output_norm=self.output_norm,
+             output_pooling=self.output_pooling,
+             logits=self.logits,
+         )
+         return result
+
+
+ class ClassifierResult(eqx.Module):
+     logits: Float[Array, "batch logits"]
+     activation_trace: ClassifierActivationTrace | None = None
+
+     def export(self) -> ParameterTree:
+         result: dict[str, ParameterTree | Array] = dict(
+             logits=self.logits,
+         )
+         if self.activation_trace is not None:
+             result["activation_trace"] = self.activation_trace.export()
+         return result
+
+
+ @dataclass(frozen=True)
+ class ClassifierConfig:
+     embedding_config: EmbeddingConfig
+     embedding_norm_config: NormalizationConfig
+     transformer_config: TransformerConfig
+     prediction_head_config: PredictionHeadConfig
+     readout_config: LinearConfig
+
+     vocab_size: int
+     model_dim: int
+     hidden_dim: int
+     attention_scale: float | None
+     num_layers: int
+     context_length: int
+     num_labels: int
+     classifier_pooling: PoolingType
+
+     output_labels: tuple[str, ...] | None
+
+     def random_init(
+         self,
+         *,
+         key: PRNGKeyArray,
+     ) -> "Classifier":
+         embedding_key, transformer_key, prediction_head_key = jax.random.split(key, num=3)
+         embedding = self.embedding_config.random_init(
+             vocab_size=self.vocab_size,
+             model_dim=self.model_dim,
+             key=embedding_key,
+         )
+         embedding_norm = self.embedding_norm_config.empty(self.model_dim)
+         transformer = self.transformer_config.random_init(
+             key=transformer_key,
+         )
+         prediction_head = self.prediction_head_config.random_init(
+             input_size=self.hidden_dim,
+             num_labels=self.num_labels,
+             key=prediction_head_key,
+         )
+         return Classifier(
+             self,
+             embedding=embedding,
+             embedding_norm=embedding_norm,
+             transformer=transformer,
+             prediction_head=prediction_head,
+         )
+
+     def empty(self) -> "Classifier":
+         embedding = self.embedding_config.empty(
+             vocab_size=self.vocab_size,
+             model_dim=self.model_dim,
+         )
+         embedding_norm = self.embedding_norm_config.empty(self.model_dim)
+         transformer = self.transformer_config.empty()
+         prediction_head = self.prediction_head_config.empty(
+             input_size=self.hidden_dim,
+             num_labels=self.num_labels,
+         )
+         return Classifier(
+             self,
+             embedding=embedding,
+             embedding_norm=embedding_norm,
+             transformer=transformer,
+             prediction_head=prediction_head,
+         )
+
+
+ class Classifier(LalamoModule[ClassifierConfig]):
+     embedding: EmbeddingBase
+     embedding_norm: Normalization
+     transformer: Transformer
+     prediction_head: PredictionHead
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.embedding.activation_precision
+
+     def __post_init__(self) -> None:
+         if self.config.output_labels is not None and len(self.config.output_labels) != self.config.num_labels:
+             raise ValueError("Number of output logits is different from provided list of labels")
+
+     @eqx.filter_jit
+     def __call__(
+         self,
+         token_ids: Int[Array, "batch tokens"],
+         token_positions: Int[Array, "batch tokens"],
+         return_activation_trace: bool = False,
+         lengths_without_padding: Int[Array, " batch"] | None = None,
+         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+         forward_pass_config: TransformerForwardPassConfig | None = None,
+     ) -> ClassifierResult:
+         inner_features = self.embedding.embed(token_ids)
+         normalized_embeddings = vmap_twice(self.embedding_norm)(inner_features)
+
+         transformer_result = self.transformer(
+             inner_features=normalized_embeddings,
+             token_positions=token_positions,
+             state=None,
+             return_updated_state=False,
+             return_layer_results=return_activation_trace,
+             return_positional_embeddings=return_activation_trace,
+             lengths_without_padding=lengths_without_padding,
+             forward_pass_mode=forward_pass_mode,
+             forward_pass_config=forward_pass_config,
+         )
+
+         if self.config.classifier_pooling == PoolingType.CLS:
+             pooled_output = transformer_result.outputs[:, 0, :]
+         elif self.config.classifier_pooling == PoolingType.MEAN:
+             attention_mask = jnp.ones((*token_ids.shape, 1), dtype=transformer_result.outputs.dtype)
+             pooled_output = (transformer_result.outputs * attention_mask).sum(axis=1) / attention_mask.sum(axis=1)
+         else:
+             raise TypeError(f"classifier_pooling of unknown type: {self.config.classifier_pooling}")
+
+         logits = self.prediction_head(pooled_output)
+
+         if return_activation_trace:
+             assert transformer_result.layer_results is not None
+             assert transformer_result.global_positional_embeddings is not None
+             assert transformer_result.local_positional_embeddings is not None
+             activation_trace = ClassifierActivationTrace(
+                 token_ids=token_ids,
+                 token_positions=token_positions,
+                 global_positional_embeddings=transformer_result.global_positional_embeddings,
+                 local_positional_embeddings=transformer_result.local_positional_embeddings,
+                 embedding_norm_output=normalized_embeddings,
+                 layer_results=tuple(transformer_result.layer_results),
+                 output_norm=transformer_result.outputs,
+                 output_pooling=pooled_output,
+                 logits=logits,
+             )
+         else:
+             activation_trace = None
+
+         return ClassifierResult(
+             logits=logits,
+             activation_trace=activation_trace,
+         )
+
+     def export_weights(self) -> ParameterTree:
+         result = dict(
+             embedding=self.embedding.export_weights(),
+             embedding_norm=self.embedding_norm.export_weights(),
+             transformer=self.transformer.export_weights(),
+             prediction_head=self.prediction_head.export_weights(),
+         )
+         return result
+
+     def import_weights(
+         self,
+         weights: ParameterTree[Array],
+     ) -> Self:
+         assert isinstance(weights, Mapping)
+         assert isinstance(weights["embedding"], Mapping)
+         assert isinstance(weights["embedding_norm"], Mapping)
+         assert isinstance(weights["transformer"], Mapping)
+         assert isinstance(weights["prediction_head"], Mapping)
+         return replace(
+             self,
+             embedding=self.embedding.import_weights(weights["embedding"]),
+             embedding_norm=self.embedding_norm.import_weights(weights["embedding_norm"]),
+             transformer=self.transformer.import_weights(weights["transformer"]),
+             prediction_head=self.prediction_head.import_weights(weights["prediction_head"]),
+         )
lalamo/modules/common.py CHANGED
@@ -2,7 +2,7 @@ from abc import abstractmethod
  from dataclasses import dataclass
  from enum import Enum
  from types import UnionType
- from typing import Any, Self
+ from typing import Any, Generic, Self, TypeVar

  import equinox as eqx
  from cattrs import Converter
@@ -32,8 +32,11 @@
      SINGLE_TOKEN = "single_token"


- class LalamoModule[ConfigT](eqx.Module):
-     config: ConfigT = eqx.field(static=True)
+ ConfigT_co = TypeVar("ConfigT_co", covariant=True)
+
+
+ class LalamoModule(eqx.Module, Generic[ConfigT_co]):  # noqa: UP046
+     config: ConfigT_co = eqx.field(static=True)

      @property
      @abstractmethod
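The last hunk swaps the PEP 695 parameterization `class LalamoModule[ConfigT]` (whose variance is inferred) for an explicitly covariant TypeVar, presumably for compatibility with tooling that predates PEP 695; the noqa silences ruff's UP046 "use PEP 695 generics" rule. Covariance is what lets code typed against a base config accept modules whose config is a subclass. A minimal, hypothetical illustration of that subtyping (only the covariant-TypeVar pattern mirrors the diff; the class names are made up):

```python
# Hypothetical names; demonstrates why covariance matters for the config slot.
from typing import Generic, TypeVar

ConfigT_co = TypeVar("ConfigT_co", covariant=True)

class Module(Generic[ConfigT_co]):
    pass

class BaseConfig: ...
class DecoderConfig(BaseConfig): ...

def uses_base(module: Module[BaseConfig]) -> None: ...

decoder_module: Module[DecoderConfig] = Module()
uses_base(decoder_module)  # type-checks only because ConfigT_co is covariant
```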