ai-edge-torch-nightly 0.6.0.dev20250828__py3-none-any.whl → 0.6.0.dev20250830__py3-none-any.whl

This diff compares the contents of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
@@ -594,6 +594,27 @@ class TestConvert(googletest.TestCase):
       self.fail(f"Conversion failed with int64 inputs: {err}")
     # pylint: enable=broad-except
 
+  def test_convert_model_with_torch_div_operation(self):
+    """Test converting a simple model with torch.div operation."""
+
+    class SampleModel(nn.Module):
+
+      def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return x / y
+
+    model = SampleModel().eval()
+    args = (
+        torch.randint(0, 100, (10, 10), dtype=torch.int64),
+        torch.randint(0, 100, (10, 10), dtype=torch.int64),
+    )
+
+    try:
+      # Expect the conversion to succeed without raising an error.
+      ai_edge_torch.convert(model, args)
+    except Exception as err:  # pylint: disable=broad-except
+      self.fail(f"Conversion failed with int64 inputs: {err}")
+    # pylint: enable=broad-except
+
   def test_compile_model(self):
     """Tests AOT compilation of a simple Add module."""
 
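For context, the behavior this new test guards can be reproduced outside the test harness with the public convert API. A minimal sketch follows; the module name and the output path are illustrative, not part of the package:

    import ai_edge_torch
    import torch
    from torch import nn

    class DivModel(nn.Module):
      """Toy module whose forward lowers through aten.div.Tensor."""

      def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x / y

    sample_args = (
        torch.randint(0, 100, (10, 10), dtype=torch.int64),
        torch.randint(0, 100, (10, 10), dtype=torch.int64),
    )
    # With the div lowerings added in this release, converting integer
    # division is expected to succeed rather than fail during lowering.
    edge_model = ai_edge_torch.convert(DivModel().eval(), sample_args)
    edge_model.export("div_model.tflite")  # hypothetical output path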
@@ -116,7 +116,6 @@ lower_by_torch_xla2(torch.ops.aten.cosh)
 lower_by_torch_xla2(torch.ops.aten.cumsum)
 lower_by_torch_xla2(torch.ops.aten.detach)
 lower_by_torch_xla2(torch.ops.aten.diagonal)
-lower_by_torch_xla2(torch.ops.aten.div)
 lower_by_torch_xla2(torch.ops.aten.dot)
 lower_by_torch_xla2(torch.ops.aten.embedding)
 lower_by_torch_xla2(torch.ops.aten.empty)
@@ -352,6 +351,8 @@ def _aten_mul_scalar(lctx: LoweringContext, self, other):
     promoted_type = jnp.promote_types(self.dtype, other_dtype)
     if promoted_type == jnp.float64:
       promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
     return jnp.multiply(
         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
     )
@@ -369,6 +370,8 @@ def _aten_mul_tensor(lctx: LoweringContext, self, other):
     promoted_type = jnp.promote_types(self.dtype, other_dtype)
     if promoted_type == jnp.float64:
       promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
     return jnp.multiply(
         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
     )
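Both mul hunks above apply the same dtype clamp: after JAX type promotion, 64-bit results are demoted to their 32-bit counterparts, since the converted graph is expected to run in 32 bits. A minimal sketch of that shared pattern, using a helper name that is illustrative and not a function in the package:

    import jax.numpy as jnp

    def demote_to_32bit(dtype):
      """Illustrative distillation of the clamp repeated in these lowerings."""
      if dtype == jnp.float64:
        return jnp.float32
      if dtype == jnp.int64:
        return jnp.int32
      return dtype

    # int64 * int64 promotes to int64, which the lowering clamps to int32.
    assert demote_to_32bit(jnp.promote_types(jnp.int64, jnp.int64)) == jnp.int32
    # float32 * int32 promotes to float32 and passes through unchanged.
    assert demote_to_32bit(jnp.promote_types(jnp.float32, jnp.int32)) == jnp.float32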
@@ -376,6 +379,96 @@ def _aten_mul_tensor(lctx: LoweringContext, self, other):
   return jax_lowering(lctx, self, other)
 
 
+@registry.lower(torch.ops.aten.div.Scalar)
+def _aten_div_scalar(lctx: LoweringContext, self, other):
+  _log_usage(torch.ops.aten.div.Scalar)
+
+  @jax_bridge.wrap
+  def jax_lowering(self, other):
+    other_dtype = jnp.result_type(other)
+    promoted_type = jnp.promote_types(self.dtype, other_dtype)
+    if promoted_type == jnp.float64:
+      promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
+    return jnp.divide(
+        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+    )
+
+  return jax_lowering(lctx, self, other)
+
+
+@registry.lower(torch.ops.aten.div.Scalar_mode)
+def _aten_div_scalar_mode(lctx: LoweringContext, self, other, rounding_mode=""):
+  _log_usage(torch.ops.aten.div.Scalar_mode)
+
+  @jax_bridge.wrap
+  def jax_lowering(self, other):
+    other_dtype = jnp.result_type(other)
+    promoted_type = jnp.promote_types(self.dtype, other_dtype)
+    if promoted_type == jnp.float64:
+      promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
+    if rounding_mode == "floor":
+      return jnp.floor_divide(
+          self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+      )
+    result = jnp.divide(
+        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+    )
+    if rounding_mode == "trunc":
+      result = jnp.trunc(result)
+    return result
+
+  return jax_lowering(lctx, self, other)
+
+
+@registry.lower(torch.ops.aten.div.Tensor)
+def _aten_div_tensor(lctx: LoweringContext, self, other):
+  _log_usage(torch.ops.aten.div.Tensor)
+
+  @jax_bridge.wrap
+  def jax_lowering(self, other):
+    other_dtype = jnp.result_type(other)
+    promoted_type = jnp.promote_types(self.dtype, other_dtype)
+    if promoted_type == jnp.float64:
+      promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
+    return jnp.divide(
+        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+    )
+
+  return jax_lowering(lctx, self, other)
+
+
+@registry.lower(torch.ops.aten.div.Tensor_mode)
+def _aten_div_tensor_mode(lctx: LoweringContext, self, other, rounding_mode=""):
+  _log_usage(torch.ops.aten.div.Tensor_mode)
+
+  @jax_bridge.wrap
+  def jax_lowering(self, other):
+    other_dtype = jnp.result_type(other)
+    promoted_type = jnp.promote_types(self.dtype, other_dtype)
+    if promoted_type == jnp.float64:
+      promoted_type = jnp.float32
+    elif promoted_type == jnp.int64:
+      promoted_type = jnp.int32
+    if rounding_mode == "floor":
+      return jnp.floor_divide(
+          self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+      )
+    result = jnp.divide(
+        self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+    )
+    if rounding_mode == "trunc":
+      result = jnp.trunc(result)
+    return result
+
+  return jax_lowering(lctx, self, other)
+
+
 @registry.lower(torch.ops.aten.where.self)
 def _aten_where_self(lctx: LoweringContext, condition, self, other):
   _log_usage(torch.ops.aten.where.self)
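The Scalar_mode and Tensor_mode lowerings mirror torch.div's rounding_mode semantics: "floor" rounds toward negative infinity via jnp.floor_divide, "trunc" truncates the true-division result toward zero, and no mode means true division. The three only differ for operands with mixed signs, as this quick sketch shows:

    import jax.numpy as jnp

    x = jnp.array([7, -7], dtype=jnp.int32)
    y = jnp.array([2, 2], dtype=jnp.int32)

    print(jnp.divide(x, y))             # [ 3.5 -3.5]  no rounding mode: true division
    print(jnp.floor_divide(x, y))       # [ 3 -4]      rounding_mode="floor"
    print(jnp.trunc(jnp.divide(x, y)))  # [ 3. -3.]    rounding_mode="trunc"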
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
 
 # The next version of ai-edge-torch.
 # The minor version code should be bumped after every release.
-__version__ = "0.6.0.dev20250828"
+__version__ = "0.6.0.dev20250830"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.6.0.dev20250828
+Version: 0.6.0.dev20250830
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
-ai_edge_torch/version.py,sha256=yl0TcYqmBJlm2Y0OYq0y49hU2ZDXXmQW_OvAUFDKogg,806
+ai_edge_torch/version.py,sha256=hqWeY2tyRT1BA2-SphW3wPQ0c-LtS2ll4xHRheEB2WY,806
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=oXbr9G5Jc21xd1dr2CDrp774I4crs0_kkN490K5fNn0,7312
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=ioeEpeS07NAs14-nHiPI-6lLTtnALxl8uNtTKvoHdgE,19316
+ai_edge_torch/_convert/test/test_convert.py,sha256=sz2-b3_Tf3tKQgvPqxtu0QR-wgKOuoAKcqz19dUby3g,19962
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -253,7 +253,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvs
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
 ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=xUuQjoR0NJhuwG36GuycpKHo9jg783bDSHj9wE4F1Sg,15439
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Hth8F1SwPVEympR1_T6sEsCmmRXtoO7zatPR1uRLFbY,18411
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.6.0.dev20250828.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.6.0.dev20250828.dist-info/METADATA,sha256=TiZuiiYOhpawH2NS7795QCs04Vxzoau7dBquFrs0Tkc,2074
-ai_edge_torch_nightly-0.6.0.dev20250828.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.6.0.dev20250828.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.6.0.dev20250828.dist-info/RECORD,,
+ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/METADATA,sha256=ni14Rod4z0CU-YvXBiqbVKbprzGRTC6JlPyH3zQMo78,2074
+ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.6.0.dev20250830.dist-info/RECORD,,