ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl
- ai_edge_torch/__init__.py +5 -4
- ai_edge_torch/_convert/conversion.py +112 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +94 -48
- ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/test_convert.py +495 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
- ai_edge_torch/config.py +27 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +72 -40
- ai_edge_torch/debug/test/test_culprit.py +7 -5
- ai_edge_torch/debug/test/test_search_model.py +8 -7
- ai_edge_torch/debug/utils.py +14 -3
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
- ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
- ai_edge_torch/generative/examples/openelm/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
- ai_edge_torch/generative/examples/phi/phi3.py +286 -0
- ai_edge_torch/generative/examples/phi/verify.py +65 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
- ai_edge_torch/generative/examples/smollm/verify.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
- ai_edge_torch/generative/examples/t5/t5.py +208 -159
- ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
- ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
- ai_edge_torch/generative/fx_passes/__init__.py +4 -5
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
- ai_edge_torch/generative/layers/attention.py +141 -102
- ai_edge_torch/generative/layers/attention_utils.py +53 -12
- ai_edge_torch/generative/layers/builder.py +37 -7
- ai_edge_torch/generative/layers/feed_forward.py +39 -14
- ai_edge_torch/generative/layers/kv_cache.py +162 -50
- ai_edge_torch/generative/layers/model_config.py +84 -30
- ai_edge_torch/generative/layers/normalization.py +185 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
- ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/layers/unet/model_config.py +17 -15
- ai_edge_torch/generative/quantize/example.py +7 -8
- ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
- ai_edge_torch/generative/test/test_model_conversion.py +124 -188
- ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
- ai_edge_torch/generative/test/test_quantize.py +76 -60
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/loader.py +120 -57
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
- ai_edge_torch/generative/utilities/t5_loader.py +110 -81
- ai_edge_torch/generative/utilities/verifier.py +247 -0
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
- ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
- ai_edge_torch/lowertools/test_utils.py +60 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
- ai_edge_torch/model.py +53 -18
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +357 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
- ai_edge_torch/quantize/quant_config.py +13 -9
- ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
- ai_edge_torch/convert/conversion.py +0 -117
- ai_edge_torch/convert/conversion_utils.py +0 -400
- ai_edge_torch/convert/fx_passes/__init__.py +0 -59
- ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
- ai_edge_torch/convert/test/test_convert.py +0 -311
- ai_edge_torch/convert/test/test_convert_composites.py +0 -192
- ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
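
The rendered diff below covers `ai_edge_torch/generative/utilities/loader.py`, where `ModelLoader` switches from the model-wide `config.attn_config`/`config.ff_config` to per-layer lookups via `config.block_config(idx)`. A minimal sketch of that lookup pattern follows; only `block_config()`, `attn_config`, and `ff_config` appear in the diff, while the dataclass shapes and field defaults here are assumptions for illustration:

```python
# Sketch only: the per-block config lookup the loader.py diff relies on.
# Only block_config(), attn_config, and ff_config come from the diff;
# everything else is a hypothetical stand-in.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class AttentionConfig:
  num_heads: int
  num_query_groups: int
  head_dim: int
  qkv_use_bias: bool = False
  qkv_fused_interleaved: bool = True
  output_proj_use_bias: bool = False


@dataclass
class FeedForwardConfig:
  use_bias: bool = False


@dataclass
class BlockConfig:
  attn_config: AttentionConfig
  ff_config: FeedForwardConfig


@dataclass
class ModelConfig:
  num_layers: int
  block_configs: Union[BlockConfig, List[BlockConfig]]

  def block_config(self, idx: int) -> BlockConfig:
    # Some models (e.g. the newly added Gemma 2) vary settings per layer;
    # others share a single block config across all layers.
    if isinstance(self.block_configs, list):
      if not 0 <= idx < self.num_layers:
        raise ValueError(f"Index {idx} out of range [0, {self.num_layers})")
      return self.block_configs[idx]
    return self.block_configs
```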
ai_edge_torch/generative/utilities/loader.py

@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch
 
-from ai_edge_torch.generative.layers import model_config
-
 
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
@@ -73,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
   patterns = []
   if os.path.isdir(full_path):
     patterns.append(os.path.join(full_path, "*.bin"))
-    patterns.append(os.path.join(full_path, "*.pt"))
+    patterns.append(os.path.join(full_path, "*pt"))
   else:
     patterns.append(full_path)
   for pattern in patterns:
@@ -93,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
 
 
 class ModelLoader:
-  """A utility class for loading and converting model checkpoints to the
-  Edge Generative API layer format.
-  """
+  """Utlity for loading model checkpoints to the Edge Generative API layer."""
 
   @dataclass
   class TensorNames:
@@ -104,25 +101,30 @@ class ModelLoader:
     attn_value_proj: str = None
     attn_fused_qkv_proj: str = None
     attn_output_proj: str = None
+    attn_query_norm: str = None
+    attn_key_norm: str = None
 
     ff_up_proj: str = None
     ff_down_proj: str = None
     ff_gate_proj: str = None
 
     pre_attn_norm: str = None
+    post_attn_norm: str = None
     pre_ff_norm: str = None
+    post_ff_norm: str = None
     embedding: str = None
     embedding_position: str = None
     final_norm: str = None
     lm_head: str = None
 
   def __init__(self, file_name: str, names: TensorNames) -> None:
-    """ModelLoader constructor.
-
+    """ModelLoader constructor.
+
+    Can be used to load multiple models of the same type.
 
     Args:
-      file_name (str): Path to the checkpoint. Can be a directory or an
-        exact file.
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
     """
     self._file_name = file_name
@@ -141,13 +143,15 @@ class ModelLoader:
 
     Returns:
       missing_keys (List[str]): a list of str containing the missing keys.
-      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected
+        keys.
 
     Raises:
      ValueError: If conversion results in unmapped tensors and strict mode is
        enabled.
     """
     state = self._loader(self._file_name)
+    state = state["model_state_dict"] if "model_state_dict" in state else state
     converted_state = dict()
     if self._names.embedding is not None:
       converted_state["tok_embedding.weight"] = state.pop(
@@ -158,14 +162,22 @@ class ModelLoader:
           f"{self._names.embedding_position}"
       )
     if self._names.lm_head is not None:
-      converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
+      converted_state["lm_head.weight"] = state.pop(
+          f"{self._names.lm_head}.weight"
+      )
       if model.config.lm_head_use_bias:
-        converted_state["lm_head.bias"] = state.pop(f"{self._names.lm_head}.bias")
+        converted_state["lm_head.bias"] = state.pop(
+            f"{self._names.lm_head}.bias"
+        )
     if self._names.final_norm is not None:
       final_norm_name = self._names.final_norm
-      converted_state["final_norm.weight"] = state.pop(f"{final_norm_name}.weight")
+      converted_state["final_norm.weight"] = state.pop(
+          f"{final_norm_name}.weight"
+      )
       if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )
 
     for i in range(model.config.num_layers):
       self._map_norm(i, model.config, state, converted_state)
@@ -191,17 +203,17 @@ class ModelLoader:
     if glob.glob(os.path.join(self._file_name, "*.safetensors")):
       return load_safetensors
     if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
-        os.path.join(self._file_name, "*.pt")
+        os.path.join(self._file_name, "*pt")
     ):
       return load_pytorch_statedict
 
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
       return load_pytorch_statedict
 
-    raise ValueError(f"File format not supported.")
+    raise ValueError("File format not supported.")
 
   def _map_feedforward(
       self,
@@ -211,31 +223,66 @@ class ModelLoader:
      converted_state: Dict[str, torch.Tensor],
  ):
    prefix = f"transformer_blocks.{idx}"
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
      ff_up_proj_name = self._names.ff_up_proj.format(idx)
      ff_down_proj_name = self._names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
-      if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+      if ff_config.use_bias:
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
    else:
      ff_up_proj_name = self._names.ff_up_proj.format(idx)
      ff_down_proj_name = self._names.ff_down_proj.format(idx)
      ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
-      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
          f"{ff_gate_proj_name}.weight"
      )
-      if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_gate_proj_name}.bias")
+      if ff_config.use_bias:
+        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_gate_proj_name}.bias"
+        )
+
+    if self._names.pre_ff_norm is not None:
+      pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.pre_ff_norm.weight"] = state.pop(
+          f"{pre_ff_norm_name}.weight"
+      )
+      if f"{pre_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.pre_ff_norm.bias"] = state.pop(
+            f"{pre_ff_norm_name}.bias"
+        )
+
+    if self._names.post_ff_norm is not None:
+      post_ff_norm_name = self._names.post_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.post_ff_norm.weight"] = state.pop(
+          f"{post_ff_norm_name}.weight"
+      )
+      if f"{post_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.post_ff_norm.bias"] = state.pop(
+            f"{post_ff_norm_name}.bias"
+        )
 
  def _map_attention(
      self,
@@ -245,6 +292,7 @@ class ModelLoader:
      converted_state: Dict[str, torch.Tensor],
  ):
    prefix = f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
    if self._names.attn_fused_qkv_proj:
      fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
@@ -254,32 +302,47 @@ class ModelLoader:
      q_name = self._names.attn_query_proj.format(idx)
      k_name = self._names.attn_key_proj.format(idx)
      v_name = self._names.attn_value_proj.format(idx)
-      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
+          self._fuse_qkv(
+              attn_config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
      )
-    if config.attn_config.qkv_use_bias:
+    if attn_config.qkv_use_bias:
      if self._names.attn_fused_qkv_proj:
        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
            f"{fused_qkv_name}.bias"
        )
      else:
-        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
+            self._fuse_qkv(
+                attn_config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
        )
 
+    if self._names.attn_query_norm is not None:
+      attn_query_norm_name = self._names.attn_query_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.query_norm.weight"] = state.pop(
+          f"{attn_query_norm_name}.weight"
+      )
+    if self._names.attn_key_norm is not None:
+      attn_key_norm_name = self._names.attn_key_norm.format(idx)
+      converted_state[f"{prefix}.atten_func.key_norm.weight"] = state.pop(
+          f"{attn_key_norm_name}.weight"
+      )
+
    o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
    )
-    if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+    if attn_config.output_proj_use_bias:
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )
 
  def _map_norm(
@@ -300,28 +363,28 @@ class ModelLoader:
          f"{pre_attn_norm_name}.bias"
      )
 
-    if self._names.pre_ff_norm is not None:
-      pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
-      converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
-          f"{pre_ff_norm_name}.weight"
+    if self._names.post_attn_norm is not None:
+      post_attn_norm_name = self._names.post_attn_norm.format(idx)
+      converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
+          f"{post_attn_norm_name}.weight"
      )
-      if f"{pre_ff_norm_name}.bias" in state:
-        converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
-            f"{pre_ff_norm_name}.bias"
+      if f"{post_attn_norm_name}.bias" in state:
+        converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
+            f"{post_attn_norm_name}.bias"
        )
 
  def _fuse_qkv(
      self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
  ) -> torch.Tensor:
-    if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
-      qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
-      ks = torch.split(k, config.attn_config.head_dim)
-      vs = torch.split(v, config.attn_config.head_dim)
+    if attn_config.qkv_fused_interleaved:
+      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+      qs = torch.split(q, attn_config.head_dim * q_per_kv)
+      ks = torch.split(k, attn_config.head_dim)
+      vs = torch.split(v, attn_config.head_dim)
      cycled = [t for group in zip(qs, ks, vs) for t in group]
      return torch.cat(cycled)
    else: