lalamo 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +15 -2
- lalamo/data/__init__.py +0 -1
- lalamo/data/huggingface_message.py +1 -0
- lalamo/main.py +167 -18
- lalamo/message_processor.py +2 -3
- lalamo/model_import/common.py +120 -27
- lalamo/model_import/decoder_configs/__init__.py +4 -2
- lalamo/model_import/decoder_configs/common.py +62 -21
- lalamo/model_import/decoder_configs/executorch.py +14 -9
- lalamo/model_import/decoder_configs/huggingface/__init__.py +4 -2
- lalamo/model_import/decoder_configs/huggingface/common.py +38 -12
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +15 -10
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +19 -16
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +16 -10
- lalamo/model_import/decoder_configs/huggingface/llama.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/llamba.py +23 -14
- lalamo/model_import/decoder_configs/huggingface/mistral.py +16 -11
- lalamo/model_import/decoder_configs/huggingface/modern_bert.py +241 -0
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +17 -10
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +15 -10
- lalamo/model_import/loaders/__init__.py +3 -2
- lalamo/model_import/loaders/executorch.py +24 -12
- lalamo/model_import/loaders/huggingface.py +258 -30
- lalamo/model_import/model_specs/__init__.py +4 -2
- lalamo/model_import/model_specs/common.py +8 -2
- lalamo/model_import/model_specs/gemma.py +5 -1
- lalamo/model_import/model_specs/huggingface.py +1 -1
- lalamo/model_import/model_specs/mirai.py +20 -0
- lalamo/models/__init__.py +10 -0
- lalamo/models/common.py +81 -0
- lalamo/{language_model.py → models/language_model.py} +32 -49
- lalamo/models/router.py +59 -0
- lalamo/modules/__init__.py +33 -16
- lalamo/modules/classifier.py +339 -0
- lalamo/modules/common.py +6 -3
- lalamo/modules/decoder.py +52 -180
- lalamo/modules/mlp.py +28 -5
- lalamo/modules/normalization.py +13 -8
- lalamo/modules/token_mixers/attention.py +10 -6
- lalamo/modules/token_mixers/state/kv_cache.py +14 -4
- lalamo/modules/transformer.py +273 -0
- lalamo/modules/{decoder_layer.py → transformer_layer.py} +62 -45
- lalamo/speculator/__init__.py +6 -2
- lalamo/speculator/estimator.py +91 -0
- lalamo/speculator/inference.py +28 -9
- lalamo/speculator/ngram.py +7 -3
- lalamo/speculator/utils.py +4 -2
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/METADATA +1 -1
- lalamo-0.5.4.dist-info/RECORD +88 -0
- lalamo-0.5.2.dist-info/RECORD +0 -80
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/WHEEL +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/entry_points.txt +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.5.2.dist-info → lalamo-0.5.4.dist-info}/top_level.txt +0 -0
lalamo/modules/transformer.py
ADDED
@@ -0,0 +1,273 @@
+from collections.abc import Mapping, Sequence
+from dataclasses import dataclass, replace
+from typing import Self
+
+import equinox as eqx
+import jax
+from jax import vmap
+from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray
+
+from lalamo.common import ParameterTree
+from lalamo.modules.token_mixers import AttentionConfig
+from lalamo.modules.utils import vmap_twice
+
+from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
+from .normalization import Normalization, NormalizationConfig
+from .rope import PositionalEmbeddings, RoPE, RoPEConfig
+from .token_mixers import State
+from .transformer_layer import (
+    TransformerLayer,
+    TransformerLayerConfig,
+    TransformerLayerForwardPassConfig,
+    TransformerLayerResult,
+)
+
+__all__ = [
+    "Transformer",
+    "TransformerConfig",
+    "TransformerResult",
+]
+
+
+type TransformerForwardPassConfig = TransformerLayerForwardPassConfig
+
+
+class TransformerResult(eqx.Module):
+    outputs: Float[Array, "batch suffix_tokens channels"]
+    updated_state: State | None = None
+    layer_results: tuple[TransformerLayerResult, ...] | None = None
+    global_positional_embeddings: PositionalEmbeddings | None = None
+    local_positional_embeddings: PositionalEmbeddings | None = None
+
+    def export(self) -> ParameterTree:
+        result: dict[str, ParameterTree | Array] = dict(
+            outputs=self.outputs,
+        )
+        if self.updated_state is not None:
+            result["updated_state"] = [state_layer.export() for state_layer in self.updated_state]
+        if self.layer_results is not None:
+            result["layer_results"] = [layer_result.export() for layer_result in self.layer_results]
+        if self.global_positional_embeddings is not None:
+            result["global_positional_embeddings"] = self.global_positional_embeddings.export()
+        if self.local_positional_embeddings is not None:
+            result["local_positional_embeddings"] = self.local_positional_embeddings.export()
+        return result
+
+
+@dataclass(frozen=True)
+class TransformerConfig:
+    global_rope_config: RoPEConfig | None
+    local_rope_config: RoPEConfig | None
+    layer_configs: tuple[TransformerLayerConfig, ...]
+    output_norm_config: NormalizationConfig
+    model_dim: int
+    hidden_dim: int
+    context_length: int
+
+    def random_init(self, *, key: PRNGKeyArray) -> "Transformer":
+        first_layer_config, *_ = self.layer_configs
+
+        if self.global_rope_config:
+            global_rope = self.global_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=self.context_length,
+            )
+        else:
+            global_rope = None
+
+        if self.local_rope_config:
+            max_sliding_window_size = max(
+                layer_config.mixer_config.sliding_window_size or 0
+                for layer_config in self.layer_configs
+                if isinstance(layer_config.mixer_config, AttentionConfig)
+            )
+
+            local_rope = self.local_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=max(max_sliding_window_size, self.context_length),
+            )
+        else:
+            local_rope = None
+
+        layers_keys = jax.random.split(key, num=len(self.layer_configs))
+        layers = tuple(
+            layer_config.random_init(
+                model_dim=self.model_dim,
+                hidden_dim=self.hidden_dim,
+                key=layer_key,
+            )
+            for layer_key, layer_config in zip(layers_keys, self.layer_configs, strict=True)
+        )
+        output_norm = self.output_norm_config.init(self.model_dim)
+
+        return Transformer(
+            config=self,
+            global_rope=global_rope,
+            local_rope=local_rope,
+            layers=layers,
+            output_norm=output_norm,
+        )
+
+    def empty(self) -> "Transformer":
+        first_layer_config, *_ = self.layer_configs
+
+        if self.global_rope_config:
+            global_rope = self.global_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=self.context_length,
+            )
+        else:
+            global_rope = None
+
+        if self.local_rope_config:
+            local_rope = self.local_rope_config.init(
+                head_dim=first_layer_config.rope_dim,
+                num_timesteps=self.context_length,
+            )
+        else:
+            local_rope = None
+
+        layers = tuple(
+            layer_config.empty(
+                model_dim=self.model_dim,
+                hidden_dim=self.hidden_dim,
+            )
+            for layer_config in self.layer_configs
+        )
+        output_norm = self.output_norm_config.empty(self.model_dim)
+
+        return Transformer(
+            config=self,
+            global_rope=global_rope,
+            local_rope=local_rope,
+            layers=layers,
+            output_norm=output_norm,
+        )
+
+
+class Transformer(LalamoModule[TransformerConfig]):
+    global_rope: RoPE | None
+    local_rope: RoPE | None
+    layers: tuple[TransformerLayer, ...]
+    output_norm: Normalization
+
+    @property
+    def activation_precision(self) -> DTypeLike:
+        return self.layers[0].activation_precision
+
+    @eqx.filter_jit
+    def __call__(
+        self,
+        inner_features: Float[Array, "batch suffix_tokens channels"],
+        token_positions: Int[Array, "batch suffix_tokens"],
+        state: State | None,
+        return_updated_state: bool,
+        return_layer_results: bool,
+        return_positional_embeddings: bool,
+        lengths_without_padding: Int[Array, " batch"] | None,
+        forward_pass_mode: ForwardPassMode,
+        forward_pass_config: TransformerForwardPassConfig | None,
+    ) -> TransformerResult:
+        if inner_features.ndim != 3:
+            raise ValueError(
+                f"inner_features must be a 3D array of size (batch_size, sequence_length, hidden_dim), got {inner_features.shape}",
+            )
+        if token_positions.ndim != 2:
+            raise ValueError(
+                "token_positions must be a 2D array of size (batch_size, sequence_length),"
+                f" got {token_positions.shape}",
+            )
+
+        maybe_state = state or ([None] * len(self.layers))
+
+        if self.global_rope is not None:
+            global_positional_embeddings = vmap(self.global_rope)(token_positions)
+        else:
+            global_positional_embeddings = None
+        if self.local_rope is not None:
+            local_positional_embeddings = vmap(self.local_rope)(token_positions)
+        else:
+            local_positional_embeddings = global_positional_embeddings
+
+        updated_state_layers = []
+        layer_results = []
+
+        for layer, state_layer in zip(self.layers, maybe_state, strict=True):
+            match layer.positional_embedding_selector:
+                case PositionalEmbeddingSelector.LOCAL:
+                    positional_embeddings_to_use = local_positional_embeddings
+                case PositionalEmbeddingSelector.GLOBAL:
+                    positional_embeddings_to_use = global_positional_embeddings
+                case PositionalEmbeddingSelector.NONE:
+                    positional_embeddings_to_use = None
+
+            layer_result = layer(
+                inner_features,
+                positional_embeddings_to_use,
+                state=state_layer,
+                return_updated_state=return_updated_state,
+                return_activation_trace=return_layer_results,
+                lengths_without_padding=lengths_without_padding,
+                forward_pass_mode=forward_pass_mode,
+                forward_pass_config=forward_pass_config,
+            )
+            inner_features = layer_result.outputs
+            layer_results.append(layer_result)
+            updated_state_layers.append(layer_result.updated_state)
+
+        normalized_outputs = vmap_twice(self.output_norm)(inner_features)
+
+        return TransformerResult(
+            outputs=normalized_outputs,
+            updated_state=(State(updated_state_layers) if return_updated_state else None),
+            layer_results=tuple(layer_results) if return_layer_results else None,
+            global_positional_embeddings=(global_positional_embeddings if return_positional_embeddings else None),
+            local_positional_embeddings=(local_positional_embeddings if return_positional_embeddings else None),
+        )
+
+    def init_static_state(self, batch_size: int, capacity: int) -> State:
+        return State(layer.init_static_state(batch_size, capacity) for layer in self.layers)
+
+    def export_weights(self) -> ParameterTree:
+        result = dict(
+            layers=[layer.export_weights() for layer in self.layers],
+            output_norm=self.output_norm.export_weights(),
+        )
+        if self.global_rope:
+            result["global_rope"] = self.global_rope.export_weights()
+        if self.local_rope:
+            result["local_rope"] = self.local_rope.export_weights()
+        return result
+
+    def import_weights(
+        self,
+        weights: ParameterTree[Array],
+    ) -> Self:
+        assert isinstance(weights, Mapping)
+        assert isinstance(weights["layers"], Sequence)
+        assert isinstance(weights["output_norm"], Mapping)
+
+        if self.global_rope:
+            assert isinstance(weights["global_rope"], Mapping)
+            global_rope = self.global_rope.import_weights(weights["global_rope"])
+        else:
+            global_rope = None
+
+        if self.local_rope:
+            assert isinstance(weights["local_rope"], Mapping)
+            local_rope = self.local_rope.import_weights(weights["local_rope"])
+        else:
+            local_rope = None
+
+        layers = []
+        for layer, layer_weights in zip(self.layers, weights["layers"], strict=True):
+            assert isinstance(layer_weights, Mapping)
+            layers.append(layer.import_weights(layer_weights))
+
+        return replace(
+            self,
+            global_rope=global_rope,
+            layers=tuple(layers),
+            output_norm=self.output_norm.import_weights(weights["output_norm"]),
+            local_rope=local_rope,
+        )
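The new `Transformer` module follows the config-driven construction pattern used throughout the package: a frozen dataclass config exposes `random_init` (random weights) and `empty` (shape-only placeholders for later weight import), and the resulting equinox module keeps a reference to its config. The sketch below is a minimal, standalone illustration of that pattern; `ToyLayer` and `ToyLayerConfig` are hypothetical names for this example only, not lalamo API.

```python
from dataclasses import dataclass

import equinox as eqx
import jax
from jaxtyping import Array, Float, PRNGKeyArray


class ToyLayer(eqx.Module):
    weight: Float[Array, "out in"]

    def __call__(self, x: Float[Array, " in"]) -> Float[Array, " out"]:
        return self.weight @ x


@dataclass(frozen=True)
class ToyLayerConfig:
    input_dim: int
    output_dim: int

    def random_init(self, *, key: PRNGKeyArray) -> ToyLayer:
        # Random weights, analogous to TransformerConfig.random_init above.
        weight = jax.random.normal(key, (self.output_dim, self.input_dim))
        return ToyLayer(weight=weight)

    def empty(self) -> ToyLayer:
        # Placeholder weights, analogous to TransformerConfig.empty,
        # intended to be overwritten by an import_weights-style call.
        weight = jax.numpy.zeros((self.output_dim, self.input_dim))
        return ToyLayer(weight=weight)


layer = ToyLayerConfig(input_dim=4, output_dim=2).random_init(key=jax.random.PRNGKey(0))
print(layer(jax.numpy.ones(4)).shape)  # (2,)
```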
lalamo/modules/{decoder_layer.py → transformer_layer.py}
RENAMED
@@ -13,24 +13,24 @@ from lalamo.common import ParameterTree
 
 from .common import ForwardPassMode, LalamoModule, PositionalEmbeddingSelector
 from .mlp import MLPBase, MLPConfig, MLPForwardPassConfig
-from .normalization import
+from .normalization import Normalization, NormalizationConfig
 from .rope import PositionalEmbeddings
 from .token_mixers import KVCacheLayer, StateLayerBase, StaticKVCacheLayer, TokenMixerBase, TokenMixerConfig
 from .utils import vmap_twice
 
 __all__ = [
-    "
-    "
-    "
-    "
-    "
+    "TransformerLayer",
+    "TransformerLayerActivationTrace",
+    "TransformerLayerConfig",
+    "TransformerLayerForwardPassConfig",
+    "TransformerLayerResult",
 ]
 
 
-type
+type TransformerLayerForwardPassConfig = MLPForwardPassConfig
 
 
-class
+class TransformerLayerActivationTrace(eqx.Module):
     inputs: Float[Array, "batch suffix_tokens channels"]
     positional_embeddings: PositionalEmbeddings | None
     state: StateLayerBase | None
@@ -63,10 +63,10 @@ class DecoderLayerActivationTrace(eqx.Module):
         return result
 
 
-class
-    outputs: Float[Array, "
+class TransformerLayerResult(eqx.Module):
+    outputs: Float[Array, "batch tokens channels"]
     updated_state: KVCacheLayer | None
-    activation_trace:
+    activation_trace: TransformerLayerActivationTrace | None
 
     def export(self) -> ParameterTree:
         result: dict[str, ParameterTree | Array] = dict(
@@ -80,13 +80,13 @@ class DecoderLayerResult(eqx.Module):
 
 
 @dataclass(frozen=True)
-class
-    pre_mixer_norm_config:
+class TransformerLayerConfig:
+    pre_mixer_norm_config: NormalizationConfig | None
     mixer_config: TokenMixerConfig
-    post_mixer_norm_config:
-    pre_mlp_norm_config:
+    post_mixer_norm_config: NormalizationConfig | None
+    pre_mlp_norm_config: NormalizationConfig
     mlp_config: MLPConfig
-    post_mlp_norm_config:
+    post_mlp_norm_config: NormalizationConfig | None
 
     @property
     def rope_dim(self) -> int:
@@ -98,28 +98,31 @@ class DecoderLayerConfig:
         hidden_dim: int,
         *,
         key: PRNGKeyArray,
-    ) -> "
+    ) -> "TransformerLayer":
         attention_key, mlp_key = jax.random.split(key)
-
+        if self.pre_mixer_norm_config is not None:
+            pre_mixer_norm = self.pre_mixer_norm_config.init(model_dim)
+        else:
+            pre_mixer_norm = None
         mixer = self.mixer_config.random_init(
             model_dim=model_dim,
             key=attention_key,
         )
         if self.post_mixer_norm_config is not None:
-
+            post_mixer_norm = self.post_mixer_norm_config.init(model_dim)
         else:
-
+            post_mixer_norm = None
         pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
         mlp = self.mlp_config.random_init(model_dim, hidden_dim, key=mlp_key)
         if self.post_mlp_norm_config is not None:
             post_mlp_norm = self.post_mlp_norm_config.init(model_dim)
         else:
             post_mlp_norm = None
-        return
+        return TransformerLayer(
             config=self,
-            pre_mixer_norm=
+            pre_mixer_norm=pre_mixer_norm,
             mixer=mixer,
-            post_mixer_norm=
+            post_mixer_norm=post_mixer_norm,
             pre_mlp_norm=pre_mlp_norm,
             mlp=mlp,
             post_mlp_norm=post_mlp_norm,
@@ -129,39 +132,42 @@ class DecoderLayerConfig:
         self,
         model_dim: int,
         hidden_dim: int,
-    ) -> "
-
+    ) -> "TransformerLayer":
+        if self.pre_mixer_norm_config is not None:
+            pre_mixer_norm = self.pre_mixer_norm_config.empty(model_dim)
+        else:
+            pre_mixer_norm = None
         attention = self.mixer_config.empty(
             model_dim=model_dim,
         )
         if self.post_mixer_norm_config is not None:
-
+            post_mixer_norm = self.post_mixer_norm_config.empty(model_dim)
         else:
-
+            post_mixer_norm = None
         pre_mlp_norm = self.pre_mlp_norm_config.empty(model_dim)
         mlp = self.mlp_config.empty(model_dim, hidden_dim)
         if self.post_mlp_norm_config is not None:
             post_mlp_norm = self.post_mlp_norm_config.empty(model_dim)
         else:
             post_mlp_norm = None
-        return
+        return TransformerLayer(
             config=self,
-            pre_mixer_norm=
+            pre_mixer_norm=pre_mixer_norm,
             mixer=attention,
-            post_mixer_norm=
+            post_mixer_norm=post_mixer_norm,
             pre_mlp_norm=pre_mlp_norm,
             mlp=mlp,
             post_mlp_norm=post_mlp_norm,
         )
 
 
-class
-    pre_mixer_norm:
+class TransformerLayer(LalamoModule[TransformerLayerConfig]):
+    pre_mixer_norm: Normalization | None
     mixer: TokenMixerBase
-    post_mixer_norm:
-    pre_mlp_norm:
+    post_mixer_norm: Normalization | None
+    pre_mlp_norm: Normalization
     mlp: MLPBase
-    post_mlp_norm:
+    post_mlp_norm: Normalization | None
 
     @property
     def activation_precision(self) -> DTypeLike:
@@ -172,7 +178,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         return self.mixer.positional_embedding_selector
 
     def __post_init__(self) -> None:
-        model_dim = self.pre_mixer_norm.input_dim
+        model_dim = self.pre_mixer_norm.input_dim if self.pre_mixer_norm is not None else self.mixer.model_dim
         if self.mixer.model_dim != model_dim:
             raise ValueError(
                 f"Attention model dim {self.mixer.model_dim} does not match"
@@ -204,15 +210,21 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         return_activation_trace: bool = False,
         lengths_without_padding: Int[Array, " batch"] | None = None,
         forward_pass_mode: ForwardPassMode = ForwardPassMode.MULTI_TOKEN,
-        forward_pass_config:
-    ) ->
+        forward_pass_config: TransformerLayerForwardPassConfig | None = None,
+    ) -> TransformerLayerResult:
         if inputs.ndim != 3:
             raise ValueError(
                 f"Inputs to decoder layers must be a 3D arrays of size (batch_size, sequence_length, hidden_dim),"
                 f" got {inputs.shape}",
             )
-
-
+        if self.pre_mixer_norm is not None:
+            normalized_mixer_inputs = vmap_twice(self.pre_mixer_norm)(inputs)
+        else:
+            normalized_mixer_inputs = inputs
+
+        batched_mixer_fn = vmap(
+            partial(self.mixer, return_updated_state=return_updated_state or return_activation_trace),
+        )
         mixer_outputs, updated_state = batched_mixer_fn(
             normalized_mixer_inputs,
             positional_embeddings,
@@ -240,7 +252,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         outputs = mlp_inputs + mlp_outputs
 
         if return_activation_trace:
-            activation_trace =
+            activation_trace = TransformerLayerActivationTrace(
                 inputs=inputs,
                 positional_embeddings=positional_embeddings,
                 state=state,
@@ -255,7 +267,7 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         else:
             activation_trace = None
 
-        return
+        return TransformerLayerResult(
             outputs=outputs,
             updated_state=updated_state,
             activation_trace=activation_trace,
@@ -269,11 +281,12 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
 
     def export_weights(self) -> ParameterTree:
         result = dict(
-            pre_mixer_norm=self.pre_mixer_norm.export_weights(),
             mixer=self.mixer.export_weights(),
             pre_mlp_norm=self.pre_mlp_norm.export_weights(),
             mlp=self.mlp.export_weights(),
         )
+        if self.pre_mixer_norm is not None:
+            result["pre_mixer_norm"] = self.pre_mixer_norm.export_weights()
         if self.post_mixer_norm is not None:
             result["post_mixer_norm"] = self.post_mixer_norm.export_weights()
         if self.post_mlp_norm is not None:
@@ -285,7 +298,6 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
         weights: ParameterTree[Array],
     ) -> Self:
         assert isinstance(weights, Mapping)
-        assert isinstance(weights["pre_mixer_norm"], Mapping)
         assert isinstance(weights["mixer"], Mapping)
         assert isinstance(weights["mlp"], Mapping)
         assert isinstance(weights["pre_mlp_norm"], Mapping)
@@ -302,9 +314,14 @@ class DecoderLayer(LalamoModule[DecoderLayerConfig]):
             post_mlp_norm = self.post_mlp_norm.import_weights(weights["post_mlp_norm"])
         else:
             post_mlp_norm = None
+        if self.pre_mixer_norm is not None:
+            assert isinstance(weights["pre_mixer_norm"], Mapping)
+            pre_mixer_norm = self.pre_mixer_norm.import_weights(weights["pre_mixer_norm"])
+        else:
+            pre_mixer_norm = None
         return replace(
             self,
-            pre_mixer_norm=
+            pre_mixer_norm=pre_mixer_norm,
             mixer=self.mixer.import_weights(weights["mixer"]),
             post_mixer_norm=post_mixer_norm,
             pre_mlp_norm=self.pre_mlp_norm.import_weights(weights["pre_mlp_norm"]),
lalamo/speculator/__init__.py
CHANGED
@@ -1,11 +1,15 @@
 from .common import Speculator
-from .
+from .estimator import estimate_batchsize_from_memory
+from .inference import CollectTracesEvent, inference_collect_traces
 from .ngram import NGramSpeculator
-from .utils import train_speculator
+from .utils import SpeculatorTrainingEvent, train_speculator
 
 __all__ = [
+    "CollectTracesEvent",
     "NGramSpeculator",
     "Speculator",
+    "SpeculatorTrainingEvent",
+    "estimate_batchsize_from_memory",
     "inference_collect_traces",
     "train_speculator",
 ]
lalamo/speculator/estimator.py
ADDED
@@ -0,0 +1,91 @@
+import functools
+import itertools
+from collections.abc import Callable
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+
+from lalamo.models import LanguageModel
+
+
+def estimate_memory_from_batchsize(
+    model: LanguageModel,
+    max_input_length: int,
+    max_output_length: int,
+    num_logits_per_token: int,
+    batch_size: int,
+) -> int:
+    memory_analysis = (
+        jax.jit(
+            functools.partial(
+                model.generate_tokens,
+                max_output_length=max_output_length,
+                num_top_logits_to_return=num_logits_per_token,
+            ),
+            backend="cpu",  # cuda backend tries to allocate in .compile() and ooms
+        )
+        .lower(
+            prompt_token_ids=jax.ShapeDtypeStruct((batch_size, max_input_length), jnp.int32),
+            prompt_lengths_without_padding=jax.ShapeDtypeStruct((batch_size,), jnp.int32),
+        )
+        .compile()
+        .memory_analysis()
+    )
+
+    assert hasattr(memory_analysis, "argument_size_in_bytes")
+    assert hasattr(memory_analysis, "output_size_in_bytes")
+    assert hasattr(memory_analysis, "temp_size_in_bytes")
+
+    return (
+        memory_analysis.argument_size_in_bytes  # type: ignore (pyright bug)
+        + memory_analysis.output_size_in_bytes  # type: ignore (pyright bug)
+        + memory_analysis.temp_size_in_bytes  # type: ignore (pyright bug)
+    )
+
+
+class EstimateBatchsizeFromMemoryEvent(NamedTuple):
+    lo: int
+    hi: int | None
+
+
+def estimate_batchsize_from_memory(
+    model: LanguageModel,
+    max_input_length: int,
+    max_output_length: int,
+    num_logits_per_token: int,
+    target_mem: int,
+    progress: Callable[[EstimateBatchsizeFromMemoryEvent], None] | None = None,
+) -> int:
+    mem_for_bs = functools.cache(
+        functools.partial(
+            estimate_memory_from_batchsize,
+            model,
+            max_input_length,
+            max_output_length,
+            num_logits_per_token,
+        ),
+    )
+
+    lo = 0
+    hi = 0
+    for candidate_exp in itertools.count():
+        lo = hi
+        hi = 2**candidate_exp
+
+        if progress is not None:
+            progress(EstimateBatchsizeFromMemoryEvent(lo, None))
+        if target_mem < mem_for_bs(hi):
+            break
+
+    while hi - lo > 1:
+        mid = (lo + hi) // 2
+
+        if progress is not None:
+            progress(EstimateBatchsizeFromMemoryEvent(lo, hi))
+        if target_mem < mem_for_bs(mid):
+            hi = mid
+        else:
+            lo = mid
+
+    return lo
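`estimate_batchsize_from_memory` finds the largest batch size whose compiled memory footprint stays under a budget: the `jax.jit(...).lower(...).compile().memory_analysis()` chain performs ahead-of-time compilation, so the estimate is obtained without allocating or running the model; the batch size is then searched by doubling the candidate until the (monotonically increasing) estimate exceeds `target_mem` and bisecting inside the last bracket. A standalone sketch of that search strategy, with a stand-in `cost` function instead of the XLA memory analysis (`largest_within_budget` is a hypothetical name, not package code):

```python
import itertools


def largest_within_budget(cost, budget: int) -> int:
    """Largest n with cost(n) <= budget, assuming cost is monotonically increasing."""
    # Phase 1: exponential ramp-up to bracket the answer between lo and hi.
    lo, hi = 0, 0
    for exp in itertools.count():
        lo, hi = hi, 2**exp
        if budget < cost(hi):
            break

    # Phase 2: binary search, maintaining cost(lo) <= budget < cost(hi).
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if budget < cost(mid):
            hi = mid
        else:
            lo = mid
    return lo


# Example: ~1 GiB per unit of batch size and a 10 GiB budget -> 10.
print(largest_within_budget(lambda batch_size: batch_size * 2**30, 10 * 2**30))
```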