lalamo 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. lalamo/__init__.py +1 -1
  2. lalamo/model_import/__init__.py +8 -0
  3. lalamo/model_import/common.py +111 -0
  4. lalamo/model_import/configs/__init__.py +23 -0
  5. lalamo/model_import/configs/common.py +62 -0
  6. lalamo/model_import/configs/executorch.py +166 -0
  7. lalamo/model_import/configs/huggingface/__init__.py +18 -0
  8. lalamo/model_import/configs/huggingface/common.py +72 -0
  9. lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  10. lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  11. lalamo/model_import/configs/huggingface/llama.py +155 -0
  12. lalamo/model_import/configs/huggingface/mistral.py +132 -0
  13. lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  14. lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  15. lalamo/model_import/loaders/__init__.py +7 -0
  16. lalamo/model_import/loaders/common.py +45 -0
  17. lalamo/model_import/loaders/executorch.py +223 -0
  18. lalamo/model_import/loaders/huggingface.py +304 -0
  19. lalamo/model_import/model_specs/__init__.py +38 -0
  20. lalamo/model_import/model_specs/common.py +118 -0
  21. lalamo/model_import/model_specs/deepseek.py +28 -0
  22. lalamo/model_import/model_specs/gemma.py +76 -0
  23. lalamo/model_import/model_specs/huggingface.py +28 -0
  24. lalamo/model_import/model_specs/llama.py +101 -0
  25. lalamo/model_import/model_specs/mistral.py +59 -0
  26. lalamo/model_import/model_specs/pleias.py +28 -0
  27. lalamo/model_import/model_specs/polaris.py +22 -0
  28. lalamo/model_import/model_specs/qwen.py +336 -0
  29. lalamo/model_import/model_specs/reka.py +28 -0
  30. lalamo/modules/__init__.py +85 -0
  31. lalamo/modules/activations.py +30 -0
  32. lalamo/modules/attention.py +326 -0
  33. lalamo/modules/common.py +133 -0
  34. lalamo/modules/decoder.py +244 -0
  35. lalamo/modules/decoder_layer.py +240 -0
  36. lalamo/modules/embedding.py +299 -0
  37. lalamo/modules/kv_cache.py +196 -0
  38. lalamo/modules/linear.py +603 -0
  39. lalamo/modules/mlp.py +79 -0
  40. lalamo/modules/normalization.py +77 -0
  41. lalamo/modules/rope.py +255 -0
  42. lalamo/modules/utils.py +13 -0
  43. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
  44. lalamo-0.2.2.dist-info/RECORD +53 -0
  45. lalamo-0.2.1.dist-info/RECORD +0 -12
  46. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
  47. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
  48. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
  49. {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/model_import/configs/huggingface/gemma3.py
@@ -0,0 +1,187 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import jax.numpy as jnp
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     DecoderConfig,
+     TiedEmbeddingConfig,
+ )
+ from lalamo.modules.activations import Activation
+ from lalamo.modules.attention import AttentionConfig
+ from lalamo.modules.decoder_layer import DecoderLayerConfig
+ from lalamo.modules.linear import FullPrecisionLinearConfig
+ from lalamo.modules.mlp import MLPConfig
+ from lalamo.modules.normalization import RMSNormConfig, UpcastMode
+ from lalamo.modules.rope import LinearScalingRoPEConfig, UnscaledRoPEConfig
+
+ from .common import HuggingFaceConfig
+
+ __all__ = ["HFGemma3Config", "HFGemma3TextConfig"]
+
+
+ NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6
+
+
+ def _round_to_bfloat16(x: float) -> float:
+     return jnp.asarray(x).astype(jnp.bfloat16).item()
+
+
+ @dataclass(frozen=True)
+ class GemmaRoPEScalingConfig:
+     factor: float
+     rope_type: Literal["linear"]
+
+
+ @dataclass(frozen=True)
+ class HFGemma3TextConfigRaw:
+     hidden_size: int
+     intermediate_size: int
+     model_type: Literal["gemma3_text"]
+     num_hidden_layers: int
+     sliding_window: int
+     rms_norm_eps: float = 1e-06
+     query_pre_attn_scalar: float = 256.0
+     attention_bias: bool = False
+     num_attention_heads: int = 8
+     num_key_value_heads: int = 4
+     attn_logit_softcapping: float | None = None
+     head_dim: int = 256
+     max_position_embeddings: int = 131072
+     rope_theta: float = 1000000.0
+     rope_local_base_freq: float = 10000.0
+     rope_scaling: GemmaRoPEScalingConfig | None = None
+     final_logit_softcapping: float | None = None
+     vocab_size: int = 262208
+
+     @property
+     def sliding_window_sizes(self) -> list[int | None]:
+         result = []
+         for i in range(self.num_hidden_layers):
+             if (i + 1) % NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER == 0:
+                 result.append(None)
+             else:
+                 result.append(self.sliding_window)
+         return result
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         input_scale = _round_to_bfloat16(self.hidden_size**0.5)
+         attention_scale = self.query_pre_attn_scalar**-0.5
+         embedding_config = TiedEmbeddingConfig(
+             input_scale=input_scale,
+             logits_soft_cap=None,
+             precision=activation_precision,
+         )
+         rms_norm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=1.0,
+             upcast_mode=UpcastMode.FULL_LAYER,
+         )
+
+         if self.rope_scaling is not None:
+             global_rope_config = LinearScalingRoPEConfig(
+                 precision=activation_precision,
+                 base=self.rope_theta,
+                 max_sequence_length=self.max_position_embeddings,
+                 scaling_factor=self.rope_scaling.factor,
+             )
+         else:
+             global_rope_config = UnscaledRoPEConfig(
+                 precision=activation_precision,
+                 base=self.rope_theta,
+                 max_sequence_length=self.max_position_embeddings,
+             )
+         local_rope_config = UnscaledRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_local_base_freq,
+             max_sequence_length=self.max_position_embeddings,
+         )
+
+         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
+         mlp_config = MLPConfig(linear_config=linear_config, activation=Activation.GELU)
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=rms_norm_config,
+             key_norm_config=rms_norm_config,
+             logit_soft_cap=self.attn_logit_softcapping,
+             has_qkv_biases=self.attention_bias,
+             has_out_biases=self.attention_bias,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rms_norm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=rms_norm_config,
+             pre_mlp_norm_config=rms_norm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=rms_norm_config,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=global_rope_config,
+             local_rope_config=local_rope_config,
+             layer_config=decoder_layer_config,
+             output_norm_config=rms_norm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=self.head_dim,
+             attention_scale=attention_scale,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=tuple(self.sliding_window_sizes),
+             context_length=context_length or self.max_position_embeddings,
+         )
+
+
+ @dataclass(frozen=True)
+ class HFGemma3TextConfig(HFGemma3TextConfigRaw, HuggingFaceConfig):
+     pass
+
+
+ @dataclass(frozen=True)
+ class HFGemma3VisionConfig:
+     hidden_size: int
+     image_size: int
+     intermediate_size: int
+     model_type: Literal["siglip_vision_model"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     patch_size: int
+     vision_use_head: bool
+
+
+ @dataclass(frozen=True)
+ class HFGemma3Config(HuggingFaceConfig):
+     architectures: list[Literal["Gemma3ForConditionalGeneration"]]
+     boi_token_index: int
+     eoi_token_index: int
+     eos_token_id: int | list[int]
+     image_token_index: int
+     initializer_range: float
+     mm_tokens_per_image: int
+     model_type: Literal["gemma3"]
+     text_config: HFGemma3TextConfigRaw
+     transformers_version: str
+     vision_config: HFGemma3VisionConfig
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         return self.text_config.to_decoder_config(
+             context_length=context_length,
+             activation_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+         )
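
The Gemma 3 config above has two model-specific details worth calling out: sliding_window_sizes interleaves five local-attention layers with one global-attention layer (every sixth layer gets None, i.e. no window), and the embedding input scale sqrt(hidden_size) is rounded to the nearest bfloat16 value before being stored in TiedEmbeddingConfig. The following standalone sketch (not part of the diff) reproduces just those two helpers; the layer count, window size, and hidden size are illustrative values, not read from a real checkpoint.

import jax.numpy as jnp

NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER = 6

def sliding_window_sizes(num_layers: int, sliding_window: int) -> list[int | None]:
    # Every sixth layer (1-indexed) attends globally (None); the rest use the local window.
    return [
        None if (i + 1) % NUM_SLIDING_WINDOW_LAYERS_PER_FULL_ATTENTION_LAYER == 0 else sliding_window
        for i in range(num_layers)
    ]

def round_to_bfloat16(x: float) -> float:
    # Same rounding as _round_to_bfloat16 in the diff.
    return jnp.asarray(x).astype(jnp.bfloat16).item()

print(sliding_window_sizes(num_layers=12, sliding_window=512))
# [512, 512, 512, 512, 512, None, 512, 512, 512, 512, 512, None]
print(round_to_bfloat16(1152**0.5))  # sqrt(1152) ≈ 33.94 snaps to 34.0 in bfloat16
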
lalamo/model_import/configs/huggingface/llama.py
@@ -0,0 +1,155 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     DecoderConfig,
+     DecoderLayerConfig,
+     FullPrecisionLinearConfig,
+     GroupQuantizedLinearConfig,
+     LlamaRoPEConfig,
+     MLPConfig,
+     RMSNormConfig,
+     TiedEmbeddingConfig,
+     UnscaledRoPEConfig,
+     UpcastMode,
+ )
+ from lalamo.modules.embedding import UntiedEmbeddingConfig
+ from lalamo.quantization import QuantizationMode
+
+ from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+
+ __all__ = ["HFLlamaConfig"]
+
+
+ @dataclass(frozen=True)
+ class LlamaRopeScalingConfig:
+     factor: float
+     high_freq_factor: float
+     low_freq_factor: float
+     original_max_position_embeddings: int
+     rope_type: Literal["llama3"]
+
+
+ @dataclass(frozen=True)
+ class HFLlamaConfig(HuggingFaceConfig):
+     architectures: list[Literal["LlamaForCausalLM"]]
+     attention_bias: bool
+     attention_dropout: float
+     bos_token_id: int | list[int]
+     eos_token_id: int | list[int]
+     hidden_act: Literal["silu"]
+     hidden_size: int
+     initializer_range: float
+     intermediate_size: int
+     max_position_embeddings: int
+     mlp_bias: bool
+     model_type: Literal["llama"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     num_key_value_heads: int
+     pretraining_tp: int
+     rms_norm_eps: float
+     rope_scaling: LlamaRopeScalingConfig | None
+     rope_theta: float
+     tie_word_embeddings: bool
+     transformers_version: str
+     use_cache: bool
+     vocab_size: int
+     head_dim: int | None = None
+
+     quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         if self.tie_word_embeddings:
+             embedding_config = TiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         else:
+             embedding_config = UntiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         if self.rope_scaling is None:
+             rope_config = UnscaledRoPEConfig(
+                 precision=activation_precision,
+                 base=self.rope_theta,
+                 max_sequence_length=self.max_position_embeddings,
+             )
+         else:
+             rope_config = LlamaRoPEConfig(
+                 precision=activation_precision,
+                 base=self.rope_theta,
+                 max_sequence_length=self.max_position_embeddings,
+                 scaling_factor=self.rope_scaling.factor,
+                 original_context_length=self.rope_scaling.original_max_position_embeddings,
+                 low_frequency_factor=self.rope_scaling.low_freq_factor,
+                 high_frequency_factor=self.rope_scaling.high_freq_factor,
+             )
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=None,
+             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+         )
+         if self.quantization_config is None:
+             linear_config = FullPrecisionLinearConfig(
+                 precision=activation_precision,
+             )
+         else:
+             linear_config = GroupQuantizedLinearConfig(
+                 group_size=self.quantization_config.group_size,
+                 weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                 activation_quantization_mode=None,
+                 activation_precision=activation_precision,
+             )
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=None,
+             key_norm_config=None,
+             logit_soft_cap=None,
+             has_qkv_biases=self.attention_bias,
+             has_out_biases=False,
+         )
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.SILU,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=None,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=None,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads,
+             attention_scale=None,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=None,
+             context_length=context_length or self.max_position_embeddings,
+         )
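
For Llama checkpoints that ship a rope_scaling block with rope_type "llama3", to_decoder_config maps the Hugging Face field names onto LlamaRoPEConfig arguments (factor → scaling_factor, original_max_position_embeddings → original_context_length, and so on); otherwise it falls back to UnscaledRoPEConfig. A minimal standalone sketch of that renaming follows; the example values mirror Llama 3.1-style defaults and are illustrative only, not read from an actual config.json.

from dataclasses import dataclass
from typing import Literal

@dataclass(frozen=True)
class LlamaRopeScalingConfig:
    # Mirrors the dataclass in the diff: the parsed "rope_scaling" block.
    factor: float
    high_freq_factor: float
    low_freq_factor: float
    original_max_position_embeddings: int
    rope_type: Literal["llama3"]

# Illustrative Llama 3.1-style values.
scaling = LlamaRopeScalingConfig(
    factor=8.0,
    high_freq_factor=4.0,
    low_freq_factor=1.0,
    original_max_position_embeddings=8192,
    rope_type="llama3",
)

# The keyword arguments to_decoder_config passes to LlamaRoPEConfig:
llama_rope_kwargs = {
    "scaling_factor": scaling.factor,
    "original_context_length": scaling.original_max_position_embeddings,
    "low_frequency_factor": scaling.low_freq_factor,
    "high_frequency_factor": scaling.high_freq_factor,
}
print(llama_rope_kwargs)
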
lalamo/model_import/configs/huggingface/mistral.py
@@ -0,0 +1,132 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     DecoderConfig,
+     DecoderLayerConfig,
+     FullPrecisionLinearConfig,
+     MLPConfig,
+     RMSNormConfig,
+     TiedEmbeddingConfig,
+     UnscaledRoPEConfig,
+     UntiedEmbeddingConfig,
+ )
+ from lalamo.modules.normalization import UpcastMode
+
+ from .common import HuggingFaceConfig
+
+ __all__ = ["HFMistralConfig"]
+
+
+ @dataclass(frozen=True)
+ class HFMistralConfig(HuggingFaceConfig):
+     architectures: list[Literal["MistralForCausalLM"]]
+     attention_dropout: float
+     bos_token_id: int
+     eos_token_id: int
+     hidden_act: Literal["silu"]
+     hidden_size: int
+     initializer_range: float
+     intermediate_size: int
+     max_position_embeddings: int
+     model_type: Literal["mistral"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     num_key_value_heads: int
+     rms_norm_eps: float
+     rope_theta: float
+     sliding_window: int | None
+     tie_word_embeddings: bool
+     torch_dtype: Literal["bfloat16", "float16", "float32"]
+     transformers_version: str
+     use_cache: bool
+     vocab_size: int
+     head_dim: int | None = None
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         # Choose embedding config based on tie_word_embeddings flag
+         if self.tie_word_embeddings:
+             embedding_config = TiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         else:
+             embedding_config = UntiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+
+         rope_config = UnscaledRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_theta,
+             max_sequence_length=self.max_position_embeddings,
+         )
+
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=None,
+             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+         )
+
+         linear_config = FullPrecisionLinearConfig(
+             precision=activation_precision,
+         )
+
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=None,
+             key_norm_config=None,
+             logit_soft_cap=None,
+             has_qkv_biases=False,
+             has_out_biases=False,
+         )
+
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.SILU,
+         )
+
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=None,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=None,
+         )
+
+         head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
+
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=head_dim,
+             attention_scale=None,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=tuple([self.sliding_window] * self.num_hidden_layers)
+             if self.sliding_window is not None
+             else None,
+             context_length=context_length or self.max_position_embeddings,
+         )
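
Two small fallbacks in the Mistral conversion above: head_dim is derived as hidden_size // num_attention_heads when the checkpoint does not set it, and a configured sliding_window is applied uniformly to every layer, while a null value disables windowing entirely. A standalone sketch of just those fallbacks, with illustrative Mistral-7B-style numbers rather than values read from a real config:

def resolve_head_dim(head_dim: int | None, hidden_size: int, num_attention_heads: int) -> int:
    # Same fallback as the diff: an explicit head_dim wins, otherwise derive it.
    return head_dim or hidden_size // num_attention_heads

def resolve_sliding_windows(sliding_window: int | None, num_hidden_layers: int) -> tuple[int, ...] | None:
    # One window size for every layer, or no windowing at all.
    if sliding_window is None:
        return None
    return tuple([sliding_window] * num_hidden_layers)

print(resolve_head_dim(None, hidden_size=4096, num_attention_heads=32))  # 128
print(resolve_sliding_windows(4096, num_hidden_layers=4))  # (4096, 4096, 4096, 4096)
print(resolve_sliding_windows(None, num_hidden_layers=4))  # None
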
lalamo/model_import/configs/huggingface/qwen2.py
@@ -0,0 +1,144 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     DecoderConfig,
+     DecoderLayerConfig,
+     FullPrecisionLinearConfig,
+     GroupQuantizedLinearConfig,
+     MLPConfig,
+     RMSNormConfig,
+     TiedEmbeddingConfig,
+     UnscaledRoPEConfig,
+     UntiedEmbeddingConfig,
+     UpcastMode,
+ )
+ from lalamo.quantization import QuantizationMode
+
+ from .common import AWQQuantizationConfig, GPTQQuantizationConfig, HuggingFaceConfig
+
+ __all__ = ["HFQwen2Config"]
+
+
+ @dataclass(frozen=True)
+ class HFQwen2Config(HuggingFaceConfig):
+     architectures: list[Literal["Qwen2ForCausalLM"]]
+     attention_dropout: float
+     bos_token_id: int | list[int]
+     eos_token_id: int | list[int]
+     hidden_act: Literal["silu"]
+     hidden_size: int
+     initializer_range: float
+     intermediate_size: int
+     max_position_embeddings: int
+     max_window_layers: int
+     model_type: Literal["qwen2"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     num_key_value_heads: int
+     rms_norm_eps: float
+     rope_theta: float
+     sliding_window: int
+     tie_word_embeddings: bool
+     transformers_version: str
+     use_cache: bool
+     use_sliding_window: bool
+     vocab_size: int
+
+     quantization_config: AWQQuantizationConfig | GPTQQuantizationConfig | None = None
+
+     def _get_sliding_window_sizes(self) -> list[int | None]:
+         if not self.use_sliding_window:
+             return [None] * self.num_hidden_layers
+
+         sliding_window_sizes = []
+         for i in range(self.num_hidden_layers):
+             if i < self.max_window_layers:
+                 sliding_window_sizes.append(self.sliding_window)
+             else:
+                 sliding_window_sizes.append(None)
+         return sliding_window_sizes
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         if self.tie_word_embeddings:
+             embedding_config = TiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         else:
+             embedding_config = UntiedEmbeddingConfig(
+                 input_scale=None,
+                 logits_soft_cap=None,
+                 precision=activation_precision,
+             )
+         rope_config = UnscaledRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_theta,
+             max_sequence_length=self.max_position_embeddings,
+         )
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=None,
+             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+         )
+         if self.quantization_config is None:
+             linear_config = FullPrecisionLinearConfig(
+                 precision=activation_precision,
+             )
+         else:
+             linear_config = GroupQuantizedLinearConfig(
+                 group_size=self.quantization_config.group_size,
+                 weight_quantization_mode=QuantizationMode.from_num_bits(self.quantization_config.bits),
+                 activation_quantization_mode=None,
+                 activation_precision=activation_precision,
+             )
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=None,
+             key_norm_config=None,
+             logit_soft_cap=None,
+             has_qkv_biases=True,
+             has_out_biases=False,
+         )
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.SILU,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=None,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=None,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=self.hidden_size // self.num_attention_heads,
+             attention_scale=None,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=tuple(self._get_sliding_window_sizes()),
+             context_length=context_length or self.max_position_embeddings,
+         )
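
HFQwen2Config._get_sliding_window_sizes above uses a different pattern from Gemma 3: when use_sliding_window is false every layer attends globally, and when it is enabled the first max_window_layers layers receive the window while the remaining layers fall back to full attention. A standalone sketch of just that helper, with illustrative parameter values only:

def get_sliding_window_sizes(
    use_sliding_window: bool,
    sliding_window: int,
    max_window_layers: int,
    num_hidden_layers: int,
) -> list[int | None]:
    # Same logic as HFQwen2Config._get_sliding_window_sizes in the diff.
    if not use_sliding_window:
        return [None] * num_hidden_layers
    return [
        sliding_window if i < max_window_layers else None
        for i in range(num_hidden_layers)
    ]

print(get_sliding_window_sizes(True, sliding_window=32768, max_window_layers=4, num_hidden_layers=6))
# [32768, 32768, 32768, 32768, None, None]
print(get_sliding_window_sizes(False, sliding_window=32768, max_window_layers=4, num_hidden_layers=6))
# [None, None, None, None, None, None]
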