ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/layers/attention.py +4 -18
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +38 -44
- ai_edge_torch/generative/test/test_model_conversion.py +38 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
- ai_edge_torch/generative/utilities/converter.py +5 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +22 -25
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/attention.py +0 -231
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
@@ -48,10 +48,8 @@ class TestModelConversion(googletest.TestCase):
|
|
48
48
|
|
49
49
|
def setUp(self):
|
50
50
|
super().setUp()
|
51
|
-
# Builder function for an Interpreter that supports custom ops.
|
52
51
|
self._interpreter_builder = (
|
53
|
-
lambda tflite_model: lambda: interpreter.
|
54
|
-
custom_op_registerers=["GenAIOpsRegisterer"],
|
52
|
+
lambda tflite_model: lambda: interpreter.Interpreter(
|
55
53
|
model_content=tflite_model,
|
56
54
|
experimental_default_delegate_latest_features=True,
|
57
55
|
)
|
@@ -94,110 +92,62 @@ class TestModelConversion(googletest.TestCase):
|
|
94
92
|
)
|
95
93
|
)
|
96
94
|
|
97
|
-
@googletest.skipIf(
|
98
|
-
ai_edge_torch.config.in_oss,
|
99
|
-
reason="tests with custom ops are not supported in oss",
|
100
|
-
)
|
101
95
|
def test_gemma1(self):
|
102
96
|
config = gemma1.get_fake_model_config()
|
103
97
|
pytorch_model = gemma1.Gemma1(config).eval()
|
104
98
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
105
99
|
|
106
|
-
@googletest.skipIf(
|
107
|
-
ai_edge_torch.config.in_oss,
|
108
|
-
reason="tests with custom ops are not supported in oss",
|
109
|
-
)
|
110
100
|
def test_gemma2(self):
|
111
101
|
config = gemma2.get_fake_model_config()
|
112
102
|
pytorch_model = gemma2.Gemma2(config).eval()
|
113
103
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
114
104
|
|
115
|
-
@googletest.skipIf(
|
116
|
-
ai_edge_torch.config.in_oss,
|
117
|
-
reason="tests with custom ops are not supported in oss",
|
118
|
-
)
|
119
105
|
def test_llama(self):
|
120
106
|
config = llama.get_fake_model_config()
|
121
107
|
pytorch_model = llama.Llama(config).eval()
|
122
108
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
123
109
|
|
124
|
-
@googletest.skipIf(
|
125
|
-
ai_edge_torch.config.in_oss,
|
126
|
-
reason="tests with custom ops are not supported in oss",
|
127
|
-
)
|
128
110
|
def test_phi2(self):
|
129
111
|
config = phi2.get_fake_model_config()
|
130
112
|
pytorch_model = phi2.Phi2(config).eval()
|
131
113
|
# Phi-2 logits are very big, so we need a larger absolute tolerance.
|
132
114
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
133
115
|
|
134
|
-
@googletest.skipIf(
|
135
|
-
ai_edge_torch.config.in_oss,
|
136
|
-
reason="tests with custom ops are not supported in oss",
|
137
|
-
)
|
138
116
|
def test_phi3(self):
|
139
117
|
config = phi3.get_fake_model_config()
|
140
118
|
pytorch_model = phi3.Phi3_5Mini(config).eval()
|
141
119
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
142
120
|
|
143
|
-
@googletest.skipIf(
|
144
|
-
ai_edge_torch.config.in_oss,
|
145
|
-
reason="tests with custom ops are not supported in oss",
|
146
|
-
)
|
147
121
|
def test_phi4(self):
|
148
122
|
config = phi4.get_fake_model_config()
|
149
123
|
pytorch_model = phi4.Phi4Mini(config).eval()
|
150
124
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
151
125
|
|
152
|
-
@googletest.skipIf(
|
153
|
-
ai_edge_torch.config.in_oss,
|
154
|
-
reason="tests with custom ops are not supported in oss",
|
155
|
-
)
|
156
126
|
def test_smollm(self):
|
157
127
|
config = smollm.get_fake_model_config()
|
158
128
|
pytorch_model = smollm.SmolLM(config).eval()
|
159
129
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
160
130
|
|
161
|
-
@googletest.skipIf(
|
162
|
-
ai_edge_torch.config.in_oss,
|
163
|
-
reason="tests with custom ops are not supported in oss",
|
164
|
-
)
|
165
131
|
def test_smollm2(self):
|
166
132
|
config = smollm.get_fake_model_config_v2()
|
167
133
|
pytorch_model = smollm.SmolLM2(config).eval()
|
168
134
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
169
135
|
|
170
|
-
@googletest.skipIf(
|
171
|
-
ai_edge_torch.config.in_oss,
|
172
|
-
reason="tests with custom ops are not supported in oss",
|
173
|
-
)
|
174
136
|
def test_openelm(self):
|
175
137
|
config = openelm.get_fake_model_config()
|
176
138
|
pytorch_model = openelm.OpenELM(config).eval()
|
177
139
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
178
140
|
|
179
|
-
@googletest.skipIf(
|
180
|
-
ai_edge_torch.config.in_oss,
|
181
|
-
reason="tests with custom ops are not supported in oss",
|
182
|
-
)
|
183
141
|
def test_qwen(self):
|
184
142
|
config = qwen.get_fake_model_config()
|
185
143
|
pytorch_model = qwen.Qwen(config).eval()
|
186
144
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
187
145
|
|
188
|
-
@googletest.skipIf(
|
189
|
-
ai_edge_torch.config.in_oss,
|
190
|
-
reason="tests with custom ops are not supported in oss",
|
191
|
-
)
|
192
146
|
def test_deepseek(self):
|
193
147
|
config = deepseek.get_fake_model_config()
|
194
148
|
pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
|
195
149
|
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
196
150
|
|
197
|
-
@googletest.skipIf(
|
198
|
-
ai_edge_torch.config.in_oss,
|
199
|
-
reason="tests with custom ops are not supported in oss",
|
200
|
-
)
|
201
151
|
def test_amd_llama_135m(self):
|
202
152
|
config = amd_llama_135m.get_fake_model_config()
|
203
153
|
pytorch_model = amd_llama_135m.AmdLlama(config).eval()
|
@@ -246,19 +196,11 @@ class TestModelConversion(googletest.TestCase):
|
|
246
196
|
)
|
247
197
|
)
|
248
198
|
|
249
|
-
@googletest.skipIf(
|
250
|
-
ai_edge_torch.config.in_oss,
|
251
|
-
reason="tests with custom ops are not supported in oss",
|
252
|
-
)
|
253
199
|
def test_paligemma1(self):
|
254
200
|
self._test_paligemma_model(
|
255
201
|
decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
|
256
202
|
)
|
257
203
|
|
258
|
-
@googletest.skipIf(
|
259
|
-
ai_edge_torch.config.in_oss,
|
260
|
-
reason="tests with custom ops are not supported in oss",
|
261
|
-
)
|
262
204
|
def test_paligemma2(self):
|
263
205
|
self._test_paligemma_model(
|
264
206
|
decoder2.Decoder2,
|
@@ -267,10 +209,6 @@ class TestModelConversion(googletest.TestCase):
|
|
267
209
|
rtol=1e-5,
|
268
210
|
)
|
269
211
|
|
270
|
-
@googletest.skipIf(
|
271
|
-
ai_edge_torch.config.in_oss,
|
272
|
-
reason="tests with custom ops are not supported in oss",
|
273
|
-
)
|
274
212
|
def test_qwen_vl_model(self):
|
275
213
|
config = qwen_vl.get_fake_model_config()
|
276
214
|
pytorch_model = qwen_vl.QwenVL(config).eval()
|
@@ -316,10 +254,7 @@ class TestModelConversion(googletest.TestCase):
|
|
316
254
|
)
|
317
255
|
)
|
318
256
|
|
319
|
-
@googletest.skipIf(
|
320
|
-
ai_edge_torch.config.in_oss,
|
321
|
-
reason="tests with custom ops are not supported in oss",
|
322
|
-
)
|
257
|
+
@googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
|
323
258
|
def test_stable_diffusion_clip(self):
|
324
259
|
config = sd_clip.get_fake_model_config()
|
325
260
|
prompt_tokens = torch.from_numpy(
|
@@ -348,10 +283,7 @@ class TestModelConversion(googletest.TestCase):
|
|
348
283
|
)
|
349
284
|
)
|
350
285
|
|
351
|
-
@googletest.skipIf(
|
352
|
-
ai_edge_torch.config.in_oss,
|
353
|
-
reason="tests with custom ops are not supported in oss",
|
354
|
-
)
|
286
|
+
@googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
|
355
287
|
def test_stable_diffusion_diffusion(self):
|
356
288
|
config = sd_diffusion.get_fake_model_config(2)
|
357
289
|
# Reduce stddev(scale) of input values to avoid too big output logits which
|
@@ -390,10 +322,6 @@ class TestModelConversion(googletest.TestCase):
|
|
390
322
|
)
|
391
323
|
)
|
392
324
|
|
393
|
-
@googletest.skipIf(
|
394
|
-
ai_edge_torch.config.in_oss,
|
395
|
-
reason="tests with custom ops are not supported in oss",
|
396
|
-
)
|
397
325
|
def test_stable_diffusion_decoder(self):
|
398
326
|
config = sd_decoder.get_fake_model_config()
|
399
327
|
# Reduce stddev(scale) of input values to avoid too big output logits which
|
@@ -81,6 +81,11 @@ def define_conversion_flags(model_name: str):
|
|
81
81
|
'If set, the model will be converted with the provided list of LoRA'
|
82
82
|
' ranks.',
|
83
83
|
)
|
84
|
+
flags.DEFINE_bool(
|
85
|
+
'transpose_kv_cache',
|
86
|
+
False,
|
87
|
+
'If set, the model will be converted with transposed KV cache.',
|
88
|
+
)
|
84
89
|
|
85
90
|
return flags
|
86
91
|
|
@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
|
|
34
34
|
torch.ops.aten.replication_pad1d,
|
35
35
|
torch.ops.aten.replication_pad2d,
|
36
36
|
torch.ops.aten.replication_pad3d,
|
37
|
+
torch.ops.aten.upsample_bilinear2d.vec,
|
38
|
+
torch.ops.aten.upsample_nearest2d.vec,
|
37
39
|
torch.ops.aten.addmm,
|
38
40
|
])
|
39
41
|
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.dev20250425
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,16 +2,15 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=_aF64u6MXH8zPBTEg6odQq2WazbUIxQYlfJNXzfkMdM,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
12
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
13
|
-
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
|
14
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
|
13
|
+
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
|
15
14
|
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
|
16
15
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
17
16
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
|
|
19
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
20
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
|
21
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
22
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
|
23
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
24
23
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
|
25
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
|
|
54
53
|
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
|
55
54
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
|
56
55
|
ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
57
|
-
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=
|
58
|
-
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=
|
56
|
+
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=1wz4h3bjyX2qMRZ310UKGNYTORegzxinVFmYz2Fupm4,2666
|
57
|
+
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
|
59
58
|
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
|
60
59
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
61
60
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
|
@@ -67,7 +66,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-
|
|
67
66
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
68
67
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
69
68
|
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
|
70
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
69
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
|
71
70
|
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
|
72
71
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
73
72
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
@@ -104,8 +103,8 @@ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfS
|
|
104
103
|
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
105
104
|
ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
|
106
105
|
ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
107
|
-
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256
|
108
|
-
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=
|
106
|
+
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=-Xe5koexhNUkWjS2XgS9Ggg7XOQAlMO8QcBJRTNjJa4,2972
|
107
|
+
ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
|
109
108
|
ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
|
110
109
|
ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
111
110
|
ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
|
@@ -150,7 +149,7 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
|
|
150
149
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
|
151
150
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
152
151
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
153
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
152
|
+
ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
|
154
153
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
155
154
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
156
155
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
@@ -159,12 +158,10 @@ ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1t
|
|
159
158
|
ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
|
160
159
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
161
160
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
162
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
163
|
-
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
|
161
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
|
162
|
+
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
|
164
163
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
165
|
-
ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
|
166
164
|
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
|
167
|
-
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
|
168
165
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
169
166
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
170
167
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
@@ -181,12 +178,12 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
|
|
181
178
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
|
182
179
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
183
180
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
184
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
185
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256
|
181
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
|
182
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=6LkLnFOvlnt7JVVDYKMaZClPRBEvdjq6xnSjIFYNdI8,12554
|
186
183
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
187
184
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
188
185
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
189
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
186
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=z3CvNJxKzglu1BU_5ri91RUeGHh7urhoWFbk0oq7i2M,10768
|
190
187
|
ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
|
191
188
|
ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
|
192
189
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
|
@@ -228,7 +225,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
|
|
228
225
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
|
229
226
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
230
227
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
231
|
-
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=
|
228
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
232
229
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
|
233
230
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
234
231
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
@@ -245,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
245
242
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
246
243
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
247
244
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
248
|
-
ai_edge_torch_nightly-0.5.0.
|
249
|
-
ai_edge_torch_nightly-0.5.0.
|
250
|
-
ai_edge_torch_nightly-0.5.0.
|
251
|
-
ai_edge_torch_nightly-0.5.0.
|
252
|
-
ai_edge_torch_nightly-0.5.0.
|
245
|
+
ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
246
|
+
ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/METADATA,sha256=owGeoLcv0XFf4tXFatFjXLSisoaRBBwrtyLx3LFq8PM,2051
|
247
|
+
ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
248
|
+
ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/RECORD,,
|
@@ -1,129 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Build interpolate composite pass."""
|
16
|
-
|
17
|
-
import functools
|
18
|
-
|
19
|
-
from ai_edge_torch import fx_infra
|
20
|
-
from ai_edge_torch.hlfb import mark_pattern
|
21
|
-
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
22
|
-
import torch
|
23
|
-
|
24
|
-
# For torch nightly released after mid June 2024,
|
25
|
-
# torch.nn.functional.interpolate no longer gets exported into decomposed graph
|
26
|
-
# but a single aten op:
|
27
|
-
# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
|
28
|
-
# This would interefere with our pattern matching based composite builder.
|
29
|
-
# Here we register the now missing decompositions first.
|
30
|
-
_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
|
31
|
-
torch.ops.aten.upsample_bilinear2d.vec,
|
32
|
-
torch.ops.aten.upsample_nearest2d.vec,
|
33
|
-
])
|
34
|
-
|
35
|
-
|
36
|
-
@functools.cache
|
37
|
-
def _get_upsample_bilinear2d_pattern():
|
38
|
-
pattern = pattern_module.Pattern(
|
39
|
-
"odml.upsample_bilinear2d",
|
40
|
-
lambda x: torch.nn.functional.interpolate(
|
41
|
-
x, scale_factor=2, mode="bilinear", align_corners=False
|
42
|
-
),
|
43
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
44
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
45
|
-
)
|
46
|
-
|
47
|
-
@pattern.register_attr_builder
|
48
|
-
def attr_builder(pattern, graph_module, internal_match):
|
49
|
-
output = internal_match.returning_nodes[0]
|
50
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
51
|
-
return {
|
52
|
-
"size": (int(output_h), int(output_w)),
|
53
|
-
"align_corners": False,
|
54
|
-
"is_nchw_op": True,
|
55
|
-
}
|
56
|
-
|
57
|
-
return pattern
|
58
|
-
|
59
|
-
|
60
|
-
@functools.cache
|
61
|
-
def _get_upsample_bilinear2d_align_corners_pattern():
|
62
|
-
pattern = pattern_module.Pattern(
|
63
|
-
"odml.upsample_bilinear2d",
|
64
|
-
lambda x: torch.nn.functional.interpolate(
|
65
|
-
x, scale_factor=2, mode="bilinear", align_corners=True
|
66
|
-
),
|
67
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
68
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
69
|
-
)
|
70
|
-
|
71
|
-
@pattern.register_attr_builder
|
72
|
-
def attr_builder(graph_module, pattern, internal_match):
|
73
|
-
output = internal_match.returning_nodes[0]
|
74
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
75
|
-
return {
|
76
|
-
"size": (int(output_h), int(output_w)),
|
77
|
-
"align_corners": True,
|
78
|
-
"is_nchw_op": True,
|
79
|
-
}
|
80
|
-
|
81
|
-
return pattern
|
82
|
-
|
83
|
-
|
84
|
-
@functools.cache
|
85
|
-
def _get_interpolate_nearest2d_pattern():
|
86
|
-
pattern = pattern_module.Pattern(
|
87
|
-
"tfl.resize_nearest_neighbor",
|
88
|
-
lambda x: torch.nn.functional.interpolate(
|
89
|
-
x, scale_factor=2, mode="nearest"
|
90
|
-
),
|
91
|
-
export_args=(torch.rand(1, 3, 100, 100),),
|
92
|
-
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
93
|
-
)
|
94
|
-
|
95
|
-
@pattern.register_attr_builder
|
96
|
-
def attr_builder(pattern, graph_module, internal_match):
|
97
|
-
output = internal_match.returning_nodes[0]
|
98
|
-
output_h, output_w = output.meta["val"].shape[-2:]
|
99
|
-
return {
|
100
|
-
"size": (int(output_h), int(output_w)),
|
101
|
-
"is_nchw_op": True,
|
102
|
-
}
|
103
|
-
|
104
|
-
return pattern
|
105
|
-
|
106
|
-
|
107
|
-
class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
|
108
|
-
|
109
|
-
def __init__(self):
|
110
|
-
super().__init__()
|
111
|
-
self._patterns = [
|
112
|
-
_get_upsample_bilinear2d_pattern(),
|
113
|
-
_get_upsample_bilinear2d_align_corners_pattern(),
|
114
|
-
_get_interpolate_nearest2d_pattern(),
|
115
|
-
]
|
116
|
-
|
117
|
-
def call(self, exported_program: torch.export.ExportedProgram):
|
118
|
-
exported_program = fx_infra.safe_run_decompositions(
|
119
|
-
exported_program,
|
120
|
-
_INTERPOLATE_DECOMPOSITIONS,
|
121
|
-
)
|
122
|
-
|
123
|
-
graph_module = exported_program.graph_module
|
124
|
-
for pattern in self._patterns:
|
125
|
-
graph_module = mark_pattern.mark_pattern(graph_module, pattern)
|
126
|
-
|
127
|
-
graph_module.graph.lint()
|
128
|
-
graph_module.recompile()
|
129
|
-
return fx_infra.ExportedProgramPassResult(exported_program, True)
|