ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
- ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
- ai_edge_torch/generative/examples/hammer/verify.py +86 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/llama/llama.py +3 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/phi2.py +1 -1
- ai_edge_torch/generative/examples/phi/phi3.py +3 -1
- ai_edge_torch/generative/examples/phi/phi4.py +3 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
- ai_edge_torch/generative/layers/kv_cache.py +2 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
- ai_edge_torch/generative/test/test_model_conversion.py +3 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
- ai_edge_torch/generative/utilities/converter.py +11 -1
- ai_edge_torch/generative/utilities/export_config.py +30 -0
- ai_edge_torch/model.py +2 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,16 @@
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
|
-
ai_edge_torch/model.py,sha256=
|
5
|
-
ai_edge_torch/version.py,sha256=
|
4
|
+
ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
|
5
|
+
ai_edge_torch/version.py,sha256=6qv9zJ0Z2J_RJ-E0S1o1-u2sbxvuuPUWnJcxWhmQEWg,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
12
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
13
|
-
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
|
14
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
|
13
|
+
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
|
15
14
|
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
|
16
15
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
17
16
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
|
|
19
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
20
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
|
21
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
22
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
|
23
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
24
23
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
|
25
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
|
|
54
53
|
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
|
55
54
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
|
56
55
|
ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
57
|
-
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=
|
58
|
-
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=
|
56
|
+
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=l0OrPGmX8WscuG9MIgtd0sqR4BeReNAu7fADzyPbnZw,1580
|
57
|
+
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
|
59
58
|
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
|
60
59
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
61
60
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
|
@@ -66,15 +65,19 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
|
|
66
65
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
67
66
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
68
67
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
69
|
-
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
|
68
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=MjkQDVynaw9C5z9ODzKfb85xW5JfxHUWBJ_Aco05FHo,1760
|
70
69
|
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
|
71
70
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
72
71
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
73
72
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
74
73
|
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
|
74
|
+
ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
75
|
+
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py,sha256=946mchDmvUhMsv1kzslp4LHtCIuHn4qjimHYQ-XnxMo,2962
|
76
|
+
ai_edge_torch/generative/examples/hammer/hammer.py,sha256=76INcjffvaNCQ02fzXcxJUW_6EKHs4sg3q1nDBbEpHE,3431
|
77
|
+
ai_edge_torch/generative/examples/hammer/verify.py,sha256=MkzAGkbPy4LKRhyCDm1cw-9jUt4VUxLPdwK_25fCGSE,2705
|
75
78
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
76
|
-
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=
|
77
|
-
ai_edge_torch/generative/examples/llama/llama.py,sha256=
|
79
|
+
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=nz5h4m8bVnw8P7OEtqhA_fKfvaRzxhT2_75vkFCqHmU,1735
|
80
|
+
ai_edge_torch/generative/examples/llama/llama.py,sha256=H7I5iNhIJ55gb0-9k7g-FPcG2IlthnA9XMR8qd__5bQ,6621
|
78
81
|
ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
|
79
82
|
ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
80
83
|
ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
|
@@ -94,18 +97,18 @@ ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4I
|
|
94
97
|
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
|
95
98
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
|
96
99
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
|
-
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
|
98
|
-
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=
|
99
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
100
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
101
|
-
ai_edge_torch/generative/examples/phi/phi3.py,sha256=
|
102
|
-
ai_edge_torch/generative/examples/phi/phi4.py,sha256=
|
100
|
+
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=k-0ZC-_zZZmkdcc6dr1QGXfX9lDZZXRQSuc6wT0n3Is,1514
|
101
|
+
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py,sha256=5KSJRySjSc89FriCOnfBabD8zRLUcGAw3L0VInuJFUY,1512
|
102
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=wVIdGenHTi9xUffYddN_uXWMBO2tgo1e_hU4OG_NmHA,1513
|
103
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=X9MfjK8rmyRSrfNzIaKQNSgqLM5_CBH-BrLFX_7BWL8,3494
|
104
|
+
ai_edge_torch/generative/examples/phi/phi3.py,sha256=65Dbv8cA4WFdluflHQHzgDmDFjdmc6rxMO4hQukaxKU,6978
|
105
|
+
ai_edge_torch/generative/examples/phi/phi4.py,sha256=y3CCZCW4MnvX74d4MNERRuQBE0p5dquC2M9vDXXqnZI,5760
|
103
106
|
ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
|
104
107
|
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
105
108
|
ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
|
106
109
|
ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
107
|
-
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=
|
108
|
-
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=
|
110
|
+
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=eOpv3scJr4mVsJ9Obl7PBhMgd3a0T1t8dqoPp_VzZaQ,1776
|
111
|
+
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
|
109
112
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
110
113
|
ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
111
114
|
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
|
@@ -116,9 +119,9 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
|
|
116
119
|
ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
|
117
120
|
ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
|
118
121
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
119
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
120
|
-
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=
|
121
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
122
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=jTM_tndbDqzq19uLz2n71S7M81L1Y6R7oVBPsMcYGzk,1785
|
123
|
+
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=wU72MzpUIi2mQ8ZODW1x4L5KZPWvuXyB-_Eqo-RKqFw,1757
|
124
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=SFE8fIJx7Y_oan0vXSmhEmI0Ib2HD3k9cyKLU_4MxfI,3807
|
122
125
|
ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
|
123
126
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
124
127
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
@@ -144,8 +147,8 @@ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNH
|
|
144
147
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Crpj-vOwSViHpblXOrRJmsIn4DrHyuB3XZ8kHifb7LA,5203
|
145
148
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=-z5tkQzGHbo37eAl9sDAJuT1Egxm8xI9CZmYLcmqIfU,4761
|
146
149
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
147
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
148
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
150
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=XM-dCBW2HG6FlwwPjlJi0I_TEaVqdv7aWpFEv-XUdLc,1539
|
151
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=6Qhml-XB8_RjQdYN948OaSsPJNrfi-Mr7PFB73C79Ug,2828
|
149
152
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1tO2i0nUCqe-VkRgboA10VZ7KNg,2431
|
150
153
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
|
151
154
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
@@ -154,16 +157,15 @@ ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWn
|
|
154
157
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
155
158
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
156
159
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
157
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
160
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=dDeirtuo9AnlN1tYoLbFi_pKhIDmn35FQY1m6X28hSY,8468
|
158
161
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
159
162
|
ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
|
160
163
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
161
164
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
162
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
163
|
-
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
|
165
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
|
166
|
+
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
|
164
167
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
165
168
|
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
|
166
|
-
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
|
167
169
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
168
170
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
169
171
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
@@ -180,13 +182,13 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
|
|
180
182
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
|
181
183
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
182
184
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
183
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
184
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256
|
185
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
|
186
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=vQWmpzMkJ2hPmWpg41ZMWwBsngTykRVzRPHtpbkwiLM,12811
|
185
187
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
186
188
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
187
189
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
188
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
189
|
-
ai_edge_torch/generative/utilities/export_config.py,sha256=
|
190
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=4RNNl7vk3WN_JG5EZajofiRSqtPnUNCYosxTacdEOto,10948
|
191
|
+
ai_edge_torch/generative/utilities/export_config.py,sha256=maUVt0T5FsLpHO5H-BZ-O0FRBZO_ejKwGhPR9Qq8ViM,2490
|
190
192
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
191
193
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
192
194
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
@@ -227,7 +229,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
|
|
227
229
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
|
228
230
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
229
231
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
230
|
-
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=
|
232
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
231
233
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
|
232
234
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
233
235
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
@@ -244,8 +246,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
244
246
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
245
247
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
246
248
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
247
|
-
ai_edge_torch_nightly-0.5.0.
|
248
|
-
ai_edge_torch_nightly-0.5.0.
|
249
|
-
ai_edge_torch_nightly-0.5.0.
|
250
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
-
ai_edge_torch_nightly-0.5.0.
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
250
|
+
ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/METADATA,sha256=y_g3V3S_WlYlEmSNZWmP4kV5f_A1Nynk77VwS8qL_X0,2051
|
251
|
+
ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
252
|
+
ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
253
|
+
ai_edge_torch_nightly-0.5.0.dev20250426.dist-info/RECORD,,
|
@@ -1,129 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Build interpolate composite pass."""
|
16
|
-
|
17
|
-
import functools
|
18
|
-
|
19
|
-
from ai_edge_torch import fx_infra
|
20
|
-
from ai_edge_torch.hlfb import mark_pattern
|
21
|
-
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
22
|
-
import torch
|
23
|
-
|
24
|
-
# For torch nightly released after mid June 2024,
|
25
|
-
# torch.nn.functional.interpolate no longer gets exported into decomposed graph
|
26
|
-
# but a single aten op:
|
27
|
-
# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
|
28
|
-
# This would interefere with our pattern matching based composite builder.
|
29
|
-
# Here we register the now missing decompositions first.
|
30
|
-
_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
|
31
|
-
torch.ops.aten.upsample_bilinear2d.vec,
|
32
|
-
torch.ops.aten.upsample_nearest2d.vec,
|
33
|
-
])
|
34
|
-
|
35
|
-
|
36
|
-
@functools.cache
|
37
|
-
def _get_upsample_bilinear2d_pattern():
|
38
|
-
pattern = pattern_module.Pattern(
|
39
|
-
"odml.upsample_bilinear2d",
|
40
|
-
lambda x: torch.nn.functional.interpolate(
|
41
|
-
x, scale_factor=2, mode="bilinear", align_corners=False
|
42
|
-
),
|
43
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
44
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
45
|
-
)
|
46
|
-
|
47
|
-
@pattern.register_attr_builder
|
48
|
-
def attr_builder(pattern, graph_module, internal_match):
|
49
|
-
output = internal_match.returning_nodes[0]
|
50
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
51
|
-
return {
|
52
|
-
"size": (int(output_h), int(output_w)),
|
53
|
-
"align_corners": False,
|
54
|
-
"is_nchw_op": True,
|
55
|
-
}
|
56
|
-
|
57
|
-
return pattern
|
58
|
-
|
59
|
-
|
60
|
-
@functools.cache
|
61
|
-
def _get_upsample_bilinear2d_align_corners_pattern():
|
62
|
-
pattern = pattern_module.Pattern(
|
63
|
-
"odml.upsample_bilinear2d",
|
64
|
-
lambda x: torch.nn.functional.interpolate(
|
65
|
-
x, scale_factor=2, mode="bilinear", align_corners=True
|
66
|
-
),
|
67
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
68
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
69
|
-
)
|
70
|
-
|
71
|
-
@pattern.register_attr_builder
|
72
|
-
def attr_builder(graph_module, pattern, internal_match):
|
73
|
-
output = internal_match.returning_nodes[0]
|
74
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
75
|
-
return {
|
76
|
-
"size": (int(output_h), int(output_w)),
|
77
|
-
"align_corners": True,
|
78
|
-
"is_nchw_op": True,
|
79
|
-
}
|
80
|
-
|
81
|
-
return pattern
|
82
|
-
|
83
|
-
|
84
|
-
@functools.cache
|
85
|
-
def _get_interpolate_nearest2d_pattern():
|
86
|
-
pattern = pattern_module.Pattern(
|
87
|
-
"tfl.resize_nearest_neighbor",
|
88
|
-
lambda x: torch.nn.functional.interpolate(
|
89
|
-
x, scale_factor=2, mode="nearest"
|
90
|
-
),
|
91
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
92
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
93
|
-
)
|
94
|
-
|
95
|
-
@pattern.register_attr_builder
|
96
|
-
def attr_builder(pattern, graph_module, internal_match):
|
97
|
-
output = internal_match.returning_nodes[0]
|
98
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
99
|
-
return {
|
100
|
-
"size": (int(output_h), int(output_w)),
|
101
|
-
"is_nchw_op": True,
|
102
|
-
}
|
103
|
-
|
104
|
-
return pattern
|
105
|
-
|
106
|
-
|
107
|
-
class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
|
108
|
-
|
109
|
-
def __init__(self):
|
110
|
-
super().__init__()
|
111
|
-
self._patterns = [
|
112
|
-
_get_upsample_bilinear2d_pattern(),
|
113
|
-
_get_upsample_bilinear2d_align_corners_pattern(),
|
114
|
-
_get_interpolate_nearest2d_pattern(),
|
115
|
-
]
|
116
|
-
|
117
|
-
def call(self, exported_program: torch.export.ExportedProgram):
|
118
|
-
exported_program = fx_infra.safe_run_decompositions(
|
119
|
-
exported_program,
|
120
|
-
_INTERPOLATE_DECOMPOSITIONS,
|
121
|
-
)
|
122
|
-
|
123
|
-
graph_module = exported_program.graph_module
|
124
|
-
for pattern in self._patterns:
|
125
|
-
graph_module = mark_pattern.mark_pattern(graph_module, pattern)
|
126
|
-
|
127
|
-
graph_module.graph.lint()
|
128
|
-
graph_module.recompile()
|
129
|
-
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -1,93 +0,0 @@
|
|
1
|
-
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
# Implements scaled dot product attention. This is experimental and
|
16
|
-
# GPU-specific code.
|
17
|
-
|
18
|
-
import math
|
19
|
-
from typing import Optional
|
20
|
-
|
21
|
-
from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
|
22
|
-
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
|
-
from ai_edge_torch.generative.utilities import types
|
24
|
-
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
25
|
-
from multipledispatch import dispatch
|
26
|
-
import torch
|
27
|
-
import torch.nn.functional as F
|
28
|
-
|
29
|
-
|
30
|
-
def scaled_dot_product_attention(
|
31
|
-
kv: kv_utils.KVCacheEntry,
|
32
|
-
query: torch.Tensor,
|
33
|
-
key: torch.Tensor,
|
34
|
-
value: torch.Tensor,
|
35
|
-
head_size: int,
|
36
|
-
mask: Optional[torch.Tensor] = None,
|
37
|
-
scale: Optional[float] = None,
|
38
|
-
softcap: Optional[float] = None,
|
39
|
-
):
|
40
|
-
if hasattr(kv, "kv_layout"):
|
41
|
-
return _sdpa(
|
42
|
-
kv.kv_layout[0](), # key layout
|
43
|
-
kv.kv_layout[1](), # value layout
|
44
|
-
query=query,
|
45
|
-
key=key,
|
46
|
-
value=value,
|
47
|
-
head_size=head_size,
|
48
|
-
mask=mask,
|
49
|
-
scale=scale,
|
50
|
-
softcap=softcap,
|
51
|
-
)
|
52
|
-
raise ValueError("No kv_layout attribute found in kv.")
|
53
|
-
|
54
|
-
|
55
|
-
@dispatch(types.BNTH, types.BNHT)
|
56
|
-
def _sdpa(k_type, v_type, *args, **kwargs):
|
57
|
-
query = kwargs["query"]
|
58
|
-
key = kwargs["key"]
|
59
|
-
value = kwargs["value"]
|
60
|
-
head_size = kwargs["head_size"]
|
61
|
-
mask = kwargs.get("mask", None)
|
62
|
-
scale = kwargs.get("scale", None)
|
63
|
-
softcap = kwargs.get("softcap", None)
|
64
|
-
|
65
|
-
if scale is None:
|
66
|
-
scale = 1.0 / math.sqrt(head_size)
|
67
|
-
|
68
|
-
query = query * scale
|
69
|
-
|
70
|
-
assert mask is not None, "Mask should not be None!"
|
71
|
-
t = mask.shape[2]
|
72
|
-
|
73
|
-
logits = bmm_lib.bmm_4d(query, key)
|
74
|
-
|
75
|
-
_, bk, gt, s = logits.shape
|
76
|
-
g = gt // t
|
77
|
-
logits = logits.reshape((bk, g, t, s))
|
78
|
-
if softcap is not None:
|
79
|
-
logits = torch.tanh(logits / softcap)
|
80
|
-
logits = logits * softcap
|
81
|
-
|
82
|
-
padded_logits = logits + mask
|
83
|
-
padded_logits = padded_logits.reshape(1, bk, gt, s)
|
84
|
-
probs = F.softmax(padded_logits, dim=-1).type_as(key)
|
85
|
-
encoded = bmm_lib.bmm_4d(probs, value)
|
86
|
-
|
87
|
-
return encoded # 1, bk, gt, h
|
88
|
-
|
89
|
-
|
90
|
-
@dispatch(object, object)
|
91
|
-
def _sdpa(k_type, v_type, *args, **kwargs):
|
92
|
-
|
93
|
-
raise ValueError(f"No implementations for k={k_type} and v={v_type}")
|
File without changes
|
File without changes
|