lalamo 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +24 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +100 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/METADATA +1 -1
- lalamo-0.2.3.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.3.dist-info}/top_level.txt +0 -0
lalamo/modules/decoder_layer.py (new file)

@@ -0,0 +1,240 @@
```python
from dataclasses import dataclass

import equinox as eqx
import jax
from jax import vmap
from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

from lalamo.common import ParameterDict

from .attention import Attention, AttentionConfig
from .common import AttentionType, LalamoModule, WeightLayout
from .kv_cache import KVCacheLayer, StaticKVCacheLayer
from .mlp import MLP, MLPConfig
from .normalization import RMSNorm, RMSNormConfig
from .rope import PositionalEmbeddings

__all__ = [
    "DecoderLayer",
    "DecoderLayerActivationTrace",
    "DecoderLayerConfig",
    "DecoderLayerResult",
]


class DecoderLayerActivationTrace(eqx.Module):
    inputs: Float[Array, "suffix_tokens channels"]
    positional_embeddings: PositionalEmbeddings
    kv_cache: KVCacheLayer | None

    mlp_inputs: Float[Array, "suffix_tokens channels"]
    pre_attention_norm: Float[Array, "suffix_tokens channels"]
    attention: Float[Array, "suffix_tokens channels"]
    post_attention_norm: Float[Array, "suffix_tokens channels"] | None
    pre_mlp_norm: Float[Array, "suffix_tokens channels"]
    mlp: Float[Array, "suffix_tokens channels"]
    post_mlp_norm: Float[Array, "suffix_tokens channels"] | None

    def export(self) -> ParameterDict:
        result = ParameterDict(
            inputs=self.inputs,
            positional_embeddings=self.positional_embeddings.export(),
            mlp_inputs=self.mlp_inputs,
            pre_attention_norm=self.pre_attention_norm,
            attention=self.attention,
            pre_mlp_norm=self.pre_mlp_norm,
            mlp=self.mlp,
        )
        if self.kv_cache is not None:
            result["kv_cache"] = self.kv_cache.export()
        if self.post_attention_norm is not None:
            result["post_attention_norm"] = self.post_attention_norm
        if self.post_mlp_norm is not None:
            result["post_mlp_norm"] = self.post_mlp_norm
        return result


class DecoderLayerResult(eqx.Module):
    outputs: Float[Array, "suffix_tokens channels"]
    updated_kv_cache: KVCacheLayer | None
    activation_trace: DecoderLayerActivationTrace | None

    def export(self) -> ParameterDict:
        result = ParameterDict(
            outputs=self.outputs,
        )
        if self.updated_kv_cache is not None:
            result["updated_kv_cache"] = self.updated_kv_cache.export()
        if self.activation_trace is not None:
            result["activation_trace"] = self.activation_trace.export()
        return result


@dataclass(frozen=True)
class DecoderLayerConfig:
    pre_attention_norm_config: RMSNormConfig
    attention_config: AttentionConfig
    post_attention_norm_config: RMSNormConfig | None
    pre_mlp_norm_config: RMSNormConfig
    mlp_config: MLPConfig
    post_mlp_norm_config: RMSNormConfig | None

    def random_init(
        self,
        model_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_groups: int,
        head_dim: int,
        attention_scale: float | None,
        sliding_window_size: int | None,
        *,
        key: PRNGKeyArray,
    ) -> "DecoderLayer":
        attention_key, mlp_key = jax.random.split(key)
        pre_attention_norm = self.pre_attention_norm_config.init(model_dim)
        attention = self.attention_config.random_init(
            model_dim=model_dim,
            num_heads=num_heads,
            num_groups=num_groups,
            head_dim=head_dim,
            is_causal=True,
            scale=attention_scale,
            sliding_window_size=sliding_window_size,
            key=attention_key,
        )
        if self.post_attention_norm_config is not None:
            post_attention_norm = self.post_attention_norm_config.init(model_dim)
        else:
            post_attention_norm = None
        pre_mlp_norm = self.pre_mlp_norm_config.init(model_dim)
        mlp = self.mlp_config.random_init(model_dim, hidden_dim, key=mlp_key)
        if self.post_mlp_norm_config is not None:
            post_mlp_norm = self.post_mlp_norm_config.init(model_dim)
        else:
            post_mlp_norm = None
        return DecoderLayer(
            config=self,
            pre_attention_norm=pre_attention_norm,
            attention=attention,
            post_attention_norm=post_attention_norm,
            pre_mlp_norm=pre_mlp_norm,
            mlp=mlp,
            post_mlp_norm=post_mlp_norm,
        )


class DecoderLayer(LalamoModule[DecoderLayerConfig]):
    pre_attention_norm: RMSNorm
    attention: Attention
    post_attention_norm: RMSNorm | None
    pre_mlp_norm: RMSNorm
    mlp: MLP
    post_mlp_norm: RMSNorm | None

    @property
    def activation_precision(self) -> DTypeLike:
        return self.attention.activation_precision

    @property
    def attention_type(self) -> AttentionType:
        return self.attention.attention_type

    def __post_init__(self) -> None:
        model_dim = self.pre_attention_norm.input_dim
        if self.attention.model_dim != model_dim:
            raise ValueError(
                f"Attention model dim {self.attention.model_dim} does not match"
                f" the first normalization layer dim {model_dim}",
            )
        if self.post_attention_norm is not None and self.post_attention_norm.input_dim != model_dim:
            raise ValueError(
                f"Post attention normalization dim {self.post_attention_norm.input_dim} does not match"
                f" the first normalization layer dim {model_dim}",
            )
        if self.pre_mlp_norm.input_dim != model_dim:
            raise ValueError(
                f"Pre MLP normalization dim {self.pre_mlp_norm.input_dim} does not match"
                f" the first normalization layer dim {model_dim}",
            )
        if self.mlp.model_dim != model_dim:
            raise ValueError(
                f"MLP up projection dim {self.mlp.up_projection.input_dim} does not match"
                f" the first normalization layer dim {model_dim}",
            )
        if self.mlp.hidden_dim != self.mlp.down_projection.input_dim:
            raise ValueError(
                f"MLP down projection dim {self.mlp.down_projection.input_dim} does not match"
                f" the up projection dim {self.mlp.hidden_dim}",
            )

    def __call__(
        self,
        inputs: Float[Array, "suffix_tokens channels"],
        positional_embeddings: PositionalEmbeddings,
        kv_cache: KVCacheLayer | None = None,
        return_updated_kv_cache: bool = False,
        return_activation_trace: bool = False,
        length_without_padding: Int[Array, ""] | int | None = None,
    ) -> DecoderLayerResult:
        normalized_attention_inputs = vmap(self.pre_attention_norm, in_axes=0)(inputs)
        attention_outputs, updated_kv_cache = self.attention(
            normalized_attention_inputs,
            positional_embeddings,
            kv_cache=kv_cache,
            return_updated_kv_cache=return_updated_kv_cache,
            length_without_padding=length_without_padding,
        )
        if self.post_attention_norm is not None:
            normalized_attention_outputs = vmap(self.post_attention_norm, in_axes=0)(attention_outputs)
            mlp_inputs = inputs + normalized_attention_outputs
        else:
            normalized_attention_outputs = None
            mlp_inputs = inputs + attention_outputs

        normalized_mlp_inputs = vmap(self.pre_mlp_norm, in_axes=0)(mlp_inputs)
        mlp_outputs = vmap(self.mlp, in_axes=0)(normalized_mlp_inputs)
        if self.post_mlp_norm is not None:
            normalized_mlp_outputs = vmap(self.post_mlp_norm, in_axes=0)(mlp_outputs)
            outputs = mlp_inputs + normalized_mlp_outputs
        else:
            normalized_mlp_outputs = None
            outputs = mlp_inputs + mlp_outputs

        if return_activation_trace:
            activation_trace = DecoderLayerActivationTrace(
                inputs=inputs,
                positional_embeddings=positional_embeddings,
                kv_cache=kv_cache,
                pre_attention_norm=normalized_attention_inputs,
                attention=attention_outputs,
                post_attention_norm=normalized_attention_outputs,
                mlp_inputs=mlp_inputs,
                pre_mlp_norm=normalized_mlp_inputs,
                mlp=mlp_outputs,
                post_mlp_norm=normalized_mlp_outputs,
            )
        else:
            activation_trace = None

        return DecoderLayerResult(
            outputs=outputs,
            updated_kv_cache=updated_kv_cache,
            activation_trace=activation_trace,
        )

    def init_static_kv_cache(self, capacity: int) -> StaticKVCacheLayer:
        return self.attention.init_static_kv_cache(capacity)

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        result = ParameterDict(
            pre_attention_norm=self.pre_attention_norm.export_weights(weight_layout),
            attention=self.attention.export_weights(weight_layout),
            pre_mlp_norm=self.pre_mlp_norm.export_weights(weight_layout),
            mlp=self.mlp.export_weights(weight_layout),
        )
        if self.post_attention_norm is not None:
            result["post_attention_norm"] = self.post_attention_norm.export_weights(weight_layout)
        if self.post_mlp_norm is not None:
            result["post_mlp_norm"] = self.post_mlp_norm.export_weights(weight_layout)
        return result
```
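For readers skimming the diff, the sketch below shows the pre-norm residual pattern that `DecoderLayer.__call__` implements: normalize, mix, add back to the residual stream, once for attention and once for the MLP. It is a self-contained toy, not the lalamo API: `rms_norm`, `toy_attention`, and `toy_mlp` are illustrative stand-ins for the `RMSNorm`, `Attention`, and `MLP` modules, and the optional post-attention/post-MLP norms are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Stand-in for RMSNorm: scale each token vector to unit RMS.
    return x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + eps)


def toy_attention(x):
    # Stand-in for Attention: any token mixer preserving [tokens, channels].
    scores = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ x


def toy_mlp(x):
    # Stand-in for the gated MLP block.
    return jax.nn.silu(x) * x


def decoder_block(inputs):
    attention_out = toy_attention(rms_norm(inputs))  # pre_attention_norm -> attention
    mlp_inputs = inputs + attention_out              # first residual add
    mlp_out = toy_mlp(rms_norm(mlp_inputs))          # pre_mlp_norm -> mlp
    return mlp_inputs + mlp_out                      # second residual add


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))  # [tokens, channels]
print(decoder_block(x).shape)  # (8, 16)
```

In the actual module, when `post_attention_norm` or `post_mlp_norm` is configured, the attention and MLP outputs are normalized once more before being added back to the residual stream, and the KV cache and activation trace are threaded through as shown in the file above.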
lalamo/modules/embedding.py (new file)

@@ -0,0 +1,299 @@
```python
from abc import abstractmethod
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from einops import rearrange
from jaxtyping import Array, DTypeLike, Float, Int, PRNGKeyArray

from lalamo.common import ParameterDict
from lalamo.quantization import QuantizationMode, dynamically_quantize_activations, quantize_weights

from .common import LalamoModule, WeightLayout, register_config_union
from .utils import apply_soft_capping

__all__ = [
    "EmbeddingBase",
    "EmbeddingConfig",
    "QuantizedTiedEmbedding",
    "QuantizedTiedEmbeddingConfig",
    "TiedEmbedding",
    "TiedEmbeddingConfig",
    "UntiedEmbedding",
    "UntiedEmbeddingConfig",
]


@dataclass(frozen=True)
class EmbeddingConfigBase:
    input_scale: float | None
    logits_soft_cap: float | None

    @abstractmethod
    def random_init(
        self,
        vocab_size: int,
        model_dim: int,
        *,
        key: PRNGKeyArray,
    ) -> "EmbeddingBase": ...


class EmbeddingBase[ConfigT: EmbeddingConfigBase](LalamoModule[ConfigT]):
    @abstractmethod
    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]: ...

    @abstractmethod
    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]: ...

    @property
    @abstractmethod
    def vocab_size(self) -> int: ...

    @property
    @abstractmethod
    def model_dim(self) -> int: ...

    @classmethod
    def _default_weight_layout(cls) -> WeightLayout:
        return WeightLayout.INPUT_OUTPUT

    def embed(self, x: Int[Array, " tokens"]) -> Float[Array, "tokens channels"]:
        result = self._prepare_input_weights()[x]
        if self.config.input_scale is not None:
            result = result * jnp.array(self.config.input_scale, dtype=result.dtype)
        return result

    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
        logits = self._prepare_output_weights() @ x
        if self.config.logits_soft_cap is not None:
            logits = apply_soft_capping(logits, self.config.logits_soft_cap)
        return logits


@dataclass(frozen=True)
class TiedEmbeddingConfig(EmbeddingConfigBase):
    precision: DTypeLike

    def random_init(
        self,
        vocab_size: int,
        model_dim: int,
        *,
        key: PRNGKeyArray,
    ) -> "TiedEmbedding":
        weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.precision)
        return TiedEmbedding(config=self, weights=weights)


class TiedEmbedding(EmbeddingBase[TiedEmbeddingConfig]):
    weights: Float[Array, "vocabulary channels"]

    @property
    def activation_precision(self) -> DTypeLike:
        return self.config.precision

    def __post_init__(self) -> None:
        if self.config.precision != self.weights.dtype:
            raise ValueError(
                f"Embedding dtype {self.weights.dtype} does not match the specified precision {self.config.precision}",
            )

    @property
    def model_dim(self) -> int:
        _, model_dim = self.weights.shape
        return model_dim

    @property
    def vocab_size(self) -> int:
        vocab_size, _ = self.weights.shape
        return vocab_size

    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
        return self.weights

    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
        return self.weights

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
        return ParameterDict(weights=self.weights)


@dataclass(frozen=True)
class UntiedEmbeddingConfig(EmbeddingConfigBase):
    precision: DTypeLike

    def random_init(
        self,
        vocab_size: int,
        model_dim: int,
        *,
        key: PRNGKeyArray,
    ) -> "UntiedEmbedding":
        input_key, output_key = jax.random.split(key)
        input_weights = jax.random.normal(input_key, (vocab_size, model_dim), dtype=self.precision)
        output_weights = jax.random.normal(output_key, (vocab_size, model_dim), dtype=self.precision)
        return UntiedEmbedding(
            config=self,
            input_weights=input_weights,
            output_weights=output_weights,
        )


class UntiedEmbedding(EmbeddingBase[UntiedEmbeddingConfig]):
    input_weights: Float[Array, "vocabulary channels"]
    output_weights: Float[Array, "vocabulary channels"]

    @property
    def activation_precision(self) -> DTypeLike:
        return self.config.precision

    @property
    def model_dim(self) -> int:
        _, model_dim = self.input_weights.shape
        return model_dim

    @property
    def vocab_size(self) -> int:
        vocab_size, _ = self.input_weights.shape
        return vocab_size

    def __post_init__(self) -> None:
        if self.config.precision != self.input_weights.dtype:
            raise ValueError(
                f"Embedding dtype {self.input_weights.dtype} does not match",
                f" the specified precision {self.config.precision}",
            )
        if self.config.precision != self.output_weights.dtype:
            raise ValueError(
                f"Embedding dtype {self.output_weights.dtype} does not match"
                f" the specified precision {self.config.precision}",
            )
        input_vocab_size, input_model_dim = self.input_weights.shape
        output_vocab_size, output_model_dim = self.output_weights.shape
        if input_vocab_size != output_vocab_size:
            raise ValueError(
                f"Input vocab size {input_vocab_size} does not match the output vocab size {output_vocab_size}",
            )
        if input_model_dim != output_model_dim:
            raise ValueError(
                f"Input model dim {input_model_dim} does not match the output model dim {output_model_dim}",
            )

    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
        return self.input_weights

    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
        return self.output_weights

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:
        if weight_layout == WeightLayout.AUTO:
            weight_layout = self._default_weight_layout()

        match weight_layout:
            case WeightLayout.OUTPUT_INPUT:
                output_weights = self.output_weights
            case WeightLayout.INPUT_OUTPUT:
                output_weights = rearrange(self.output_weights, "token_ids channels -> channels token_ids")
            case _:
                raise ValueError(f"Unsupported weight layout: {weight_layout}")

        return ParameterDict(
            input_weights=self.input_weights,
            output_weights=output_weights,
        )


@dataclass(frozen=True)
class QuantizedTiedEmbeddingConfig(EmbeddingConfigBase):
    embedding_quantization_mode: QuantizationMode
    activation_quantization_mode: QuantizationMode | None
    activation_precision: DTypeLike

    def random_init(
        self,
        vocab_size: int,
        model_dim: int,
        *,
        key: PRNGKeyArray,
    ) -> "QuantizedTiedEmbedding":
        min_val, max_val = self.embedding_quantization_mode.range
        min_abs_val = min(abs(min_val), abs(max_val))
        scale = 1 / min_abs_val
        scales = scale * jnp.ones(vocab_size, dtype=self.activation_precision)
        weights = jax.random.normal(key, (vocab_size, model_dim), dtype=self.activation_precision)
        weights = quantize_weights(weights * min_abs_val, self.embedding_quantization_mode)
        return QuantizedTiedEmbedding(config=self, weights=weights, scales=scales)


class QuantizedTiedEmbedding(EmbeddingBase[QuantizedTiedEmbeddingConfig]):
    weights: Float[Array, "vocabulary channels"]
    scales: Float[Array, " vocabulary"]

    @property
    def activation_precision(self) -> DTypeLike:
        return self.config.activation_precision

    @property
    def model_dim(self) -> int:
        _, model_dim = self.weights.shape
        return model_dim

    @property
    def vocab_size(self) -> int:
        vocab_size, _ = self.weights.shape
        return vocab_size

    def __post_init__(self) -> None:
        if self.weights.dtype != self.config.activation_precision:
            raise ValueError(
                f"Embedding dtype ({self.scales.dtype}) is not equal to specified activation precision"
                f" ({self.config.activation_precision})."
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        if self.scales.dtype != self.config.activation_precision:
            raise ValueError(
                f"Scales dtype {self.scales.dtype} does not match the specified activation precision"
                f" {self.config.activation_precision}"
                " Quantized layers require parameter dtypes to be equal to the activation precision.",
            )
        weights_vocab_size, weights_model_dim = self.weights.shape
        (scales_vocab_size,) = self.scales.shape
        if weights_vocab_size != scales_vocab_size:
            raise ValueError(
                f"Embedding vocab size {weights_vocab_size} does not match"
                f" the scales dimension size {scales_vocab_size}",
            )

    @property
    def int_weights(self) -> Int[Array, "vocabulary channels"]:
        result = quantize_weights(self.weights, self.config.embedding_quantization_mode)
        return result.astype(self.config.embedding_quantization_mode.dtype)

    def _prepare_weights(self) -> Float[Array, "vocabulary channels"]:
        quantized_weights = quantize_weights(self.weights, self.config.embedding_quantization_mode)
        quantized_weights = quantized_weights * self.scales.reshape(-1, 1)
        return quantized_weights

    def _prepare_input_weights(self) -> Float[Array, "vocabulary channels"]:
        return self._prepare_weights()

    def _prepare_output_weights(self) -> Float[Array, "vocabulary channels"]:
        return self._prepare_weights()

    def readout(self, x: Float[Array, " channels"]) -> Float[Array, " vocabulary"]:
        if self.config.activation_quantization_mode is not None:
            x = dynamically_quantize_activations(x, self.config.activation_quantization_mode)
        return super().readout(x)

    def export_weights(self, weight_layout: WeightLayout = WeightLayout.AUTO) -> ParameterDict:  # noqa: ARG002
        return ParameterDict(
            weights=self.int_weights,
            scales=self.scales,
        )


EmbeddingConfig = TiedEmbeddingConfig | UntiedEmbeddingConfig | QuantizedTiedEmbeddingConfig


register_config_union(EmbeddingConfig)
```
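As a quick orientation to the embedding variants above, the following self-contained sketch mirrors what `TiedEmbedding.embed` and `.readout` compute: a single `[vocabulary, channels]` matrix serves both as the token lookup table and, via a matrix-vector product, as the readout head. The names and values are illustrative, not the lalamo API; in `QuantizedTiedEmbedding`, the stored table is additionally passed through `quantize_weights(...)` and rescaled by `scales.reshape(-1, 1)` before being used the same way.

```python
import jax
import jax.numpy as jnp

vocab_size, model_dim = 32, 8
key = jax.random.PRNGKey(0)
weights = jax.random.normal(key, (vocab_size, model_dim))  # shared [vocab, channels] table


def embed(token_ids, input_scale=None):
    out = weights[token_ids]          # row lookup, as in _prepare_input_weights()[x]
    if input_scale is not None:
        out = out * input_scale       # optional input_scale from EmbeddingConfigBase
    return out


def readout(hidden):
    logits = weights @ hidden         # [vocab] logits, as in _prepare_output_weights() @ x
    return logits                     # logits_soft_cap, if set, would be applied here


tokens = jnp.array([1, 5, 7])
hidden = embed(tokens, input_scale=model_dim**0.5)[-1]  # last token's activations
print(readout(hidden).shape)  # (32,)
```

`UntiedEmbedding` follows the same pattern but keeps separate `input_weights` and `output_weights` matrices of the same shape.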