ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly has been flagged as potentially problematic; see the package's registry page for details.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -48
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_quantize.py

@@ -15,9 +15,6 @@
 
 import unittest
 
-from parameterized import parameterized
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
@@ -29,20 +26,20 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
+from parameterized import parameterized
+import torch
 
 
 class TestVerifyRecipes(unittest.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.expand(
-      [
-          (Dtype.FP32, Dtype.FP32),
-          (Dtype.INT8, Dtype.INT8),
-          (Dtype.INT8, Dtype.FP16),
-          (Dtype.FP16, Dtype.INT8),
-          (Dtype.FP16, Dtype.FP16),
-      ]
-  )
+  @parameterized.expand([
+      (Dtype.FP32, Dtype.FP32),
+      (Dtype.INT8, Dtype.INT8),
+      (Dtype.INT8, Dtype.FP16),
+      (Dtype.FP16, Dtype.INT8),
+      (Dtype.FP16, Dtype.FP16),
+  ])
   def test_verify_invalid_recipes(
       self,
      activation,
@@ -54,31 +51,29 @@ class TestVerifyRecipes(unittest.TestCase):
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.expand(
-      [
-          (
-              Dtype.FP32,
-              Dtype.INT8,
-              Mode.DYNAMIC_RANGE,
-              Algorithm.MIN_MAX,
-              Granularity.CHANNELWISE,
-          ),
-          (
-              Dtype.FP32,
-              Dtype.INT8,
-              Mode.WEIGHT_ONLY,
-              Algorithm.MIN_MAX,
-              Granularity.CHANNELWISE,
-          ),
-          (
-              Dtype.FP32,
-              Dtype.FP16,
-              Mode.WEIGHT_ONLY,
-              Algorithm.FLOAT_CAST,
-              Granularity.NONE,
-          ),
-      ]
-  )
+  @parameterized.expand([
+      (
+          Dtype.FP32,
+          Dtype.INT8,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.CHANNELWISE,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.INT8,
+          Mode.WEIGHT_ONLY,
+          Algorithm.MIN_MAX,
+          Granularity.CHANNELWISE,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.FP16,
+          Mode.WEIGHT_ONLY,
+          Algorithm.FLOAT_CAST,
+          Granularity.NONE,
+      ),
+  ])
   def test_verify_valid_recipes(
       self,
       activation,
@@ -87,7 +82,9 @@ class TestVerifyRecipes(unittest.TestCase):
       algo,
       granularity,
   ):
-    quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
+    quant_recipe.LayerQuantRecipe(
+        activation, weight, mode, algo, granularity
+    ).verify()
 
 
 class TestQuantizeConvert(unittest.TestCase):
@@ -107,15 +104,13 @@ class TestQuantizeConvert(unittest.TestCase):
       )
   )
 
-  @parameterized.expand(
-      [
-          (quant_recipes.full_fp16_recipe()),
-          (quant_recipes.full_int8_dynamic_recipe()),
-          (quant_recipes.full_int8_weight_only_recipe()),
-          (_attention_int8_dynamic_recipe()),
-          (_feedforward_int8_dynamic_recipe()),
-      ]
-  )
+  @parameterized.expand([
+      (quant_recipes.full_fp16_recipe()),
+      (quant_recipes.full_int8_dynamic_recipe()),
+      (quant_recipes.full_int8_weight_only_recipe()),
+      (_attention_int8_dynamic_recipe()),
+      (_feedforward_int8_dynamic_recipe()),
+  ])
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -146,7 +141,9 @@ class TestQuantizeConvert(unittest.TestCase):
     )
     float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertLess(len(quantized_model._tflite_model), len(float_model._tflite_model))
+    self.assertLess(
+        len(quantized_model._tflite_model), len(float_model._tflite_model)
+    )
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             quantized_model,
ai_edge_torch/generative/utilities/loader.py

@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch
 
-from ai_edge_torch.generative.layers import model_config
-
 
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
@@ -158,14 +157,22 @@ class ModelLoader:
           f"{self._names.embedding_position}"
       )
     if self._names.lm_head is not None:
-      converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
+      converted_state["lm_head.weight"] = state.pop(
+          f"{self._names.lm_head}.weight"
+      )
       if model.config.lm_head_use_bias:
-        converted_state["lm_head.bias"] = state.pop(f"{self._names.lm_head}.bias")
+        converted_state["lm_head.bias"] = state.pop(
+            f"{self._names.lm_head}.bias"
+        )
     if self._names.final_norm is not None:
       final_norm_name = self._names.final_norm
-      converted_state["final_norm.weight"] = state.pop(f"{final_norm_name}.weight")
+      converted_state["final_norm.weight"] = state.pop(
+          f"{final_norm_name}.weight"
+      )
       if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )
 
     for i in range(model.config.num_layers):
       self._map_norm(i, model.config, state, converted_state)
@@ -214,18 +221,26 @@ class ModelLoader:
     if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
       if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
     else:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
       ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
-      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
@@ -233,9 +248,15 @@ class ModelLoader:
           f"{ff_gate_proj_name}.weight"
       )
       if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_gate_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_gate_proj_name}.bias"
+        )
 
   def _map_attention(
       self,
@@ -254,11 +275,13 @@ class ModelLoader:
       q_name = self._names.attn_query_proj.format(idx)
       k_name = self._names.attn_key_proj.format(idx)
       v_name = self._names.attn_value_proj.format(idx)
-      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
+          self._fuse_qkv(
+              config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
       )
     if config.attn_config.qkv_use_bias:
       if self._names.attn_fused_qkv_proj:
@@ -266,20 +289,22 @@ class ModelLoader:
             f"{fused_qkv_name}.bias"
         )
       else:
-        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
+            self._fuse_qkv(
+                config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
         )
 
     o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
     )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
       )
 
   def _map_norm(
@@ -318,7 +343,9 @@ class ModelLoader:
       v: torch.Tensor,
   ) -> torch.Tensor:
     if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+      q_per_kv = (
+          config.attn_config.num_heads // config.attn_config.num_query_groups
+      )
       qs = torch.split(q, config.head_dim * q_per_kv)
       ks = torch.split(k, config.head_dim)
       vs = torch.split(v, config.head_dim)
ai_edge_torch/generative/utilities/stable_diffusion_loader.py

@@ -16,11 +16,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import torch
-
 import ai_edge_torch.generative.layers.model_config as layers_config
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 import ai_edge_torch.generative.utilities.loader as loader
+import torch
 
 
 @dataclass
@@ -80,27 +79,35 @@ class TransformerBlockTensorNames:
 class MidBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
   attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
 
 
 @dataclass
 class DownEncoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   downsample_conv: str = None
 
 
 @dataclass
 class UpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
 @dataclass
 class SkipUpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
@@ -119,7 +126,9 @@ def _map_to_converted_state(
         converted_state[f"{converted_state_param}.weight"]
     )
   if f"{state_param}.bias" in state:
-    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+    converted_state[f"{converted_state_param}.bias"] = state.pop(
+        f"{state_param}.bias"
+    )
     if squeeze_dims:
       converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
           converted_state[f"{converted_state_param}.bias"]
@@ -220,25 +229,41 @@ class BaseLoader(loader.ModelLoader):
         f"{attention_layer_prefix}.v_projection",
         squeeze_dims=True,
     )
-    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
-        [
-            converted_state[f"{attention_layer_prefix}.q_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.k_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.v_projection.weight"],
-        ],
-        axis=0,
+    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = (
+        torch.concat(
+            [
+                converted_state[
+                    f"{attention_layer_prefix}.q_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.k_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.v_projection.weight"
+                ],
+            ],
+            axis=0,
+        )
     )
     del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
     del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
     del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
     if config.attention_config.qkv_use_bias:
-      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
-          [
-              converted_state[f"{attention_layer_prefix}.q_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.k_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.v_projection.bias"],
-          ],
-          axis=0,
+      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = (
+          torch.concat(
+              [
+                  converted_state[
+                      f"{attention_layer_prefix}.q_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.k_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.v_projection.bias"
+                  ],
+              ],
+              axis=0,
+          )
      )
      del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
      del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
@@ -316,11 +341,17 @@ class BaseLoader(loader.ModelLoader):
       )
     else:
       _map_to_converted_state(
-          state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+          state,
+          tensor_names.w1,
+          converted_state,
+          f"{converted_state_param_prefix}.w1",
       )
 
       _map_to_converted_state(
-          state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+          state,
+          tensor_names.w2,
+          converted_state,
+          f"{converted_state_param_prefix}.w2",
      )
 
   def _map_transformer_block(
@@ -509,9 +540,13 @@ class BaseLoader(loader.ModelLoader):
   ):
     for i in range(config.num_layers):
       res_skip_channels = (
-          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+          config.in_channels
+          if (i == config.num_layers - 1)
+          else config.out_channels
+      )
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else config.out_channels
       )
-      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
       self._map_residual_block(
           state,
           converted_state,
@@ -599,9 +634,13 @@ class AutoEncoderModelLoader(BaseLoader):
         state, self._names.post_quant_conv, converted_state, "post_quant_conv"
     )
     if self._names.conv_in is not None:
-      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+      _map_to_converted_state(
+          state, self._names.conv_in, converted_state, "conv_in"
+      )
     if self._names.conv_out is not None:
-      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+      _map_to_converted_state(
+          state, self._names.conv_out, converted_state, "conv_out"
+      )
     if self._names.final_norm is not None:
       _map_to_converted_state(
           state, self._names.final_norm, converted_state, "final_norm"
@@ -614,7 +653,9 @@ class AutoEncoderModelLoader(BaseLoader):
         model.config.mid_block_config,
     )
 
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     block_out_channels = reversed_block_out_channels[0]
     for i, out_channels in enumerate(reversed_block_out_channels):
       prev_output_channel = block_out_channels
@@ -690,8 +731,12 @@ class DiffusionModelLoader(BaseLoader):
     self._map_time_embedding(
         state, converted_state, "time_embedding", self._names.time_embedding
     )
-    _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-    _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+    _map_to_converted_state(
+        state, self._names.conv_in, converted_state, "conv_in"
+    )
+    _map_to_converted_state(
+        state, self._names.conv_out, converted_state, "conv_out"
+    )
     _map_to_converted_state(
         state, self._names.final_norm, converted_state, "final_norm"
     )
@@ -825,7 +870,9 @@ class DiffusionModelLoader(BaseLoader):
     )
 
     # Map up_decoders.
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     up_decoder_layers_per_block = config.layers_per_block + 1
     output_channel = reversed_block_out_channels[0]
     for i, block_out_channel in enumerate(reversed_block_out_channels):
@@ -917,8 +964,14 @@ class DiffusionModelLoader(BaseLoader):
       tensor_names: TimeEmbeddingTensorNames,
   ):
     _map_to_converted_state(
-        state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+        state,
+        tensor_names.w1,
+        converted_state,
+        f"{converted_state_param_prefix}.w1",
     )
     _map_to_converted_state(
-        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+        state,
+        tensor_names.w2,
+        converted_state,
+        f"{converted_state_param_prefix}.w2",
     )
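
The `_fuse_qkv` hunk in loader.py above (`@@ -318,7 +343,9 @@`) only reflows the `q_per_kv` computation, but it shows the core of how separate query/key/value weights get packed into the fused, interleaved `qkv_projection` layout. A minimal sketch of that interleaving, reusing the names visible in the diff (`num_heads`, `num_query_groups`, `head_dim`); the standalone function is hypothetical, since the real code is a `ModelLoader` method driven by `config.attn_config`:

import torch

def fuse_qkv_interleaved(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    num_query_groups: int,
    head_dim: int,
) -> torch.Tensor:
  # Each KV group serves num_heads // num_query_groups query heads.
  q_per_kv = num_heads // num_query_groups
  # Split along dim 0: q into per-group chunks of q_per_kv heads,
  # k and v into one head-sized chunk per group.
  qs = torch.split(q, head_dim * q_per_kv)
  ks = torch.split(k, head_dim)
  vs = torch.split(v, head_dim)
  # Re-pack group by group: [q_group0, k0, v0, q_group1, k1, v1, ...].
  return torch.cat([t for group in zip(qs, ks, vs) for t in group], dim=0)

For example, with num_heads=8, num_query_groups=2, and head_dim=64, each of the two groups contributes four query heads followed by its single key and value head.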