lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- lalamo/__init__.py +20 -5
- lalamo/data/__init__.py +8 -0
- lalamo/data/huggingface_message.py +38 -0
- lalamo/data/lalamo_completions.py +43 -0
- lalamo/data/utils.py +8 -0
- lalamo/language_model.py +152 -69
- lalamo/main.py +271 -43
- lalamo/message_processor.py +11 -1
- lalamo/model_import/common.py +17 -7
- lalamo/model_import/decoder_configs/__init__.py +3 -0
- lalamo/model_import/decoder_configs/executorch.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
- lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
- lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
- lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
- lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
- lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
- lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
- lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
- lalamo/model_import/huggingface_tokenizer_config.py +1 -4
- lalamo/model_import/loaders/executorch.py +10 -9
- lalamo/model_import/loaders/huggingface.py +104 -9
- lalamo/model_import/loaders/utils.py +92 -0
- lalamo/model_import/model_specs/__init__.py +4 -1
- lalamo/model_import/model_specs/common.py +15 -12
- lalamo/model_import/model_specs/gpt_oss.py +21 -0
- lalamo/modules/__init__.py +35 -7
- lalamo/modules/activations.py +24 -14
- lalamo/modules/attention.py +73 -20
- lalamo/modules/common.py +8 -57
- lalamo/modules/decoder.py +48 -34
- lalamo/modules/decoder_layer.py +57 -43
- lalamo/modules/embedding.py +13 -19
- lalamo/modules/kv_cache.py +53 -16
- lalamo/modules/linear.py +260 -79
- lalamo/modules/mlp.py +395 -23
- lalamo/modules/normalization.py +2 -3
- lalamo/modules/rope.py +32 -21
- lalamo/modules/utils.py +10 -0
- lalamo/speculator/__init__.py +11 -0
- lalamo/speculator/common.py +22 -0
- lalamo/speculator/inference.py +75 -0
- lalamo/speculator/ngram.py +154 -0
- lalamo/speculator/utils.py +52 -0
- lalamo/utils.py +27 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
- lalamo-0.4.0.dist-info/RECORD +71 -0
- lalamo-0.3.3.dist-info/RECORD +0 -59
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py
CHANGED

@@ -6,10 +6,10 @@ from jaxtyping import Array
 
 from lalamo.common import ParameterPath
 from lalamo.modules import (
-    MLP,
     Attention,
     Decoder,
     DecoderLayer,
+    DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
@@ -17,9 +17,11 @@ from lalamo.modules import (
     TiedEmbedding,
     UntiedEmbedding,
 )
+from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode
 
 from .common import load_parameters
+from .utils import decode_mxfp4, deinterleave_pairwise_columns
 
 __all__ = ["load_huggingface"]
 
@@ -78,7 +80,7 @@ def _process_quantized_tensors(
     zero_points = unpacked_zero_points.astype(module.config.activation_precision)
     processed_scales = scales.astype(module.config.activation_precision)
 
-    return weights
+    return weights, zero_points, processed_scales
 
 
 def _fuse_full_precision_weights(
@@ -158,16 +160,103 @@ def load_linear(
         return load_parameters(
             lambda m: (m.weights, m.scales, m.zero_points, m.biases),
             module,
-            (weights, scales, zero_points, bias),
+            (weights.T, scales.T, zero_points.T, bias),
         )
 
     raise TypeError(f"Unsupported module type for loading: {type(module)}")
 
 
-def load_mlp(module:
-
-
-
+def load_mlp(module: MLPBase, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLPBase:
+    if isinstance(module, DenseMLP):
+        # Standard dense MLP with separate sublayers.
+        up_projection = load_linear(
+            module.up_projection,
+            weights_dict,
+            path,
+            sublayers_to_fuse=["up_proj", "gate_proj"],
+        )
+        down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
+        return load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module,
+            (up_projection, down_projection),
+        )
+
+    if isinstance(module, MixtureOfExperts):
+        return load_moe(module, weights_dict, path)
+
+    raise TypeError(f"Unsupported module type for loading: {type(module)}")
+
+
+def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path: ParameterPath) -> MixtureOfExperts:
+    # Load router via the standard linear loader
+    router = load_linear(module.router, weights_dict, path / "router")
+
+    experts_path = path / "experts"
+    # Handle fused MXFP4 experts layout if present
+    if (experts_path / "gate_up_proj_blocks") in weights_dict:
+        # Decode fused gate/up (interleaved), split into (up, gate), and add +1.0 to up bias
+        fused = decode_mxfp4(
+            weights_dict[experts_path / "gate_up_proj_blocks"],
+            weights_dict[experts_path / "gate_up_proj_scales"],
+            dtype=module.activation_precision,
+            flatten=False,
+        )
+        # Stored as (experts, outputs=2*hidden_dim, input_blocks, input_block_elems)
+        # Merge blocks and move outputs last
+        fused_eio = rearrange(fused, "e o ib ie -> e (ib ie) o")
+        up_w, gate_w = deinterleave_pairwise_columns(fused_eio, first="odd")
+        combined_up_gate = jnp.concatenate([up_w, gate_w], axis=-1)
+        # Transpose to new layout: (experts, outputs, inputs)
+        combined_up_gate_w = jnp.swapaxes(combined_up_gate, -1, -2)
+
+        gub = weights_dict[experts_path / "gate_up_proj_bias"]
+        if gub.ndim == 1:
+            # Broadcast to (experts, 2*hidden_dim)
+            gub = jnp.broadcast_to(gub, (combined_up_gate_w.shape[0], gub.shape[0]))
+        up_b, gate_b = deinterleave_pairwise_columns(gub, first="odd")
+        combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
+
+        up_projection = load_parameters(
+            lambda m: (m.weights, m.biases),  # type: ignore
+            module.experts.up_projection,
+            (combined_up_gate_w, combined_up_gate_b),
+        )
+
+        # Down projection: decode MXFP4 to dense
+        down_w = decode_mxfp4(
+            weights_dict[experts_path / "down_proj_blocks"],
+            weights_dict[experts_path / "down_proj_scales"],
+            dtype=module.activation_precision,
+            flatten=False,
+        )
+        # Stored as (experts, outputs=model_dim, input_blocks, input_block_elems)
+        # Merge blocks and move outputs last
+        down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
+        down_b = weights_dict[experts_path / "down_proj_bias"]
+        if down_b.ndim == 1:
+            down_b = jnp.broadcast_to(down_b, down_w.shape[:-1] + (down_b.shape[0],))
+
+        down_projection = load_parameters(
+            lambda m: (m.weights, m.biases),  # type: ignore
+            module.experts.down_projection,
+            (down_w, down_b),
+        )
+
+        experts = load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module.experts,
+            (up_projection, down_projection),
+        )
+    else:
+        # Fallback: recursively load a standard DenseMLP experts module
+        experts = load_mlp(module.experts, weights_dict, experts_path)
+
+    return load_parameters(
+        lambda m: (m.router, m.experts),
+        module,
+        (router, experts),
+    )
 
 
 def load_rmsnorm(
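The `first="odd"` split above follows from the checkpoint layout: the fused tensor interleaves gate columns at even indices and up columns at odd indices, and the `up_b + 1.0` term appears to fold GPT-OSS's `(up + 1)` activation factor into the bias so the runtime MLP needs no special case. A minimal sketch of the layout transformation on hypothetical toy shapes, assuming einops' rearrange (which the loader appears to use):

import jax.numpy as jnp
from einops import rearrange

experts, hidden, blocks, block_elems = 2, 4, 3, 8  # toy dimensions
fused = jnp.zeros((experts, 2 * hidden, blocks, block_elems))  # decoded MXFP4 layout

fused_eio = rearrange(fused, "e o ib ie -> e (ib ie) o")  # merge input blocks
up_w = fused_eio[..., 1::2]    # odd output columns -> up projection
gate_w = fused_eio[..., 0::2]  # even output columns -> gate projection
combined = jnp.concatenate([up_w, gate_w], axis=-1)  # (experts, inputs, 2*hidden)
weights = jnp.swapaxes(combined, -1, -2)             # (experts, outputs, inputs)
assert weights.shape == (experts, 2 * hidden, blocks * block_elems)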
@@ -202,10 +291,16 @@ def load_attention(
     else:
         key_norm = None
 
+    # GPT-OSS adds per-head attention sinks; load them if present.
+    if (path / "sinks") in weights_dict:
+        sinks = weights_dict[path / "sinks"]
+    else:
+        sinks = module.sinks
+
     return load_parameters(
-        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm, m.sinks),
         module,
-        (qkv_projection, out_projection, query_norm, key_norm),
+        (qkv_projection, out_projection, query_norm, key_norm, sinks),
     )
 
 
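For context, GPT-OSS attention sinks are learned per-head logits that join the softmax normalization without contributing a value vector, letting a head attend "nowhere". A sketch of the usual formulation; lalamo's actual kernel lives in lalamo/modules/attention.py and may differ in detail:

import jax
import jax.numpy as jnp

def softmax_with_sinks(scores, sinks):
    # scores: (heads, q_len, kv_len) attention logits; sinks: (heads,) learned logits.
    heads, q_len, _ = scores.shape
    sink_col = jnp.broadcast_to(sinks[:, None, None], (heads, q_len, 1))
    weights = jax.nn.softmax(jnp.concatenate([scores, sink_col], axis=-1), axis=-1)
    return weights[..., :-1]  # drop the sink column; rows now sum to <= 1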
lalamo/model_import/loaders/utils.py
ADDED

@@ -0,0 +1,92 @@
+# MXFP4 decoding utilities for model loaders.
+# Based on OpenAI's reference implementation logic for GPT-OSS MXFP4 weights.
+# Converts packed FP4 blocks plus per-row scales into dense weights in the target dtype.
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike
+
+__all__ = [
+    "decode_mxfp4",
+    "deinterleave_pairwise_columns",
+]
+
+
+# The 16 representable FP4 values used by MXFP4, in logical order (low nibble indices 0..15).
+# See: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/weights.py
+_MXFP4_LUT_VALUES = (
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+)
+
+
+def decode_mxfp4(
+    blocks: Array,
+    scales: Array,
+    *,
+    dtype: DTypeLike,
+    flatten: bool = False,
+) -> Array:
+    target_dtype = jnp.dtype(dtype)
+
+    # Prepare LUT in target dtype
+    lut = jnp.array(_MXFP4_LUT_VALUES, dtype=target_dtype)
+
+    *prefix, rows, packed_cols = blocks.shape
+    if scales.shape != (*prefix, rows):
+        raise ValueError(
+            f"MXFP4 scales shape {scales.shape} does not match blocks prefix/rows {(*prefix, rows)}",
+        )
+
+    # Extract low/high nibble indices
+    low_mask = jnp.array(0x0F, dtype=blocks.dtype)
+    idx_lo = (blocks & low_mask).astype(jnp.int32)
+    idx_hi = (blocks >> jnp.array(4, dtype=blocks.dtype)).astype(jnp.int32)
+
+    # Lookup FP4 base values
+    vals_lo = lut[idx_lo]
+    vals_hi = lut[idx_hi]
+
+    # Interleave into (..., rows, 2*packed_cols)
+    out_shape = (*prefix, rows, packed_cols * 2)
+    out = jnp.empty(out_shape, dtype=target_dtype)
+    out = out.at[..., 0::2].set(vals_lo)
+    out = out.at[..., 1::2].set(vals_hi)
+
+    # Apply exponent scaling: exponents are biased by 127 in checkpoints
+    exp = scales.astype(jnp.int32) - 127
+    out = jnp.ldexp(out, exp[..., None])
+
+    if flatten:
+        return out.reshape(*prefix, rows * (packed_cols * 2))
+    return out
+
+
+def deinterleave_pairwise_columns(
+    matrix: Array,
+    *,
+    first: str = "even",
+) -> tuple[Array, Array]:
+    if matrix.shape[-1] % 2 != 0:
+        raise ValueError(f"Last dimension must be even, got {matrix.shape[-1]}")
+
+    match first:
+        case "even":
+            return matrix[..., 0::2], matrix[..., 1::2]
+        case "odd":
+            return matrix[..., 1::2], matrix[..., 0::2]
+        case _:
+            raise ValueError("Parameter 'first' must be either 'even' or 'odd'")
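To make the packing concrete, here is a worked toy example (byte and scale values hand-picked for illustration): each byte holds two FP4 codes, low nibble first, and each row shares one power-of-two scale stored as a biased exponent:

import jax.numpy as jnp
from lalamo.model_import.loaders.utils import decode_mxfp4

# 0x72 -> nibbles (0x2, 0x7) -> LUT values (1.0, 6.0)
# 0x9C -> nibbles (0xC, 0x9) -> LUT values (-2.0, -0.5)
blocks = jnp.array([[0x72, 0x9C]], dtype=jnp.uint8)  # (rows=1, packed_cols=2)
scales = jnp.array([128], dtype=jnp.uint8)           # exponent 128 - 127 = 1

dense = decode_mxfp4(blocks, scales, dtype=jnp.float32)
print(dense)  # [[ 2. 12. -4. -1.]] -- each base value scaled by 2**1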
lalamo/model_import/model_specs/__init__.py
CHANGED

@@ -1,10 +1,12 @@
 from .common import FileSpec, ModelSpec, UseCase, build_quantized_models
 from .deepseek import DEEPSEEK_MODELS
 from .gemma import GEMMA_MODELS
+from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
 from .mistral import MISTRAL_MODELS
-
+
+# from .pleias import PLEIAS_MODELS
 from .polaris import POLARIS_MODELS
 from .qwen import QWEN_MODELS
 from .reka import REKA_MODELS
@@ -23,6 +25,7 @@ ALL_MODEL_LISTS = [
     DEEPSEEK_MODELS,
     GEMMA_MODELS,
     HUGGINGFACE_MODELS,
+    GPT_OSS_MODELS,
     MISTRAL_MODELS,
     # PLEIAS_MODELS,  # TODO(norpadon): Add chat template
     POLARIS_MODELS,
lalamo/model_import/model_specs/common.py
CHANGED

@@ -1,7 +1,9 @@
 from collections.abc import (
     Callable,
+    Iterator,
     Mapping,
 )
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
@@ -10,11 +12,10 @@ from typing import ClassVar, cast, get_args, get_origin
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
-from safetensors.flax import load_file as load_safetensors
 
 from lalamo.model_import.decoder_configs import ForeignConfig
 from lalamo.quantization import QuantizationMode
-from lalamo.utils import MapDictValues
+from lalamo.utils import MapDictValues, open_safetensors
 
 __all__ = [
     "ConfigMap",
@@ -37,16 +38,18 @@ class WeightsType(Enum):
     SAFETENSORS = "safetensors"
     TORCH = "torch"
 
-
+    @contextmanager
+    def load(self, filename: Path | str, float_dtype: DTypeLike) -> Iterator[Mapping[str, jnp.ndarray]]:
         if self == WeightsType.SAFETENSORS:
-
+            with open_safetensors(filename) as weights_dict:
+                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict)
+        else:
+            import torch
 
-
+            from lalamo.modules.torch_interop import torch_to_jax
 
-
-
-            torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
-            return MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
+            torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
+            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
 
 
 class UseCase(Enum):
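Since `load` is now a generator-based context manager, call sites switch from a plain call to a `with` block, which lets `open_safetensors` keep the file open (for example memory-mapped) only for the duration of the load. A hypothetical call site; the filename and tensor key are illustrative:

import jax.numpy as jnp

from lalamo.model_import.model_specs.common import WeightsType

with WeightsType.SAFETENSORS.load("model.safetensors", jnp.bfloat16) as weights:
    embeddings = weights["model.embed_tokens.weight"]  # illustrative key
# the mapping (and any underlying file handle) must not be used past this point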
@@ -75,8 +78,8 @@ def _is_foreign_config_type(t: object) -> bool:
 
 
 def _structure_foreign_config_factory(
-    t: object,
-    c: cattrs.Converter,
+    t: object,  # noqa: ARG001
+    c: cattrs.Converter,  # noqa: ARG001
 ) -> Callable[[object, object], type[ForeignConfig]]:
     name_to_type = {t.__name__: t for t in ForeignConfig.__descendants__()}
 
@@ -88,7 +91,7 @@ def _structure_foreign_config_factory(
     return _hook
 
 
-def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Callable[[type[ForeignConfig]], str]:
+def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Callable[[type[ForeignConfig]], str]:  # noqa: ARG001
     def _hook(v: type[ForeignConfig]) -> str:
         return v.__name__
 
lalamo/model_import/model_specs/gpt_oss.py
ADDED

@@ -0,0 +1,21 @@
+from lalamo.model_import.decoder_configs import HFGPTOssConfig
+
+from .common import ConfigMap, FileSpec, ModelSpec, WeightsType
+
+__all__ = ["GPT_OSS_MODELS"]
+
+GPT_OSS_MODELS = [
+    ModelSpec(
+        vendor="OpenAI",
+        family="GPT-OSS",
+        name="GPT-OSS-20B",
+        size="20B",
+        quantization=None,
+        repo="openai/gpt-oss-20b",
+        config_type=HFGPTOssConfig,
+        weights_type=WeightsType.SAFETENSORS,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+    ),
+]
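The new spec plugs into the existing registry via the `GPT_OSS_MODELS` import shown above; a hypothetical lookup, using only fields visible in this diff:

from lalamo.model_import.model_specs import GPT_OSS_MODELS

spec = GPT_OSS_MODELS[0]
print(spec.repo)          # "openai/gpt-oss-20b"
print(spec.weights_type)  # WeightsType.SAFETENSORS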
lalamo/modules/__init__.py
CHANGED
@@ -1,8 +1,14 @@
-from .activations import Activation
+from .activations import GELU, Activation, SiLU
 from .attention import Attention, AttentionConfig
-from .common import
-from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderResult
-from .decoder_layer import
+from .common import AttentionType, ForwardPassMode, LalamoModule, config_converter
+from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderForwardPassConfig, DecoderResult
+from .decoder_layer import (
+    DecoderLayer,
+    DecoderLayerActivationTrace,
+    DecoderLayerConfig,
+    DecoderLayerForwardPassConfig,
+    DecoderLayerResult,
+)
 from .embedding import (
     EmbeddingBase,
     EmbeddingConfig,
@@ -24,7 +30,17 @@ from .linear import (
     QLoRALinear,
     QLoRALinearConfig,
 )
-from .mlp import
+from .mlp import (
+    DenseMLP,
+    DenseMLPConfig,
+    MixtureOfExperts,
+    MixtureOfExpertsConfig,
+    MLPBase,
+    MLPConfig,
+    MLPForwardPassConfig,
+    RoutingFunction,
+    SoftmaxRouting,
+)
 from .normalization import RMSNorm, RMSNormConfig, UpcastMode
 from .rope import (
     LinearScalingRoPEConfig,
@@ -37,21 +53,27 @@ from .rope import (
 )
 
 __all__ = [
-    "
+    "GELU",
     "Activation",
     "Attention",
     "AttentionConfig",
+    "AttentionType",
     "Decoder",
     "DecoderActivationTrace",
     "DecoderConfig",
+    "DecoderForwardPassConfig",
     "DecoderLayer",
     "DecoderLayerActivationTrace",
     "DecoderLayerConfig",
+    "DecoderLayerForwardPassConfig",
     "DecoderLayerResult",
     "DecoderResult",
+    "DenseMLP",
+    "DenseMLPConfig",
     "DynamicKVCacheLayer",
     "EmbeddingBase",
     "EmbeddingConfig",
+    "ForwardPassMode",
     "FullPrecisionLinear",
     "FullPrecisionLinearConfig",
     "GroupQuantizedLinear",
@@ -63,7 +85,11 @@ __all__ = [
     "LinearConfig",
     "LinearScalingRoPEConfig",
     "LlamaRoPEConfig",
+    "MLPBase",
     "MLPConfig",
+    "MLPForwardPassConfig",
+    "MixtureOfExperts",
+    "MixtureOfExpertsConfig",
     "PositionalEmbeddings",
     "QLoRALinear",
     "QLoRALinearConfig",
@@ -73,6 +99,9 @@ __all__ = [
     "RMSNormConfig",
     "RoPE",
     "RoPEConfig",
+    "RoutingFunction",
+    "SiLU",
+    "SoftmaxRouting",
     "StaticKVCacheLayer",
     "TiedEmbedding",
     "TiedEmbeddingConfig",
@@ -80,7 +109,6 @@ __all__ = [
     "UntiedEmbedding",
     "UntiedEmbeddingConfig",
     "UpcastMode",
-    "WeightLayout",
     "YARNRoPEConfig",
     "config_converter",
 ]
lalamo/modules/activations.py
CHANGED
@@ -1,30 +1,40 @@
-from
+from abc import abstractmethod
 
 import jax
 import jax.numpy as jnp
-from
+from attr import dataclass
 from jaxtyping import Array, Float
 
+from lalamo.modules.common import register_config_union
+
 __all__ = [
+    "GELU",
     "Activation",
-    "
+    "SiLU",
 ]
 
 
-@
-
-
+@dataclass(frozen=True)
+class ActivationBase:
+    @abstractmethod
+    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]: ...
+
+
+@dataclass(frozen=True)
+class SiLU(ActivationBase):
+    alpha: float = 1.0
 
+    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
+        return x / (1 + jnp.exp(-x * self.alpha))
 
-class Activation(Enum):
-    SILU = "silu"
-    GELU = "gelu"
 
+@dataclass(frozen=True)
+class GELU(ActivationBase):
     def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
-        return
+        return jax.nn.gelu(x)
+
+
+Activation = SiLU | GELU
 
 
-
-    Activation.SILU: silu,
-    Activation.GELU: jax.nn.gelu,
-}
+register_config_union(Activation)
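The enum-plus-function-table pairing is gone: `Activation` is now a union of frozen dataclasses registered with the config converter, so an activation can carry parameters (e.g. `SiLU.alpha`, presumably for GPT-OSS's scaled sigmoid) and still round-trip through config serialization. A quick check that the defaults reproduce the old behaviour:

import jax
import jax.numpy as jnp

from lalamo.modules import GELU, Activation, SiLU

x = jnp.linspace(-3.0, 3.0, 7)

silu: Activation = SiLU()  # alpha=1.0 matches the old SILU entry
assert jnp.allclose(silu(x), jax.nn.silu(x), atol=1e-6)

gelu: Activation = GELU()  # delegates to jax.nn.gelu, as before
assert jnp.allclose(gelu(x), jax.nn.gelu(x))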