ai-edge-torch-nightly 0.2.0.dev20240610__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion_utils.py +17 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +4 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +164 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +2 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +49 -4
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -2
- ai_edge_torch/generative/quantize/quant_recipes.py +3 -3
- ai_edge_torch/generative/quantize/supported_schemes.py +2 -1
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +75 -20
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- ai_edge_torch/quantize/quant_config.py +11 -15
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +29 -27
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240610.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
|
@@ -318,7 +318,7 @@ class ModelLoader:
|
|
|
318
318
|
q_name = names.attn_query_proj.format(idx)
|
|
319
319
|
k_name = names.attn_key_proj.format(idx)
|
|
320
320
|
v_name = names.attn_value_proj.format(idx)
|
|
321
|
-
# model.encoder.transformer_blocks[0].atten_func.
|
|
321
|
+
# model.encoder.transformer_blocks[0].atten_func.q_projection.weight
|
|
322
322
|
if fuse_attention:
|
|
323
323
|
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
|
|
324
324
|
config,
|
|
@@ -334,18 +334,34 @@ class ModelLoader:
|
|
|
334
334
|
state.pop(f"{v_name}.bias"),
|
|
335
335
|
)
|
|
336
336
|
else:
|
|
337
|
-
converted_state[f"{prefix}.atten_func.
|
|
338
|
-
|
|
339
|
-
|
|
337
|
+
converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
|
|
338
|
+
f"{q_name}.weight"
|
|
339
|
+
)
|
|
340
|
+
converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
|
|
341
|
+
f"{k_name}.weight"
|
|
342
|
+
)
|
|
343
|
+
converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
|
|
344
|
+
f"{v_name}.weight"
|
|
345
|
+
)
|
|
340
346
|
if config.attn_config.qkv_use_bias:
|
|
341
|
-
converted_state[f"{prefix}.atten_func.
|
|
342
|
-
|
|
343
|
-
|
|
347
|
+
converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
|
|
348
|
+
f"{q_name}.bias"
|
|
349
|
+
)
|
|
350
|
+
converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
|
|
351
|
+
f"{k_name}.bias"
|
|
352
|
+
)
|
|
353
|
+
converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
|
|
354
|
+
f"{v_name}.bias"
|
|
355
|
+
)
|
|
344
356
|
|
|
345
357
|
o_name = names.attn_output_proj.format(idx)
|
|
346
|
-
converted_state[f"{prefix}.atten_func.
|
|
358
|
+
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
|
|
359
|
+
f"{o_name}.weight"
|
|
360
|
+
)
|
|
347
361
|
if config.attn_config.output_proj_use_bias:
|
|
348
|
-
converted_state[f"{prefix}.atten_func.
|
|
362
|
+
converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
|
|
363
|
+
f"{o_name}.bias"
|
|
364
|
+
)
|
|
349
365
|
|
|
350
366
|
def _map_cross_attention(
|
|
351
367
|
self,
|
|
@@ -383,32 +399,32 @@ class ModelLoader:
|
|
|
383
399
|
state.pop(f"{v_name}.bias"),
|
|
384
400
|
)
|
|
385
401
|
else:
|
|
386
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
402
|
+
converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
|
|
387
403
|
f"{q_name}.weight"
|
|
388
404
|
)
|
|
389
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
405
|
+
converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
|
|
390
406
|
f"{k_name}.weight"
|
|
391
407
|
)
|
|
392
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
408
|
+
converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
|
|
393
409
|
f"{v_name}.weight"
|
|
394
410
|
)
|
|
395
411
|
if config.attn_config.qkv_use_bias:
|
|
396
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
412
|
+
converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
|
|
397
413
|
f"{q_name}.bias"
|
|
398
414
|
)
|
|
399
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
415
|
+
converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
|
|
400
416
|
f"{k_name}.bias"
|
|
401
417
|
)
|
|
402
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
418
|
+
converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
|
|
403
419
|
f"{v_name}.bias"
|
|
404
420
|
)
|
|
405
421
|
|
|
406
422
|
o_name = names.cross_attn_output_proj.format(idx)
|
|
407
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
423
|
+
converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
|
|
408
424
|
f"{o_name}.weight"
|
|
409
425
|
)
|
|
410
426
|
if config.attn_config.output_proj_use_bias:
|
|
411
|
-
converted_state[f"{prefix}.cross_atten_func.
|
|
427
|
+
converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
|
|
412
428
|
f"{o_name}.bias"
|
|
413
429
|
)
|
|
414
430
|
|
|
@@ -32,27 +32,26 @@ class QuantConfig:
|
|
|
32
32
|
pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
|
|
33
33
|
with PT2E quantization. This method of quantization is not applicable to
|
|
34
34
|
models created with the Edge Generative API.
|
|
35
|
-
|
|
35
|
+
generative_recipe: Quantization recipe to be applied on a model created
|
|
36
36
|
with the Edge Generative API.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
39
|
pt2e_quantizer: pt2eq.PT2EQuantizer = None
|
|
40
|
-
|
|
40
|
+
generative_recipe: quant_recipe.GenerativeQuantRecipe = None
|
|
41
41
|
|
|
42
42
|
@enum.unique
|
|
43
43
|
class _QuantizerMode(enum.Enum):
|
|
44
44
|
NONE = enum.auto()
|
|
45
45
|
PT2E_DYNAMIC = enum.auto()
|
|
46
46
|
PT2E_STATIC = enum.auto()
|
|
47
|
-
|
|
48
|
-
TFLITE_FP16 = enum.auto()
|
|
47
|
+
AI_EDGE_QUANTIZER = enum.auto()
|
|
49
48
|
|
|
50
49
|
_quantizer_mode: _QuantizerMode = _QuantizerMode.NONE
|
|
51
50
|
|
|
52
51
|
def __init__(
|
|
53
52
|
self,
|
|
54
53
|
pt2e_quantizer: Optional[pt2eq.PT2EQuantizer] = None,
|
|
55
|
-
|
|
54
|
+
generative_recipe: Optional[quant_recipe.GenerativeQuantRecipe] = None,
|
|
56
55
|
):
|
|
57
56
|
"""Initializes some internal states based on selected quantization method.
|
|
58
57
|
|
|
@@ -61,8 +60,8 @@ class QuantConfig:
|
|
|
61
60
|
is properly setup. Additionally sets up an utility enum _quantizer_mode to
|
|
62
61
|
guide certain conversion processes.
|
|
63
62
|
"""
|
|
64
|
-
if pt2e_quantizer is not None and transformer_recipe is not None:
|
|
65
|
-
raise ValueError('Cannot set both pt2e_quantizer and transformer_recipe.')
|
|
63
|
+
if pt2e_quantizer is not None and generative_recipe is not None:
|
|
64
|
+
raise ValueError('Cannot set both pt2e_quantizer and generative_recipe.')
|
|
66
65
|
elif pt2e_quantizer is not None:
|
|
67
66
|
object.__setattr__(self, 'pt2e_quantizer', pt2e_quantizer)
|
|
68
67
|
object.__setattr__(
|
|
@@ -74,12 +73,9 @@ class QuantConfig:
|
|
|
74
73
|
else self._QuantizerMode.PT2E_STATIC
|
|
75
74
|
),
|
|
76
75
|
)
|
|
77
|
-
elif transformer_recipe is not None:
|
|
78
|
-
|
|
79
|
-
object.__setattr__(self, 'transformer_recipe', transformer_recipe)
|
|
80
|
-
|
|
81
|
-
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_DYNAMIC)
|
|
82
|
-
elif self.transformer_recipe.default.weight_dtype == quant_attrs.Dtype.FP16:
|
|
83
|
-
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.TFLITE_FP16)
|
|
76
|
+
elif generative_recipe is not None:
|
|
77
|
+
generative_recipe.verify()
|
|
78
|
+
object.__setattr__(self, 'generative_recipe', generative_recipe)
|
|
79
|
+
object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
|
|
84
80
|
else:
|
|
85
|
-
raise ValueError('Either pt2e_quantizer or transformer_recipe must be set.')
|
|
81
|
+
raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.dev20240610
|
|
3
|
+
Version: 0.2.0.dev20240617
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=FPMmuFU3pyMREtjB_san1fy_0PFtAsgA0VZfOYvDrb4,100
|
|
|
2
2
|
ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
|
|
3
3
|
ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
4
4
|
ai_edge_torch/convert/conversion.py,sha256=GN2Js232u_5Y118wg3qIfEoYewxbxLl3TpSnO6osi8c,4029
|
|
5
|
-
ai_edge_torch/convert/conversion_utils.py,sha256=
|
|
5
|
+
ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
|
|
6
6
|
ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
|
|
7
7
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
|
|
8
8
|
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
|
|
@@ -15,11 +15,11 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,
|
|
|
15
15
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
|
|
16
16
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
|
|
17
17
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
|
|
18
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
|
18
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
|
|
19
19
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
|
|
20
20
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
|
|
21
21
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
22
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=
|
|
22
|
+
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
|
|
23
23
|
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
24
24
|
ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
|
|
25
25
|
ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
|
|
@@ -41,9 +41,9 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTc
|
|
|
41
41
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
42
42
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
|
|
43
43
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
|
|
44
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
|
|
45
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
|
46
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
|
44
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
|
|
45
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
|
|
46
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
|
|
47
47
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
|
|
48
48
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
|
|
49
49
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
|
|
@@ -56,7 +56,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
|
|
|
56
56
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
57
57
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
|
|
58
58
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
|
|
59
|
-
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
|
|
59
|
+
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
|
|
60
60
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
61
61
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
|
|
62
62
|
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
|
|
@@ -65,34 +65,36 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
|
|
|
65
65
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
|
|
66
66
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
|
|
67
67
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
68
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
|
68
|
+
ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
|
|
69
69
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
|
|
70
70
|
ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
|
|
71
71
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
|
|
72
72
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
|
|
73
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
|
73
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
|
|
74
74
|
ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
|
|
75
75
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
|
|
76
76
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
|
|
77
77
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
78
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
|
79
|
-
ai_edge_torch/generative/layers/unet/builder.py,sha256=
|
|
80
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
|
78
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
|
|
79
|
+
ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
|
|
80
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
|
|
81
81
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
82
82
|
ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
|
|
83
|
-
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=
|
|
84
|
-
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=
|
|
85
|
-
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256
|
|
86
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
|
|
87
|
-
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=
|
|
83
|
+
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
|
84
|
+
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=Y8zahKw7b_h7ajPaJZVef4jG-MoqImRCpVSbFtV_i24,5139
|
|
85
|
+
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=-vd6Qp0BdXJVKg4f0_hhwbKOi3QPIAPVqyXnJ-ZnISQ,1915
|
|
86
|
+
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=9ItD70jQRXMEhWod-nUfEeoWGJUUu6V9YOffF07VU9g,1795
|
|
87
|
+
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
|
88
|
+
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
89
|
+
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
|
|
88
90
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
89
91
|
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
|
|
90
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
|
91
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
|
92
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
|
|
93
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
|
|
92
94
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
93
|
-
ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
|
|
94
95
|
ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
|
|
95
|
-
ai_edge_torch/generative/utilities/
|
|
96
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
|
|
97
|
+
ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
|
|
96
98
|
ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
|
|
97
99
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
|
|
98
100
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
|
|
@@ -103,12 +105,12 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=aUAPKnH4_Jxpp
|
|
|
103
105
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
104
106
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=ye1f5vAZ0Vr4RWAtfrgU1o3JLs03Sa4inHRq3YxJDGo,15602
|
|
105
107
|
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCtUMXVzq3vHKxblsd5Y,36046
|
|
106
|
-
ai_edge_torch/quantize/quant_config.py,sha256=
|
|
108
|
+
ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
|
|
107
109
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
108
110
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
109
111
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
|
|
110
|
-
ai_edge_torch_nightly-0.2.0.
|
|
111
|
-
ai_edge_torch_nightly-0.2.0.
|
|
112
|
-
ai_edge_torch_nightly-0.2.0.
|
|
113
|
-
ai_edge_torch_nightly-0.2.0.
|
|
114
|
-
ai_edge_torch_nightly-0.2.0.
|
|
112
|
+
ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
113
|
+
ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/METADATA,sha256=Z9rUO2CabVbBpydpRk8OxNlwK4yznGCb2QHGlJhqRsM,1748
|
|
114
|
+
ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
115
|
+
ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
116
|
+
ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/RECORD,,
|
|
@@ -1,298 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
# Common utility functions for data loading etc.
|
|
16
|
-
from dataclasses import dataclass
|
|
17
|
-
from typing import Dict, List, Tuple
|
|
18
|
-
|
|
19
|
-
import torch
|
|
20
|
-
|
|
21
|
-
import ai_edge_torch.generative.layers.model_config as layers_config
|
|
22
|
-
import ai_edge_torch.generative.layers.unet.model_config as unet_config
|
|
23
|
-
import ai_edge_torch.generative.utilities.loader as loader
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
@dataclass
|
|
27
|
-
class ResidualBlockTensorNames:
|
|
28
|
-
norm_1: str = None
|
|
29
|
-
conv_1: str = None
|
|
30
|
-
norm_2: str = None
|
|
31
|
-
conv_2: str = None
|
|
32
|
-
residual_layer: str = None
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
@dataclass
|
|
36
|
-
class AttnetionBlockTensorNames:
|
|
37
|
-
norm: str = None
|
|
38
|
-
fused_qkv_proj: str = None
|
|
39
|
-
output_proj: str = None
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@dataclass
|
|
43
|
-
class MidBlockTensorNames:
|
|
44
|
-
residual_block_tensor_names: List[ResidualBlockTensorNames]
|
|
45
|
-
attention_block_tensor_names: List[AttnetionBlockTensorNames]
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@dataclass
|
|
49
|
-
class UpDecoderBlockTensorNames:
|
|
50
|
-
residual_block_tensor_names: List[ResidualBlockTensorNames]
|
|
51
|
-
upsample_conv: str = None
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def _map_to_converted_state(
|
|
55
|
-
state: Dict[str, torch.Tensor],
|
|
56
|
-
state_param: str,
|
|
57
|
-
converted_state: Dict[str, torch.Tensor],
|
|
58
|
-
converted_state_param: str,
|
|
59
|
-
):
|
|
60
|
-
converted_state[f"{converted_state_param}.weight"] = state.pop(
|
|
61
|
-
f"{state_param}.weight"
|
|
62
|
-
)
|
|
63
|
-
if f"{state_param}.bias" in state:
|
|
64
|
-
converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class AutoEncoderModelLoader(loader.ModelLoader):
|
|
68
|
-
|
|
69
|
-
@dataclass
|
|
70
|
-
class TensorNames:
|
|
71
|
-
quant_conv: str = None
|
|
72
|
-
post_quant_conv: str = None
|
|
73
|
-
conv_in: str = None
|
|
74
|
-
conv_out: str = None
|
|
75
|
-
final_norm: str = None
|
|
76
|
-
mid_block_tensor_names: MidBlockTensorNames = None
|
|
77
|
-
up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
|
|
78
|
-
|
|
79
|
-
def __init__(self, file_name: str, names: TensorNames):
|
|
80
|
-
"""AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
file_name (str): Path to the checkpoint. Can be a directory or an
|
|
84
|
-
exact file.
|
|
85
|
-
names (TensorNames): An instance of `TensorNames` to determine mappings.
|
|
86
|
-
"""
|
|
87
|
-
self._file_name = file_name
|
|
88
|
-
self._names = names
|
|
89
|
-
self._loader = self._get_loader()
|
|
90
|
-
|
|
91
|
-
def load(
|
|
92
|
-
self, model: torch.nn.Module, strict: bool = True
|
|
93
|
-
) -> Tuple[List[str], List[str]]:
|
|
94
|
-
"""Load the model from the checkpoint.
|
|
95
|
-
|
|
96
|
-
Args:
|
|
97
|
-
model (torch.nn.Module): The pytorch model that needs to be loaded.
|
|
98
|
-
strict (bool, optional): Whether the converted keys are strictly
|
|
99
|
-
matched. Defaults to True.
|
|
100
|
-
|
|
101
|
-
Returns:
|
|
102
|
-
missing_keys (List[str]): a list of str containing the missing keys.
|
|
103
|
-
unexpected_keys (List[str]): a list of str containing the unexpected keys.
|
|
104
|
-
|
|
105
|
-
Raises:
|
|
106
|
-
ValueError: If conversion results in unmapped tensors and strict mode is
|
|
107
|
-
enabled.
|
|
108
|
-
"""
|
|
109
|
-
state = self._loader(self._file_name)
|
|
110
|
-
converted_state = dict()
|
|
111
|
-
if self._names.quant_conv is not None:
|
|
112
|
-
_map_to_converted_state(
|
|
113
|
-
state, self._names.quant_conv, converted_state, "quant_conv"
|
|
114
|
-
)
|
|
115
|
-
if self._names.post_quant_conv is not None:
|
|
116
|
-
_map_to_converted_state(
|
|
117
|
-
state, self._names.post_quant_conv, converted_state, "post_quant_conv"
|
|
118
|
-
)
|
|
119
|
-
if self._names.conv_in is not None:
|
|
120
|
-
_map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
|
|
121
|
-
if self._names.conv_out is not None:
|
|
122
|
-
_map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
|
|
123
|
-
if self._names.final_norm is not None:
|
|
124
|
-
_map_to_converted_state(
|
|
125
|
-
state, self._names.final_norm, converted_state, "final_norm"
|
|
126
|
-
)
|
|
127
|
-
self._map_mid_block(
|
|
128
|
-
state,
|
|
129
|
-
converted_state,
|
|
130
|
-
model.config.mid_block_config,
|
|
131
|
-
self._names.mid_block_tensor_names,
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
reversed_block_out_channels = list(reversed(model.config.block_out_channels))
|
|
135
|
-
block_out_channels = reversed_block_out_channels[0]
|
|
136
|
-
for i, out_channels in enumerate(reversed_block_out_channels):
|
|
137
|
-
prev_output_channel = block_out_channels
|
|
138
|
-
block_out_channels = out_channels
|
|
139
|
-
not_final_block = i < len(reversed_block_out_channels) - 1
|
|
140
|
-
self._map_up_decoder_block(
|
|
141
|
-
state,
|
|
142
|
-
converted_state,
|
|
143
|
-
f"up_decoder_blocks.{i}",
|
|
144
|
-
unet_config.UpDecoderBlock2DConfig(
|
|
145
|
-
in_channels=prev_output_channel,
|
|
146
|
-
out_channels=block_out_channels,
|
|
147
|
-
normalization_config=model.config.normalization_config,
|
|
148
|
-
activation_type=model.config.activation_type,
|
|
149
|
-
num_layers=model.config.layers_per_block,
|
|
150
|
-
add_upsample=not_final_block,
|
|
151
|
-
upsample_conv=True,
|
|
152
|
-
),
|
|
153
|
-
self._names.up_decoder_blocks_tensor_names[i],
|
|
154
|
-
)
|
|
155
|
-
if strict and state:
|
|
156
|
-
raise ValueError(
|
|
157
|
-
f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
|
|
158
|
-
)
|
|
159
|
-
return model.load_state_dict(converted_state, strict=strict)
|
|
160
|
-
|
|
161
|
-
def _map_residual_block(
|
|
162
|
-
self,
|
|
163
|
-
state: Dict[str, torch.Tensor],
|
|
164
|
-
converted_state: Dict[str, torch.Tensor],
|
|
165
|
-
tensor_names: ResidualBlockTensorNames,
|
|
166
|
-
converted_state_param_prefix: str,
|
|
167
|
-
config: unet_config.ResidualBlock2DConfig,
|
|
168
|
-
):
|
|
169
|
-
_map_to_converted_state(
|
|
170
|
-
state,
|
|
171
|
-
tensor_names.norm_1,
|
|
172
|
-
converted_state,
|
|
173
|
-
f"{converted_state_param_prefix}.norm_1",
|
|
174
|
-
)
|
|
175
|
-
_map_to_converted_state(
|
|
176
|
-
state,
|
|
177
|
-
tensor_names.conv_1,
|
|
178
|
-
converted_state,
|
|
179
|
-
f"{converted_state_param_prefix}.conv_1",
|
|
180
|
-
)
|
|
181
|
-
_map_to_converted_state(
|
|
182
|
-
state,
|
|
183
|
-
tensor_names.norm_2,
|
|
184
|
-
converted_state,
|
|
185
|
-
f"{converted_state_param_prefix}.norm_2",
|
|
186
|
-
)
|
|
187
|
-
_map_to_converted_state(
|
|
188
|
-
state,
|
|
189
|
-
tensor_names.conv_2,
|
|
190
|
-
converted_state,
|
|
191
|
-
f"{converted_state_param_prefix}.conv_2",
|
|
192
|
-
)
|
|
193
|
-
if config.in_channels != config.out_channels:
|
|
194
|
-
_map_to_converted_state(
|
|
195
|
-
state,
|
|
196
|
-
tensor_names.residual_layer,
|
|
197
|
-
converted_state,
|
|
198
|
-
f"{converted_state_param_prefix}.residual_layer",
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
def _map_attention_block(
|
|
202
|
-
self,
|
|
203
|
-
state: Dict[str, torch.Tensor],
|
|
204
|
-
converted_state: Dict[str, torch.Tensor],
|
|
205
|
-
tensor_names: AttnetionBlockTensorNames,
|
|
206
|
-
converted_state_param_prefix: str,
|
|
207
|
-
config: unet_config.AttentionBlock2DConfig,
|
|
208
|
-
):
|
|
209
|
-
if config.normalization_config.type != layers_config.NormalizationType.NONE:
|
|
210
|
-
_map_to_converted_state(
|
|
211
|
-
state,
|
|
212
|
-
tensor_names.norm,
|
|
213
|
-
converted_state,
|
|
214
|
-
f"{converted_state_param_prefix}.norm",
|
|
215
|
-
)
|
|
216
|
-
attention_layer_prefix = f"{converted_state_param_prefix}.attention"
|
|
217
|
-
_map_to_converted_state(
|
|
218
|
-
state,
|
|
219
|
-
tensor_names.fused_qkv_proj,
|
|
220
|
-
converted_state,
|
|
221
|
-
f"{attention_layer_prefix}.qkv_projection",
|
|
222
|
-
)
|
|
223
|
-
_map_to_converted_state(
|
|
224
|
-
state,
|
|
225
|
-
tensor_names.output_proj,
|
|
226
|
-
converted_state,
|
|
227
|
-
f"{attention_layer_prefix}.output_projection",
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
def _map_mid_block(
|
|
231
|
-
self,
|
|
232
|
-
state: Dict[str, torch.Tensor],
|
|
233
|
-
converted_state: Dict[str, torch.Tensor],
|
|
234
|
-
config: unet_config.MidBlock2DConfig,
|
|
235
|
-
tensor_names: MidBlockTensorNames,
|
|
236
|
-
):
|
|
237
|
-
converted_state_param_prefix = "mid_block"
|
|
238
|
-
residual_block_config = unet_config.ResidualBlock2DConfig(
|
|
239
|
-
in_channels=config.in_channels,
|
|
240
|
-
out_channels=config.in_channels,
|
|
241
|
-
time_embedding_channels=config.time_embedding_channels,
|
|
242
|
-
normalization_config=config.normalization_config,
|
|
243
|
-
activation_type=config.activation_type,
|
|
244
|
-
)
|
|
245
|
-
self._map_residual_block(
|
|
246
|
-
state,
|
|
247
|
-
converted_state,
|
|
248
|
-
tensor_names.residual_block_tensor_names[0],
|
|
249
|
-
f"{converted_state_param_prefix}.resnets.0",
|
|
250
|
-
residual_block_config,
|
|
251
|
-
)
|
|
252
|
-
for i in range(config.num_layers):
|
|
253
|
-
if config.attention_block_config:
|
|
254
|
-
self._map_attention_block(
|
|
255
|
-
state,
|
|
256
|
-
converted_state,
|
|
257
|
-
tensor_names.attention_block_tensor_names[i],
|
|
258
|
-
f"{converted_state_param_prefix}.attentions.{i}",
|
|
259
|
-
config.attention_block_config,
|
|
260
|
-
)
|
|
261
|
-
self._map_residual_block(
|
|
262
|
-
state,
|
|
263
|
-
converted_state,
|
|
264
|
-
tensor_names.residual_block_tensor_names[i + 1],
|
|
265
|
-
f"{converted_state_param_prefix}.resnets.{i+1}",
|
|
266
|
-
residual_block_config,
|
|
267
|
-
)
|
|
268
|
-
|
|
269
|
-
def _map_up_decoder_block(
|
|
270
|
-
self,
|
|
271
|
-
state: Dict[str, torch.Tensor],
|
|
272
|
-
converted_state: Dict[str, torch.Tensor],
|
|
273
|
-
converted_state_param_prefix: str,
|
|
274
|
-
config: unet_config.UpDecoderBlock2DConfig,
|
|
275
|
-
tensor_names: UpDecoderBlockTensorNames,
|
|
276
|
-
):
|
|
277
|
-
for i in range(config.num_layers):
|
|
278
|
-
input_channels = config.in_channels if i == 0 else config.out_channels
|
|
279
|
-
self._map_residual_block(
|
|
280
|
-
state,
|
|
281
|
-
converted_state,
|
|
282
|
-
tensor_names.residual_block_tensor_names[i],
|
|
283
|
-
f"{converted_state_param_prefix}.resnets.{i}",
|
|
284
|
-
unet_config.ResidualBlock2DConfig(
|
|
285
|
-
in_channels=input_channels,
|
|
286
|
-
out_channels=config.out_channels,
|
|
287
|
-
time_embedding_channels=config.time_embedding_channels,
|
|
288
|
-
normalization_config=config.normalization_config,
|
|
289
|
-
activation_type=config.activation_type,
|
|
290
|
-
),
|
|
291
|
-
)
|
|
292
|
-
if config.add_upsample and config.upsample_conv:
|
|
293
|
-
_map_to_converted_state(
|
|
294
|
-
state,
|
|
295
|
-
tensor_names.upsample_conv,
|
|
296
|
-
converted_state,
|
|
297
|
-
f"{converted_state_param_prefix}.upsample_conv",
|
|
298
|
-
)
|
|
File without changes
|
|
File without changes
|