ai-edge-torch-nightly 0.6.0.dev20250828__py3-none-any.whl → 0.6.0.dev20250830__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert.py +21 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +94 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250828.dist-info → ai_edge_torch_nightly-0.6.0.dev20250830.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250828.dist-info → ai_edge_torch_nightly-0.6.0.dev20250830.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.6.0.dev20250828.dist-info → ai_edge_torch_nightly-0.6.0.dev20250830.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250828.dist-info → ai_edge_torch_nightly-0.6.0.dev20250830.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250828.dist-info → ai_edge_torch_nightly-0.6.0.dev20250830.dist-info}/top_level.txt +0 -0
@@ -594,6 +594,27 @@ class TestConvert(googletest.TestCase):
|
|
594
594
|
self.fail(f"Conversion failed with int64 inputs: {err}")
|
595
595
|
# pylint: enable=broad-except
|
596
596
|
|
597
|
+
def test_convert_model_with_torch_div_operation(self):
  """Test converting a simple model with torch.div operation.

  Uses int64 operands so the int64 -> int32 narrowing in the div lowering is
  exercised; `x / y` on integer tensors is true division in torch.
  """

  class SampleModel(nn.Module):

    def forward(self, x: torch.Tensor, y: torch.Tensor):
      return x / y

  model = SampleModel().eval()
  args = (
      torch.randint(0, 100, (10, 10), dtype=torch.int64),
      # Start the divisor range at 1 so the sample inputs never divide by
      # zero.
      torch.randint(1, 100, (10, 10), dtype=torch.int64),
  )

  # pylint: disable=broad-except
  try:
    # Converting int64 / int64 true division is expected to succeed.
    ai_edge_torch.convert(model, args)
  except Exception as err:
    self.fail(f"Conversion failed with int64 inputs: {err}")
  # pylint: enable=broad-except
|
617
|
+
|
597
618
|
def test_compile_model(self):
|
598
619
|
"""Tests AOT compilation of a simple Add module."""
|
599
620
|
|
@@ -116,7 +116,6 @@ lower_by_torch_xla2(torch.ops.aten.cosh)
|
|
116
116
|
lower_by_torch_xla2(torch.ops.aten.cumsum)
|
117
117
|
lower_by_torch_xla2(torch.ops.aten.detach)
|
118
118
|
lower_by_torch_xla2(torch.ops.aten.diagonal)
|
119
|
-
lower_by_torch_xla2(torch.ops.aten.div)
|
120
119
|
lower_by_torch_xla2(torch.ops.aten.dot)
|
121
120
|
lower_by_torch_xla2(torch.ops.aten.embedding)
|
122
121
|
lower_by_torch_xla2(torch.ops.aten.empty)
|
@@ -352,6 +351,8 @@ def _aten_mul_scalar(lctx: LoweringContext, self, other):
|
|
352
351
|
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
353
352
|
if promoted_type == jnp.float64:
|
354
353
|
promoted_type = jnp.float32
|
354
|
+
elif promoted_type == jnp.int64:
|
355
|
+
promoted_type = jnp.int32
|
355
356
|
return jnp.multiply(
|
356
357
|
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
357
358
|
)
|
@@ -369,6 +370,8 @@ def _aten_mul_tensor(lctx: LoweringContext, self, other):
|
|
369
370
|
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
370
371
|
if promoted_type == jnp.float64:
|
371
372
|
promoted_type = jnp.float32
|
373
|
+
elif promoted_type == jnp.int64:
|
374
|
+
promoted_type = jnp.int32
|
372
375
|
return jnp.multiply(
|
373
376
|
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
374
377
|
)
|
@@ -376,6 +379,96 @@ def _aten_mul_tensor(lctx: LoweringContext, self, other):
|
|
376
379
|
return jax_lowering(lctx, self, other)
|
377
380
|
|
378
381
|
|
382
|
+
@registry.lower(torch.ops.aten.div.Scalar)
def _aten_div_scalar(lctx: LoweringContext, self, other):
  """Lowers `aten.div.Scalar`: true division of a tensor by a scalar.

  Both operands are cast to their common promoted dtype before dividing.
  A float64 promotion is replaced by float32 and an int64 promotion by
  int32, mirroring the mul lowerings in this file.
  """
  _log_usage(torch.ops.aten.div.Scalar)

  @jax_bridge.wrap
  def jax_lowering(dividend, divisor):
    common = jnp.promote_types(dividend.dtype, jnp.result_type(divisor))
    # Narrow 64-bit float/int promotions to their 32-bit counterparts.
    if common == jnp.float64:
      common = jnp.float32
    elif common == jnp.int64:
      common = jnp.int32
    lhs = dividend.astype(common)
    rhs = jnp.array(divisor, dtype=common)
    return jnp.divide(lhs, rhs)

  return jax_lowering(lctx, self, other)
|
399
|
+
|
400
|
+
|
401
|
+
@registry.lower(torch.ops.aten.div.Scalar_mode)
def _aten_div_scalar_mode(lctx: LoweringContext, self, other, rounding_mode=""):
  """Lowers `aten.div.Scalar_mode`: tensor / scalar with a rounding mode.

  Args:
    lctx: The lowering context.
    self: The dividend tensor.
    other: The scalar divisor.
    rounding_mode: "floor" rounds the quotient down, "trunc" rounds it toward
      zero, and any other value (e.g. "" or None) performs true division.

  Operands are cast to their common promoted dtype. A float64 promotion is
  replaced by float32 and an int64 promotion by int32.
  """
  _log_usage(torch.ops.aten.div.Scalar_mode)

  @jax_bridge.wrap
  def jax_lowering(self, other):
    other_dtype = jnp.result_type(other)
    promoted_type = jnp.promote_types(self.dtype, other_dtype)
    if promoted_type == jnp.float64:
      promoted_type = jnp.float32
    elif promoted_type == jnp.int64:
      promoted_type = jnp.int32
    lhs = self.astype(promoted_type)
    rhs = jnp.array(other, dtype=promoted_type)

    if rounding_mode == "floor":
      return jnp.floor_divide(lhs, rhs)
    if rounding_mode == "trunc":
      if jnp.issubdtype(promoted_type, jnp.integer):
        # torch keeps integer inputs integral for trunc division. Going through
        # jnp.divide would yield a float result (and lose precision above
        # 2**24), so compute it exactly: floor and trunc differ only when the
        # division is inexact and the operand signs differ.
        quotient = jnp.floor_divide(lhs, rhs)
        inexact = jnp.remainder(lhs, rhs) != 0
        signs_differ = (lhs < 0) != (rhs < 0)
        return jnp.where(inexact & signs_differ, quotient + 1, quotient)
      return jnp.trunc(jnp.divide(lhs, rhs))
    return jnp.divide(lhs, rhs)

  return jax_lowering(lctx, self, other)
|
425
|
+
|
426
|
+
|
427
|
+
@registry.lower(torch.ops.aten.div.Tensor)
def _aten_div_tensor(lctx: LoweringContext, self, other):
  """Lowers `aten.div.Tensor`: element-wise true division of two operands.

  The dividend and divisor are cast to their common promoted dtype first.
  A float64 promotion becomes float32 and an int64 promotion becomes int32,
  mirroring the mul lowerings in this file.
  """
  _log_usage(torch.ops.aten.div.Tensor)

  @jax_bridge.wrap
  def jax_lowering(numerator, denominator):
    target = jnp.promote_types(numerator.dtype, jnp.result_type(denominator))
    # Narrow 64-bit float/int promotions to their 32-bit counterparts.
    if target == jnp.float64:
      target = jnp.float32
    elif target == jnp.int64:
      target = jnp.int32
    lhs = numerator.astype(target)
    rhs = jnp.array(denominator, dtype=target)
    return jnp.divide(lhs, rhs)

  return jax_lowering(lctx, self, other)
|
444
|
+
|
445
|
+
|
446
|
+
@registry.lower(torch.ops.aten.div.Tensor_mode)
def _aten_div_tensor_mode(lctx: LoweringContext, self, other, rounding_mode=""):
  """Lowers `aten.div.Tensor_mode`: element-wise division with a rounding mode.

  Args:
    lctx: The lowering context.
    self: The dividend.
    other: The divisor.
    rounding_mode: "floor" rounds the quotient down, "trunc" rounds it toward
      zero, and any other value (e.g. "" or None) performs true division.

  Operands are cast to their common promoted dtype. A float64 promotion is
  replaced by float32 and an int64 promotion by int32.
  """
  _log_usage(torch.ops.aten.div.Tensor_mode)

  @jax_bridge.wrap
  def jax_lowering(self, other):
    other_dtype = jnp.result_type(other)
    promoted_type = jnp.promote_types(self.dtype, other_dtype)
    if promoted_type == jnp.float64:
      promoted_type = jnp.float32
    elif promoted_type == jnp.int64:
      promoted_type = jnp.int32
    lhs = self.astype(promoted_type)
    rhs = jnp.array(other, dtype=promoted_type)

    if rounding_mode == "floor":
      return jnp.floor_divide(lhs, rhs)
    if rounding_mode == "trunc":
      if jnp.issubdtype(promoted_type, jnp.integer):
        # torch keeps integer inputs integral for trunc division. Going through
        # jnp.divide would yield a float result (and lose precision above
        # 2**24), so compute it exactly: floor and trunc differ only when the
        # division is inexact and the operand signs differ.
        quotient = jnp.floor_divide(lhs, rhs)
        inexact = jnp.remainder(lhs, rhs) != 0
        signs_differ = (lhs < 0) != (rhs < 0)
        return jnp.where(inexact & signs_differ, quotient + 1, quotient)
      return jnp.trunc(jnp.divide(lhs, rhs))
    return jnp.divide(lhs, rhs)

  return jax_lowering(lctx, self, other)
|
470
|
+
|
471
|
+
|
379
472
|
@registry.lower(torch.ops.aten.where.self)
|
380
473
|
def _aten_where_self(lctx: LoweringContext, condition, self, other):
|
381
474
|
_log_usage(torch.ops.aten.where.self)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.6.0.dev20250828
|
3
|
+
Version: 0.6.0.dev20250830
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=hqWeY2tyRT1BA2-SphW3wPQ0c-LtS2ll4xHRheEB2WY,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
28
28
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=oXbr9G5Jc21xd1dr2CDrp774I4crs0_kkN490K5fNn0,7312
|
29
29
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
30
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
30
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=sz2-b3_Tf3tKQgvPqxtu0QR-wgKOuoAKcqz19dUby3g,19962
|
31
31
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
32
32
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
33
33
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -253,7 +253,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvs
|
|
253
253
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
254
254
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
255
255
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
256
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
256
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Hth8F1SwPVEympR1_T6sEsCmmRXtoO7zatPR1uRLFbY,18411
|
257
257
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
258
258
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
259
259
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
269
269
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
270
270
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
271
271
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
272
|
-
ai_edge_torch_nightly-0.6.0.
|
273
|
-
ai_edge_torch_nightly-0.6.0.
|
274
|
-
ai_edge_torch_nightly-0.6.0.
|
275
|
-
ai_edge_torch_nightly-0.6.0.
|
276
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
+
ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
273
|
+
ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/METADATA,sha256=ni14Rod4z0CU-YvXBiqbVKbprzGRTC6JlPyH3zQMo78,2074
|
274
|
+
ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
275
|
+
ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
276
|
+
ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/RECORD,,
|
File without changes
|
File without changes
|