ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240618__py3-none-any.whl
This diff shows the content changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
- ai_edge_torch/debug/__init__.py +1 -0
- ai_edge_torch/debug/culprit.py +70 -29
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
- ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
- ai_edge_torch/generative/layers/attention.py +154 -26
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
- ai_edge_torch/generative/layers/unet/builder.py +20 -2
- ai_edge_torch/generative/layers/unet/model_config.py +157 -5
- ai_edge_torch/generative/test/test_model_conversion.py +24 -0
- ai_edge_torch/generative/test/test_quantize.py +1 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
- ai_edge_torch/generative/utilities/t5_loader.py +33 -17
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/RECORD +23 -22
- ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/t5_loader.py

@@ -318,7 +318,7 @@ class ModelLoader:
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
-    # model.encoder.transformer_blocks[0].atten_func.
+    # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
           config,
@@ -334,18 +334,34 @@ class ModelLoader:
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.atten_func.
-
-
+      converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
+          f"{q_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
+          f"{k_name}.weight"
+      )
+      converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
+          f"{v_name}.weight"
+      )
     if config.attn_config.qkv_use_bias:
-      converted_state[f"{prefix}.atten_func.
-
-
+      converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
+          f"{q_name}.bias"
+      )
+      converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
+          f"{k_name}.bias"
+      )
+      converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
+          f"{v_name}.bias"
+      )

     o_name = names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+        f"{o_name}.weight"
+    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+          f"{o_name}.bias"
+      )

   def _map_cross_attention(
       self,
@@ -383,32 +399,32 @@ class ModelLoader:
           state.pop(f"{v_name}.bias"),
       )
     else:
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
           f"{q_name}.weight"
       )
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
           f"{k_name}.weight"
       )
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
       )
     if config.attn_config.qkv_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
           f"{q_name}.bias"
       )
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
           f"{k_name}.bias"
       )
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
           f"{v_name}.bias"
       )

     o_name = names.cross_attn_output_proj.format(idx)
-    converted_state[f"{prefix}.cross_atten_func.
+    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
         f"{o_name}.weight"
     )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.
+      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
       )

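In the fused branch above, `self._fuse_qkv` packs the three checkpoint tensors into the single `atten_func.attn.weight` parameter, while the rewritten `else` branch keeps them as separate `q_projection` / `k_projection` / `v_projection` entries. A minimal sketch of what a fused mapping computes, assuming a plain concatenation along the output dimension (the actual `_fuse_qkv` in this package may order per-head slices differently):

import torch

def fuse_qkv_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
  """Stack Q/K/V projection weights, each (out_features, in_features) as
  nn.Linear stores them, into one matrix whose output is [q; k; v]."""
  return torch.cat([q, k, v], dim=0)

# One matmul now yields all three projections.
q_w, k_w, v_w = (torch.randn(64, 64) for _ in range(3))
fused = fuse_qkv_sketch(q_w, k_w, v_w)   # shape (192, 64)
x = torch.randn(8, 64)
q_out, k_out, v_out = (x @ fused.T).split(64, dim=-1)
assert torch.allclose(q_out, x @ q_w.T)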
{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240611
+Version: 0.2.0.dev20240618
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/RECORD

@@ -15,20 +15,21 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
-ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
+ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
-ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=
+ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
 ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
 ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
 ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
-ai_edge_torch/debug/__init__.py,sha256=
-ai_edge_torch/debug/culprit.py,sha256=
+ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
+ai_edge_torch/debug/culprit.py,sha256=urtCKPXORPvn6oyDxDSCSjgvngUnjjcsUMwAOeIl15E,14236
 ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
 ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/debug/test/test_culprit.py,sha256=9An_n9p_RWTAYdHYTCO-__EJlbnjclCDo8tDhOzMlwk,3731
+ai_edge_torch/debug/test/test_search_model.py,sha256=0guAEon5cvwBpPXk6J0wVOKj7TXMDaiuomEEQmHgO5o,1590
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -41,9 +42,9 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTc
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,7 +57,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
 ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
-ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
+ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -65,19 +66,19 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
 ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
 ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
 ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
 ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
 ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
-ai_edge_torch/generative/layers/unet/builder.py,sha256=
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
+ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -89,12 +90,12 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
 ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
+ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
 ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
-ai_edge_torch/generative/utilities/
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
+ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
 ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
@@ -109,8 +110,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
-ai_edge_torch_nightly-0.2.0.
+ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/METADATA,sha256=aMhby_ftyg_8pWf8klYbHTCP7rMDcmuSTeryoRKt4U0,1748
+ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/RECORD,,

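For reference, each row in a wheel's RECORD file is `path,sha256=<digest>,<size>`, where the digest is the file's SHA-256 hash encoded as urlsafe base64 with the trailing `=` padding stripped. A small sketch of how such a row is produced (the helper name is illustrative, not part of any packaging library):

import base64
import hashlib
from pathlib import Path

def record_row(path: str) -> str:
  """Build one RECORD row: path, urlsafe-base64 sha256 (padding stripped), size."""
  data = Path(path).read_bytes()
  digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
  return f"{path},sha256={digest.decode()},{len(data)}"

# e.g. record_row("ai_edge_torch/debug/utils.py") on an unpacked wheel should
# reproduce the "...sha256=hjVmQVVl...,1430" entry above.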
ai_edge_torch/generative/utilities/autoencoder_loader.py

@@ -1,298 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Common utility functions for data loading etc.
-from dataclasses import dataclass
-from typing import Dict, List, Tuple
-
-import torch
-
-import ai_edge_torch.generative.layers.model_config as layers_config
-import ai_edge_torch.generative.layers.unet.model_config as unet_config
-import ai_edge_torch.generative.utilities.loader as loader
-
-
-@dataclass
-class ResidualBlockTensorNames:
-  norm_1: str = None
-  conv_1: str = None
-  norm_2: str = None
-  conv_2: str = None
-  residual_layer: str = None
-
-
-@dataclass
-class AttnetionBlockTensorNames:
-  norm: str = None
-  fused_qkv_proj: str = None
-  output_proj: str = None
-
-
-@dataclass
-class MidBlockTensorNames:
-  residual_block_tensor_names: List[ResidualBlockTensorNames]
-  attention_block_tensor_names: List[AttnetionBlockTensorNames]
-
-
-@dataclass
-class UpDecoderBlockTensorNames:
-  residual_block_tensor_names: List[ResidualBlockTensorNames]
-  upsample_conv: str = None
-
-
-def _map_to_converted_state(
-    state: Dict[str, torch.Tensor],
-    state_param: str,
-    converted_state: Dict[str, torch.Tensor],
-    converted_state_param: str,
-):
-  converted_state[f"{converted_state_param}.weight"] = state.pop(
-      f"{state_param}.weight"
-  )
-  if f"{state_param}.bias" in state:
-    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
-
-
-class AutoEncoderModelLoader(loader.ModelLoader):
-
-  @dataclass
-  class TensorNames:
-    quant_conv: str = None
-    post_quant_conv: str = None
-    conv_in: str = None
-    conv_out: str = None
-    final_norm: str = None
-    mid_block_tensor_names: MidBlockTensorNames = None
-    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
-
-  def __init__(self, file_name: str, names: TensorNames):
-    """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
-
-    Args:
-        file_name (str): Path to the checkpoint. Can be a directory or an
-          exact file.
-        names (TensorNames): An instance of `TensorNames` to determine mappings.
-    """
-    self._file_name = file_name
-    self._names = names
-    self._loader = self._get_loader()
-
-  def load(
-      self, model: torch.nn.Module, strict: bool = True
-  ) -> Tuple[List[str], List[str]]:
-    """Load the model from the checkpoint.
-
-    Args:
-        model (torch.nn.Module): The pytorch model that needs to be loaded.
-        strict (bool, optional): Whether the converted keys are strictly
-          matched. Defaults to True.
-
-    Returns:
-        missing_keys (List[str]): a list of str containing the missing keys.
-        unexpected_keys (List[str]): a list of str containing the unexpected keys.
-
-    Raises:
-        ValueError: If conversion results in unmapped tensors and strict mode is
-          enabled.
-    """
-    state = self._loader(self._file_name)
-    converted_state = dict()
-    if self._names.quant_conv is not None:
-      _map_to_converted_state(
-          state, self._names.quant_conv, converted_state, "quant_conv"
-      )
-    if self._names.post_quant_conv is not None:
-      _map_to_converted_state(
-          state, self._names.post_quant_conv, converted_state, "post_quant_conv"
-      )
-    if self._names.conv_in is not None:
-      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-    if self._names.conv_out is not None:
-      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
-    if self._names.final_norm is not None:
-      _map_to_converted_state(
-          state, self._names.final_norm, converted_state, "final_norm"
-      )
-    self._map_mid_block(
-        state,
-        converted_state,
-        model.config.mid_block_config,
-        self._names.mid_block_tensor_names,
-    )
-
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
-    block_out_channels = reversed_block_out_channels[0]
-    for i, out_channels in enumerate(reversed_block_out_channels):
-      prev_output_channel = block_out_channels
-      block_out_channels = out_channels
-      not_final_block = i < len(reversed_block_out_channels) - 1
-      self._map_up_decoder_block(
-          state,
-          converted_state,
-          f"up_decoder_blocks.{i}",
-          unet_config.UpDecoderBlock2DConfig(
-              in_channels=prev_output_channel,
-              out_channels=block_out_channels,
-              normalization_config=model.config.normalization_config,
-              activation_type=model.config.activation_type,
-              num_layers=model.config.layers_per_block,
-              add_upsample=not_final_block,
-              upsample_conv=True,
-          ),
-          self._names.up_decoder_blocks_tensor_names[i],
-      )
-    if strict and state:
-      raise ValueError(
-          f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
-      )
-    return model.load_state_dict(converted_state, strict=strict)
-
-  def _map_residual_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      tensor_names: ResidualBlockTensorNames,
-      converted_state_param_prefix: str,
-      config: unet_config.ResidualBlock2DConfig,
-  ):
-    _map_to_converted_state(
-        state,
-        tensor_names.norm_1,
-        converted_state,
-        f"{converted_state_param_prefix}.norm_1",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.conv_1,
-        converted_state,
-        f"{converted_state_param_prefix}.conv_1",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.norm_2,
-        converted_state,
-        f"{converted_state_param_prefix}.norm_2",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.conv_2,
-        converted_state,
-        f"{converted_state_param_prefix}.conv_2",
-    )
-    if config.in_channels != config.out_channels:
-      _map_to_converted_state(
-          state,
-          tensor_names.residual_layer,
-          converted_state,
-          f"{converted_state_param_prefix}.residual_layer",
-      )
-
-  def _map_attention_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      tensor_names: AttnetionBlockTensorNames,
-      converted_state_param_prefix: str,
-      config: unet_config.AttentionBlock2DConfig,
-  ):
-    if config.normalization_config.type != layers_config.NormalizationType.NONE:
-      _map_to_converted_state(
-          state,
-          tensor_names.norm,
-          converted_state,
-          f"{converted_state_param_prefix}.norm",
-      )
-    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
-    _map_to_converted_state(
-        state,
-        tensor_names.fused_qkv_proj,
-        converted_state,
-        f"{attention_layer_prefix}.qkv_projection",
-    )
-    _map_to_converted_state(
-        state,
-        tensor_names.output_proj,
-        converted_state,
-        f"{attention_layer_prefix}.output_projection",
-    )
-
-  def _map_mid_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      config: unet_config.MidBlock2DConfig,
-      tensor_names: MidBlockTensorNames,
-  ):
-    converted_state_param_prefix = "mid_block"
-    residual_block_config = unet_config.ResidualBlock2DConfig(
-        in_channels=config.in_channels,
-        out_channels=config.in_channels,
-        time_embedding_channels=config.time_embedding_channels,
-        normalization_config=config.normalization_config,
-        activation_type=config.activation_type,
-    )
-    self._map_residual_block(
-        state,
-        converted_state,
-        tensor_names.residual_block_tensor_names[0],
-        f"{converted_state_param_prefix}.resnets.0",
-        residual_block_config,
-    )
-    for i in range(config.num_layers):
-      if config.attention_block_config:
-        self._map_attention_block(
-            state,
-            converted_state,
-            tensor_names.attention_block_tensor_names[i],
-            f"{converted_state_param_prefix}.attentions.{i}",
-            config.attention_block_config,
-        )
-      self._map_residual_block(
-          state,
-          converted_state,
-          tensor_names.residual_block_tensor_names[i + 1],
-          f"{converted_state_param_prefix}.resnets.{i+1}",
-          residual_block_config,
-      )
-
-  def _map_up_decoder_block(
-      self,
-      state: Dict[str, torch.Tensor],
-      converted_state: Dict[str, torch.Tensor],
-      converted_state_param_prefix: str,
-      config: unet_config.UpDecoderBlock2DConfig,
-      tensor_names: UpDecoderBlockTensorNames,
-  ):
-    for i in range(config.num_layers):
-      input_channels = config.in_channels if i == 0 else config.out_channels
-      self._map_residual_block(
-          state,
-          converted_state,
-          tensor_names.residual_block_tensor_names[i],
-          f"{converted_state_param_prefix}.resnets.{i}",
-          unet_config.ResidualBlock2DConfig(
-              in_channels=input_channels,
-              out_channels=config.out_channels,
-              time_embedding_channels=config.time_embedding_channels,
-              normalization_config=config.normalization_config,
-              activation_type=config.activation_type,
-          ),
-      )
-    if config.add_upsample and config.upsample_conv:
-      _map_to_converted_state(
-          state,
-          tensor_names.upsample_conv,
-          converted_state,
-          f"{converted_state_param_prefix}.upsample_conv",
-      )
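The pattern running through this removed file, and presumably its stable_diffusion_loader.py replacement, is to pop tensors out of the source `state` dict into a `converted_state` dict keyed by the target module names, so that anything left in `state` afterwards is an unmapped tensor. A standalone sketch of that idea (the function name and checkpoint keys below are illustrative, not the package's API):

import torch

def remap_state(
    state: dict[str, torch.Tensor],
    key_map: dict[str, str],
    strict: bool = True,
) -> dict[str, torch.Tensor]:
  """Pop tensors out of `state` under new names.

  Popping (rather than reading) means whatever remains in `state` was never
  mapped, which is exactly the condition load() above checks before raising.
  """
  converted = {new: state.pop(old) for old, new in key_map.items()}
  if strict and state:
    raise ValueError(f"Unmapped tensors: {list(state.keys())}")
  return converted

# Usage with a placeholder checkpoint key:
state = {"decoder.conv_in.weight": torch.zeros(4, 3, 3, 3)}
converted = remap_state(state, {"decoder.conv_in.weight": "conv_in.weight"})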