ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240618__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


Files changed (24)
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/debug/__init__.py +1 -0
  4. ai_edge_torch/debug/culprit.py +70 -29
  5. ai_edge_torch/debug/test/test_search_model.py +50 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  9. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  10. ai_edge_torch/generative/layers/attention.py +154 -26
  11. ai_edge_torch/generative/layers/model_config.py +3 -0
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  13. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  14. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  15. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  16. ai_edge_torch/generative/test/test_quantize.py +1 -0
  17. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  18. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/RECORD +23 -22
  21. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  22. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/utilities/t5_loader.py

@@ -318,7 +318,7 @@ class ModelLoader:
      q_name = names.attn_query_proj.format(idx)
      k_name = names.attn_key_proj.format(idx)
      v_name = names.attn_value_proj.format(idx)
-     # model.encoder.transformer_blocks[0].atten_func.q.weight
+     # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
      if fuse_attention:
        converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
            config,
@@ -334,18 +334,34 @@ class ModelLoader:
            state.pop(f"{v_name}.bias"),
        )
      else:
-       converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
-       converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
-       converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
+       converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
+           f"{q_name}.weight"
+       )
+       converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
+           f"{k_name}.weight"
+       )
+       converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
+           f"{v_name}.weight"
+       )
        if config.attn_config.qkv_use_bias:
-         converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
-         converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
-         converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
+         converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
+             f"{q_name}.bias"
+         )
+         converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
+             f"{k_name}.bias"
+         )
+         converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
+             f"{v_name}.bias"
+         )

      o_name = names.attn_output_proj.format(idx)
-     converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+     converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+         f"{o_name}.weight"
+     )
      if config.attn_config.output_proj_use_bias:
-       converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+       converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+           f"{o_name}.bias"
+       )

    def _map_cross_attention(
        self,
@@ -383,32 +399,32 @@ class ModelLoader:
            state.pop(f"{v_name}.bias"),
        )
      else:
-       converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
            f"{q_name}.weight"
        )
-       converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
            f"{k_name}.weight"
        )
-       converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
            f"{v_name}.weight"
        )
        if config.attn_config.qkv_use_bias:
-         converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
              f"{q_name}.bias"
          )
-         converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
              f"{k_name}.bias"
          )
-         converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
              f"{v_name}.bias"
          )

      o_name = names.cross_attn_output_proj.format(idx)
-     converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
+     converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
          f"{o_name}.weight"
      )
      if config.attn_config.output_proj_use_bias:
-       converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
            f"{o_name}.bias"
        )
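The three hunks above apply one mechanical rename: checkpoint tensors that previously landed on atten_func.q / .k / .v / .proj (and their cross_atten_func counterparts) now land on q_projection / k_projection / v_projection / output_projection. A minimal sketch of that key-renaming pattern on a plain PyTorch state dict; rename_attention_keys and SUBMODULE_RENAMES are illustrative names, not part of the package:

from typing import Dict

import torch

# Illustrative rename table mirroring the mapping in the hunks above;
# cross_atten_func keys follow the same old -> new suffixes.
SUBMODULE_RENAMES = {
    ".q.": ".q_projection.",
    ".k.": ".k_projection.",
    ".v.": ".v_projection.",
    ".proj.": ".output_projection.",
}


def rename_attention_keys(
    state: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
  """Returns a copy of `state` with attention submodule keys renamed."""
  renamed = {}
  for key, tensor in state.items():
    if ".atten_func." in key or ".cross_atten_func." in key:
      for old, new in SUBMODULE_RENAMES.items():
        key = key.replace(old, new)
    renamed[key] = tensor
  return renamed


# Example: an old-style key becomes the new-style key.
old = {"encoder.transformer_blocks.0.atten_func.q.weight": torch.zeros(4, 4)}
assert (
    "encoder.transformer_blocks.0.atten_func.q_projection.weight"
    in rename_attention_keys(old)
)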

{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.2.0.dev20240611
+ Version: 0.2.0.dev20240618
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240618.dist-info}/RECORD

@@ -15,20 +15,21 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=wr59GFss8fP8Vy--BaBj34Bto0N16gXxQj6OuTXH8cE,10030
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
  ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
- ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
- ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
+ ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
+ ai_edge_torch/debug/culprit.py,sha256=urtCKPXORPvn6oyDxDSCSjgvngUnjjcsUMwAOeIl15E,14236
  ai_edge_torch/debug/utils.py,sha256=hjVmQVVl1dKxEF0D6KB4a3ouQ3wBkTsebOX2YsUObZM,1430
  ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/debug/test/test_culprit.py,sha256=9An_n9p_RWTAYdHYTCO-__EJlbnjclCDo8tDhOzMlwk,3731
+ ai_edge_torch/debug/test/test_search_model.py,sha256=0guAEon5cvwBpPXk6J0wVOKj7TXMDaiuomEEQmHgO5o,1590
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -41,9 +42,9 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTc
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,7 +57,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
  ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
  ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -65,19 +66,19 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
+ ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
  ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
  ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
- ai_edge_torch/generative/layers/model_config.py,sha256=toWECENDWgay9hsZcy4C89qph0KI3CpaeFqFc8Fr-Xk,4584
+ ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
- ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
+ ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -89,12 +90,12 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
  ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
- ai_edge_torch/generative/test/test_quantize.py,sha256=NVlMixAxVpDUabEvp6zTHHgIDgHFsMRwlf5MuyDwrPg,5355
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
+ ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
  ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
- ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
+ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
+ ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
@@ -109,8 +110,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/METADATA,sha256=WPGu2pq6N57fBtpunyFhunPe73UK_SVbqlZQsZwjWGo,1748
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/METADATA,sha256=aMhby_ftyg_8pWf8klYbHTCP7rMDcmuSTeryoRKt4U0,1748
+ ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.2.0.dev20240618.dist-info/RECORD,,

ai_edge_torch/generative/utilities/autoencoder_loader.py

@@ -1,298 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Common utility functions for data loading etc.
- from dataclasses import dataclass
- from typing import Dict, List, Tuple
-
- import torch
-
- import ai_edge_torch.generative.layers.model_config as layers_config
- import ai_edge_torch.generative.layers.unet.model_config as unet_config
- import ai_edge_torch.generative.utilities.loader as loader
-
-
- @dataclass
- class ResidualBlockTensorNames:
-   norm_1: str = None
-   conv_1: str = None
-   norm_2: str = None
-   conv_2: str = None
-   residual_layer: str = None
-
-
- @dataclass
- class AttnetionBlockTensorNames:
-   norm: str = None
-   fused_qkv_proj: str = None
-   output_proj: str = None
-
-
- @dataclass
- class MidBlockTensorNames:
-   residual_block_tensor_names: List[ResidualBlockTensorNames]
-   attention_block_tensor_names: List[AttnetionBlockTensorNames]
-
-
- @dataclass
- class UpDecoderBlockTensorNames:
-   residual_block_tensor_names: List[ResidualBlockTensorNames]
-   upsample_conv: str = None
-
-
- def _map_to_converted_state(
-     state: Dict[str, torch.Tensor],
-     state_param: str,
-     converted_state: Dict[str, torch.Tensor],
-     converted_state_param: str,
- ):
-   converted_state[f"{converted_state_param}.weight"] = state.pop(
-       f"{state_param}.weight"
-   )
-   if f"{state_param}.bias" in state:
-     converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
-
-
- class AutoEncoderModelLoader(loader.ModelLoader):
-
-   @dataclass
-   class TensorNames:
-     quant_conv: str = None
-     post_quant_conv: str = None
-     conv_in: str = None
-     conv_out: str = None
-     final_norm: str = None
-     mid_block_tensor_names: MidBlockTensorNames = None
-     up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
-
-   def __init__(self, file_name: str, names: TensorNames):
-     """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
-
-     Args:
-       file_name (str): Path to the checkpoint. Can be a directory or an
-         exact file.
-       names (TensorNames): An instance of `TensorNames` to determine mappings.
-     """
-     self._file_name = file_name
-     self._names = names
-     self._loader = self._get_loader()
-
-   def load(
-       self, model: torch.nn.Module, strict: bool = True
-   ) -> Tuple[List[str], List[str]]:
-     """Load the model from the checkpoint.
-
-     Args:
-       model (torch.nn.Module): The pytorch model that needs to be loaded.
-       strict (bool, optional): Whether the converted keys are strictly
-         matched. Defaults to True.
-
-     Returns:
-       missing_keys (List[str]): a list of str containing the missing keys.
-       unexpected_keys (List[str]): a list of str containing the unexpected keys.
-
-     Raises:
-       ValueError: If conversion results in unmapped tensors and strict mode is
-         enabled.
-     """
-     state = self._loader(self._file_name)
-     converted_state = dict()
-     if self._names.quant_conv is not None:
-       _map_to_converted_state(
-           state, self._names.quant_conv, converted_state, "quant_conv"
-       )
-     if self._names.post_quant_conv is not None:
-       _map_to_converted_state(
-           state, self._names.post_quant_conv, converted_state, "post_quant_conv"
-       )
-     if self._names.conv_in is not None:
-       _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-     if self._names.conv_out is not None:
-       _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
-     if self._names.final_norm is not None:
-       _map_to_converted_state(
-           state, self._names.final_norm, converted_state, "final_norm"
-       )
-     self._map_mid_block(
-         state,
-         converted_state,
-         model.config.mid_block_config,
-         self._names.mid_block_tensor_names,
-     )
-
-     reversed_block_out_channels = list(reversed(model.config.block_out_channels))
-     block_out_channels = reversed_block_out_channels[0]
-     for i, out_channels in enumerate(reversed_block_out_channels):
-       prev_output_channel = block_out_channels
-       block_out_channels = out_channels
-       not_final_block = i < len(reversed_block_out_channels) - 1
-       self._map_up_decoder_block(
-           state,
-           converted_state,
-           f"up_decoder_blocks.{i}",
-           unet_config.UpDecoderBlock2DConfig(
-               in_channels=prev_output_channel,
-               out_channels=block_out_channels,
-               normalization_config=model.config.normalization_config,
-               activation_type=model.config.activation_type,
-               num_layers=model.config.layers_per_block,
-               add_upsample=not_final_block,
-               upsample_conv=True,
-           ),
-           self._names.up_decoder_blocks_tensor_names[i],
-       )
-     if strict and state:
-       raise ValueError(
-           f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
-       )
-     return model.load_state_dict(converted_state, strict=strict)
-
-   def _map_residual_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       tensor_names: ResidualBlockTensorNames,
-       converted_state_param_prefix: str,
-       config: unet_config.ResidualBlock2DConfig,
-   ):
-     _map_to_converted_state(
-         state,
-         tensor_names.norm_1,
-         converted_state,
-         f"{converted_state_param_prefix}.norm_1",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.conv_1,
-         converted_state,
-         f"{converted_state_param_prefix}.conv_1",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.norm_2,
-         converted_state,
-         f"{converted_state_param_prefix}.norm_2",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.conv_2,
-         converted_state,
-         f"{converted_state_param_prefix}.conv_2",
-     )
-     if config.in_channels != config.out_channels:
-       _map_to_converted_state(
-           state,
-           tensor_names.residual_layer,
-           converted_state,
-           f"{converted_state_param_prefix}.residual_layer",
-       )
-
-   def _map_attention_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       tensor_names: AttnetionBlockTensorNames,
-       converted_state_param_prefix: str,
-       config: unet_config.AttentionBlock2DConfig,
-   ):
-     if config.normalization_config.type != layers_config.NormalizationType.NONE:
-       _map_to_converted_state(
-           state,
-           tensor_names.norm,
-           converted_state,
-           f"{converted_state_param_prefix}.norm",
-       )
-     attention_layer_prefix = f"{converted_state_param_prefix}.attention"
-     _map_to_converted_state(
-         state,
-         tensor_names.fused_qkv_proj,
-         converted_state,
-         f"{attention_layer_prefix}.qkv_projection",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.output_proj,
-         converted_state,
-         f"{attention_layer_prefix}.output_projection",
-     )
-
-   def _map_mid_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       config: unet_config.MidBlock2DConfig,
-       tensor_names: MidBlockTensorNames,
-   ):
-     converted_state_param_prefix = "mid_block"
-     residual_block_config = unet_config.ResidualBlock2DConfig(
-         in_channels=config.in_channels,
-         out_channels=config.in_channels,
-         time_embedding_channels=config.time_embedding_channels,
-         normalization_config=config.normalization_config,
-         activation_type=config.activation_type,
-     )
-     self._map_residual_block(
-         state,
-         converted_state,
-         tensor_names.residual_block_tensor_names[0],
-         f"{converted_state_param_prefix}.resnets.0",
-         residual_block_config,
-     )
-     for i in range(config.num_layers):
-       if config.attention_block_config:
-         self._map_attention_block(
-             state,
-             converted_state,
-             tensor_names.attention_block_tensor_names[i],
-             f"{converted_state_param_prefix}.attentions.{i}",
-             config.attention_block_config,
-         )
-       self._map_residual_block(
-           state,
-           converted_state,
-           tensor_names.residual_block_tensor_names[i + 1],
-           f"{converted_state_param_prefix}.resnets.{i+1}",
-           residual_block_config,
-       )
-
-   def _map_up_decoder_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       converted_state_param_prefix: str,
-       config: unet_config.UpDecoderBlock2DConfig,
-       tensor_names: UpDecoderBlockTensorNames,
-   ):
-     for i in range(config.num_layers):
-       input_channels = config.in_channels if i == 0 else config.out_channels
-       self._map_residual_block(
-           state,
-           converted_state,
-           tensor_names.residual_block_tensor_names[i],
-           f"{converted_state_param_prefix}.resnets.{i}",
-           unet_config.ResidualBlock2DConfig(
-               in_channels=input_channels,
-               out_channels=config.out_channels,
-               time_embedding_channels=config.time_embedding_channels,
-               normalization_config=config.normalization_config,
-               activation_type=config.activation_type,
-           ),
-       )
-     if config.add_upsample and config.upsample_conv:
-       _map_to_converted_state(
-           state,
-           tensor_names.upsample_conv,
-           converted_state,
-           f"{converted_state_param_prefix}.upsample_conv",
-       )
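
For reference, the removed _map_to_converted_state helper re-keys a .weight tensor (and a .bias tensor, when the checkpoint has one) from the source state into the converted state, popping as it goes so that leftover tensors can be reported as unmapped. A self-contained sketch of that behavior with made-up tensor names:

from typing import Dict

import torch

state: Dict[str, torch.Tensor] = {
    "decoder.conv_in.weight": torch.zeros(4, 4, 3, 3),
    "decoder.conv_in.bias": torch.zeros(4),
}
converted_state: Dict[str, torch.Tensor] = {}

# Re-key "decoder.conv_in.*" to "conv_in.*". Popping from `state` means that
# whatever remains afterwards was never mapped, which is what the loader's
# strict mode checks before calling load_state_dict.
converted_state["conv_in.weight"] = state.pop("decoder.conv_in.weight")
if "decoder.conv_in.bias" in state:
  converted_state["conv_in.bias"] = state.pop("decoder.conv_in.bias")

assert not state  # every source tensor was mapped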