ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=
|
2
|
-
ai_edge_torch/
|
1
|
+
ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
|
2
|
+
ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=iCH8lnlOrtbGwvxnT3knpY_keeu2UnrJ_ZXNK2LSvf4,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -34,56 +34,56 @@ ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyu
|
|
34
34
|
ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
|
35
35
|
ai_edge_torch/debug/utils.py,sha256=vOAL4t6Lj47uhKapfEsc_WHmvwew3eKO9hSJyzvPXnU,1625
|
36
36
|
ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
37
|
-
ai_edge_torch/debug/test/test_culprit.py,sha256=
|
37
|
+
ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
|
38
38
|
ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
|
39
39
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
42
|
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
43
|
-
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=
|
44
|
-
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256
|
43
|
+
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
|
44
|
+
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
|
45
45
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
|
46
46
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=
|
48
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
49
|
-
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=
|
50
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
|
48
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
|
49
|
+
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
|
50
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
|
51
51
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
52
52
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
53
53
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
54
54
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
55
|
-
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=
|
56
|
-
ai_edge_torch/generative/examples/llama/llama.py,sha256=
|
55
|
+
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=ck7tXN0U25wAbbRjDcf-aqiS2YhismkmoZIsMpjIsjc,2536
|
56
|
+
ai_edge_torch/generative/examples/llama/llama.py,sha256=BMjpdw6oOXmtqXCAfW9o7Iewaj-Hxd57xVrvSLBuHTk,6656
|
57
57
|
ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
|
58
58
|
ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
59
59
|
ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
|
60
60
|
ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
|
61
61
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
62
|
-
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256
|
63
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
62
|
+
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6xImvGh95--8cNl3QPs-Xml2XBgNJB_c2hY,2323
|
63
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
|
64
64
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
65
65
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
|
-
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=
|
67
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
66
|
+
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
|
67
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
|
68
68
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
|
69
|
-
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
|
69
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
|
70
70
|
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
71
71
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
72
72
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
73
73
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
74
|
-
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=
|
75
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
76
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
77
|
-
ai_edge_torch/generative/examples/phi/phi3.py,sha256=
|
74
|
+
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
|
75
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
|
76
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
|
77
|
+
ai_edge_torch/generative/examples/phi/phi3.py,sha256=7Y1E4XpRuZOiSbeZJ-C2uJjmlnDtWv6L0XvPRE8oEQs,7112
|
78
78
|
ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
|
79
79
|
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
80
80
|
ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
81
|
-
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=
|
82
|
-
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=
|
81
|
+
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=BHkDsivbbfVBPxknkgWwtLskvxyrd42TXuCB0aLVbMY,2633
|
82
|
+
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
|
83
83
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
84
84
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
85
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
86
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
85
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=RKmSBMrup5A2bsPPaTdrBQb8NeRiUHy_1SUOA8DAs9U,2305
|
86
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=kk3cB_qaCzbFOhHtJlLb7qvSEBQTsILnoAcSFE3AkpE,2711
|
87
87
|
ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
|
88
88
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
89
89
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
@@ -107,21 +107,21 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
|
|
107
107
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
108
108
|
ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
|
109
109
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
|
110
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
|
110
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
|
111
111
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
112
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
113
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
112
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
|
113
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
|
114
114
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
|
115
115
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
116
116
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
117
117
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
118
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
118
|
+
ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
|
119
119
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
120
120
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
122
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
124
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
125
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
126
126
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
127
127
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -137,34 +137,33 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
|
|
137
137
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
138
138
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
139
139
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
|
140
|
-
ai_edge_torch/generative/test/test_kv_cache.py,sha256=
|
140
|
+
ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
|
141
141
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
142
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
143
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
142
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
143
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
|
144
144
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
145
|
-
ai_edge_torch/generative/test/utils.py,sha256=
|
145
|
+
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
146
146
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
147
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
147
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
|
148
148
|
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
149
149
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
150
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
150
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
|
151
151
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
152
152
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
153
153
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
154
154
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
155
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
155
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
|
156
156
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
157
157
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
158
158
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
159
159
|
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
|
-
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
|
163
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
164
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
|
165
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
166
165
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
|
167
|
-
ai_edge_torch/lowertools/test_utils.py,sha256=
|
166
|
+
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
168
167
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
|
169
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
170
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
@@ -182,15 +181,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
|
|
182
181
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
183
182
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
184
183
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
185
|
-
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
|
186
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
184
|
+
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
|
185
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
|
187
186
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
188
|
-
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=
|
189
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=
|
191
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
187
|
+
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
188
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
189
|
+
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
|
192
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
193
|
-
ai_edge_torch/odml_torch/lowerings/
|
192
|
+
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
193
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
194
194
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
|
195
195
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
196
196
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
200
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
201
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
202
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
203
|
-
ai_edge_torch_nightly-0.3.0.
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/METADATA,sha256=fUbq26zB0WUU1l6eUud8vq3Nm3KSIhox74pzFSFTmoM,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/RECORD,,
|
ai_edge_torch/config.py
DELETED
@@ -1,27 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""Provides a configuration for the AI Edge Torch library."""
|
17
|
-
|
18
|
-
import dataclasses
|
19
|
-
import os
|
20
|
-
|
21
|
-
|
22
|
-
@dataclasses.dataclass
|
23
|
-
class Config:
|
24
|
-
use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
|
25
|
-
"1",
|
26
|
-
"true",
|
27
|
-
)
|
@@ -1,283 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Tests for StableHLOCompositeBuilder."""
|
16
|
-
|
17
|
-
import math
|
18
|
-
|
19
|
-
from ai_edge_torch import config
|
20
|
-
from ai_edge_torch import lowertools
|
21
|
-
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
22
|
-
import torch
|
23
|
-
import torch.nn.functional as F
|
24
|
-
|
25
|
-
from absl.testing import absltest as googletest
|
26
|
-
|
27
|
-
|
28
|
-
def _export_stablehlo_mlir(model, args):
|
29
|
-
ep = torch.export.export(model, args)
|
30
|
-
return lowertools.exported_program_to_mlir_text(ep)
|
31
|
-
|
32
|
-
|
33
|
-
@googletest.skipIf(
|
34
|
-
not config.Config.use_torch_xla,
|
35
|
-
reason="The odml_torch counter part is in odml_torch.",
|
36
|
-
)
|
37
|
-
class TestStableHLOCompositeBuilder(googletest.TestCase):
|
38
|
-
|
39
|
-
def test_build_composite(self):
|
40
|
-
class SampleModel(torch.nn.Module):
|
41
|
-
|
42
|
-
def forward(self, x):
|
43
|
-
builder = StableHLOCompositeBuilder(name="test.plus_two")
|
44
|
-
y = x + 1
|
45
|
-
y = builder.mark_inputs(y)
|
46
|
-
z = y + 2
|
47
|
-
z = builder.mark_outputs(z)
|
48
|
-
return z
|
49
|
-
|
50
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
51
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
|
52
|
-
|
53
|
-
def test_build_multiple_composites(self):
|
54
|
-
class SampleModel(torch.nn.Module):
|
55
|
-
|
56
|
-
def plus_one(self, x: torch.Tensor):
|
57
|
-
builder = StableHLOCompositeBuilder("test.plus_one")
|
58
|
-
x = builder.mark_inputs(x)
|
59
|
-
y = x + 1
|
60
|
-
y = builder.mark_outputs(y)
|
61
|
-
return y
|
62
|
-
|
63
|
-
def plus_two(self, x: torch.Tensor):
|
64
|
-
builder = StableHLOCompositeBuilder("test.plus_two")
|
65
|
-
x = builder.mark_inputs(x)
|
66
|
-
y = x + 2
|
67
|
-
y = builder.mark_outputs(y)
|
68
|
-
return y
|
69
|
-
|
70
|
-
def forward(self, x):
|
71
|
-
x = self.plus_two(x)
|
72
|
-
x = x + 3
|
73
|
-
x = self.plus_one(x)
|
74
|
-
x = x + 4
|
75
|
-
x = self.plus_two(x)
|
76
|
-
return x
|
77
|
-
|
78
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
79
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
|
80
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
|
81
|
-
|
82
|
-
def test_build_composite_with_attr(self):
|
83
|
-
class SampleModel(torch.nn.Module):
|
84
|
-
|
85
|
-
def __init__(self):
|
86
|
-
super().__init__()
|
87
|
-
|
88
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
89
|
-
builder = StableHLOCompositeBuilder(
|
90
|
-
name="test.log_softmax", attr={"dim": dim}
|
91
|
-
)
|
92
|
-
x = builder.mark_inputs(x)
|
93
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
94
|
-
y = builder.mark_outputs(y)
|
95
|
-
return y
|
96
|
-
|
97
|
-
def forward(self, x):
|
98
|
-
x = x + 1
|
99
|
-
x = self.log_softmax(x, 0)
|
100
|
-
x = self.log_softmax(x, 1)
|
101
|
-
return x
|
102
|
-
|
103
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
104
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
|
105
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
|
106
|
-
self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
|
107
|
-
|
108
|
-
def test_build_composite_with_mix_type_attrs(self):
|
109
|
-
class SampleModel(torch.nn.Module):
|
110
|
-
|
111
|
-
def __init__(self):
|
112
|
-
super().__init__()
|
113
|
-
|
114
|
-
def log_softmax(self, x: torch.Tensor, dim: int):
|
115
|
-
builder = StableHLOCompositeBuilder(
|
116
|
-
name="test.log_softmax",
|
117
|
-
attr={
|
118
|
-
"dim": dim,
|
119
|
-
"source": "torch.nn",
|
120
|
-
"version": 1.0,
|
121
|
-
},
|
122
|
-
)
|
123
|
-
x = builder.mark_inputs(x)
|
124
|
-
y = torch.nn.functional.log_softmax(x, dim=dim)
|
125
|
-
y = builder.mark_outputs(y)
|
126
|
-
return y
|
127
|
-
|
128
|
-
def forward(self, x):
|
129
|
-
x = x + 1
|
130
|
-
x = self.log_softmax(x, 0)
|
131
|
-
return x
|
132
|
-
|
133
|
-
mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
|
134
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
|
135
|
-
self.assertEqual(
|
136
|
-
mlir.count(
|
137
|
-
'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
|
138
|
-
" version = 1.000000e+00 : f32}"
|
139
|
-
),
|
140
|
-
1,
|
141
|
-
)
|
142
|
-
|
143
|
-
def test_sdpa_composite(self):
|
144
|
-
class SDPAModel(torch.nn.Module):
|
145
|
-
|
146
|
-
def scaled_dot_product_attention(
|
147
|
-
self,
|
148
|
-
q: torch.Tensor,
|
149
|
-
k: torch.Tensor,
|
150
|
-
v: torch.Tensor,
|
151
|
-
head_size: int,
|
152
|
-
mask: torch.Tensor,
|
153
|
-
):
|
154
|
-
builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
|
155
|
-
q, k, v, mask = builder.mark_inputs(q, k, v, mask)
|
156
|
-
|
157
|
-
scale = 1.0 / math.sqrt(head_size)
|
158
|
-
|
159
|
-
q = q.transpose(1, 2)
|
160
|
-
k = k.transpose(1, 2)
|
161
|
-
v = v.transpose(1, 2)
|
162
|
-
y = F.scaled_dot_product_attention(
|
163
|
-
q,
|
164
|
-
k,
|
165
|
-
v,
|
166
|
-
attn_mask=mask,
|
167
|
-
dropout_p=0.0,
|
168
|
-
is_causal=mask is None,
|
169
|
-
scale=scale,
|
170
|
-
)
|
171
|
-
result = y.transpose(1, 2)
|
172
|
-
result = builder.mark_outputs(result)
|
173
|
-
return result
|
174
|
-
|
175
|
-
def forward(self, q, k, v, mask):
|
176
|
-
x = self.scaled_dot_product_attention(
|
177
|
-
q,
|
178
|
-
k,
|
179
|
-
v,
|
180
|
-
8,
|
181
|
-
mask,
|
182
|
-
)
|
183
|
-
return x
|
184
|
-
|
185
|
-
query = torch.rand(1, 1, 32, 4)
|
186
|
-
key = torch.rand(1, 500, 1, 4)
|
187
|
-
value = torch.rand(1, 500, 1, 4)
|
188
|
-
mask = torch.rand(1, 1, 1, 500)
|
189
|
-
|
190
|
-
mlir = _export_stablehlo_mlir(
|
191
|
-
SDPAModel().eval(),
|
192
|
-
(query, key, value, mask),
|
193
|
-
)
|
194
|
-
self.assertEqual(
|
195
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
|
196
|
-
)
|
197
|
-
|
198
|
-
def test_sdpa_composite_with_attr(self):
|
199
|
-
class SDPAModel(torch.nn.Module):
|
200
|
-
|
201
|
-
def scaled_dot_product_attention(
|
202
|
-
self,
|
203
|
-
q: torch.Tensor,
|
204
|
-
k: torch.Tensor,
|
205
|
-
v: torch.Tensor,
|
206
|
-
head_size: int,
|
207
|
-
include_captanh: bool,
|
208
|
-
):
|
209
|
-
builder = StableHLOCompositeBuilder(
|
210
|
-
name="test.scaled_dot_product_attention",
|
211
|
-
attr={"include_captanh": include_captanh},
|
212
|
-
)
|
213
|
-
q, k, v = builder.mark_inputs(q, k, v)
|
214
|
-
|
215
|
-
scale = 1.0 / math.sqrt(head_size)
|
216
|
-
|
217
|
-
q = q.transpose(1, 2)
|
218
|
-
k = k.transpose(1, 2)
|
219
|
-
v = v.transpose(1, 2)
|
220
|
-
y = F.scaled_dot_product_attention(
|
221
|
-
q,
|
222
|
-
k,
|
223
|
-
v,
|
224
|
-
attn_mask=None,
|
225
|
-
dropout_p=0.0,
|
226
|
-
is_causal=True,
|
227
|
-
scale=scale,
|
228
|
-
)
|
229
|
-
result = y.transpose(1, 2)
|
230
|
-
result = builder.mark_outputs(result)
|
231
|
-
return result
|
232
|
-
|
233
|
-
def forward(self, q, k, v):
|
234
|
-
x = self.scaled_dot_product_attention(q, k, v, 8, True)
|
235
|
-
y = self.scaled_dot_product_attention(q, k, v, 8, False)
|
236
|
-
return x + y
|
237
|
-
|
238
|
-
query = torch.rand(1, 1, 32, 4)
|
239
|
-
key = torch.rand(1, 500, 1, 4)
|
240
|
-
value = torch.rand(1, 500, 1, 4)
|
241
|
-
mlir = _export_stablehlo_mlir(
|
242
|
-
SDPAModel().eval(),
|
243
|
-
(query, key, value),
|
244
|
-
)
|
245
|
-
self.assertEqual(
|
246
|
-
mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
|
247
|
-
)
|
248
|
-
self.assertEqual(
|
249
|
-
mlir.count("composite_attributes = {include_captanh = true}"), 1
|
250
|
-
)
|
251
|
-
self.assertEqual(
|
252
|
-
mlir.count("composite_attributes = {include_captanh = false}"), 1
|
253
|
-
)
|
254
|
-
|
255
|
-
def test_build_composite_with_multiple_inputs_outputs(self):
|
256
|
-
class SampleModel(torch.nn.Module):
|
257
|
-
|
258
|
-
def mimo_sample(self, a, b, c):
|
259
|
-
builder = StableHLOCompositeBuilder(name="test.mimo_sample")
|
260
|
-
|
261
|
-
a, b, c = builder.mark_inputs(a, b, c)
|
262
|
-
x = a + b + c
|
263
|
-
y = (a - b) * x
|
264
|
-
z = (c + 1.0) * a
|
265
|
-
x, y, z = builder.mark_outputs(x, y, z)
|
266
|
-
|
267
|
-
result = x + y * z
|
268
|
-
return result
|
269
|
-
|
270
|
-
def forward(self, a, b, c):
|
271
|
-
x = self.mimo_sample(a, b, c)
|
272
|
-
x = self.mimo_sample(a, b, x)
|
273
|
-
x = self.mimo_sample(x, x, c)
|
274
|
-
return x
|
275
|
-
|
276
|
-
mlir = _export_stablehlo_mlir(
|
277
|
-
SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
|
278
|
-
)
|
279
|
-
self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
|
280
|
-
|
281
|
-
|
282
|
-
if __name__ == "__main__":
|
283
|
-
googletest.main()
|
File without changes
|
File without changes
|