lalamo 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lalamo/__init__.py +20 -5
  2. lalamo/data/__init__.py +8 -0
  3. lalamo/data/huggingface_message.py +38 -0
  4. lalamo/data/lalamo_completions.py +43 -0
  5. lalamo/data/utils.py +8 -0
  6. lalamo/language_model.py +152 -69
  7. lalamo/main.py +271 -43
  8. lalamo/message_processor.py +11 -1
  9. lalamo/model_import/common.py +17 -7
  10. lalamo/model_import/decoder_configs/__init__.py +3 -0
  11. lalamo/model_import/decoder_configs/executorch.py +12 -6
  12. lalamo/model_import/decoder_configs/huggingface/__init__.py +2 -0
  13. lalamo/model_import/decoder_configs/huggingface/common.py +1 -3
  14. lalamo/model_import/decoder_configs/huggingface/gemma2.py +11 -5
  15. lalamo/model_import/decoder_configs/huggingface/gemma3.py +14 -5
  16. lalamo/model_import/decoder_configs/huggingface/gpt_oss.py +195 -0
  17. lalamo/model_import/decoder_configs/huggingface/llama.py +38 -8
  18. lalamo/model_import/decoder_configs/huggingface/mistral.py +12 -6
  19. lalamo/model_import/decoder_configs/huggingface/qwen2.py +12 -6
  20. lalamo/model_import/decoder_configs/huggingface/qwen3.py +12 -6
  21. lalamo/model_import/huggingface_tokenizer_config.py +1 -4
  22. lalamo/model_import/loaders/executorch.py +10 -9
  23. lalamo/model_import/loaders/huggingface.py +104 -9
  24. lalamo/model_import/loaders/utils.py +92 -0
  25. lalamo/model_import/model_specs/__init__.py +4 -1
  26. lalamo/model_import/model_specs/common.py +15 -12
  27. lalamo/model_import/model_specs/gpt_oss.py +21 -0
  28. lalamo/modules/__init__.py +35 -7
  29. lalamo/modules/activations.py +24 -14
  30. lalamo/modules/attention.py +73 -20
  31. lalamo/modules/common.py +8 -57
  32. lalamo/modules/decoder.py +48 -34
  33. lalamo/modules/decoder_layer.py +57 -43
  34. lalamo/modules/embedding.py +13 -19
  35. lalamo/modules/kv_cache.py +53 -16
  36. lalamo/modules/linear.py +260 -79
  37. lalamo/modules/mlp.py +395 -23
  38. lalamo/modules/normalization.py +2 -3
  39. lalamo/modules/rope.py +32 -21
  40. lalamo/modules/utils.py +10 -0
  41. lalamo/speculator/__init__.py +11 -0
  42. lalamo/speculator/common.py +22 -0
  43. lalamo/speculator/inference.py +75 -0
  44. lalamo/speculator/ngram.py +154 -0
  45. lalamo/speculator/utils.py +52 -0
  46. lalamo/utils.py +27 -0
  47. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/METADATA +11 -4
  48. lalamo-0.4.0.dist-info/RECORD +71 -0
  49. lalamo-0.3.3.dist-info/RECORD +0 -59
  50. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/WHEEL +0 -0
  51. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/entry_points.txt +0 -0
  52. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/licenses/LICENSE +0 -0
  53. {lalamo-0.3.3.dist-info → lalamo-0.4.0.dist-info}/top_level.txt +0 -0
lalamo/model_import/loaders/huggingface.py

@@ -6,10 +6,10 @@ from jaxtyping import Array
 
 from lalamo.common import ParameterPath
 from lalamo.modules import (
-    MLP,
     Attention,
     Decoder,
     DecoderLayer,
+    DenseMLP,
     FullPrecisionLinear,
     GroupQuantizedLinear,
     LinearBase,
@@ -17,9 +17,11 @@ from lalamo.modules import (
     TiedEmbedding,
     UntiedEmbedding,
 )
+from lalamo.modules.mlp import MixtureOfExperts, MLPBase
 from lalamo.quantization import QuantizationMode
 
 from .common import load_parameters
+from .utils import decode_mxfp4, deinterleave_pairwise_columns
 
 __all__ = ["load_huggingface"]
 
@@ -78,7 +80,7 @@ def _process_quantized_tensors(
     zero_points = unpacked_zero_points.astype(module.config.activation_precision)
     processed_scales = scales.astype(module.config.activation_precision)
 
-    return weights.transpose(), zero_points.transpose(), processed_scales.transpose()
+    return weights, zero_points, processed_scales
 
 
 def _fuse_full_precision_weights(
@@ -158,16 +160,103 @@ def load_linear(
         return load_parameters(
             lambda m: (m.weights, m.scales, m.zero_points, m.biases),
             module,
-            (weights, scales, zero_points, bias),
+            (weights.T, scales.T, zero_points.T, bias),
         )
 
     raise TypeError(f"Unsupported module type for loading: {type(module)}")
 
 
-def load_mlp(module: MLP, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLP:
-    up_projection = load_linear(module.up_projection, weights_dict, path, sublayers_to_fuse=["up_proj", "gate_proj"])
-    down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
-    return load_parameters(lambda m: (m.up_projection, m.down_projection), module, (up_projection, down_projection))
+def load_mlp(module: MLPBase, weights_dict: Mapping[str, Array], path: ParameterPath) -> MLPBase:
+    if isinstance(module, DenseMLP):
+        # Standard dense MLP with separate sublayers.
+        up_projection = load_linear(
+            module.up_projection,
+            weights_dict,
+            path,
+            sublayers_to_fuse=["up_proj", "gate_proj"],
+        )
+        down_projection = load_linear(module.down_projection, weights_dict, path / "down_proj")
+        return load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module,
+            (up_projection, down_projection),
+        )
+
+    if isinstance(module, MixtureOfExperts):
+        return load_moe(module, weights_dict, path)
+
+    raise TypeError(f"Unsupported module type for loading: {type(module)}")
+
+
+def load_moe(module: MixtureOfExperts, weights_dict: Mapping[str, Array], path: ParameterPath) -> MixtureOfExperts:
+    # Load router via the standard linear loader
+    router = load_linear(module.router, weights_dict, path / "router")
+
+    experts_path = path / "experts"
+    # Handle fused MXFP4 experts layout if present
+    if (experts_path / "gate_up_proj_blocks") in weights_dict:
+        # Decode fused gate/up (interleaved), split into (up, gate), and add +1.0 to up bias
+        fused = decode_mxfp4(
+            weights_dict[experts_path / "gate_up_proj_blocks"],
+            weights_dict[experts_path / "gate_up_proj_scales"],
+            dtype=module.activation_precision,
+            flatten=False,
+        )
+        # Stored as (experts, outputs=2*hidden_dim, input_blocks, input_block_elems)
+        # Merge blocks and move outputs last
+        fused_eio = rearrange(fused, "e o ib ie -> e (ib ie) o")
+        up_w, gate_w = deinterleave_pairwise_columns(fused_eio, first="odd")
+        combined_up_gate = jnp.concatenate([up_w, gate_w], axis=-1)
+        # Transpose to new layout: (experts, outputs, inputs)
+        combined_up_gate_w = jnp.swapaxes(combined_up_gate, -1, -2)
+
+        gub = weights_dict[experts_path / "gate_up_proj_bias"]
+        if gub.ndim == 1:
+            # Broadcast to (experts, 2*hidden_dim)
+            gub = jnp.broadcast_to(gub, (combined_up_gate_w.shape[0], gub.shape[0]))
+        up_b, gate_b = deinterleave_pairwise_columns(gub, first="odd")
+        combined_up_gate_b = jnp.concatenate([up_b + 1.0, gate_b], axis=-1)
+
+        up_projection = load_parameters(
+            lambda m: (m.weights, m.biases),  # type: ignore
+            module.experts.up_projection,
+            (combined_up_gate_w, combined_up_gate_b),
+        )
+
+        # Down projection: decode MXFP4 to dense
+        down_w = decode_mxfp4(
+            weights_dict[experts_path / "down_proj_blocks"],
+            weights_dict[experts_path / "down_proj_scales"],
+            dtype=module.activation_precision,
+            flatten=False,
+        )
+        # Stored as (experts, outputs=model_dim, input_blocks, input_block_elems)
+        # Merge blocks and move outputs last
+        down_w = rearrange(down_w, "e o ib ie -> e o (ib ie)")
+        down_b = weights_dict[experts_path / "down_proj_bias"]
+        if down_b.ndim == 1:
+            down_b = jnp.broadcast_to(down_b, down_w.shape[:-1] + (down_b.shape[0],))
+
+        down_projection = load_parameters(
+            lambda m: (m.weights, m.biases),  # type: ignore
+            module.experts.down_projection,
+            (down_w, down_b),
+        )
+
+        experts = load_parameters(
+            lambda m: (m.up_projection, m.down_projection),
+            module.experts,
+            (up_projection, down_projection),
+        )
+    else:
+        # Fallback: recursively load a standard DenseMLP experts module
+        experts = load_mlp(module.experts, weights_dict, experts_path)
+
+    return load_parameters(
+        lambda m: (m.router, m.experts),
+        module,
+        (router, experts),
+    )
 
 
 def load_rmsnorm(
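
As a sanity check on the fused-expert plumbing above, here is a minimal shape trace with toy dimensions. It assumes `rearrange` comes from einops (the import lies outside the shown hunks) and that the helpers are importable from lalamo.model_import.loaders.utils, as the file list suggests; the arrays hold dummy values, only the shapes matter.

import jax.numpy as jnp
from einops import rearrange  # assumed source of the loader's `rearrange`

from lalamo.model_import.loaders.utils import decode_mxfp4, deinterleave_pairwise_columns

E, H, IB, IE = 2, 4, 3, 16  # experts, expert hidden dim, input blocks, packed bytes per block
blocks = jnp.zeros((E, 2 * H, IB, IE), dtype=jnp.uint8)  # stand-in for gate_up_proj_blocks
scales = jnp.full((E, 2 * H, IB), 127, dtype=jnp.uint8)  # stand-in for gate_up_proj_scales (bias 127 -> 2**0)

fused = decode_mxfp4(blocks, scales, dtype=jnp.float32, flatten=False)  # (E, 2H, IB, 2*IE)
fused_eio = rearrange(fused, "e o ib ie -> e (ib ie) o")                # (E, IB*2*IE, 2H)
up_w, gate_w = deinterleave_pairwise_columns(fused_eio, first="odd")    # up = odd columns, gate = even columns
combined = jnp.swapaxes(jnp.concatenate([up_w, gate_w], axis=-1), -1, -2)
assert combined.shape == (E, 2 * H, IB * 2 * IE)  # (experts, outputs, inputs), as the loader expects
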
@@ -202,10 +291,16 @@ def load_attention(
     else:
         key_norm = None
 
+    # GPT-OSS adds per-head attention sinks; load them if present.
+    if (path / "sinks") in weights_dict:
+        sinks = weights_dict[path / "sinks"]
+    else:
+        sinks = module.sinks
+
     return load_parameters(
-        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm),
+        lambda m: (m.qkv_projection, m.out_projection, m.query_norm, m.key_norm, m.sinks),
         module,
-        (qkv_projection, out_projection, query_norm, key_norm),
+        (qkv_projection, out_projection, query_norm, key_norm, sinks),
     )
 
 
lalamo/model_import/loaders/utils.py

@@ -0,0 +1,92 @@
+# MXFP4 decoding utilities for model loaders.
+# Based on OpenAI's reference implementation logic for GPT-OSS MXFP4 weights.
+# Converts packed FP4 blocks plus per-row scales into dense weights in the target dtype.
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike
+
+__all__ = [
+    "decode_mxfp4",
+    "deinterleave_pairwise_columns",
+]
+
+
+# The 16 representable FP4 values used by MXFP4, in logical order (low nibble indices 0..15).
+# See: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/weights.py
+_MXFP4_LUT_VALUES = (
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+)
+
+
+def decode_mxfp4(
+    blocks: Array,
+    scales: Array,
+    *,
+    dtype: DTypeLike,
+    flatten: bool = False,
+) -> Array:
+    target_dtype = jnp.dtype(dtype)
+
+    # Prepare LUT in target dtype
+    lut = jnp.array(_MXFP4_LUT_VALUES, dtype=target_dtype)
+
+    *prefix, rows, packed_cols = blocks.shape
+    if scales.shape != (*prefix, rows):
+        raise ValueError(
+            f"MXFP4 scales shape {scales.shape} does not match blocks prefix/rows {(*prefix, rows)}",
+        )
+
+    # Extract low/high nibble indices
+    low_mask = jnp.array(0x0F, dtype=blocks.dtype)
+    idx_lo = (blocks & low_mask).astype(jnp.int32)
+    idx_hi = (blocks >> jnp.array(4, dtype=blocks.dtype)).astype(jnp.int32)
+
+    # Lookup FP4 base values
+    vals_lo = lut[idx_lo]
+    vals_hi = lut[idx_hi]
+
+    # Interleave into (..., rows, 2*packed_cols)
+    out_shape = (*prefix, rows, packed_cols * 2)
+    out = jnp.empty(out_shape, dtype=target_dtype)
+    out = out.at[..., 0::2].set(vals_lo)
+    out = out.at[..., 1::2].set(vals_hi)
+
+    # Apply exponent scaling: exponents are biased by 127 in checkpoints
+    exp = scales.astype(jnp.int32) - 127
+    out = jnp.ldexp(out, exp[..., None])
+
+    if flatten:
+        return out.reshape(*prefix, rows * (packed_cols * 2))
+    return out
+
+
+def deinterleave_pairwise_columns(
+    matrix: Array,
+    *,
+    first: str = "even",
+) -> tuple[Array, Array]:
+    if matrix.shape[-1] % 2 != 0:
+        raise ValueError(f"Last dimension must be even, got {matrix.shape[-1]}")
+
+    match first:
+        case "even":
+            return matrix[..., 0::2], matrix[..., 1::2]
+        case "odd":
+            return matrix[..., 1::2], matrix[..., 0::2]
+        case _:
+            raise ValueError("Parameter 'first' must be either 'even' or 'odd'")
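
A quick usage sketch of the two helpers above (import path inferred from the file list; a single packed byte per row keeps the arithmetic checkable by hand):

import jax.numpy as jnp

from lalamo.model_import.loaders.utils import decode_mxfp4, deinterleave_pairwise_columns

# One row of one packed byte: low nibble 0x1 -> 0.5, high nibble 0x2 -> 1.0.
blocks = jnp.array([[0x21]], dtype=jnp.uint8)  # shape (rows=1, packed_cols=1)
scales = jnp.array([128], dtype=jnp.uint8)     # biased exponent 128 -> scale 2**1

dense = decode_mxfp4(blocks, scales, dtype=jnp.float32)
print(dense)  # [[1. 2.]] -- the low nibble lands in the even column, the high nibble in the odd one

# Pairwise deinterleaving, as used for the fused gate/up expert weights:
m = jnp.arange(6).reshape(1, 6)
odd, even = deinterleave_pairwise_columns(m, first="odd")
print(odd)   # [[1 3 5]]
print(even)  # [[0 2 4]]
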
lalamo/model_import/model_specs/__init__.py

@@ -1,10 +1,12 @@
 from .common import FileSpec, ModelSpec, UseCase, build_quantized_models
 from .deepseek import DEEPSEEK_MODELS
 from .gemma import GEMMA_MODELS
+from .gpt_oss import GPT_OSS_MODELS
 from .huggingface import HUGGINGFACE_MODELS
 from .llama import LLAMA_MODELS
 from .mistral import MISTRAL_MODELS
-from .pleias import PLEIAS_MODELS
+
+# from .pleias import PLEIAS_MODELS
 from .polaris import POLARIS_MODELS
 from .qwen import QWEN_MODELS
 from .reka import REKA_MODELS
@@ -23,6 +25,7 @@ ALL_MODEL_LISTS = [
     DEEPSEEK_MODELS,
     GEMMA_MODELS,
     HUGGINGFACE_MODELS,
+    GPT_OSS_MODELS,
     MISTRAL_MODELS,
     # PLEIAS_MODELS,  # TODO(norpadon): Add chat template
     POLARIS_MODELS,
lalamo/model_import/model_specs/common.py

@@ -1,7 +1,9 @@
 from collections.abc import (
     Callable,
+    Iterator,
     Mapping,
 )
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
@@ -10,11 +12,10 @@ from typing import ClassVar, cast, get_args, get_origin
 import cattrs
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
-from safetensors.flax import load_file as load_safetensors
 
 from lalamo.model_import.decoder_configs import ForeignConfig
 from lalamo.quantization import QuantizationMode
-from lalamo.utils import MapDictValues
+from lalamo.utils import MapDictValues, open_safetensors
 
 __all__ = [
     "ConfigMap",
@@ -37,16 +38,18 @@ class WeightsType(Enum):
     SAFETENSORS = "safetensors"
     TORCH = "torch"
 
-    def load(self, filename: Path | str, float_dtype: DTypeLike) -> Mapping[str, jnp.ndarray]:
+    @contextmanager
+    def load(self, filename: Path | str, float_dtype: DTypeLike) -> Iterator[Mapping[str, jnp.ndarray]]:
         if self == WeightsType.SAFETENSORS:
-            return MapDictValues(lambda v: cast_if_float(v, float_dtype), load_safetensors(filename))
+            with open_safetensors(filename) as weights_dict:
+                yield MapDictValues(lambda v: cast_if_float(v, float_dtype), weights_dict)
+        else:
+            import torch
 
-        import torch
+            from lalamo.modules.torch_interop import torch_to_jax
 
-        from lalamo.modules.torch_interop import torch_to_jax
-
-        torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
-        return MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
+            torch_weights = torch.load(filename, map_location="cpu", weights_only=True)
+            yield MapDictValues(lambda v: cast_if_float(torch_to_jax(v), float_dtype), torch_weights)
 
 
 class UseCase(Enum):
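
For illustration, a minimal sketch of the new context-manager form of WeightsType.load; the file name is a placeholder, and the snippet assumes the returned MapDictValues behaves like a read-only Mapping, as its type hint indicates:

import jax.numpy as jnp

from lalamo.model_import.model_specs.common import WeightsType

with WeightsType.SAFETENSORS.load("model.safetensors", jnp.bfloat16) as weights:
    # Float tensors are cast to bfloat16 lazily, on access.
    name = next(iter(weights))
    print(name, weights[name].dtype)
# Exiting the block lets open_safetensors close the underlying file.
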
@@ -75,8 +78,8 @@ def _is_foreign_config_type(t: object) -> bool:
 
 
 def _structure_foreign_config_factory(
-    t: object,
-    c: cattrs.Converter,
+    t: object,  # noqa: ARG001
+    c: cattrs.Converter,  # noqa: ARG001
 ) -> Callable[[object, object], type[ForeignConfig]]:
     name_to_type = {t.__name__: t for t in ForeignConfig.__descendants__()}
 
@@ -88,7 +91,7 @@
     return _hook
 
 
-def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Callable[[type[ForeignConfig]], str]:
+def _unstructure_foreign_config_factory(t: object, c: cattrs.Converter) -> Callable[[type[ForeignConfig]], str]:  # noqa: ARG001
     def _hook(v: type[ForeignConfig]) -> str:
         return v.__name__
 
lalamo/model_import/model_specs/gpt_oss.py

@@ -0,0 +1,21 @@
+from lalamo.model_import.decoder_configs import HFGPTOssConfig
+
+from .common import ConfigMap, FileSpec, ModelSpec, WeightsType
+
+__all__ = ["GPT_OSS_MODELS"]
+
+GPT_OSS_MODELS = [
+    ModelSpec(
+        vendor="OpenAI",
+        family="GPT-OSS",
+        name="GPT-OSS-20B",
+        size="20B",
+        quantization=None,
+        repo="openai/gpt-oss-20b",
+        config_type=HFGPTOssConfig,
+        weights_type=WeightsType.SAFETENSORS,
+        configs=ConfigMap(
+            chat_template=FileSpec("chat_template.jinja"),
+        ),
+    ),
+]
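
A small, hypothetical lookup over the new spec list, assuming ModelSpec is a dataclass-style record whose fields match the keyword arguments above:

from lalamo.model_import.model_specs import GPT_OSS_MODELS

spec = next(s for s in GPT_OSS_MODELS if s.repo == "openai/gpt-oss-20b")
print(spec.vendor, spec.family, spec.name, spec.weights_type)
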
lalamo/modules/__init__.py

@@ -1,8 +1,14 @@
-from .activations import Activation
+from .activations import GELU, Activation, SiLU
 from .attention import Attention, AttentionConfig
-from .common import LalamoModule, WeightLayout, config_converter
-from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderResult
-from .decoder_layer import DecoderLayer, DecoderLayerActivationTrace, DecoderLayerConfig, DecoderLayerResult
+from .common import AttentionType, ForwardPassMode, LalamoModule, config_converter
+from .decoder import Decoder, DecoderActivationTrace, DecoderConfig, DecoderForwardPassConfig, DecoderResult
+from .decoder_layer import (
+    DecoderLayer,
+    DecoderLayerActivationTrace,
+    DecoderLayerConfig,
+    DecoderLayerForwardPassConfig,
+    DecoderLayerResult,
+)
 from .embedding import (
     EmbeddingBase,
     EmbeddingConfig,
@@ -24,7 +30,17 @@ from .linear import (
     QLoRALinear,
     QLoRALinearConfig,
 )
-from .mlp import MLP, MLPConfig
+from .mlp import (
+    DenseMLP,
+    DenseMLPConfig,
+    MixtureOfExperts,
+    MixtureOfExpertsConfig,
+    MLPBase,
+    MLPConfig,
+    MLPForwardPassConfig,
+    RoutingFunction,
+    SoftmaxRouting,
+)
 from .normalization import RMSNorm, RMSNormConfig, UpcastMode
 from .rope import (
     LinearScalingRoPEConfig,
@@ -37,21 +53,27 @@ from .rope import (
 )
 
 __all__ = [
-    "MLP",
+    "GELU",
     "Activation",
     "Attention",
     "AttentionConfig",
+    "AttentionType",
     "Decoder",
     "DecoderActivationTrace",
     "DecoderConfig",
+    "DecoderForwardPassConfig",
     "DecoderLayer",
     "DecoderLayerActivationTrace",
     "DecoderLayerConfig",
+    "DecoderLayerForwardPassConfig",
    "DecoderLayerResult",
     "DecoderResult",
+    "DenseMLP",
+    "DenseMLPConfig",
     "DynamicKVCacheLayer",
     "EmbeddingBase",
     "EmbeddingConfig",
+    "ForwardPassMode",
     "FullPrecisionLinear",
     "FullPrecisionLinearConfig",
     "GroupQuantizedLinear",
@@ -63,7 +85,11 @@ __all__ = [
     "LinearConfig",
     "LinearScalingRoPEConfig",
     "LlamaRoPEConfig",
+    "MLPBase",
     "MLPConfig",
+    "MLPForwardPassConfig",
+    "MixtureOfExperts",
+    "MixtureOfExpertsConfig",
     "PositionalEmbeddings",
     "QLoRALinear",
     "QLoRALinearConfig",
@@ -73,6 +99,9 @@ __all__ = [
     "RMSNormConfig",
     "RoPE",
     "RoPEConfig",
+    "RoutingFunction",
+    "SiLU",
+    "SoftmaxRouting",
     "StaticKVCacheLayer",
     "TiedEmbedding",
     "TiedEmbeddingConfig",
@@ -80,7 +109,6 @@ __all__ = [
     "UntiedEmbedding",
     "UntiedEmbeddingConfig",
     "UpcastMode",
-    "WeightLayout",
     "YARNRoPEConfig",
     "config_converter",
 ]
lalamo/modules/activations.py

@@ -1,30 +1,40 @@
-from enum import Enum
+from abc import abstractmethod
 
 import jax
 import jax.numpy as jnp
-from jax import jit
+from attr import dataclass
 from jaxtyping import Array, Float
 
+from lalamo.modules.common import register_config_union
+
 __all__ = [
+    "GELU",
     "Activation",
-    "silu",
+    "SiLU",
 ]
 
 
-@jit
-def silu(x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
-    return x / (1 + jnp.exp(-x))
+@dataclass(frozen=True)
+class ActivationBase:
+    @abstractmethod
+    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]: ...
+
+
+@dataclass(frozen=True)
+class SiLU(ActivationBase):
+    alpha: float = 1.0
 
+    def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
+        return x / (1 + jnp.exp(-x * self.alpha))
 
-class Activation(Enum):
-    SILU = "silu"
-    GELU = "gelu"
 
+@dataclass(frozen=True)
+class GELU(ActivationBase):
     def __call__(self, x: Float[Array, "*dims"]) -> Float[Array, "*dims"]:
-        return ACTIVATION_FUNCTIONS[self](x)
+        return jax.nn.gelu(x)
+
+
+Activation = SiLU | GELU
 
 
-ACTIVATION_FUNCTIONS = {
-    Activation.SILU: silu,
-    Activation.GELU: jax.nn.gelu,
-}
+register_config_union(Activation)
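
A short usage sketch of the reworked activation API; the classes are re-exported from lalamo.modules per the __init__.py hunk above, and the note on configs is an inference from the register_config_union call:

import jax.numpy as jnp

from lalamo.modules import GELU, SiLU

x = jnp.array([-1.0, 0.0, 1.0])
print(SiLU()(x))  # x * sigmoid(alpha * x); alpha defaults to 1.0
print(GELU()(x))  # thin wrapper around jax.nn.gelu

# `Activation = SiLU | GELU` is registered via register_config_union, presumably so
# either variant can be (de)serialized wherever a config expects an Activation.
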