ai-edge-torch-nightly 0.4.0.dev20250401__py3-none-any.whl → 0.4.0.dev20250404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/test/test_custom_dus.py +1 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +47 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/top_level.txt +0 -0
@@ -14,10 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Utilities for building MLIR lowerings."""
|
16
16
|
|
17
|
+
from collections.abc import Callable
|
17
18
|
import functools
|
18
19
|
import numbers
|
19
|
-
from typing import Any
|
20
|
-
from typing import Optional
|
20
|
+
from typing import Any, Optional, Union
|
21
21
|
from ai_edge_torch.odml_torch import export_utils
|
22
22
|
from jax._src.lib.mlir import ir
|
23
23
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
@@ -222,3 +222,48 @@ def convert_int_to_float(t: ir.Value) -> ir.Value:
|
|
222
222
|
return stablehlo.convert(
|
223
223
|
ir.RankedTensorType.get(t.type.shape, ir.F64Type.get()), t
|
224
224
|
)
|
225
|
+
|
226
|
+
|
227
|
+
# IR helpers.

# What an IR-emitting helper hands back: a single ir.Value, or a tuple of them.
IrValues = Union[ir.Value, tuple[ir.Value, ...]]

# Table from a (non-canonicalized) numpy dtype to a zero-argument factory that
# builds the matching MLIR type. Entries are factories rather than ir.Type
# objects, so the type is only constructed when a lookup result is called.
_numpy_dtype_to_ir_type: dict[np.dtype, Callable[[], ir.Type]] = {
    np.dtype(np.bool_): functools.partial(ir.IntegerType.get_signless, 1),
    # Signed integers map to signless MLIR integers of the same width.
    **{
        np.dtype(f"int{bits}"): functools.partial(
            ir.IntegerType.get_signless, bits
        )
        for bits in (8, 16, 32, 64)
    },
    # Unsigned integers map to explicitly unsigned MLIR integers.
    **{
        np.dtype(f"uint{bits}"): functools.partial(
            ir.IntegerType.get_unsigned, bits
        )
        for bits in (8, 16, 32, 64)
    },
    np.dtype(np.float16): ir.F16Type.get,
    np.dtype(np.float32): ir.F32Type.get,
    np.dtype(np.float64): ir.F64Type.get,
    np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
    np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
|
248
|
+
|
249
|
+
|
250
|
+
def numpy_dtype_to_ir_type(dtype: np.dtype | np.generic) -> ir.Type:
|
251
|
+
assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
|
252
|
+
dtype = np.dtype(dtype)
|
253
|
+
try:
|
254
|
+
ir_type_factory = _numpy_dtype_to_ir_type[dtype]
|
255
|
+
except KeyError as err:
|
256
|
+
raise TypeError(
|
257
|
+
f"No numpy_dtype_to_ir_type handler for dtype: {dtype}"
|
258
|
+
) from err
|
259
|
+
return ir_type_factory()
|
260
|
+
|
261
|
+
|
262
|
+
def numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
  """Emits a StableHLO constant op holding the contents of `x`.

  Args:
    x: A numpy array or numpy scalar whose dtype is accepted by
      `numpy_dtype_to_ir_type`.

  Returns:
    The result of `stablehlo.constant` applied to a DenseElementsAttr built
    from `x`'s data, typed with `x`'s element type and original shape.
  """
  ir_element_type = numpy_dtype_to_ir_type(x.dtype)
  # Record the logical shape first. The boolean path below swaps in a packed
  # buffer whose shape no longer matches `x`.
  logical_shape = x.shape
  buffer = x
  if x.dtype == np.bool_:
    # Boolean data is handed over as little-endian packed bits.
    buffer = np.packbits(x, bitorder="little")  # type: ignore
  dense = ir.DenseElementsAttr.get(  # type: ignore
      np.ascontiguousarray(buffer),
      type=ir_element_type,
      shape=logical_shape,
  )
  return stablehlo.constant(dense)
|
{ai_edge_torch_nightly-0.4.0.dev20250401.dist-info → ai_edge_torch_nightly-0.4.0.dev20250404.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.dev20250401
|
3
|
+
Version: 0.4.0.dev20250404
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=5Y5PWe5kvjSSYTDe16N_sj2O1E-n6hrbOjUlSz4mABA,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -176,7 +176,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
|
|
176
176
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
177
177
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
178
178
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
179
|
-
ai_edge_torch/generative/test/test_custom_dus.py,sha256=
|
179
|
+
ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
|
180
180
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
|
181
181
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
182
182
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
@@ -232,7 +232,7 @@ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQD
|
|
232
232
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
233
233
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
234
234
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
235
|
-
ai_edge_torch/odml_torch/lowerings/utils.py,sha256
|
235
|
+
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=-TzK1igPgR38oZkU1iPh-DZhlKVwuBtGWVC-y81PXzY,8935
|
236
236
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
237
237
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
238
238
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
@@ -242,8 +242,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
242
242
|
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
243
243
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
244
244
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
245
|
-
ai_edge_torch_nightly-0.4.0.
|
246
|
-
ai_edge_torch_nightly-0.4.0.
|
247
|
-
ai_edge_torch_nightly-0.4.0.
|
248
|
-
ai_edge_torch_nightly-0.4.0.
|
249
|
-
ai_edge_torch_nightly-0.4.0.
|
245
|
+
ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
246
|
+
ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/METADATA,sha256=jpYvoXGbLlhrXvFTQ_mgRIp3skG8vmKU4966Je4VDNU,1966
|
247
|
+
ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
248
|
+
ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
249
|
+
ai_edge_torch_nightly-0.4.0.dev20250404.dist-info/RECORD,,
|
File without changes
|
File without changes
|