ai-edge-torch-nightly 0.5.0.dev20250423__py3-none-any.whl → 0.5.0.dev20250425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +38 -4
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/decoder.py +1 -1
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +37 -2
  9. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  10. ai_edge_torch/generative/layers/attention.py +4 -18
  11. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  12. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +38 -44
  13. ai_edge_torch/generative/test/test_model_conversion.py +38 -33
  14. ai_edge_torch/generative/test/test_model_conversion_large.py +3 -75
  15. ai_edge_torch/generative/utilities/converter.py +5 -0
  16. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/METADATA +1 -1
  19. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/RECORD +22 -25
  20. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  21. ai_edge_torch/generative/layers/experimental/attention.py +0 -231
  22. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  23. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.5.0.dev20250423.dist-info → ai_edge_torch_nightly-0.5.0.dev20250425.dist-info}/top_level.txt +0 -0
@@ -48,10 +48,8 @@ class TestModelConversion(googletest.TestCase):
 
   def setUp(self):
     super().setUp()
-    # Builder function for an Interpreter that supports custom ops.
     self._interpreter_builder = (
-        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
-            custom_op_registerers=["GenAIOpsRegisterer"],
+        lambda tflite_model: lambda: interpreter.Interpreter(
            model_content=tflite_model,
            experimental_default_delegate_latest_features=True,
        )
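
Note on the hunk above: the test fixture now builds a stock interpreter.Interpreter instead of interpreter.InterpreterWithCustomOps. A minimal sketch of the resulting builder shape, assuming the ai_edge_litert interpreter module these tests appear to use (the import is not shown in this hunk) and a tflite_model byte string produced by an earlier conversion; keyword arguments are taken verbatim from the diff:

# Sketch only: the import path is an assumption.
from ai_edge_litert import interpreter


def make_interpreter_builder(tflite_model: bytes):
  """Returns a zero-arg callable that creates a plain Interpreter."""
  return lambda: interpreter.Interpreter(
      model_content=tflite_model,
      experimental_default_delegate_latest_features=True,
  )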
@@ -94,110 +92,62 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma1(self):
     config = gemma1.get_fake_model_config()
     pytorch_model = gemma1.Gemma1(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_llama(self):
     config = llama.get_fake_model_config()
     pytorch_model = llama.Llama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi2(self):
     config = phi2.get_fake_model_config()
     pytorch_model = phi2.Phi2(config).eval()
     # Phi-2 logits are very big, so we need a larger absolute tolerance.
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_phi4(self):
     config = phi4.get_fake_model_config()
     pytorch_model = phi4.Phi4Mini(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm(self):
     config = smollm.get_fake_model_config()
     pytorch_model = smollm.SmolLM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_smollm2(self):
     config = smollm.get_fake_model_config_v2()
     pytorch_model = smollm.SmolLM2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_openelm(self):
     config = openelm.get_fake_model_config()
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen(self):
     config = qwen.get_fake_model_config()
     pytorch_model = qwen.Qwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_deepseek(self):
     config = deepseek.get_fake_model_config()
     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_amd_llama_135m(self):
     config = amd_llama_135m.get_fake_model_config()
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
@@ -246,19 +196,11 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma1(self):
     self._test_paligemma_model(
         decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_paligemma2(self):
     self._test_paligemma_model(
         decoder2.Decoder2,
@@ -267,10 +209,6 @@ class TestModelConversion(googletest.TestCase):
         rtol=1e-5,
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_qwen_vl_model(self):
     config = qwen_vl.get_fake_model_config()
     pytorch_model = qwen_vl.QwenVL(config).eval()
@@ -316,10 +254,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="flaky")
   def test_stable_diffusion_clip(self):
     config = sd_clip.get_fake_model_config()
     prompt_tokens = torch.from_numpy(
@@ -348,10 +283,7 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
+  @googletest.skipIf(ai_edge_torch.config.in_oss, reason="b/413106901")
   def test_stable_diffusion_diffusion(self):
     config = sd_diffusion.get_fake_model_config(2)
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -390,10 +322,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
   def test_stable_diffusion_decoder(self):
     config = sd_decoder.get_fake_model_config()
     # Reduce stddev(scale) of input values to avoid too big output logits which
@@ -81,6 +81,11 @@ def define_conversion_flags(model_name: str):
       'If set, the model will be converted with the provided list of LoRA'
       ' ranks.',
   )
+  flags.DEFINE_bool(
+      'transpose_kv_cache',
+      False,
+      'If set, the model will be converted with transposed KV cache.',
+  )
 
   return flags
 
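The hunk above adds a transpose_kv_cache boolean to define_conversion_flags. A minimal sketch of how a conversion script could consume the flag, assuming the absl flags pattern already used here; the branch bodies are placeholders, not the converter's actual behavior:

from absl import app
from absl import flags

flags.DEFINE_bool(
    'transpose_kv_cache',
    False,
    'If set, the model will be converted with transposed KV cache.',
)

_FLAGS = flags.FLAGS


def main(_):
  # Placeholder logic: a real conversion script would thread this value into
  # its export/conversion configuration rather than print it.
  if _FLAGS.transpose_kv_cache:
    print('Converting with a transposed KV cache layout.')
  else:
    print('Converting with the default KV cache layout.')


if __name__ == '__main__':
  app.run(main)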
@@ -34,6 +34,8 @@ fx_infra.decomp.update_pre_lower_decomp(
         torch.ops.aten.replication_pad1d,
         torch.ops.aten.replication_pad2d,
         torch.ops.aten.replication_pad3d,
+        torch.ops.aten.upsample_bilinear2d.vec,
+        torch.ops.aten.upsample_nearest2d.vec,
         torch.ops.aten.addmm,
     ])
 )
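
Registering upsample_bilinear2d.vec and upsample_nearest2d.vec as pre-lower decompositions means torch.nn.functional.interpolate is broken into primitive aten ops before lowering, which appears to replace the pattern-matching approach of the deleted build_interpolate_composite_pass.py shown at the end of this diff. A standalone sketch of the effect, using plain torch.export for illustration rather than the ai-edge-torch lowering pipeline:

import torch
import torch.nn.functional as F


class Upsample(torch.nn.Module):

  def forward(self, x):
    return F.interpolate(x, scale_factor=2, mode='nearest')


# Recent torch releases export interpolate as a single aten.upsample_nearest2d.vec node.
ep = torch.export.export(Upsample().eval(), (torch.rand(1, 3, 8, 8),))

# Applying the same decompositions registered above rewrites that node into
# primitive ops before any further lowering.
decomp_table = torch._decomp.get_decompositions([
    torch.ops.aten.upsample_nearest2d.vec,
    torch.ops.aten.upsample_bilinear2d.vec,
])
ep = ep.run_decompositions(decomp_table)
print(ep.graph_module.graph)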
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.5.0.dev20250423"
+__version__ = "0.5.0.dev20250425"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250423
+Version: 0.5.0.dev20250425
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,16 +2,15 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=DjzQwP8czvLmUu-dJhnWVQJHOuaOqJJKuH2_TOViMvg,706
+ai_edge_torch/version.py,sha256=_aF64u6MXH8zPBTEg6odQq2WazbUIxQYlfJNXzfkMdM,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=dOr3TUfF0UCvkmlUrMqKvgaN4jh3lJ9XFuO-sHaAmIw,5521
+ai_edge_torch/_convert/conversion.py,sha256=QVugYVfbyaeBgSKKbhFzHG5oXA7t3M-40JcpcdSu6W8,5436
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
-ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=jbRCZmSduG_1qmngaEEtbofAyL1PKZ8P1uxzzsXQhsw,1253
+ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
 ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
@@ -19,7 +18,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=IhEh3tTP3-AmQlpt24stKKEl0AIRyuo2REZIbhkmgJo,13940
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -54,8 +53,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif8
 ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=z5MWiZLnsQzhNYMiQbcI9i0ki-dtkbimCptkiTFZxwo,1586
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
 ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=r6Pb5_LRKvw2QrOMn3PzunrVxPB-LSdyU2H1XORZo9A,1553
-ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
+ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=1wz4h3bjyX2qMRZ310UKGNYTORegzxinVFmYz2Fupm4,2666
+ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=yhS_i2kR0GJWpWciCt4p9Z9nHYh6A5uJ8Ycy2ebFN9w,2909
 ai_edge_torch/generative/examples/deepseek/verify.py,sha256=iYldze-pvZGvPkkqr6zA7EmitPnH9sXkzjNVx353IcE,2403
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=tSEtGeS-Ndcc_cTm7c4CT4FqRiwrHedEv1oJk4Y_zYU,1552
@@ -67,7 +66,7 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=eXWE5CSX0KeUMsPevgsYOfvyajl9F1RFF4DCWhHcYPA,15646
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=GACDBI_MsFowR8A3wAWrpzradPYe-AUgB9ZjXaVBG-s,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
@@ -104,8 +103,8 @@ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfS
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
 ai_edge_torch/generative/examples/phi/verify_phi4.py,sha256=BoCa5kUBRHtMQ-5ql6yD4pG4xHJMyUiQlpMOWVx-JgY,2356
 ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=om3lXL1RnA87PkfU_cRfP6RnPgXrCmaB-cK98H-nqbA,1802
-ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
+ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=-Xe5koexhNUkWjS2XgS9Ggg7XOQAlMO8QcBJRTNjJa4,2972
+ai_edge_torch/generative/examples/qwen/qwen.py,sha256=m8APYzo9N0SXsdvCxC8HtCcbN3W7gLKkRBL-Tg0BWXU,4223
 ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
 ai_edge_torch/generative/examples/qwen_vl/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/qwen_vl/convert_to_tflite.py,sha256=yVebRatt2SLCsGvrYTBXOM-0S2REhkpikHTyy5MCjUw,2222
@@ -150,7 +149,7 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
+ai_edge_torch/generative/layers/attention.py,sha256=uK1ih2kxPZherwi-pGSm8B--NNWnQ8npEAfgcjMIkEY,12964
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
@@ -159,12 +158,10 @@ ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1t
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
-ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=oo9h7pi0GcuylRgp2yUuvUJCrhj03aoWt_fP7EDP4LM,3775
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=efqqGRZPJ55hKn1MQJ-cXfrJD85uS1v7W_juyGyts58,5648
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=Hn8Zw-jiB9GH2uZ-yaRMcDdpmjECcW4uCy-YNH9zV8c,3693
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
 ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
-ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -181,12 +178,12 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvI
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3GyBn4k7BWVgyGzrbcL30Su3nxZYLtwkCs,14787
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=mhNJikLnGVGi9NKmXB8FhnqeDy9gtrvC3yEbrTABZ4Y,6163
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=6LkLnFOvlnt7JVVDYKMaZClPRBEvdjq6xnSjIFYNdI8,12554
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=LtBHjnslhL-uf4sDRoC8JIbbUD73g0QW3FiWsHUdV1g,10631
+ai_edge_torch/generative/utilities/converter.py,sha256=z3CvNJxKzglu1BU_5ri91RUeGHh7urhoWFbk0oq7i2M,10768
 ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
@@ -228,7 +225,7 @@ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62l
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=fEWjIdEpDIqT1EYLZE13O9A41OuaNdbfBrv3vNxS9gI,11601
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
+ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
@@ -245,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/METADATA,sha256=PGzcX4WVfFW0wE0TSKLAuRB94iemrNff4L8CL_VUMnQ,2051
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250423.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/METADATA,sha256=owGeoLcv0XFf4tXFatFjXLSisoaRBBwrtyLx3LFq8PM,2051
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250425.dist-info/RECORD,,
@@ -1,129 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Build interpolate composite pass."""
-
-import functools
-
-from ai_edge_torch import fx_infra
-from ai_edge_torch.hlfb import mark_pattern
-from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
-import torch
-
-# For torch nightly released after mid June 2024,
-# torch.nn.functional.interpolate no longer gets exported into decomposed graph
-# but a single aten op:
-# torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
-# This would interefere with our pattern matching based composite builder.
-# Here we register the now missing decompositions first.
-_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
-    torch.ops.aten.upsample_bilinear2d.vec,
-    torch.ops.aten.upsample_nearest2d.vec,
-])
-
-
-@functools.cache
-def _get_upsample_bilinear2d_pattern():
-  pattern = pattern_module.Pattern(
-      "odml.upsample_bilinear2d",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="bilinear", align_corners=False
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(pattern, graph_module, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "align_corners": False,
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-@functools.cache
-def _get_upsample_bilinear2d_align_corners_pattern():
-  pattern = pattern_module.Pattern(
-      "odml.upsample_bilinear2d",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="bilinear", align_corners=True
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(graph_module, pattern, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "align_corners": True,
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-@functools.cache
-def _get_interpolate_nearest2d_pattern():
-  pattern = pattern_module.Pattern(
-      "tfl.resize_nearest_neighbor",
-      lambda x: torch.nn.functional.interpolate(
-          x, scale_factor=2, mode="nearest"
-      ),
-      export_args=(torch.rand(1, 3, 100, 100),),
-      extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
-  )
-
-  @pattern.register_attr_builder
-  def attr_builder(pattern, graph_module, internal_match):
-    output = internal_match.returning_nodes[0]
-    output_h, output_w = output.meta["val"].shape[-2:]
-    return {
-        "size": (int(output_h), int(output_w)),
-        "is_nchw_op": True,
-    }
-
-  return pattern
-
-
-class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
-
-  def __init__(self):
-    super().__init__()
-    self._patterns = [
-        _get_upsample_bilinear2d_pattern(),
-        _get_upsample_bilinear2d_align_corners_pattern(),
-        _get_interpolate_nearest2d_pattern(),
-    ]
-
-  def call(self, exported_program: torch.export.ExportedProgram):
-    exported_program = fx_infra.safe_run_decompositions(
-        exported_program,
-        _INTERPOLATE_DECOMPOSITIONS,
-    )
-
-    graph_module = exported_program.graph_module
-    for pattern in self._patterns:
-      graph_module = mark_pattern.mark_pattern(graph_module, pattern)
-
-    graph_module.graph.lint()
-    graph_module.recompile()
-    return fx_infra.ExportedProgramPassResult(exported_program, True)