ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
  8. ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
  9. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
  10. ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
  11. ai_edge_torch/generative/examples/hammer/verify.py +86 -0
  12. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
  13. ai_edge_torch/generative/examples/llama/llama.py +3 -1
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
  15. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
  16. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
  17. ai_edge_torch/generative/examples/phi/phi2.py +1 -1
  18. ai_edge_torch/generative/examples/phi/phi3.py +3 -1
  19. ai_edge_torch/generative/examples/phi/phi4.py +3 -1
  20. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
  21. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
  23. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
  24. ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
  27. ai_edge_torch/generative/layers/kv_cache.py +2 -4
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
  30. ai_edge_torch/generative/test/test_model_conversion.py +3 -33
  31. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
  32. ai_edge_torch/generative/utilities/converter.py +11 -1
  33. ai_edge_torch/generative/utilities/export_config.py +30 -0
  34. ai_edge_torch/model.py +2 -0
  35. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
  39. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  40. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  41. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
  42. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
  43. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,16 @@
1
1
  ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=Nixp49eAXZPPMWEWkqpm_M4Mi_WGPx-I8q2noKuh0hw,706
4
+ ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
+ ai_edge_torch/version.py,sha256=6qv9zJ0Z2J_RJ-E0S1o1-u2sbxvuuPUWnJcxWhmQEWg,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
- ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
7
+ ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
9
9
  ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
10
10
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
11
11
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
12
- ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
13
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
14
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
12
+ ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
13
+ ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
15
14
  ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
16
15
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
17
16
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
19
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
20
19
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
21
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
22
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
23
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
24
23
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
25
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
54
53
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
55
54
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
56
55
  ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
57
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=r6Pb5_LRKvw2QrOMn3PzunrVxPB-LSdyU2H1XORZo9A,1553
58
- ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
56
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
57
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
59
58
  ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
60
59
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
60
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
@@ -66,15 +65,19 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
66
65
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
67
66
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
68
67
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
69
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
68
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
70
69
  ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
71
70
  ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
72
71
  ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
73
72
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
74
73
  ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
74
+ ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
+ ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
76
+ ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
77
+ ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
75
78
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=A4uLUdqvU1NKo3seqZlWSS3fqYahnEKqNBQBJO6yXvE,1762
77
- ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
79
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=nz5h4m8bVnw8P7OEtqhA_fKfvaRzxhT2_75vkFCqHmU,1735
80
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=H7I5iNhIJ55gb0-9k7g-FPcG2IlthnA9XMR8qd__5bQ,6621
78
81
  ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
79
82
  ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
80
83
  ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
@@ -94,18 +97,18 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4I
94
97
  ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
95
98
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
96
99
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=Y2qaObMJeh9UABkUI7FBm4sCGi2YMQhsj0CSOS2fYek,1540
98
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=TuGW_FPMs0pV7ZBe46FfaDrlfte4Dz75vGHmBOCFfww,1538
99
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=VZe7OQ54dgOGWe74XT2W7zZBm5uJaeIF8ZuNakkL0iA,1539
100
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
101
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=ddo52Inl5ub81q460cEyKhnsC3txellRErut-_qtBbM,6949
102
- ai_edge_torch/generative/examples/phi/phi4.py,sha256=OkMwLGe8l2JEAgOFi19AdbNBl1xp1djZBZo8MJP58ho,5732
100
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=k-0ZC-_zZZmkdcc6dr1QGXfX9lDZZXRQSuc6wT0n3Is,1514
101
+ ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=5KSJRySjSc89FriCOnfBabD8zRLUcGAw3L0VInuJFUY,1512
102
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=wVIdGenHTi9xUffYddN_uXWMBO2tgo1e_hU4OG_NmHA,1513
103
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=X9MfjK8rmyRSrfNzIaKQNSgqLM5_CBH-BrLFX_7BWL8,3494
104
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=65Dbv8cA4WFdluflHQHzgDmDFjdmc6rxMO4hQukaxKU,6978
105
+ ai_edge_torch/generative/examples/phi/phi4.py,sha256=y3CCZCW4MnvX74d4MNERRuQBE0p5dquC2M9vDXXqnZI,5760
103
106
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
104
107
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
105
108
  ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
106
109
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
107
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87PkfU_cRfP6RnPgXrCmaB-cK98H-nqbA,1802
108
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
110
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVsJ9Obl7PBhMgd3a0T1t8dqoPp_VzZaQ,1776
111
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
109
112
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
110
113
  ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
111
114
  ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
@@ -116,9 +119,9 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
116
119
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
117
120
  ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
118
121
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
119
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=IjV0jriRKlF9aV5yLjtONjACb4_VxNIAGk9w1sr_hmc,1748
120
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=Wa_3OWXcM389iySwS5E47uCYZaTj6h-4RTP_Xi2-1aE,1721
121
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
122
+ ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=jTM_tndbDqzq19uLz2n71S7M81L1Y6R7oVBPsMcYGzk,1785
123
+ ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=wU72MzpUIi2mQ8ZODW1x4L5KZPWvuXyB-_Eqo-RKqFw,1757
124
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=SFE8fIJx7Y_oan0vXSmhEmI0Ib2HD3k9cyKLU_4MxfI,3807
122
125
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
123
126
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
124
127
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -144,8 +147,8 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
144
147
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
145
148
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=-z5tkQzGHbo37eAl9sDAJuT1Egxm8xI9CZmYLcmqIfU,4761
146
149
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
147
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=LPxg7mAJ_aAUIx6eE5bxixPA8Ep9Vul0CWJoNcrD5oE,1565
148
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
150
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=XM-dCBW2HG6FlwwPjlJi0I_TEaVqdv7aWpFEv-XUdLc,1539
151
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=6Qhml-XB8_RjQdYN948OaSsPJNrfi-Mr7PFB73C79Ug,2828
149
152
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
150
153
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
151
154
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
@@ -154,16 +157,15 @@ ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWn
154
157
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
155
158
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
156
159
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
157
- ai_edge_torch/generative/layers/kv_cache.py,sha256=WNH_Ab29eXKXs8HAm3Wmdv_LBzO6PQW5d34Eo6Yzgd0,8492
160
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=dDeirtuo9AnlN1tYoLbFi_pKhIDmn35FQY1m6X28hSY,8468
158
161
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
159
162
  ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
160
163
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
161
164
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
162
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
163
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=D4rATT2Ppa9Su7yuRHYnQPJ1dFvUDAyH1GrFnCed7p8,3810
165
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
166
+ ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
164
167
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
165
168
  ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
166
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
167
169
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
168
170
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
169
171
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -180,13 +182,13 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
180
182
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
181
183
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
182
184
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
183
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=jSNJ0Eex6VYCkGn3FXbCOOJ2S3-F_QuwJctu3VycjR4,7200
184
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
185
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
186
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
185
187
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
186
188
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
187
189
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
188
- ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
189
- ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
190
+ ai_edge_torch/generative/utilities/converter.py,sha256=4RNNl7vk3WN_JG5EZajofiRSqtPnUNCYosxTacdEOto,10948
191
+ ai_edge_torch/generative/utilities/export_config.py,sha256=maUVt0T5FsLpHO5H-BZ-O0FRBZO_ejKwGhPR9Qq8ViM,2490
190
192
  ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
191
193
  ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
192
194
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -227,7 +229,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
227
229
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
228
230
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
229
231
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
230
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
232
+ ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
231
233
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
232
234
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
233
235
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
@@ -244,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
244
246
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
245
247
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
246
248
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
247
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
248
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/METADATA,sha256=Gz8c2qvL6qiK7lrd001P55TXltKdycDvDaAq4d4Y-eQ,2051
249
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
250
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
251
- ai_edge_torch_nightly-0.5.0.dev20250424.dist-info/RECORD,,
249
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
250
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/METADATA,sha256=y_g3V3S_WlYlEmSNZWmP4kV5f_A1Nynk77VwS8qL_X0,2051
251
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
252
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
253
+ ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/RECORD,,
@@ -1,129 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Build interpolate composite pass."""
16
-
17
- import functools
18
-
19
- from ai_edge_torch import fx_infra
20
- from ai_edge_torch.hlfb import mark_pattern
21
- from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
22
- import torch
23
-
24
- # For torch nightly released after mid June 2024,
25
- # torch.nn.functional.interpolate no longer gets exported into decomposed graph
26
- # but a single aten op:
27
- # torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
28
- # This would interefere with our pattern matching based composite builder.
29
- # Here we register the now missing decompositions first.
30
- _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
31
- torch.ops.aten.upsample_bilinear2d.vec,
32
- torch.ops.aten.upsample_nearest2d.vec,
33
- ])
34
-
35
-
36
- @functools.cache
37
- def _get_upsample_bilinear2d_pattern():
38
- pattern = pattern_module.Pattern(
39
- "odml.upsample_bilinear2d",
40
- lambda x: torch.nn.functional.interpolate(
41
- x, scale_factor=2, mode="bilinear", align_corners=False
42
- ),
43
- export_args=(torch.rand(1, 3, 100, 100),),
44
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
45
- )
46
-
47
- @pattern.register_attr_builder
48
- def attr_builder(pattern, graph_module, internal_match):
49
- output = internal_match.returning_nodes[0]
50
- output_h, output_w = output.meta["val"].shape[-2:]
51
- return {
52
- "size": (int(output_h), int(output_w)),
53
- "align_corners": False,
54
- "is_nchw_op": True,
55
- }
56
-
57
- return pattern
58
-
59
-
60
- @functools.cache
61
- def _get_upsample_bilinear2d_align_corners_pattern():
62
- pattern = pattern_module.Pattern(
63
- "odml.upsample_bilinear2d",
64
- lambda x: torch.nn.functional.interpolate(
65
- x, scale_factor=2, mode="bilinear", align_corners=True
66
- ),
67
- export_args=(torch.rand(1, 3, 100, 100),),
68
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
69
- )
70
-
71
- @pattern.register_attr_builder
72
- def attr_builder(graph_module, pattern, internal_match):
73
- output = internal_match.returning_nodes[0]
74
- output_h, output_w = output.meta["val"].shape[-2:]
75
- return {
76
- "size": (int(output_h), int(output_w)),
77
- "align_corners": True,
78
- "is_nchw_op": True,
79
- }
80
-
81
- return pattern
82
-
83
-
84
- @functools.cache
85
- def _get_interpolate_nearest2d_pattern():
86
- pattern = pattern_module.Pattern(
87
- "tfl.resize_nearest_neighbor",
88
- lambda x: torch.nn.functional.interpolate(
89
- x, scale_factor=2, mode="nearest"
90
- ),
91
- export_args=(torch.rand(1, 3, 100, 100),),
92
- extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
93
- )
94
-
95
- @pattern.register_attr_builder
96
- def attr_builder(pattern, graph_module, internal_match):
97
- output = internal_match.returning_nodes[0]
98
- output_h, output_w = output.meta["val"].shape[-2:]
99
- return {
100
- "size": (int(output_h), int(output_w)),
101
- "is_nchw_op": True,
102
- }
103
-
104
- return pattern
105
-
106
-
107
- class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
108
-
109
- def __init__(self):
110
- super().__init__()
111
- self._patterns = [
112
- _get_upsample_bilinear2d_pattern(),
113
- _get_upsample_bilinear2d_align_corners_pattern(),
114
- _get_interpolate_nearest2d_pattern(),
115
- ]
116
-
117
- def call(self, exported_program: torch.export.ExportedProgram):
118
- exported_program = fx_infra.safe_run_decompositions(
119
- exported_program,
120
- _INTERPOLATE_DECOMPOSITIONS,
121
- )
122
-
123
- graph_module = exported_program.graph_module
124
- for pattern in self._patterns:
125
- graph_module = mark_pattern.mark_pattern(graph_module, pattern)
126
-
127
- graph_module.graph.lint()
128
- graph_module.recompile()
129
- return fx_infra.ExportedProgramPassResult(exported_program, True)
@@ -1,93 +0,0 @@
1
- # Copyright 2025 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # Implements scaled dot product attention. This is experimental and
16
- # GPU-specific code.
17
-
18
- import math
19
- from typing import Optional
20
-
21
- from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
22
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
- from ai_edge_torch.generative.utilities import types
24
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
- from multipledispatch import dispatch
26
- import torch
27
- import torch.nn.functional as F
28
-
29
-
30
- def scaled_dot_product_attention(
31
- kv: kv_utils.KVCacheEntry,
32
- query: torch.Tensor,
33
- key: torch.Tensor,
34
- value: torch.Tensor,
35
- head_size: int,
36
- mask: Optional[torch.Tensor] = None,
37
- scale: Optional[float] = None,
38
- softcap: Optional[float] = None,
39
- ):
40
- if hasattr(kv, "kv_layout"):
41
- return _sdpa(
42
- kv.kv_layout[0](), # key layout
43
- kv.kv_layout[1](), # value layout
44
- query=query,
45
- key=key,
46
- value=value,
47
- head_size=head_size,
48
- mask=mask,
49
- scale=scale,
50
- softcap=softcap,
51
- )
52
- raise ValueError("No kv_layout attribute found in kv.")
53
-
54
-
55
- @dispatch(types.BNTH, types.BNHT)
56
- def _sdpa(k_type, v_type, *args, **kwargs):
57
- query = kwargs["query"]
58
- key = kwargs["key"]
59
- value = kwargs["value"]
60
- head_size = kwargs["head_size"]
61
- mask = kwargs.get("mask", None)
62
- scale = kwargs.get("scale", None)
63
- softcap = kwargs.get("softcap", None)
64
-
65
- if scale is None:
66
- scale = 1.0 / math.sqrt(head_size)
67
-
68
- query = query * scale
69
-
70
- assert mask is not None, "Mask should not be None!"
71
- t = mask.shape[2]
72
-
73
- logits = bmm_lib.bmm_4d(query, key)
74
-
75
- _, bk, gt, s = logits.shape
76
- g = gt // t
77
- logits = logits.reshape((bk, g, t, s))
78
- if softcap is not None:
79
- logits = torch.tanh(logits / softcap)
80
- logits = logits * softcap
81
-
82
- padded_logits = logits + mask
83
- padded_logits = padded_logits.reshape(1, bk, gt, s)
84
- probs = F.softmax(padded_logits, dim=-1).type_as(key)
85
- encoded = bmm_lib.bmm_4d(probs, value)
86
-
87
- return encoded # 1, bk, gt, h
88
-
89
-
90
- @dispatch(object, object)
91
- def _sdpa(k_type, v_type, *args, **kwargs):
92
-
93
- raise ValueError(f"No implementations for k={k_type} and v={v_type}")