lalamo 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +273 -45
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +10 -6
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -3
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/METADATA +11 -4
  48. lalamo-0.4.1.dist-info/RECORD +71 -0
  49. lalamo-0.3.4.dist-info/RECORD +0 -59
  50. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/top_level.txt +0 -0
lalamo/modules/mlp.py CHANGED
@@ -1,60 +1,170 @@
+import math
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
+from functools import partial
 from typing import Self
 
 import equinox as eqx
 import jax
-from jaxtyping import Array, DTypeLike, Float, PRNGKeyArray
+import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
+from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray
 
 from lalamo.common import ParameterTree
+from lalamo.modules.utils import vmap_twice
 
 from .activations import Activation
-from .common import LalamoModule, WeightLayout
+from .common import DummyUnionMember, ForwardPassMode, LalamoModule, register_config_union
 from .linear import LinearBase, LinearConfig
 
-__all__ = ["MLP", "MLPConfig"]
+__all__ = [
+    "DenseMLP",
+    "DenseMLPConfig",
+    "MLPBase",
+    "MLPConfig",
+    "MLPForwardPassConfig",
+    "MixtureOfExperts",
+    "MixtureOfExpertsConfig",
+    "RoutingFunction",
+    "SoftmaxRouting",
+]
+
+
+_SENTINEL = 2**31 - 1
+
+
+@dataclass(frozen=True)
+class MLPForwardPassConfig:
+    moe_chunk_size_ratio: float = 0.2
+
+
+class MLPBase[ConfigT: MLPConfig](LalamoModule[ConfigT]):
+    @property
+    @abstractmethod
+    def activation_precision(self) -> DTypeLike: ...
+
+    @property
+    @abstractmethod
+    def model_dim(self) -> int: ...
+
+    @property
+    @abstractmethod
+    def hidden_dim(self) -> int: ...
+
+    @abstractmethod
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]: ...
+
+
+@dataclass(frozen=True)
+class MLPConfigBase(ABC):
+    @abstractmethod
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> MLPBase: ...
+
+    @abstractmethod
+    def empty(self, model_dim: int, hidden_dim: int) -> MLPBase: ...
 
 
 @dataclass(frozen=True)
-class MLPConfig:
+class DenseMLPConfig(MLPConfigBase):
     linear_config: LinearConfig
     activation: Activation
+    has_up_biases: bool
+    has_down_biases: bool
+    gate_clipping: tuple[float | None, float | None] | None
+    up_clipping: tuple[float | None, float | None] | None
 
-    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MLP":
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "DenseMLP":
         up_key, down_key = jax.random.split(key)
-        return MLP(
+        return DenseMLP(
             self,
             up_projection=self.linear_config.random_init(
                 model_dim,
                 (hidden_dim, hidden_dim),
-                has_biases=False,
+                has_biases=self.has_up_biases,
                 key=up_key,
             ),
             down_projection=self.linear_config.random_init(
                 hidden_dim,
                 (model_dim,),
-                has_biases=False,
+                has_biases=self.has_down_biases,
                 key=down_key,
             ),
         )
 
-    def empty(self, model_dim: int, hidden_dim: int) -> "MLP":
-        return MLP(
+    def empty(self, model_dim: int, hidden_dim: int) -> "DenseMLP":
+        return DenseMLP(
             self,
             up_projection=self.linear_config.empty(
                 model_dim,
                 (hidden_dim, hidden_dim),
-                has_biases=False,
+                has_biases=self.has_up_biases,
             ),
             down_projection=self.linear_config.empty(
                 hidden_dim,
                 (model_dim,),
-                has_biases=False,
+                has_biases=self.has_down_biases,
             ),
         )
 
+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        model_dim: int,
+        hidden_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "DenseMLP":
+        up_key, down_key = jax.random.split(key)
+        return DenseMLP(
+            self,
+            up_projection=self.linear_config.random_init_mixture(
+                mixture_size,
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=self.has_up_biases,
+                key=up_key,
+            ),
+            down_projection=self.linear_config.random_init_mixture(
+                mixture_size,
+                hidden_dim,
+                (model_dim,),
+                has_biases=self.has_down_biases,
+                key=down_key,
+            ),
+        )
 
-class MLP(LalamoModule[MLPConfig]):
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        model_dim: int,
+        hidden_dim: int,
+    ) -> "DenseMLP":
+        return DenseMLP(
+            self,
+            up_projection=self.linear_config.empty_mixture(
+                mixture_size,
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=self.has_up_biases,
+            ),
+            down_projection=self.linear_config.empty_mixture(
+                mixture_size,
+                hidden_dim,
+                (model_dim,),
+                has_biases=self.has_down_biases,
+            ),
+        )
+
+
+class DenseMLP(MLPBase[DenseMLPConfig]):
     up_projection: LinearBase
     down_projection: LinearBase
 
@@ -70,6 +180,10 @@ class MLP(LalamoModule[MLPConfig]):
     def hidden_dim(self) -> int:
         return self.down_projection.input_dim
 
+    @property
+    def mixture_size(self) -> int | None:
+        return self.up_projection.mixture_size
+
     def __post_init__(self) -> None:
         up_output_dim, gate_output_dim = self.up_projection.output_dims
         if up_output_dim != gate_output_dim:
@@ -78,35 +192,293 @@ class MLP(LalamoModule[MLPConfig]):
                 f" the gate output dimension {gate_output_dim}",
             )
         (down_output_dim,) = self.down_projection.output_dims
-        if self.up_projection.input_dim != down_output_dim:
+        if (self.up_projection.input_dim, up_output_dim) != (down_output_dim, self.down_projection.input_dim):
             raise ValueError(
-                f"Down projection input dimension {down_output_dim} does not match"
-                f" the up projection output dimension {self.up_projection.input_dim}",
+                f"Down projection dimensions {self.down_projection.input_dim, down_output_dim} do not match"
+                f" the up projection output dimensions {self.up_projection.input_dim, up_output_dim}",
             )
 
     @eqx.filter_jit
-    def __call__(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,  # noqa: ARG002
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,  # noqa: ARG002
+        forward_pass_config: MLPForwardPassConfig | None = None,  # noqa: ARG002
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        return vmap_twice(self.call_unbatched)(inputs)
+
+    @eqx.filter_jit
+    def call_unbatched(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         up_proj, gate = self.up_projection(inputs)
+        if self.config.gate_clipping:
+            gate = jnp.clip(gate, *self.config.gate_clipping)
+        if self.config.up_clipping:
+            up_proj = jnp.clip(up_proj, *self.config.up_clipping)
         gate = self.config.activation(gate)
         (result,) = self.down_projection(up_proj * gate)
         return result
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:
+    def export_weights(self) -> ParameterTree:
         return {
-            "up_projection": self.up_projection.export_weights(weight_layout),
-            "down_projection": self.down_projection.export_weights(weight_layout),
+            "up_projection": self.up_projection.export_weights(),
+            "down_projection": self.down_projection.export_weights(),
         }
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_projection"], Mapping)
         assert isinstance(weights["down_projection"], Mapping)
         return replace(
             self,
-            up_projection=self.up_projection.import_weights(weights["up_projection"], weight_layout),
-            down_projection=self.down_projection.import_weights(weights["down_projection"], weight_layout),
+            up_projection=self.up_projection.import_weights(weights["up_projection"]),
+            down_projection=self.down_projection.import_weights(weights["down_projection"]),
+        )
+
+
+class RoutingMap(eqx.Module):
+    expert_mask: Bool[Array, "*batch_tokens experts"]
+    expert_weights: Float[Array, "*batch_tokens experts"]
+
+
+@dataclass(frozen=True)
+class RoutingFunctionBase(ABC):
+    def __call__(self, logits: Float[Array, "batch_tokens experts"], num_active: int) -> RoutingMap:
+        return vmap(partial(self.call_unbatched, num_active=num_active))(logits)
+
+    @abstractmethod
+    def call_unbatched(self, logits: Float[Array, " experts"], num_active: int) -> RoutingMap: ...
+
+
+@dataclass(frozen=True)
+class SoftmaxRouting(RoutingFunctionBase):
+    def call_unbatched(self, logits: Float[Array, " experts"], num_active: int) -> RoutingMap:
+        active_logits, active_indices = jax.lax.top_k(logits, num_active)
+        active_weights = jax.nn.softmax(active_logits)
+        mask = jnp.zeros_like(logits, dtype=bool)
+        mask = mask.at[active_indices].set(True)
+        expert_weights = jnp.zeros_like(logits)
+        expert_weights = expert_weights.at[active_indices].set(active_weights)
+        return RoutingMap(expert_mask=mask, expert_weights=expert_weights)
+
+
+RoutingFunction = SoftmaxRouting | DummyUnionMember
+
+
+register_config_union(RoutingFunction)
+
+
+@dataclass(frozen=True)
+class MixtureOfExpertsConfig(ABC):
+    mixture_size: int
+    num_experts_per_token: int
+    routing_function: RoutingFunction
+
+    router_config: LinearConfig
+    router_has_biases: bool
+
+    expert_config: DenseMLPConfig
+
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MixtureOfExperts":
+        experts_key, router_key = jax.random.split(key)
+        router = self.router_config.random_init(
+            model_dim,
+            (self.mixture_size,),
+            has_biases=self.router_has_biases,
+            key=router_key,
+        )
+        experts = self.expert_config.random_init_mixture(self.mixture_size, model_dim, hidden_dim, key=experts_key)
+        return MixtureOfExperts(self, router, experts)
+
+    def empty(self, model_dim: int, hidden_dim: int) -> "MixtureOfExperts":
+        router = self.router_config.empty(model_dim, (self.mixture_size,), has_biases=self.router_has_biases)
+        experts = self.expert_config.empty_mixture(self.mixture_size, model_dim, hidden_dim)
+        return MixtureOfExperts(self, router, experts)
+
+
+class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
+    router: LinearBase
+    experts: DenseMLP
+
+    @property
+    def mixture_size(self) -> int:
+        return self.config.mixture_size
+
+    @property
+    def num_experts_per_token(self) -> int:
+        return self.config.num_experts_per_token
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.experts.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        return self.experts.model_dim
+
+    @property
+    def hidden_dim(self) -> int:
+        return self.experts.hidden_dim
+
+    def __post_init__(self) -> None:
+        if self.router.input_dim != self.experts.model_dim:
+            raise ValueError(
+                f"Router input dimension ({self.router.input_dim}) must match experts model_dim"
+                f" ({self.experts.model_dim}).",
+            )
+
+        (router_output_dim,) = self.router.output_dims
+        if router_output_dim != self.mixture_size:
+            raise ValueError(
+                f"Router output dimension ({router_output_dim}) must equal mixture_size ({self.mixture_size}).",
+            )
+
+        if self.experts.mixture_size != self.mixture_size:
+            raise ValueError(
+                f"Experts mixture_size ({self.experts.mixture_size}) does not match specified mixture_size"
+                f" ({self.mixture_size}).",
+            )
+
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        match forward_pass_mode:
+            case ForwardPassMode.MULTI_TOKEN:
+                return self.call_prefill_mode(inputs, lengths_without_padding, forward_pass_config)
+            case ForwardPassMode.SINGLE_TOKEN:
+                return self.call_decode_mode(inputs)
+
+    @eqx.filter_jit
+    def call_decode_mode(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        def per_token(x: Float[Array, " channels"]) -> Float[Array, " channels"]:
+            (router_logits,) = self.router(x)
+            routing = self.config.routing_function.call_unbatched(
+                router_logits,
+                num_active=self.num_experts_per_token,
+            )
+            active_indices = jnp.flatnonzero(routing.expert_mask, size=self.num_experts_per_token)
+            active_weights = routing.expert_weights[active_indices]
+
+            def apply_one(idx: Int[Array, ""], w: Float[Array, ""]) -> Float[Array, " channels"]:
+                selected_expert = jax.tree_util.tree_map(
+                    lambda leaf: jax.lax.dynamic_index_in_dim(leaf, idx, axis=0, keepdims=False),
+                    self.experts,
+                )
+                return selected_expert.call_unbatched(x) * w
+
+            contributions = vmap(apply_one)(active_indices, active_weights)
+            return jnp.sum(contributions, axis=0)
+
+        return vmap_twice(per_token)(inputs)
+
+    @eqx.filter_jit
+    def call_prefill_mode(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        forward_pass_config = forward_pass_config or MLPForwardPassConfig()
+        batch_size, sequence_length, _ = inputs.shape
+        num_tokens = batch_size * sequence_length
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.ones(batch_size, dtype=jnp.int32) * sequence_length
+        padding_mask = jnp.arange(sequence_length)[None, :] < lengths_without_padding[:, None]
+
+        flattened_inputs = rearrange(inputs, "batch suffix_tokens channels -> (batch suffix_tokens) channels")
+        flattened_padding_mask = rearrange(padding_mask, "batch suffix_tokens -> (batch suffix_tokens)")
+
+        (router_logits,) = vmap(self.router)(flattened_inputs)
+        routing_map = self.config.routing_function(router_logits, self.num_experts_per_token)
+        token_mask = rearrange(
+            routing_map.expert_mask & flattened_padding_mask[:, None],
+            "tokens experts -> experts tokens",
+        )
+        expert_weights = rearrange(
+            routing_map.expert_weights,
+            "tokens experts -> experts tokens",
+        )
+        expert_weights = jnp.where(token_mask, expert_weights, 0.0)
+
+        chunk_size = math.ceil(num_tokens * forward_pass_config.moe_chunk_size_ratio)
+        num_padded_tokens = math.ceil(num_tokens / chunk_size) * chunk_size
+        token_indices = vmap(lambda m: jnp.flatnonzero(m, size=num_padded_tokens, fill_value=_SENTINEL))(token_mask)
+        chunked_token_indices = rearrange(
+            token_indices,
+            "experts (chunks chunk_tokens) -> chunks experts chunk_tokens",
+            chunk_tokens=chunk_size,
+        )
+
+        def loop_iteration(
+            accumulator: Float[Array, "tokens channels"],
+            token_indices_for_chunk: Int[Array, "experts chunk_tokens"],
+        ) -> tuple[Float[Array, "tokens channels"], None]:
+            def inner() -> Float[Array, "tokens channels"]:
+                weights_for_chunk = jnp.take_along_axis(
+                    expert_weights,
+                    token_indices_for_chunk,
+                    axis=1,
+                    mode="fill",
+                    fill_value=0.0,
+                )
+
+                def run_expert(
+                    expert: DenseMLP,
+                    indices: Int[Array, " tokens_per_chunk"],
+                    weights: Float[Array, " tokens_per_chunk"],
+                ) -> Float[Array, "tokens_per_chunk channels"]:
+                    inputs = flattened_inputs.at[indices].get(mode="fill", fill_value=0.0)
+                    return vmap(expert.call_unbatched)(inputs) * weights[:, None]
+
+                expert_outputs = vmap(run_expert)(self.experts, token_indices_for_chunk, weights_for_chunk)
+                return accumulator.at[token_indices_for_chunk].add(
+                    expert_outputs,
+                    mode="drop",
+                )
+
+            return jax.lax.cond(jnp.any(token_indices_for_chunk != _SENTINEL), inner, lambda: accumulator), None
+
+        result, _ = jax.lax.scan(loop_iteration, jnp.zeros_like(flattened_inputs), chunked_token_indices)
+        return rearrange(result, "(batch suffix_tokens) channels -> batch suffix_tokens channels", batch=batch_size)
+
+    def export_weights(
+        self,
+    ) -> ParameterTree[Array]:
+        return {
+            "router": self.router.export_weights(),
+            "experts": self.experts.export_weights(),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["router"], Mapping)
+        assert isinstance(weights["experts"], Mapping)
+        return replace(
+            self,
+            router=self.router.import_weights(weights["router"]),
+            experts=self.experts.import_weights(weights["experts"]),
         )
+
+
+MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
+
+
+register_config_union(MLPConfig)
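Note on the routing math above: SoftmaxRouting.call_unbatched keeps only the top-k router logits per token, softmaxes over that subset, and scatters the resulting weights back into a dense per-expert vector. A minimal, self-contained sketch of that step in plain JAX (the expert count and k below are illustrative, not taken from any lalamo model):

import jax
import jax.numpy as jnp


def softmax_route(logits, num_active):
    # Keep the num_active largest router logits, softmax over just those,
    # and scatter the weights back to a dense per-expert vector
    # (mirrors SoftmaxRouting.call_unbatched in the diff above).
    active_logits, active_indices = jax.lax.top_k(logits, num_active)
    active_weights = jax.nn.softmax(active_logits)
    mask = jnp.zeros_like(logits, dtype=bool).at[active_indices].set(True)
    weights = jnp.zeros_like(logits).at[active_indices].set(active_weights)
    return mask, weights


# Illustrative sizes only: 8 experts, 2 active per token.
logits = jax.random.normal(jax.random.PRNGKey(0), (8,))
mask, weights = softmax_route(logits, num_active=2)
print(mask, weights, weights.sum())  # the two selected weights sum to 1

The dense weights vector produced here is what call_prefill_mode later transposes to "experts tokens" layout for the chunked expert dispatch.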
lalamo/modules/normalization.py CHANGED
@@ -10,7 +10,7 @@ from jaxtyping import Array, DTypeLike, Float
 
 from lalamo.common import ParameterTree, dummy_array
 
-from .common import LalamoModule, WeightLayout
+from .common import LalamoModule
 
 __all__ = [
     "RMSNorm",
@@ -83,13 +83,12 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
         result = normalized_x * adjusted_scales
         return result.astype(inputs.dtype)
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+    def export_weights(self) -> ParameterTree:
         return {"scales": self.scales}
 
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(self, scales=weights["scales"])
lalamo/modules/rope.py CHANGED
@@ -25,7 +25,7 @@ from jaxtyping import Array, DTypeLike, Float, Int
 
 from lalamo.common import ParameterTree
 
-from .common import LalamoModule, WeightLayout, register_config_union
+from .common import LalamoModule, register_config_union
 
 __all__ = [
     "LinearScalingRoPEConfig",
@@ -39,22 +39,25 @@
 
 
 class PositionalEmbeddings(eqx.Module):
-    cosines: Float[Array, "tokens head_channels"]
-    sines: Float[Array, "tokens head_channels"]
+    cosines: Float[Array, "*batch tokens head_channels"]
+    sines: Float[Array, "*batch tokens head_channels"]
 
     @property
    def head_dim(self) -> int:
         return self.cosines.shape[-1]
 
-    def rotate_half(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
+    def rotate_half(
+        self,
+        heads: Float[Array, "*batch tokens head_channels"],
+    ) -> Float[Array, "*batch tokens head_channels"]:
         x1 = heads[..., : self.head_dim // 2]
         x2 = heads[..., self.head_dim // 2 :]
         return jnp.concatenate((-x2, x1), axis=-1)
 
-    def apply(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
+    def apply(self, heads: Float[Array, "*batch tokens head_channels"]) -> Float[Array, "*batch tokens head_channels"]:
         return heads * self.cosines + self.rotate_half(heads) * self.sines
 
-    def export(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree:  # noqa: ARG002
+    def export(self) -> ParameterTree:
         return dict(
             cosines=self.cosines,
             sines=self.sines,
@@ -105,9 +108,9 @@ class RoPE(LalamoModule[RoPEConfigBase]):
 
     def __post_init__(self) -> None:
         num_tokens, _ = self.sines.shape
-        if num_tokens != self.config.max_sequence_length:
+        if num_tokens > self.config.max_sequence_length:
             raise ValueError(
-                f"{num_tokens} does not match the specified max sequence length {self.config.max_sequence_length}",
+                f"{num_tokens} exceeds the specified max sequence length {self.config.max_sequence_length}",
             )
         if self.cosines.dtype != self.config.precision:
             raise ValueError(
@@ -140,7 +143,7 @@ class RoPE(LalamoModule[RoPEConfigBase]):
             sines=self.sines[timesteps],
         )
 
-    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterTree[Array]:  # noqa: ARG002
+    def export_weights(self) -> ParameterTree[Array]:
         return {
             "cosines": self.cosines,
             "sines": self.sines,
@@ -149,7 +152,6 @@ class RoPE(LalamoModule[RoPEConfigBase]):
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
     ) -> "RoPE":
         assert isinstance(weights, Mapping)
         return replace(self, cosines=weights["cosines"], sines=weights["sines"])
@@ -199,13 +201,15 @@ class LlamaRoPEConfig(RoPEConfigBase):
 @dataclass(frozen=True)
 class YARNRoPEConfig(RoPEConfigBase):
     scaling_factor: float
+    original_context_length: int
     beta_fast: float
     beta_slow: float
+    truncate: bool
 
     @classmethod
-    def _find_correction_dim(cls, num_rotations: float, dim: int, base: float, max_position_embeddings: int) -> float:
+    def _find_correction_dim(cls, num_rotations: float, dim: int, base: float, original_context_length: int) -> float:
         """Inverse dimension formula to find the dimension based on the number of rotations"""
-        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+        return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
 
     @classmethod
     def _find_correction_range(
@@ -214,19 +218,25 @@ class YARNRoPEConfig(RoPEConfigBase):
         high_rot: float,
         dim: int,
         base: float,
-        max_position_embeddings: int,
-    ) -> tuple[int, int]:
+        original_context_length: int,
+        truncate: bool,
+    ) -> tuple[float, float]:
         """Find dimension range bounds based on rotations"""
-        low = math.floor(cls._find_correction_dim(low_rot, dim, base, max_position_embeddings))
-        high = math.ceil(cls._find_correction_dim(high_rot, dim, base, max_position_embeddings))
-        return max(low, 0), min(high, dim - 1)
+        low = cls._find_correction_dim(low_rot, dim, base, original_context_length)
+        high = cls._find_correction_dim(high_rot, dim, base, original_context_length)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
+        return max(low, 0.0), min(high, float(dim - 1))
 
     @classmethod
     def _linear_ramp_factor(cls, min_value: float, max_value: float, dim: int) -> Float[Array, " head_dim"]:
         if min_value == max_value:
             max_value += 0.001  # Prevent singularity
 
-        linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_value) / (max_value - min_value)
+        min_v = jnp.float32(min_value)
+        max_v = jnp.float32(max_value)
+        linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_v) / (max_v - min_v)
         ramp_func = jnp.clip(linear_func, 0, 1)
         return ramp_func
 
@@ -234,7 +244,7 @@ class YARNRoPEConfig(RoPEConfigBase):
         self,
         inverse_frequencies: Float[Array, " tokens"],
         head_dim: int,
-        max_sequence_length: int,
+        max_sequence_length: int,  # noqa: ARG002
     ) -> Float[Array, " tokens"]:
         scaled_frequencies = inverse_frequencies / self.scaling_factor
 
@@ -243,7 +253,8 @@ class YARNRoPEConfig(RoPEConfigBase):
             self.beta_slow,
             head_dim,
             self.base,
-            max_sequence_length,
+            self.original_context_length,
+            self.truncate,
         )
 
         # Get n-dimensional rotational scaling corrected for extrapolation
@@ -251,7 +262,7 @@ class YARNRoPEConfig(RoPEConfigBase):
         return scaled_frequencies * (1 - smoothing_factor) + inverse_frequencies * smoothing_factor
 
     @property
-    def attention_scaling_factor(self) -> float:
+    def _attention_scaling_factor(self) -> float:
         return 0.1 * math.log(self.scaling_factor) + 1.0
 
 
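Note on the YARN change above: the correction range is now computed against a fixed original_context_length rather than the runtime max_sequence_length, and the floor/ceil rounding of the band edges is gated behind a truncate flag. A self-contained sketch of the formula as written in the diff (the numbers fed in at the end are illustrative, not taken from a specific model config):

import math


def find_correction_dim(num_rotations, dim, base, original_context_length):
    # Inverse-dimension formula from YARNRoPEConfig._find_correction_dim above.
    return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot, high_rot, dim, base, original_context_length, truncate):
    # truncate=True reproduces the old floor/ceil behaviour; truncate=False
    # keeps fractional band edges, as the new code above allows.
    low = find_correction_dim(low_rot, dim, base, original_context_length)
    high = find_correction_dim(high_rot, dim, base, original_context_length)
    if truncate:
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0.0), min(high, float(dim - 1))


# Illustrative inputs: beta_fast=32, beta_slow=1, head_dim=128, base=10000, 4k original context.
print(find_correction_range(32.0, 1.0, 128, 10000.0, 4096, truncate=False))
print(find_correction_range(32.0, 1.0, 128, 10000.0, 4096, truncate=True))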
lalamo/modules/utils.py CHANGED
@@ -1,11 +1,21 @@
+from collections.abc import Callable
+
 import jax
+from jax import vmap
 from jaxtyping import Array, Float
 
 __all__ = [
     "apply_soft_capping",
+    "vmap_twice",
 ]
 
 
+def vmap_twice[F: Callable](
+    func: F,
+) -> F:
+    return vmap(vmap(func, in_axes=0), in_axes=0)
+
+
 def apply_soft_capping(
     values: Float[Array, "*"],
     soft_cap: float,
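Note on vmap_twice above: it simply stacks two jax.vmap calls so a function written for a single [channels] vector can be applied over [batch, tokens, channels] inputs, which is how DenseMLP.__call__ now wraps call_unbatched. A minimal usage sketch with an arbitrary per-token function (shapes are illustrative):

import jax
import jax.numpy as jnp
from jax import vmap


def vmap_twice(func):
    # Same idea as lalamo.modules.utils.vmap_twice: map over the two leading axes.
    return vmap(vmap(func, in_axes=0), in_axes=0)


def per_token(x):  # x has shape [channels]
    return x / (1.0 + jnp.linalg.norm(x))


inputs = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))  # [batch, tokens, channels]
outputs = vmap_twice(per_token)(inputs)
print(outputs.shape)  # (2, 5, 16)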