ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (21)
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  4. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  5. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  6. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  7. ai_edge_torch/generative/layers/attention.py +154 -26
  8. ai_edge_torch/generative/layers/model_config.py +3 -0
  9. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  10. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  11. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  12. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  13. ai_edge_torch/generative/test/test_quantize.py +1 -0
  14. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  15. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  16. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +20 -20
  18. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/utilities/t5_loader.py

@@ -318,7 +318,7 @@ class ModelLoader:
    q_name = names.attn_query_proj.format(idx)
    k_name = names.attn_key_proj.format(idx)
    v_name = names.attn_value_proj.format(idx)
-     # model.encoder.transformer_blocks[0].atten_func.q.weight
+     # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
    if fuse_attention:
      converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
          config,
@@ -334,18 +334,34 @@ class ModelLoader:
          state.pop(f"{v_name}.bias"),
      )
    else:
-       converted_state[f"{prefix}.atten_func.q.weight"] = state.pop(f"{q_name}.weight")
-       converted_state[f"{prefix}.atten_func.k.weight"] = state.pop(f"{k_name}.weight")
-       converted_state[f"{prefix}.atten_func.v.weight"] = state.pop(f"{v_name}.weight")
+       converted_state[f"{prefix}.atten_func.q_projection.weight"] = state.pop(
+           f"{q_name}.weight"
+       )
+       converted_state[f"{prefix}.atten_func.k_projection.weight"] = state.pop(
+           f"{k_name}.weight"
+       )
+       converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
+           f"{v_name}.weight"
+       )
      if config.attn_config.qkv_use_bias:
-         converted_state[f"{prefix}.atten_func.q.bias"] = state.pop(f"{q_name}.bias")
-         converted_state[f"{prefix}.atten_func.k.bias"] = state.pop(f"{k_name}.bias")
-         converted_state[f"{prefix}.atten_func.v.bias"] = state.pop(f"{v_name}.bias")
+         converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
+             f"{q_name}.bias"
+         )
+         converted_state[f"{prefix}.atten_func.k_projection.bias"] = state.pop(
+             f"{k_name}.bias"
+         )
+         converted_state[f"{prefix}.atten_func.v_projection.bias"] = state.pop(
+             f"{v_name}.bias"
+         )

    o_name = names.attn_output_proj.format(idx)
-     converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+     converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+         f"{o_name}.weight"
+     )
    if config.attn_config.output_proj_use_bias:
-       converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+       converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+           f"{o_name}.bias"
+       )

  def _map_cross_attention(
      self,
@@ -383,32 +399,32 @@ class ModelLoader:
          state.pop(f"{v_name}.bias"),
      )
    else:
-       converted_state[f"{prefix}.cross_atten_func.q.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
          f"{q_name}.weight"
      )
-       converted_state[f"{prefix}.cross_atten_func.k.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
          f"{k_name}.weight"
      )
-       converted_state[f"{prefix}.cross_atten_func.v.weight"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
          f"{v_name}.weight"
      )
      if config.attn_config.qkv_use_bias:
-         converted_state[f"{prefix}.cross_atten_func.q.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )
-         converted_state[f"{prefix}.cross_atten_func.k.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
            f"{k_name}.bias"
        )
-         converted_state[f"{prefix}.cross_atten_func.v.bias"] = state.pop(
+         converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
            f"{v_name}.bias"
        )

    o_name = names.cross_attn_output_proj.format(idx)
-     converted_state[f"{prefix}.cross_atten_func.proj.weight"] = state.pop(
+     converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
        f"{o_name}.weight"
    )
    if config.attn_config.output_proj_use_bias:
-       converted_state[f"{prefix}.cross_atten_func.proj.bias"] = state.pop(
+       converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
          f"{o_name}.bias"
      )


{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA

@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: ai-edge-torch-nightly
- Version: 0.2.0.dev20240611
+ Version: 0.2.0.dev20240617
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
Home-page: https://github.com/google-ai-edge/ai-edge-torch
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
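
The only METADATA change is the version bump. If you need to confirm which nightly is actually installed in an environment, the standard library can report it; a minimal check:

# Minimal check of the installed nightly version (importlib.metadata is stdlib).
from importlib.metadata import version

print(version("ai-edge-torch-nightly"))  # e.g. "0.2.0.dev20240617"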

{ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD

@@ -15,11 +15,11 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=RAgU31B98PQmXEIM3GOjgS0q9aRe2whJhGXpW2EjoqY,12438
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=tCx7J-WIFnxFCeRBtqJ159jWLgK9_9DCJrR4mkeBuYE,982
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=wr59GFss8fP8Vy--BaBj34Bto0N16gXxQj6OuTXH8cE,10030
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=cfY6RTWQTGXNoQxKHaDcBYR9QdkVQXOWjKhuxvglocw,10383
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
+ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
@@ -41,9 +41,9 @@ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTc
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,7 +56,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -65,19 +65,19 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
+ ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
- ai_edge_torch/generative/layers/model_config.py,sha256=toWECENDWgay9hsZcy4C89qph0KI3CpaeFqFc8Fr-Xk,4584
+ ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
- ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
+ ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -89,12 +89,12 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DE
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=qUB4f2DoB14dLkNPWf6TZodpT81mfAJeWM-lCAmkuHY,5735
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
- ai_edge_torch/generative/test/test_quantize.py,sha256=NVlMixAxVpDUabEvp6zTHHgIDgHFsMRwlf5MuyDwrPg,5355
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
+ ai_edge_torch/generative/test/test_quantize.py,sha256=IjCbCPWzIgXk3s7y7SJsg2usIxhOqs3PuhFvEYR4Sdw,5388
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/autoencoder_loader.py,sha256=G2Nosy33JzkjGALPR4JjvffdFX1JWOj2zjbbuaDJEgg,10065
ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
- ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
+ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
+ ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
@@ -109,8 +109,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/METADATA,sha256=WPGu2pq6N57fBtpunyFhunPe73UK_SVbqlZQsZwjWGo,1748
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.2.0.dev20240611.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/METADATA,sha256=Z9rUO2CabVbBpydpRk8OxNlwK4yznGCb2QHGlJhqRsM,1748
+ ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.2.0.dev20240617.dist-info/RECORD,,
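
The RECORD diff also captures the module swap behind entry 18 of the file list: ai_edge_torch/generative/utilities/autoencoder_loader.py (removed in full below) is gone, while ai_edge_torch/generative/utilities/stable_diffusion_loader.py is new in this release. Downstream code that imported the old module will break on this nightly; a sketch of the import change only, since the new module's class and function names are not shown in this diff excerpt:

# Old import, removed as of 0.2.0.dev20240617:
# import ai_edge_torch.generative.utilities.autoencoder_loader as autoencoder_loader

# New module added in this release (its API is not shown in this diff excerpt):
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader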

ai_edge_torch/generative/utilities/autoencoder_loader.py

@@ -1,298 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Common utility functions for data loading etc.
- from dataclasses import dataclass
- from typing import Dict, List, Tuple
-
- import torch
-
- import ai_edge_torch.generative.layers.model_config as layers_config
- import ai_edge_torch.generative.layers.unet.model_config as unet_config
- import ai_edge_torch.generative.utilities.loader as loader
-
-
- @dataclass
- class ResidualBlockTensorNames:
-   norm_1: str = None
-   conv_1: str = None
-   norm_2: str = None
-   conv_2: str = None
-   residual_layer: str = None
-
-
- @dataclass
- class AttnetionBlockTensorNames:
-   norm: str = None
-   fused_qkv_proj: str = None
-   output_proj: str = None
-
-
- @dataclass
- class MidBlockTensorNames:
-   residual_block_tensor_names: List[ResidualBlockTensorNames]
-   attention_block_tensor_names: List[AttnetionBlockTensorNames]
-
-
- @dataclass
- class UpDecoderBlockTensorNames:
-   residual_block_tensor_names: List[ResidualBlockTensorNames]
-   upsample_conv: str = None
-
-
- def _map_to_converted_state(
-     state: Dict[str, torch.Tensor],
-     state_param: str,
-     converted_state: Dict[str, torch.Tensor],
-     converted_state_param: str,
- ):
-   converted_state[f"{converted_state_param}.weight"] = state.pop(
-       f"{state_param}.weight"
-   )
-   if f"{state_param}.bias" in state:
-     converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
-
-
- class AutoEncoderModelLoader(loader.ModelLoader):
-
-   @dataclass
-   class TensorNames:
-     quant_conv: str = None
-     post_quant_conv: str = None
-     conv_in: str = None
-     conv_out: str = None
-     final_norm: str = None
-     mid_block_tensor_names: MidBlockTensorNames = None
-     up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
-
-   def __init__(self, file_name: str, names: TensorNames):
-     """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
-
-     Args:
-       file_name (str): Path to the checkpoint. Can be a directory or an
-         exact file.
-       names (TensorNames): An instance of `TensorNames` to determine mappings.
-     """
-     self._file_name = file_name
-     self._names = names
-     self._loader = self._get_loader()
-
-   def load(
-       self, model: torch.nn.Module, strict: bool = True
-   ) -> Tuple[List[str], List[str]]:
-     """Load the model from the checkpoint.
-
-     Args:
-       model (torch.nn.Module): The pytorch model that needs to be loaded.
-       strict (bool, optional): Whether the converted keys are strictly
-         matched. Defaults to True.
-
-     Returns:
-       missing_keys (List[str]): a list of str containing the missing keys.
-       unexpected_keys (List[str]): a list of str containing the unexpected keys.
-
-     Raises:
-       ValueError: If conversion results in unmapped tensors and strict mode is
-         enabled.
-     """
-     state = self._loader(self._file_name)
-     converted_state = dict()
-     if self._names.quant_conv is not None:
-       _map_to_converted_state(
-           state, self._names.quant_conv, converted_state, "quant_conv"
-       )
-     if self._names.post_quant_conv is not None:
-       _map_to_converted_state(
-           state, self._names.post_quant_conv, converted_state, "post_quant_conv"
-       )
-     if self._names.conv_in is not None:
-       _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-     if self._names.conv_out is not None:
-       _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
-     if self._names.final_norm is not None:
-       _map_to_converted_state(
-           state, self._names.final_norm, converted_state, "final_norm"
-       )
-     self._map_mid_block(
-         state,
-         converted_state,
-         model.config.mid_block_config,
-         self._names.mid_block_tensor_names,
-     )
-
-     reversed_block_out_channels = list(reversed(model.config.block_out_channels))
-     block_out_channels = reversed_block_out_channels[0]
-     for i, out_channels in enumerate(reversed_block_out_channels):
-       prev_output_channel = block_out_channels
-       block_out_channels = out_channels
-       not_final_block = i < len(reversed_block_out_channels) - 1
-       self._map_up_decoder_block(
-           state,
-           converted_state,
-           f"up_decoder_blocks.{i}",
-           unet_config.UpDecoderBlock2DConfig(
-               in_channels=prev_output_channel,
-               out_channels=block_out_channels,
-               normalization_config=model.config.normalization_config,
-               activation_type=model.config.activation_type,
-               num_layers=model.config.layers_per_block,
-               add_upsample=not_final_block,
-               upsample_conv=True,
-           ),
-           self._names.up_decoder_blocks_tensor_names[i],
-       )
-     if strict and state:
-       raise ValueError(
-           f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
-       )
-     return model.load_state_dict(converted_state, strict=strict)
-
-   def _map_residual_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       tensor_names: ResidualBlockTensorNames,
-       converted_state_param_prefix: str,
-       config: unet_config.ResidualBlock2DConfig,
-   ):
-     _map_to_converted_state(
-         state,
-         tensor_names.norm_1,
-         converted_state,
-         f"{converted_state_param_prefix}.norm_1",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.conv_1,
-         converted_state,
-         f"{converted_state_param_prefix}.conv_1",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.norm_2,
-         converted_state,
-         f"{converted_state_param_prefix}.norm_2",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.conv_2,
-         converted_state,
-         f"{converted_state_param_prefix}.conv_2",
-     )
-     if config.in_channels != config.out_channels:
-       _map_to_converted_state(
-           state,
-           tensor_names.residual_layer,
-           converted_state,
-           f"{converted_state_param_prefix}.residual_layer",
-       )
-
-   def _map_attention_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       tensor_names: AttnetionBlockTensorNames,
-       converted_state_param_prefix: str,
-       config: unet_config.AttentionBlock2DConfig,
-   ):
-     if config.normalization_config.type != layers_config.NormalizationType.NONE:
-       _map_to_converted_state(
-           state,
-           tensor_names.norm,
-           converted_state,
-           f"{converted_state_param_prefix}.norm",
-       )
-     attention_layer_prefix = f"{converted_state_param_prefix}.attention"
-     _map_to_converted_state(
-         state,
-         tensor_names.fused_qkv_proj,
-         converted_state,
-         f"{attention_layer_prefix}.qkv_projection",
-     )
-     _map_to_converted_state(
-         state,
-         tensor_names.output_proj,
-         converted_state,
-         f"{attention_layer_prefix}.output_projection",
-     )
-
-   def _map_mid_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       config: unet_config.MidBlock2DConfig,
-       tensor_names: MidBlockTensorNames,
-   ):
-     converted_state_param_prefix = "mid_block"
-     residual_block_config = unet_config.ResidualBlock2DConfig(
-         in_channels=config.in_channels,
-         out_channels=config.in_channels,
-         time_embedding_channels=config.time_embedding_channels,
-         normalization_config=config.normalization_config,
-         activation_type=config.activation_type,
-     )
-     self._map_residual_block(
-         state,
-         converted_state,
-         tensor_names.residual_block_tensor_names[0],
-         f"{converted_state_param_prefix}.resnets.0",
-         residual_block_config,
-     )
-     for i in range(config.num_layers):
-       if config.attention_block_config:
-         self._map_attention_block(
-             state,
-             converted_state,
-             tensor_names.attention_block_tensor_names[i],
-             f"{converted_state_param_prefix}.attentions.{i}",
-             config.attention_block_config,
-         )
-       self._map_residual_block(
-           state,
-           converted_state,
-           tensor_names.residual_block_tensor_names[i + 1],
-           f"{converted_state_param_prefix}.resnets.{i+1}",
-           residual_block_config,
-       )
-
-   def _map_up_decoder_block(
-       self,
-       state: Dict[str, torch.Tensor],
-       converted_state: Dict[str, torch.Tensor],
-       converted_state_param_prefix: str,
-       config: unet_config.UpDecoderBlock2DConfig,
-       tensor_names: UpDecoderBlockTensorNames,
-   ):
-     for i in range(config.num_layers):
-       input_channels = config.in_channels if i == 0 else config.out_channels
-       self._map_residual_block(
-           state,
-           converted_state,
-           tensor_names.residual_block_tensor_names[i],
-           f"{converted_state_param_prefix}.resnets.{i}",
-           unet_config.ResidualBlock2DConfig(
-               in_channels=input_channels,
-               out_channels=config.out_channels,
-               time_embedding_channels=config.time_embedding_channels,
-               normalization_config=config.normalization_config,
-               activation_type=config.activation_type,
-           ),
-       )
-     if config.add_upsample and config.upsample_conv:
-       _map_to_converted_state(
-           state,
-           tensor_names.upsample_conv,
-           converted_state,
-           f"{converted_state_param_prefix}.upsample_conv",
-       )