ai-edge-torch-nightly 0.4.0.dev20250402__py3-none-any.whl → 0.4.0.dev20250404__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@
15
15
 
16
16
  """A suite of tests to validate the Dynamic Update Slice Custom Op."""
17
17
 
18
+ from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
18
19
  import torch
19
20
  from torch import nn
20
21
 
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
  """Utilities for building MLIR lowerings."""
16
16
 
17
+ from collections.abc import Callable
17
18
  import functools
18
19
  import numbers
19
- from typing import Any
20
- from typing import Optional
20
+ from typing import Any, Optional, Union
21
21
  from ai_edge_torch.odml_torch import export_utils
22
22
  from jax._src.lib.mlir import ir
23
23
  from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -222,3 +222,48 @@ def convert_int_to_float(t: ir.Value) -> ir.Value:
222
222
  return stablehlo.convert(
223
223
  ir.RankedTensorType.get(t.type.shape, ir.F64Type.get()), t
224
224
  )
225
+
226
+
227
# IR Helpers
IrValues = Union[ir.Value, tuple[ir.Value, ...]]


# Non-canonicalized dtype to IR type mapping. Each NumPy dtype maps to a
# zero-argument factory that builds the matching MLIR type when called.
_numpy_dtype_to_ir_type: dict[np.dtype, Callable[[], ir.Type]] = {
    # Signless integers; bool is represented as a 1-bit signless integer.
    **{
        np.dtype(scalar): functools.partial(ir.IntegerType.get_signless, bits)
        for scalar, bits in (
            (np.bool_, 1),
            (np.int8, 8),
            (np.int16, 16),
            (np.int32, 32),
            (np.int64, 64),
        )
    },
    # Unsigned integers.
    **{
        np.dtype(scalar): functools.partial(ir.IntegerType.get_unsigned, bits)
        for scalar, bits in (
            (np.uint8, 8),
            (np.uint16, 16),
            (np.uint32, 32),
            (np.uint64, 64),
        )
    },
    # Floating point.
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    # Complex types are built from their floating-point component type.
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
248
+
249
+
250
+ def numpy_dtype_to_ir_type(dtype: np.dtype | np.generic) -> ir.Type:
251
+ assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
252
+ dtype = np.dtype(dtype)
253
+ try:
254
+ ir_type_factory = _numpy_dtype_to_ir_type[dtype]
255
+ except KeyError as err:
256
+ raise TypeError(
257
+ f"No numpy_dtype_to_ir_type handler for dtype: {dtype}"
258
+ ) from err
259
+ return ir_type_factory()
260
+
261
+
262
def numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
  """Emits a StableHLO constant holding the data of a NumPy array or scalar.

  Args:
    x: The array (or NumPy scalar) to embed as a constant.

  Returns:
    The result of ``stablehlo.constant`` on a dense elements attribute that
    carries ``x``'s data, its original shape, and the IR element type mapped
    from ``x.dtype``.
  """
  ir_elem_type = numpy_dtype_to_ir_type(x.dtype)
  # Record the logical shape first: bit-packing below flattens boolean data.
  logical_shape = x.shape
  if x.dtype == np.bool_:
    # Boolean payloads are passed to DenseElementsAttr as little-endian
    # packed bits.
    payload = np.packbits(x, bitorder="little")  # type: ignore
  else:
    payload = x
  payload = np.ascontiguousarray(payload)
  dense_attr = ir.DenseElementsAttr.get(  # type: ignore
      payload, type=ir_elem_type, shape=logical_shape
  )
  return stablehlo.constant(dense_attr)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250402"
16
+ __version__ = "0.4.0.dev20250404"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250402
3
+ Version: 0.4.0.dev20250404
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=WqI1n7ZtdJC36jS18ODSCCHgcMQ12kAmRNk-TTrOeAo,706
5
+ ai_edge_torch/version.py,sha256=5Y5PWe5kvjSSYTDe16N_sj2O1E-n6hrbOjUlSz4mABA,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -176,7 +176,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
176
176
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
177
177
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
178
178
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
179
- ai_edge_torch/generative/test/test_custom_dus.py,sha256=ifgnUCWihT59eFdLrlc5_j9sWygEKclU6Iqw6zdlgeI,3177
179
+ ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
180
180
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
181
181
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
182
182
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
@@ -232,7 +232,7 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
232
232
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
233
233
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
234
234
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
235
- ai_edge_torch/odml_torch/lowerings/utils.py,sha256=tIyZiSy2Rbtea9ZlRYUfTaYE5vW_lAU6itT6_rUp8Qg,7028
235
+ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=-TzK1igPgR38oZkU1iPh-DZhlKVwuBtGWVC-y81PXzY,8935
236
236
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
237
237
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
238
238
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -242,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
242
242
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
243
243
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
244
244
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
245
- ai_edge_torch_nightly-0.4.0.dev20250402.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
246
- ai_edge_torch_nightly-0.4.0.dev20250402.dist-info/METADATA,sha256=Dm1OfGpmUSSXIT6gnxQlAYu5qm8e6WbLdq-ocbXEoBU,1966
247
- ai_edge_torch_nightly-0.4.0.dev20250402.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
248
- ai_edge_torch_nightly-0.4.0.dev20250402.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
249
- ai_edge_torch_nightly-0.4.0.dev20250402.dist-info/RECORD,,
245
+ ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
246
+ ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/METADATA,sha256=jpYvoXGbLlhrXvFTQ_mgRIp3skG8vmKU4966Je4VDNU,1966
247
+ ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
248
+ ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
249
+ ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/RECORD,,