lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. lalamo/__init__.py +15 -2
  2. lalamo/data/__init__.py +0 -1
  3. lalamo/data/huggingface_message.py +1 -0
  4. lalamo/main.py +167 -18
  5. lalamo/message_processor.py +2 -3
  6. lalamo/model_import/common.py +120 -27
  7. lalamo/model_import/decoder_configs/__init__.py +4 -2
  8. lalamo/model_import/decoder_configs/common.py +62 -21
  9. lalamo/model_import/decoder_configs/executorch.py +14 -9
  10. lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
  11. lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
  12. lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
  13. lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
  14. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
  15. lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
  16. lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
  17. lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
  18. lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
  21. lalamo/model_import/loaders/__init__.py +3 -2
  22. lalamo/model_import/loaders/executorch.py +24 -12
  23. lalamo/model_import/loaders/huggingface.py +258 -30
  24. lalamo/model_import/model_specs/__init__.py +4 -2
  25. lalamo/model_import/model_specs/common.py +8 -2
  26. lalamo/model_import/model_specs/gemma.py +5 -1
  27. lalamo/model_import/model_specs/huggingface.py +1 -1
  28. lalamo/model_import/model_specs/mirai.py +20 -0
  29. lalamo/models/__init__.py +10 -0
  30. lalamo/models/common.py +81 -0
  31. lalamo/{language_model.py → models/language_model.py} +32 -49
  32. lalamo/models/router.py +59 -0
  33. lalamo/modules/__init__.py +33 -16
  34. lalamo/modules/classifier.py +339 -0
  35. lalamo/modules/common.py +6 -3
  36. lalamo/modules/decoder.py +52 -180
  37. lalamo/modules/mlp.py +28 -5
  38. lalamo/modules/normalization.py +13 -8
  39. lalamo/modules/token_mixers/attention.py +10 -6
  40. lalamo/modules/token_mixers/state/kv_cache.py +14 -4
  41. lalamo/modules/transformer.py +273 -0
  42. lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
  43. lalamo/speculator/__init__.py +6 -2
  44. lalamo/speculator/estimator.py +91 -0
  45. lalamo/speculator/inference.py +28 -9
  46. lalamo/speculator/ngram.py +7 -3
  47. lalamo/speculator/utils.py +4 -2
  48. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
  49. lalamo-0.5.4.dist-info/RECORD +88 -0
  50. lalamo-0.5.2.dist-info/RECORD +0 -80
  51. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
  52. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
  53. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
  54. {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/modules/transformer.py (new file)
@@ -0,0 +1,273 @@
+ from collections.abc import Mapping, Sequence
+ from dataclasses import dataclass, replace
+ from typing import Self
+
+ import equinox as eqx
+ import jax
+ from jax import vmap
+ from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+ from lalamo.common import ParameterTree
+ from lalamo.modules.token_mixers import AttentionConfig
+ from lalamo.modules.utils import vmap_twice
+
+ from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
+ from .normalization import Normalization, NormalizationConfig
+ from .rope import PositionalEmbeddings, RoPE, RoPEConfig
+ from .token_mixers import State
+ from .transformer_layer import (
+     TransformerLayer,
+     TransformerLayerConfig,
+     TransformerLayerForwardPassConfig,
+     TransformerLayerResult,
+ )
+
+ __all__ = [
+     "Transformer",
+     "TransformerConfig",
+     "TransformerResult",
+ ]
+
+
+ type TransformerForwardPassConfig = TransformerLayerForwardPassConfig
+
+
+ class TransformerResult(eqx.Module):
+     outputs: Float[Array, "batch suffix_tokens channels"]
+     updated_state: State | None = None
+     layer_results: tuple[TransformerLayerResult, ...] | None = None
+     global_positional_embeddings: PositionalEmbeddings | None = None
+     local_positional_embeddings: PositionalEmbeddings | None = None
+
+     def export(self) -> ParameterTree:
+         result: dict[str, ParameterTree | Array] = dict(
+             outputs=self.outputs,
+         )
+         if self.updated_state is not None:
+             result["updated_state"] = [state_layer.export() for state_layer in self.updated_state]
+         if self.layer_results is not None:
+             result["layer_results"] = [layer_result.export() for layer_result in self.layer_results]
+         if self.global_positional_embeddings is not None:
+             result["global_positional_embeddings"] = self.global_positional_embeddings.export()
+         if self.local_positional_embeddings is not None:
+             result["local_positional_embeddings"] = self.local_positional_embeddings.export()
+         return result
+
+
+ @dataclass(frozen=True)
+ class TransformerConfig:
+     global_rope_config: RoPEConfig | None
+     local_rope_config: RoPEConfig | None
+     layer_configs: tuple[TransformerLayerConfig, ...]
+     output_norm_config: NormalizationConfig
+     model_dim: int
+     hidden_dim: int
+     context_length: int
+
+     def random_init(self, *, key: PRNGKeyArray) -> "Transformer":
+         first_layer_config, *_ = self.layer_configs
+
+         if self.global_rope_config:
+             global_rope = self.global_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=self.context_length,
+             )
+         else:
+             global_rope = None
+
+         if self.local_rope_config:
+             max_sliding_window_size = max(
+                 layer_config.mixer_config.sliding_window_size or 0
+                 for layer_config in self.layer_configs
+                 if isinstance(layer_config.mixer_config, AttentionConfig)
+             )
+
+             local_rope = self.local_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=max(max_sliding_window_size, self.context_length),
+             )
+         else:
+             local_rope = None
+
+         layers_keys = jax.random.split(key, num=len(self.layer_configs))
+         layers = tuple(
+             layer_config.random_init(
+                 model_dim=self.model_dim,
+                 hidden_dim=self.hidden_dim,
+                 key=layer_key,
+             )
+             for layer_key, layer_config in zip(layers_keys, self.layer_configs, strict=True)
+         )
+         output_norm = self.output_norm_config.init(self.model_dim)
+
+         return Transformer(
+             config=self,
+             global_rope=global_rope,
+             local_rope=local_rope,
+             layers=layers,
+             output_norm=output_norm,
+         )
+
+     def empty(self) -> "Transformer":
+         first_layer_config, *_ = self.layer_configs
+
+         if self.global_rope_config:
+             global_rope = self.global_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=self.context_length,
+             )
+         else:
+             global_rope = None
+
+         if self.local_rope_config:
+             local_rope = self.local_rope_config.init(
+                 head_dim=first_layer_config.rope_dim,
+                 num_timesteps=self.context_length,
+             )
+         else:
+             local_rope = None
+
+         layers = tuple(
+             layer_config.empty(
+                 model_dim=self.model_dim,
+                 hidden_dim=self.hidden_dim,
+             )
+             for layer_config in self.layer_configs
+         )
+         output_norm = self.output_norm_config.empty(self.model_dim)
+
+         return Transformer(
+             config=self,
+             global_rope=global_rope,
+             local_rope=local_rope,
+             layers=layers,
+             output_norm=output_norm,
+         )
+
+
+ class Transformer(LalamoModule[TransformerConfig]):
+     global_rope: RoPE | None
+     local_rope: RoPE | None
+     layers: tuple[TransformerLayer, ...]
+     output_norm: Normalization
+
+     @property
+     def activation_precision(self) -> DTypeLike:
+         return self.layers[0].activation_precision
+
+     @eqx.filter_jit
+     def __call__(
+         self,
+         inner_features: Float[Array, "batch suffix_tokens channels"],
+         token_positions: Int[Array, "batch suffix_tokens"],
+         state: State | None,
+         return_updated_state: bool,
+         return_layer_results: bool,
+         return_positional_embeddings: bool,
+         lengths_without_padding: Int[Array, " batch"] | None,
+         forward_pass_mode: ForwardPassMode,
+         forward_pass_config: TransformerForwardPassConfig | None,
+     ) -> TransformerResult:
+         if inner_features.ndim != 3:
+             raise ValueError(
+                 f"inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim), got {inner_features.shape}",
+             )
+         if token_positions.ndim != 2:
+             raise ValueError(
+                 "token_positions must be a 2D array of size (batch_size, sequence_length),"
+                 f" got {token_positions.shape}",
+             )
+
+         maybe_state = state or ([None] * len(self.layers))
+
+         if self.global_rope is not None:
+             global_positional_embeddings = vmap(self.global_rope)(token_positions)
+         else:
+             global_positional_embeddings = None
+         if self.local_rope is not None:
+             local_positional_embeddings = vmap(self.local_rope)(token_positions)
+         else:
+             local_positional_embeddings = global_positional_embeddings
+
+         updated_state_layers = []
+         layer_results = []
+
+         for layer, state_layer in zip(self.layers, maybe_state, strict=True):
+             match layer.positional_embedding_selector:
+                 case PositionalEmbeddingSelector.LOCAL:
+                     positional_embeddings_to_use = local_positional_embeddings
+                 case PositionalEmbeddingSelector.GLOBAL:
+                     positional_embeddings_to_use = global_positional_embeddings
+                 case PositionalEmbeddingSelector.NONE:
+                     positional_embeddings_to_use = None
+
+             layer_result = layer(
+                 inner_features,
+                 positional_embeddings_to_use,
+                 state=state_layer,
+                 return_updated_state=return_updated_state,
+                 return_activation_trace=return_layer_results,
+                 lengths_without_padding=lengths_without_padding,
+                 forward_pass_mode=forward_pass_mode,
+                 forward_pass_config=forward_pass_config,
+             )
+             inner_features = layer_result.outputs
+             layer_results.append(layer_result)
+             updated_state_layers.append(layer_result.updated_state)
+
+         normalized_outputs = vmap_twice(self.output_norm)(inner_features)
+
+         return TransformerResult(
+             outputs=normalized_outputs,
+             updated_state=(State(updated_state_layers) if return_updated_state else None),
+             layer_results=tuple(layer_results) if return_layer_results else None,
+             global_positional_embeddings=(global_positional_embeddings if return_positional_embeddings else None),
+             local_positional_embeddings=(local_positional_embeddings if return_positional_embeddings else None),
+         )
+
+     def init_static_state(self, batch_size: int, capacity: int) -> State:
+         return State(layer.init_static_state(batch_size, capacity) for layer in self.layers)
+
+     def export_weights(self) -> ParameterTree:
+         result = dict(
+             layers=[layer.export_weights() for layer in self.layers],
+             output_norm=self.output_norm.export_weights(),
+         )
+         if self.global_rope:
+             result["global_rope"] = self.global_rope.export_weights()
+         if self.local_rope:
+             result["local_rope"] = self.local_rope.export_weights()
+         return result
+
+     def import_weights(
+         self,
+         weights: ParameterTree[Array],
+     ) -> Self:
+         assert isinstance(weights, Mapping)
+         assert isinstance(weights["layers"], Sequence)
+         assert isinstance(weights["output_norm"], Mapping)
+
+         if self.global_rope:
+             assert isinstance(weights["global_rope"], Mapping)
+             global_rope = self.global_rope.import_weights(weights["global_rope"])
+         else:
+             global_rope = None
+
+         if self.local_rope:
+             assert isinstance(weights["local_rope"], Mapping)
+             local_rope = self.local_rope.import_weights(weights["local_rope"])
+         else:
+             local_rope = None
+
+         layers = []
+         for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
+             assert isinstance(layer_weights, Mapping)
+             layers.append(layer.import_weights(layer_weights))
+
+         return replace(
+             self,
+             global_rope=global_rope,
+             layers=tuple(layers),
+             output_norm=self.output_norm.import_weights(weights["output_norm"]),
+             local_rope=local_rope,
+         )
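
For orientation, a minimal usage sketch of the new Transformer call contract follows. It is not part of the diff: the `transformer` instance and the shapes are hypothetical, and it assumes ForwardPassMode is re-exported from lalamo.modules (otherwise it lives in lalamo.modules.common).

import jax.numpy as jnp

from lalamo.modules import ForwardPassMode  # assumed re-export; see the lalamo/modules/__init__.py changes

# `transformer` is a hypothetical, already-constructed Transformer instance.
batch, seq_len = 1, 16
features = jnp.zeros((batch, seq_len, transformer.config.model_dim), dtype=transformer.activation_precision)
positions = jnp.arange(seq_len, dtype=jnp.int32)[None, :]

result = transformer(
    features,
    positions,
    state=None,  # or transformer.init_static_state(batch, capacity) for incremental decoding
    return_updated_state=True,
    return_layer_results=False,
    return_positional_embeddings=False,
    lengths_without_padding=None,
    forward_pass_mode=ForwardPassMode.MULTI_TOKEN,
    forward_pass_config=None,
)
hidden = result.outputs        # (batch, suffix_tokens, channels), already passed through output_norm
cache = result.updated_state   # per-layer state when return_updated_state=True, otherwise None

Note that the RoPE tables are computed once per call and each layer then picks global, local, or no positional embeddings through its positional_embedding_selector.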
lalamo/modules/{decoder_layer.py → transformer_layer.py}
@@ -13,24 +13,24 @@ from lalamo.common import ParameterTree

  from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
  from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
- from .normalization import RMSNorm, RMSNormConfig
+ from .normalization import Normalization, NormalizationConfig
  from .rope import PositionalEmbeddings
  from .token_mixers import KVCacheLayer, StateLayerBase, StaticKVCacheLayer, TokenMixerBase, TokenMixerConfig
  from .utils import vmap_twice

  __all__ = [
-     "DecoderLayer",
-     "DecoderLayerActivationTrace",
-     "DecoderLayerConfig",
-     "DecoderLayerForwardPassConfig",
-     "DecoderLayerResult",
+     "TransformerLayer",
+     "TransformerLayerActivationTrace",
+     "TransformerLayerConfig",
+     "TransformerLayerForwardPassConfig",
+     "TransformerLayerResult",
  ]


- type DecoderLayerForwardPassConfig = MLPForwardPassConfig
+ type TransformerLayerForwardPassConfig = MLPForwardPassConfig


- class DecoderLayerActivationTrace(eqx.Module):
+ class TransformerLayerActivationTrace(eqx.Module):
      inputs: Float[Array, "batch suffix_tokens channels"]
      positional_embeddings: PositionalEmbeddings | None
      state: StateLayerBase | None
@@ -63,10 +63,10 @@ class DecoderLayerActivationTrace(eqx.Module):
          return result


- class DecoderLayerResult(eqx.Module):
-     outputs: Float[Array, "suffix_tokens channels"]
+ class TransformerLayerResult(eqx.Module):
+     outputs: Float[Array, "batch tokens channels"]
      updated_state: KVCacheLayer | None
-     activation_trace: DecoderLayerActivationTrace | None
+     activation_trace: TransformerLayerActivationTrace | None

      def export(self) -> ParameterTree:
          result: dict[str, ParameterTree | Array] = dict(
@@ -80,13 +80,13 @@ class DecoderLayerResult(eqx.Module):


  @dataclass(frozen=True)
- class DecoderLayerConfig:
-     pre_mixer_norm_config: RMSNormConfig
+ class TransformerLayerConfig:
+     pre_mixer_norm_config: NormalizationConfig | None
      mixer_config: TokenMixerConfig
-     post_mixer_norm_config: RMSNormConfig | None
-     pre_mlp_norm_config: RMSNormConfig
+     post_mixer_norm_config: NormalizationConfig | None
+     pre_mlp_norm_config: NormalizationConfig
      mlp_config: MLPConfig
-     post_mlp_norm_config: RMSNormConfig | None
+     post_mlp_norm_config: NormalizationConfig | None

      @property
      def rope_dim(self) -> int:
@@ -98,28 +98,31 @@ class DecoderLayerConfig:
          hidden_dim: int,
          *,
          key: PRNGKeyArray,
-     ) -> "DecoderLayer":
+     ) -> "TransformerLayer":
          attention_key, mlp_key = jax.random.split(key)
-         pre_attention_norm = self.pre_mixer_norm_config.init(model_dim)
+         if self.pre_mixer_norm_config is not None:
+             pre_mixer_norm = self.pre_mixer_norm_config.init(model_dim)
+         else:
+             pre_mixer_norm = None
          mixer = self.mixer_config.random_init(
              model_dim=model_dim,
              key=attention_key,
          )
          if self.post_mixer_norm_config is not None:
-             post_attention_norm = self.post_mixer_norm_config.init(model_dim)
+             post_mixer_norm = self.post_mixer_norm_config.init(model_dim)
          else:
-             post_attention_norm = None
+             post_mixer_norm = None
          pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
          mlp = self.mlp_config.random_init(model_dim, hidden_dim, key=mlp_key)
          if self.post_mlp_norm_config is not None:
              post_mlp_norm = self.post_mlp_norm_config.init(model_dim)
          else:
              post_mlp_norm = None
-         return DecoderLayer(
+         return TransformerLayer(
              config=self,
-             pre_mixer_norm=pre_attention_norm,
+             pre_mixer_norm=pre_mixer_norm,
              mixer=mixer,
-             post_mixer_norm=post_attention_norm,
+             post_mixer_norm=post_mixer_norm,
              pre_mlp_norm=pre_mlp_norm,
              mlp=mlp,
              post_mlp_norm=post_mlp_norm,
@@ -129,39 +132,42 @@ class DecoderLayerConfig:
          self,
          model_dim: int,
          hidden_dim: int,
-     ) -> "DecoderLayer":
-         pre_attention_norm = self.pre_mixer_norm_config.empty(model_dim)
+     ) -> "TransformerLayer":
+         if self.pre_mixer_norm_config is not None:
+             pre_mixer_norm = self.pre_mixer_norm_config.empty(model_dim)
+         else:
+             pre_mixer_norm = None
          attention = self.mixer_config.empty(
              model_dim=model_dim,
          )
          if self.post_mixer_norm_config is not None:
-             post_attention_norm = self.post_mixer_norm_config.empty(model_dim)
+             post_mixer_norm = self.post_mixer_norm_config.empty(model_dim)
          else:
-             post_attention_norm = None
+             post_mixer_norm = None
          pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
          mlp = self.mlp_config.empty(model_dim, hidden_dim)
          if self.post_mlp_norm_config is not None:
              post_mlp_norm = self.post_mlp_norm_config.empty(model_dim)
          else:
              post_mlp_norm = None
-         return DecoderLayer(
+         return TransformerLayer(
              config=self,
-             pre_mixer_norm=pre_attention_norm,
+             pre_mixer_norm=pre_mixer_norm,
              mixer=attention,
-             post_mixer_norm=post_attention_norm,
+             post_mixer_norm=post_mixer_norm,
              pre_mlp_norm=pre_mlp_norm,
              mlp=mlp,
              post_mlp_norm=post_mlp_norm,
          )


- class DecoderLayer(LalamoModule[DecoderLayerConfig]):
-     pre_mixer_norm: RMSNorm
+ class TransformerLayer(LalamoModule[TransformerLayerConfig]):
+     pre_mixer_norm: Normalization | None
      mixer: TokenMixerBase
-     post_mixer_norm: RMSNorm | None
-     pre_mlp_norm: RMSNorm
+     post_mixer_norm: Normalization | None
+     pre_mlp_norm: Normalization
      mlp: MLPBase
-     post_mlp_norm: RMSNorm | None
+     post_mlp_norm: Normalization | None

      @property
      def activation_precision(self) -> DTypeLike:
@@ -172,7 +178,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          return self.mixer.positional_embedding_selector

      def __post_init__(self) -> None:
-         model_dim = self.pre_mixer_norm.input_dim
+         model_dim = self.pre_mixer_norm.input_dim if self.pre_mixer_norm is not None else self.mixer.model_dim
          if self.mixer.model_dim != model_dim:
              raise ValueError(
                  f"Attention model dim {self.mixer.model_dim} does not match"
@@ -204,15 +210,21 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          return_activation_trace: bool = False,
          lengths_without_padding: Int[Array, " batch"] | None = None,
          forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
-         forward_pass_config: DecoderLayerForwardPassConfig | None = None,
-     ) -> DecoderLayerResult:
+         forward_pass_config: TransformerLayerForwardPassConfig | None = None,
+     ) -> TransformerLayerResult:
          if inputs.ndim != 3:
              raise ValueError(
                  f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
                  f" got {inputs.shape}",
              )
-         normalized_mixer_inputs = vmap_twice(self.pre_mixer_norm)(inputs)
-         batched_mixer_fn = vmap(partial(self.mixer, return_updated_state=return_updated_state))
+         if self.pre_mixer_norm is not None:
+             normalized_mixer_inputs = vmap_twice(self.pre_mixer_norm)(inputs)
+         else:
+             normalized_mixer_inputs = inputs
+
+         batched_mixer_fn = vmap(
+             partial(self.mixer, return_updated_state=return_updated_state or return_activation_trace),
+         )
          mixer_outputs, updated_state = batched_mixer_fn(
              normalized_mixer_inputs,
              positional_embeddings,
@@ -240,7 +252,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          outputs = mlp_inputs + mlp_outputs

          if return_activation_trace:
-             activation_trace = DecoderLayerActivationTrace(
+             activation_trace = TransformerLayerActivationTrace(
                  inputs=inputs,
                  positional_embeddings=positional_embeddings,
                  state=state,
@@ -255,7 +267,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          else:
              activation_trace = None

-         return DecoderLayerResult(
+         return TransformerLayerResult(
              outputs=outputs,
              updated_state=updated_state,
              activation_trace=activation_trace,
@@ -269,11 +281,12 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):

      def export_weights(self) -> ParameterTree:
          result = dict(
-             pre_mixer_norm=self.pre_mixer_norm.export_weights(),
              mixer=self.mixer.export_weights(),
              pre_mlp_norm=self.pre_mlp_norm.export_weights(),
              mlp=self.mlp.export_weights(),
          )
+         if self.pre_mixer_norm is not None:
+             result["pre_mixer_norm"] = self.pre_mixer_norm.export_weights()
          if self.post_mixer_norm is not None:
              result["post_mixer_norm"] = self.post_mixer_norm.export_weights()
          if self.post_mlp_norm is not None:
@@ -285,7 +298,6 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
          weights: ParameterTree[Array],
      ) -> Self:
          assert isinstance(weights, Mapping)
-         assert isinstance(weights["pre_mixer_norm"], Mapping)
          assert isinstance(weights["mixer"], Mapping)
          assert isinstance(weights["mlp"], Mapping)
          assert isinstance(weights["pre_mlp_norm"], Mapping)
@@ -302,9 +314,14 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
              post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
          else:
              post_mlp_norm = None
+         if self.pre_mixer_norm is not None:
+             assert isinstance(weights["pre_mixer_norm"], Mapping)
+             pre_mixer_norm = self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"])
+         else:
+             pre_mixer_norm = None
          return replace(
              self,
-             pre_mixer_norm=self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"]),
+             pre_mixer_norm=pre_mixer_norm,
              mixer=self.mixer.import_weights(weights["mixer"]),
              post_mixer_norm=post_mixer_norm,
              pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
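
A small sketch of what the now-optional pre-mixer norm means for weight export and import. This is illustrative only: `layer` is a hypothetical TransformerLayer instance, and the round trip assumes each submodule exports a mapping, as the asserts above expect.

# Hypothetical `layer`: any TransformerLayer built from a TransformerLayerConfig.
exported = layer.export_weights()

# "pre_mixer_norm" only appears in the exported tree when the layer actually has one,
# mirroring pre_mixer_norm_config: NormalizationConfig | None in the config.
assert ("pre_mixer_norm" in exported) == (layer.pre_mixer_norm is not None)

# import_weights applies the same rule, so an export/import round trip works either way.
rebuilt = layer.import_weights(exported)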
lalamo/speculator/__init__.py
@@ -1,11 +1,15 @@
  from .common import Speculator
- from .inference import inference_collect_traces
+ from .estimator import estimate_batchsize_from_memory
+ from .inference import CollectTracesEvent, inference_collect_traces
  from .ngram import NGramSpeculator
- from .utils import train_speculator
+ from .utils import SpeculatorTrainingEvent, train_speculator

  __all__ = [
+     "CollectTracesEvent",
      "NGramSpeculator",
      "Speculator",
+     "SpeculatorTrainingEvent",
+     "estimate_batchsize_from_memory",
      "inference_collect_traces",
      "train_speculator",
  ]
lalamo/speculator/estimator.py (new file)
@@ -0,0 +1,91 @@
+ import functools
+ import itertools
+ from collections.abc import Callable
+ from typing import NamedTuple
+
+ import jax
+ import jax.numpy as jnp
+
+ from lalamo.models import LanguageModel
+
+
+ def estimate_memory_from_batchsize(
+     model: LanguageModel,
+     max_input_length: int,
+     max_output_length: int,
+     num_logits_per_token: int,
+     batch_size: int,
+ ) -> int:
+     memory_analysis = (
+         jax.jit(
+             functools.partial(
+                 model.generate_tokens,
+                 max_output_length=max_output_length,
+                 num_top_logits_to_return=num_logits_per_token,
+             ),
+             backend="cpu",  # cuda backend tries to allocate in .compile() and ooms
+         )
+         .lower(
+             prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+             prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+         )
+         .compile()
+         .memory_analysis()
+     )
+
+     assert hasattr(memory_analysis, "argument_size_in_bytes")
+     assert hasattr(memory_analysis, "output_size_in_bytes")
+     assert hasattr(memory_analysis, "temp_size_in_bytes")
+
+     return (
+         memory_analysis.argument_size_in_bytes  # type: ignore (pyright bug)
+         + memory_analysis.output_size_in_bytes  # type: ignore (pyright bug)
+         + memory_analysis.temp_size_in_bytes  # type: ignore (pyright bug)
+     )
+
+
+ class EstimateBatchsizeFromMemoryEvent(NamedTuple):
+     lo: int
+     hi: int | None
+
+
+ def estimate_batchsize_from_memory(
+     model: LanguageModel,
+     max_input_length: int,
+     max_output_length: int,
+     num_logits_per_token: int,
+     target_mem: int,
+     progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
+ ) -> int:
+     mem_for_bs = functools.cache(
+         functools.partial(
+             estimate_memory_from_batchsize,
+             model,
+             max_input_length,
+             max_output_length,
+             num_logits_per_token,
+         ),
+     )
+
+     lo = 0
+     hi = 0
+     for candidate_exp in itertools.count():
+         lo = hi
+         hi = 2**candidate_exp
+
+         if progress is not None:
+             progress(EstimateBatchsizeFromMemoryEvent(lo, None))
+         if target_mem < mem_for_bs(hi):
+             break
+
+     while hi - lo > 1:
+         mid = (lo + hi) // 2
+
+         if progress is not None:
+             progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
+         if target_mem < mem_for_bs(mid):
+             hi = mid
+         else:
+             lo = mid
+
+     return lo
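
A usage sketch for the new batch-size estimator (illustrative only; `model` stands for any already-loaded LanguageModel, and the concrete numbers are placeholders):

from lalamo.speculator import estimate_batchsize_from_memory

# `model` is a hypothetical LanguageModel instance loaded elsewhere.
batch_size = estimate_batchsize_from_memory(
    model,
    max_input_length=2048,
    max_output_length=256,
    num_logits_per_token=8,
    target_mem=16 * 2**30,  # 16 GiB budget
    progress=lambda event: print(f"bracket: lo={event.lo}, hi={event.hi}"),
)

The helper doubles the candidate batch size until the compiled-executable memory estimate exceeds target_mem, then binary-searches the resulting bracket and returns the largest batch size whose estimate still fits (possibly 0).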