lalamo 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +273 -45
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +10 -6
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -3
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/METADATA +11 -4
- lalamo-0.4.1.dist-info/RECORD +71 -0
- lalamo-0.3.4.dist-info/RECORD +0 -59
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/WHEEL +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.4.dist-info → lalamo-0.4.1.dist-info}/top_level.txt +0 -0
lalamo/modules/mlp.py
CHANGED
@@ -1,60 +1,170 @@
+import math
+from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
+from functools import partial
 from typing import Self

 import equinox as eqx
 import jax
-
+import jax.numpy as jnp
+from einops import rearrange
+from jax import vmap
+from jaxtyping import Array, Bool, DTypeLike, Float, Int, PRNGKeyArray

 from lalamo.common import ParameterTree
+from lalamo.modules.utils import vmap_twice

 from .activations import Activation
-from .common import LalamoModule,
+from .common import DummyUnionMember, ForwardPassMode, LalamoModule, register_config_union
 from .linear import LinearBase, LinearConfig

-__all__ = [
+__all__ = [
+    "DenseMLP",
+    "DenseMLPConfig",
+    "MLPBase",
+    "MLPConfig",
+    "MLPForwardPassConfig",
+    "MixtureOfExperts",
+    "MixtureOfExpertsConfig",
+    "RoutingFunction",
+    "SoftmaxRouting",
+]
+
+
+_SENTINEL = 2**31 - 1
+
+
+@dataclass(frozen=True)
+class MLPForwardPassConfig:
+    moe_chunk_size_ratio: float = 0.2
+
+
+class MLPBase[ConfigT: MLPConfig](LalamoModule[ConfigT]):
+    @property
+    @abstractmethod
+    def activation_precision(self) -> DTypeLike: ...
+
+    @property
+    @abstractmethod
+    def model_dim(self) -> int: ...
+
+    @property
+    @abstractmethod
+    def hidden_dim(self) -> int: ...
+
+    @abstractmethod
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]: ...
+
+
+@dataclass(frozen=True)
+class MLPConfigBase(ABC):
+    @abstractmethod
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> MLPBase: ...
+
+    @abstractmethod
+    def empty(self, model_dim: int, hidden_dim: int) -> MLPBase: ...


 @dataclass(frozen=True)
-class
+class DenseMLPConfig(MLPConfigBase):
     linear_config: LinearConfig
     activation: Activation
+    has_up_biases: bool
+    has_down_biases: bool
+    gate_clipping: tuple[float | None, float | None] | None
+    up_clipping: tuple[float | None, float | None] | None

-    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "DenseMLP":
         up_key, down_key = jax.random.split(key)
-        return
+        return DenseMLP(
             self,
             up_projection=self.linear_config.random_init(
                 model_dim,
                 (hidden_dim, hidden_dim),
-                has_biases=
+                has_biases=self.has_up_biases,
                 key=up_key,
             ),
             down_projection=self.linear_config.random_init(
                 hidden_dim,
                 (model_dim,),
-                has_biases=
+                has_biases=self.has_down_biases,
                 key=down_key,
             ),
         )

-    def empty(self, model_dim: int, hidden_dim: int) -> "
-        return
+    def empty(self, model_dim: int, hidden_dim: int) -> "DenseMLP":
+        return DenseMLP(
             self,
             up_projection=self.linear_config.empty(
                 model_dim,
                 (hidden_dim, hidden_dim),
-                has_biases=
+                has_biases=self.has_up_biases,
             ),
             down_projection=self.linear_config.empty(
                 hidden_dim,
                 (model_dim,),
-                has_biases=
+                has_biases=self.has_down_biases,
             ),
         )

+    def random_init_mixture(
+        self,
+        mixture_size: int,
+        model_dim: int,
+        hidden_dim: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> "DenseMLP":
+        up_key, down_key = jax.random.split(key)
+        return DenseMLP(
+            self,
+            up_projection=self.linear_config.random_init_mixture(
+                mixture_size,
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=self.has_up_biases,
+                key=up_key,
+            ),
+            down_projection=self.linear_config.random_init_mixture(
+                mixture_size,
+                hidden_dim,
+                (model_dim,),
+                has_biases=self.has_down_biases,
+                key=down_key,
+            ),
+        )

-
+    def empty_mixture(
+        self,
+        mixture_size: int,
+        model_dim: int,
+        hidden_dim: int,
+    ) -> "DenseMLP":
+        return DenseMLP(
+            self,
+            up_projection=self.linear_config.empty_mixture(
+                mixture_size,
+                model_dim,
+                (hidden_dim, hidden_dim),
+                has_biases=self.has_up_biases,
+            ),
+            down_projection=self.linear_config.empty_mixture(
+                mixture_size,
+                hidden_dim,
+                (model_dim,),
+                has_biases=self.has_down_biases,
+            ),
+        )
+
+
+class DenseMLP(MLPBase[DenseMLPConfig]):
     up_projection: LinearBase
     down_projection: LinearBase

@@ -70,6 +180,10 @@ class MLP(LalamoModule[MLPConfig]):
     def hidden_dim(self) -> int:
         return self.down_projection.input_dim

+    @property
+    def mixture_size(self) -> int | None:
+        return self.up_projection.mixture_size
+
     def __post_init__(self) -> None:
         up_output_dim, gate_output_dim = self.up_projection.output_dims
         if up_output_dim != gate_output_dim:
@@ -78,35 +192,293 @@ class MLP(LalamoModule[MLPConfig]):
                 f" the gate output dimension {gate_output_dim}",
             )
         (down_output_dim,) = self.down_projection.output_dims
-        if self.up_projection.input_dim != down_output_dim:
+        if (self.up_projection.input_dim, up_output_dim) != (down_output_dim, self.down_projection.input_dim):
             raise ValueError(
-                f"Down projection
-                f" the up projection output
+                f"Down projection dimensions {self.down_projection.input_dim, down_output_dim} do not match"
+                f" the up projection output dimensions {self.up_projection.input_dim, up_output_dim}",
             )

     @eqx.filter_jit
-    def __call__(
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,  # noqa: ARG002
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,  # noqa: ARG002
+        forward_pass_config: MLPForwardPassConfig | None = None,  # noqa: ARG002
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        return vmap_twice(self.call_unbatched)(inputs)
+
+    @eqx.filter_jit
+    def call_unbatched(self, inputs: Float[Array, " channels"]) -> Float[Array, " channels"]:
+        if self.mixture_size is not None:
+            raise ValueError(
+                "Mixtures of linear layers cannot be called directly."
+                "They are intended to be used with methods eqx.filter_vmap or lax.scan instead.",
+            )
         up_proj, gate = self.up_projection(inputs)
+        if self.config.gate_clipping:
+            gate = jnp.clip(gate, *self.config.gate_clipping)
+        if self.config.up_clipping:
+            up_proj = jnp.clip(up_proj, *self.config.up_clipping)
         gate = self.config.activation(gate)
         (result,) = self.down_projection(up_proj * gate)
         return result

-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         return {
-            "up_projection": self.up_projection.export_weights(
-            "down_projection": self.down_projection.export_weights(
+            "up_projection": self.up_projection.export_weights(),
+            "down_projection": self.down_projection.export_weights(),
         }

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,
     ) -> Self:
         assert isinstance(weights, Mapping)
         assert isinstance(weights["up_projection"], Mapping)
         assert isinstance(weights["down_projection"], Mapping)
         return replace(
             self,
-            up_projection=self.up_projection.import_weights(weights["up_projection"]
-            down_projection=self.down_projection.import_weights(weights["down_projection"]
+            up_projection=self.up_projection.import_weights(weights["up_projection"]),
+            down_projection=self.down_projection.import_weights(weights["down_projection"]),
+        )
+
+
+class RoutingMap(eqx.Module):
+    expert_mask: Bool[Array, "*batch_tokens experts"]
+    expert_weights: Float[Array, "*batch_tokens experts"]
+
+
+@dataclass(frozen=True)
+class RoutingFunctionBase(ABC):
+    def __call__(self, logits: Float[Array, "batch_tokens experts"], num_active: int) -> RoutingMap:
+        return vmap(partial(self.call_unbatched, num_active=num_active))(logits)
+
+    @abstractmethod
+    def call_unbatched(self, logits: Float[Array, " experts"], num_active: int) -> RoutingMap: ...
+
+
+@dataclass(frozen=True)
+class SoftmaxRouting(RoutingFunctionBase):
+    def call_unbatched(self, logits: Float[Array, " experts"], num_active: int) -> RoutingMap:
+        active_logits, active_indices = jax.lax.top_k(logits, num_active)
+        active_weights = jax.nn.softmax(active_logits)
+        mask = jnp.zeros_like(logits, dtype=bool)
+        mask = mask.at[active_indices].set(True)
+        expert_weights = jnp.zeros_like(logits)
+        expert_weights = expert_weights.at[active_indices].set(active_weights)
+        return RoutingMap(expert_mask=mask, expert_weights=expert_weights)
+
+
+RoutingFunction = SoftmaxRouting | DummyUnionMember
+
+
+register_config_union(RoutingFunction)
+
+
+@dataclass(frozen=True)
+class MixtureOfExpertsConfig(ABC):
+    mixture_size: int
+    num_experts_per_token: int
+    routing_function: RoutingFunction
+
+    router_config: LinearConfig
+    router_has_biases: bool
+
+    expert_config: DenseMLPConfig
+
+    def random_init(self, model_dim: int, hidden_dim: int, *, key: PRNGKeyArray) -> "MixtureOfExperts":
+        experts_key, router_key = jax.random.split(key)
+        router = self.router_config.random_init(
+            model_dim,
+            (self.mixture_size,),
+            has_biases=self.router_has_biases,
+            key=router_key,
+        )
+        experts = self.expert_config.random_init_mixture(self.mixture_size, model_dim, hidden_dim, key=experts_key)
+        return MixtureOfExperts(self, router, experts)
+
+    def empty(self, model_dim: int, hidden_dim: int) -> "MixtureOfExperts":
+        router = self.router_config.empty(model_dim, (self.mixture_size,), has_biases=self.router_has_biases)
+        experts = self.expert_config.empty_mixture(self.mixture_size, model_dim, hidden_dim)
+        return MixtureOfExperts(self, router, experts)
+
+
+class MixtureOfExperts(MLPBase[MixtureOfExpertsConfig]):
+    router: LinearBase
+    experts: DenseMLP
+
+    @property
+    def mixture_size(self) -> int:
+        return self.config.mixture_size
+
+    @property
+    def num_experts_per_token(self) -> int:
+        return self.config.num_experts_per_token
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.experts.activation_precision
+
+    @property
+    def model_dim(self) -> int:
+        return self.experts.model_dim
+
+    @property
+    def hidden_dim(self) -> int:
+        return self.experts.hidden_dim
+
+    def __post_init__(self) -> None:
+        if self.router.input_dim != self.experts.model_dim:
+            raise ValueError(
+                f"Router input dimension ({self.router.input_dim}) must match experts model_dim"
+                f" ({self.experts.model_dim}).",
+            )
+
+        (router_output_dim,) = self.router.output_dims
+        if router_output_dim != self.mixture_size:
+            raise ValueError(
+                f"Router output dimension ({router_output_dim}) must equal mixture_size ({self.mixture_size}).",
+            )
+
+        if self.experts.mixture_size != self.mixture_size:
+            raise ValueError(
+                f"Experts mixture_size ({self.experts.mixture_size}) does not match specified mixture_size"
+                f" ({self.mixture_size}).",
+            )
+
+    def __call__(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        match forward_pass_mode:
+            case ForwardPassMode.MULTI_TOKEN:
+                return self.call_prefill_mode(inputs, lengths_without_padding, forward_pass_config)
+            case ForwardPassMode.SINGLE_TOKEN:
+                return self.call_decode_mode(inputs)
+
+    @eqx.filter_jit
+    def call_decode_mode(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        def per_token(x: Float[Array, " channels"]) -> Float[Array, " channels"]:
+            (router_logits,) = self.router(x)
+            routing = self.config.routing_function.call_unbatched(
+                router_logits,
+                num_active=self.num_experts_per_token,
+            )
+            active_indices = jnp.flatnonzero(routing.expert_mask, size=self.num_experts_per_token)
+            active_weights = routing.expert_weights[active_indices]
+
+            def apply_one(idx: Int[Array, ""], w: Float[Array, ""]) -> Float[Array, " channels"]:
+                selected_expert = jax.tree_util.tree_map(
+                    lambda leaf: jax.lax.dynamic_index_in_dim(leaf, idx, axis=0, keepdims=False),
+                    self.experts,
+                )
+                return selected_expert.call_unbatched(x) * w
+
+            contributions = vmap(apply_one)(active_indices, active_weights)
+            return jnp.sum(contributions, axis=0)
+
+        return vmap_twice(per_token)(inputs)
+
+    @eqx.filter_jit
+    def call_prefill_mode(
+        self,
+        inputs: Float[Array, "batch suffix_tokens channels"],
+        lengths_without_padding: Int[Array, " batch"] | None = None,
+        forward_pass_config: MLPForwardPassConfig | None = None,
+    ) -> Float[Array, "batch suffix_tokens channels"]:
+        forward_pass_config = forward_pass_config or MLPForwardPassConfig()
+        batch_size, sequence_length, _ = inputs.shape
+        num_tokens = batch_size * sequence_length
+        if lengths_without_padding is None:
+            lengths_without_padding = jnp.ones(batch_size, dtype=jnp.int32) * sequence_length
+        padding_mask = jnp.arange(sequence_length)[None, :] < lengths_without_padding[:, None]
+
+        flattened_inputs = rearrange(inputs, "batch suffix_tokens channels -> (batch suffix_tokens) channels")
+        flattened_padding_mask = rearrange(padding_mask, "batch suffix_tokens -> (batch suffix_tokens)")
+
+        (router_logits,) = vmap(self.router)(flattened_inputs)
+        routing_map = self.config.routing_function(router_logits, self.num_experts_per_token)
+        token_mask = rearrange(
+            routing_map.expert_mask & flattened_padding_mask[:, None],
+            "tokens experts -> experts tokens",
+        )
+        expert_weights = rearrange(
+            routing_map.expert_weights,
+            "tokens experts -> experts tokens",
+        )
+        expert_weights = jnp.where(token_mask, expert_weights, 0.0)
+
+        chunk_size = math.ceil(num_tokens * forward_pass_config.moe_chunk_size_ratio)
+        num_padded_tokens = math.ceil(num_tokens / chunk_size) * chunk_size
+        token_indices = vmap(lambda m: jnp.flatnonzero(m, size=num_padded_tokens, fill_value=_SENTINEL))(token_mask)
+        chunked_token_indices = rearrange(
+            token_indices,
+            "experts (chunks chunk_tokens) -> chunks experts chunk_tokens",
+            chunk_tokens=chunk_size,
+        )
+
+        def loop_iteration(
+            accumulator: Float[Array, "tokens channels"],
+            token_indices_for_chunk: Int[Array, "experts chunk_tokens"],
+        ) -> tuple[Float[Array, "tokens channels"], None]:
+            def inner() -> Float[Array, "tokens channels"]:
+                weights_for_chunk = jnp.take_along_axis(
+                    expert_weights,
+                    token_indices_for_chunk,
+                    axis=1,
+                    mode="fill",
+                    fill_value=0.0,
+                )
+
+                def run_expert(
+                    expert: DenseMLP,
+                    indices: Int[Array, " tokens_per_chunk"],
+                    weights: Float[Array, " tokens_per_chunk"],
+                ) -> Float[Array, "tokens_per_chunk channels"]:
+                    inputs = flattened_inputs.at[indices].get(mode="fill", fill_value=0.0)
+                    return vmap(expert.call_unbatched)(inputs) * weights[:, None]
+
+                expert_outputs = vmap(run_expert)(self.experts, token_indices_for_chunk, weights_for_chunk)
+                return accumulator.at[token_indices_for_chunk].add(
+                    expert_outputs,
+                    mode="drop",
+                )
+
+            return jax.lax.cond(jnp.any(token_indices_for_chunk != _SENTINEL), inner, lambda: accumulator), None
+
+        result, _ = jax.lax.scan(loop_iteration, jnp.zeros_like(flattened_inputs), chunked_token_indices)
+        return rearrange(result, "(batch suffix_tokens) channels -> batch suffix_tokens channels", batch=batch_size)
+
+    def export_weights(
+        self,
+    ) -> ParameterTree[Array]:
+        return {
+            "router": self.router.export_weights(),
+            "experts": self.experts.export_weights(),
+        }
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["router"], Mapping)
+        assert isinstance(weights["experts"], Mapping)
+        return replace(
+            self,
+            router=self.router.import_weights(weights["router"]),
+            experts=self.experts.import_weights(weights["experts"]),
         )
+
+
+MLPConfig = DenseMLPConfig | MixtureOfExpertsConfig
+
+
+register_config_union(MLPConfig)
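The new SoftmaxRouting class implements standard top-k routing: it keeps the num_active highest router logits, softmax-normalizes only those, and scatters the results back into dense per-expert mask and weight vectors. Below is a minimal standalone sketch of that scheme; the expert count, logits, and helper name are illustrative only and not part of the package API.

# Hedged sketch of the top-k softmax routing shown in the diff above.
import jax
import jax.numpy as jnp

def topk_softmax_routing(logits: jnp.ndarray, num_active: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    # Keep the num_active highest-scoring experts for this token.
    active_logits, active_indices = jax.lax.top_k(logits, num_active)
    # Normalize weights over the selected experts only.
    active_weights = jax.nn.softmax(active_logits)
    # Scatter back into dense per-expert mask and weight vectors.
    mask = jnp.zeros_like(logits, dtype=bool).at[active_indices].set(True)
    weights = jnp.zeros_like(logits).at[active_indices].set(active_weights)
    return mask, weights

logits = jnp.array([0.1, 2.0, -1.0, 0.5])  # one token, four experts (toy values)
mask, weights = topk_softmax_routing(logits, num_active=2)
# mask selects experts 1 and 3; weights are nonzero only there and sum to 1.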
lalamo/modules/normalization.py
CHANGED
@@ -10,7 +10,7 @@ from jaxtyping import Array, DTypeLike, Float

 from lalamo.common import ParameterTree, dummy_array

-from .common import LalamoModule
+from .common import LalamoModule

 __all__ = [
     "RMSNorm",
@@ -83,13 +83,12 @@ class RMSNorm(LalamoModule[RMSNormConfig]):
         result = normalized_x * adjusted_scales
         return result.astype(inputs.dtype)

-    def export_weights(self
+    def export_weights(self) -> ParameterTree:
         return {"scales": self.scales}

     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
     ) -> Self:
         assert isinstance(weights, Mapping)
         return replace(self, scales=weights["scales"])
lalamo/modules/rope.py
CHANGED
@@ -25,7 +25,7 @@ from jaxtyping import Array, DTypeLike, Float, Int

 from lalamo.common import ParameterTree

-from .common import LalamoModule,
+from .common import LalamoModule, register_config_union

 __all__ = [
     "LinearScalingRoPEConfig",
@@ -39,22 +39,25 @@ __all__ = [


 class PositionalEmbeddings(eqx.Module):
-    cosines: Float[Array, "tokens head_channels"]
-    sines: Float[Array, "tokens head_channels"]
+    cosines: Float[Array, "*batch tokens head_channels"]
+    sines: Float[Array, "*batch tokens head_channels"]

     @property
     def head_dim(self) -> int:
         return self.cosines.shape[-1]

-    def rotate_half(
+    def rotate_half(
+        self,
+        heads: Float[Array, "*batch tokens head_channels"],
+    ) -> Float[Array, "*batch tokens head_channels"]:
         x1 = heads[..., : self.head_dim // 2]
         x2 = heads[..., self.head_dim // 2 :]
         return jnp.concatenate((-x2, x1), axis=-1)

-    def apply(self, heads: Float[Array, "tokens head_channels"]) -> Float[Array, "tokens head_channels"]:
+    def apply(self, heads: Float[Array, "*batch tokens head_channels"]) -> Float[Array, "*batch tokens head_channels"]:
         return heads * self.cosines + self.rotate_half(heads) * self.sines

-    def export(self
+    def export(self) -> ParameterTree:
         return dict(
             cosines=self.cosines,
             sines=self.sines,
@@ -105,9 +108,9 @@ class RoPE(LalamoModule[RoPEConfigBase]):

     def __post_init__(self) -> None:
         num_tokens, _ = self.sines.shape
-        if num_tokens
+        if num_tokens > self.config.max_sequence_length:
             raise ValueError(
-                f"{num_tokens}
+                f"{num_tokens} exceeds the specified max sequence length {self.config.max_sequence_length}",
             )
         if self.cosines.dtype != self.config.precision:
             raise ValueError(
@@ -140,7 +143,7 @@ class RoPE(LalamoModule[RoPEConfigBase]):
             sines=self.sines[timesteps],
         )

-    def export_weights(self
+    def export_weights(self) -> ParameterTree[Array]:
         return {
             "cosines": self.cosines,
             "sines": self.sines,
@@ -149,7 +152,6 @@ class RoPE(LalamoModule[RoPEConfigBase]):
     def import_weights(
         self,
         weights: ParameterTree[Array],
-        weight_layout: WeightLayout = WeightLayout.AUTO,  # noqa: ARG002
     ) -> "RoPE":
         assert isinstance(weights, Mapping)
         return replace(self, cosines=weights["cosines"], sines=weights["sines"])
@@ -199,13 +201,15 @@ class LlamaRoPEConfig(RoPEConfigBase):
 @dataclass(frozen=True)
 class YARNRoPEConfig(RoPEConfigBase):
     scaling_factor: float
+    original_context_length: int
     beta_fast: float
     beta_slow: float
+    truncate: bool

     @classmethod
-    def _find_correction_dim(cls, num_rotations: float, dim: int, base: float,
+    def _find_correction_dim(cls, num_rotations: float, dim: int, base: float, original_context_length: int) -> float:
         """Inverse dimension formula to find the dimension based on the number of rotations"""
-        return (dim * math.log(
+        return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

     @classmethod
     def _find_correction_range(
@@ -214,19 +218,25 @@ class YARNRoPEConfig(RoPEConfigBase):
         high_rot: float,
         dim: int,
         base: float,
-
-
+        original_context_length: int,
+        truncate: bool,
+    ) -> tuple[float, float]:
         """Find dimension range bounds based on rotations"""
-        low =
-        high =
-
+        low = cls._find_correction_dim(low_rot, dim, base, original_context_length)
+        high = cls._find_correction_dim(high_rot, dim, base, original_context_length)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
+        return max(low, 0.0), min(high, float(dim - 1))

     @classmethod
     def _linear_ramp_factor(cls, min_value: float, max_value: float, dim: int) -> Float[Array, " head_dim"]:
         if min_value == max_value:
             max_value += 0.001  # Prevent singularity

-
+        min_v = jnp.float32(min_value)
+        max_v = jnp.float32(max_value)
+        linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_v) / (max_v - min_v)
         ramp_func = jnp.clip(linear_func, 0, 1)
         return ramp_func

@@ -234,7 +244,7 @@ class YARNRoPEConfig(RoPEConfigBase):
         self,
         inverse_frequencies: Float[Array, " tokens"],
         head_dim: int,
-        max_sequence_length: int,
+        max_sequence_length: int,  # noqa: ARG002
     ) -> Float[Array, " tokens"]:
         scaled_frequencies = inverse_frequencies / self.scaling_factor

@@ -243,7 +253,8 @@ class YARNRoPEConfig(RoPEConfigBase):
             self.beta_slow,
             head_dim,
             self.base,
-
+            self.original_context_length,
+            self.truncate,
         )

         # Get n-dimensional rotational scaling corrected for extrapolation
@@ -251,7 +262,7 @@ class YARNRoPEConfig(RoPEConfigBase):
         return scaled_frequencies * (1 - smoothing_factor) + inverse_frequencies * smoothing_factor

     @property
-    def
+    def _attention_scaling_factor(self) -> float:
         return 0.1 * math.log(self.scaling_factor) + 1.0


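The YARNRoPEConfig changes thread two new fields, original_context_length and truncate, through the correction-dimension math: _find_correction_dim inverts the RoPE wavelength formula to find which head dimension completes a given number of rotations over the original context, and _find_correction_range optionally floors/ceils the resulting bounds, which then delimit where scaled and unscaled inverse frequencies are blended. A self-contained sketch of that computation follows; the standalone helper names and the example head_dim, base, context length, and rotation counts are assumptions for illustration, not values taken from this package.

# Hedged sketch of the YaRN correction-range math shown in the diff above.
import math

def find_correction_dim(num_rotations: float, dim: int, base: float, original_context_length: int) -> float:
    # Inverse of the RoPE wavelength formula: the dimension whose frequency
    # completes `num_rotations` full turns over the original context length.
    return (dim * math.log(original_context_length / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(
    low_rot: float, high_rot: float, dim: int, base: float, original_context_length: int, truncate: bool
) -> tuple[float, float]:
    low = find_correction_dim(low_rot, dim, base, original_context_length)
    high = find_correction_dim(high_rot, dim, base, original_context_length)
    if truncate:
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0.0), min(high, float(dim - 1))

# Illustrative call: head_dim=128, base=10000, 4096-token original context.
print(find_correction_range(32.0, 1.0, 128, 10000.0, 4096, truncate=True))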
lalamo/modules/utils.py
CHANGED
@@ -1,11 +1,21 @@
+from collections.abc import Callable
+
 import jax
+from jax import vmap
 from jaxtyping import Array, Float

 __all__ = [
     "apply_soft_capping",
+    "vmap_twice",
 ]


+def vmap_twice[F: Callable](
+    func: F,
+) -> F:
+    return vmap(vmap(func, in_axes=0), in_axes=0)
+
+
 def apply_soft_capping(
     values: Float[Array, "*"],
     soft_cap: float,