ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,28 @@
|
|
15
15
|
|
16
16
|
"""Common utility functions for model conversion."""
|
17
17
|
|
18
|
-
from
|
18
|
+
from functools import partial
|
19
|
+
from typing import Any, Union
|
19
20
|
|
20
21
|
from ai_edge_torch._convert import converter as converter_utils
|
21
22
|
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
22
23
|
import ai_edge_torch.generative.layers.model_config as cfg
|
23
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
24
26
|
import torch
|
27
|
+
import torch.nn as nn
|
28
|
+
|
29
|
+
|
30
|
+
class ExportableModule(torch.nn.Module):
|
31
|
+
|
32
|
+
def __init__(self, module, **extra_kwargs):
|
33
|
+
super().__init__()
|
34
|
+
self.module = module
|
35
|
+
self.extra_kwargs = extra_kwargs
|
36
|
+
|
37
|
+
def forward(self, *export_args, **export_kwargs):
|
38
|
+
full_kwargs = {**export_kwargs, **self.extra_kwargs}
|
39
|
+
return self.module(*export_args, **full_kwargs)
|
25
40
|
|
26
41
|
|
27
42
|
def convert_to_tflite(
|
@@ -31,6 +46,7 @@ def convert_to_tflite(
|
|
31
46
|
pixel_values_size: torch.Size = None,
|
32
47
|
quantize: bool = True,
|
33
48
|
config: cfg.ModelConfig = None,
|
49
|
+
export_config: ExportConfig = None,
|
34
50
|
):
|
35
51
|
"""Converts a nn.Module model to multi-signature tflite model.
|
36
52
|
|
@@ -97,6 +113,11 @@ def convert_to_tflite(
|
|
97
113
|
)
|
98
114
|
|
99
115
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
116
|
+
|
117
|
+
# For export, we create a module that captures any non-exportable
|
118
|
+
# arguments, e.g. the generation config object.
|
119
|
+
mod = ExportableModule(pytorch_model, export_config=export_config)
|
120
|
+
|
100
121
|
converter = converter_utils.Converter()
|
101
122
|
for i in range(len(prefill_seq_lens)):
|
102
123
|
prefill_seq_len = prefill_seq_lens[i]
|
@@ -108,7 +129,7 @@ def convert_to_tflite(
|
|
108
129
|
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
109
130
|
converter.add_signature(
|
110
131
|
prefill_signature_name,
|
111
|
-
|
132
|
+
mod,
|
112
133
|
sample_kwargs={
|
113
134
|
'tokens': prefill_tokens,
|
114
135
|
'input_pos': prefill_input_pos,
|
@@ -118,7 +139,7 @@ def convert_to_tflite(
|
|
118
139
|
if prefill_pixel_values is not None:
|
119
140
|
converter.add_signature(
|
120
141
|
prefill_signature_name + '_pixel',
|
121
|
-
|
142
|
+
mod,
|
122
143
|
sample_kwargs={
|
123
144
|
'tokens': prefill_tokens,
|
124
145
|
'input_pos': prefill_input_pos,
|
@@ -129,7 +150,7 @@ def convert_to_tflite(
|
|
129
150
|
|
130
151
|
converter.add_signature(
|
131
152
|
'decode',
|
132
|
-
|
153
|
+
mod,
|
133
154
|
sample_kwargs={
|
134
155
|
'tokens': decode_token,
|
135
156
|
'input_pos': decode_input_pos,
|
@@ -16,7 +16,8 @@
|
|
16
16
|
"""Utilities to be used for re-authoring transformer models."""
|
17
17
|
|
18
18
|
import copy
|
19
|
-
from
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from typing import Optional, Tuple
|
20
21
|
|
21
22
|
from ai_edge_torch.generative.layers import attention
|
22
23
|
from ai_edge_torch.generative.layers import builder
|
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
|
|
45
46
|
TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
|
46
47
|
|
47
48
|
|
49
|
+
@dataclass
|
50
|
+
class ExportConfig:
|
51
|
+
"""Model generating configuration settings."""
|
52
|
+
|
53
|
+
# On prefill signatures, should the model produce logit output?
|
54
|
+
# When False, only decode signatures will produce output.
|
55
|
+
output_logits_on_prefill: bool = False
|
56
|
+
|
57
|
+
|
48
58
|
class DecoderOnlyModel(nn.Module):
|
49
59
|
"""A simple decoder-only transformer model built from the Edge Generative API.
|
50
60
|
|
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
|
|
93
103
|
tokens: torch.Tensor,
|
94
104
|
input_pos: torch.Tensor,
|
95
105
|
kv_cache: kv_utils.KVCache,
|
106
|
+
export_config: Optional[ExportConfig] = None,
|
96
107
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
97
108
|
_, seq_len = tokens.size()
|
98
109
|
assert self.config.max_seq_len >= seq_len, (
|
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
|
|
108
119
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
109
120
|
|
110
121
|
return self.forward_with_embeds(
|
111
|
-
input_embeds, rope, mask, input_pos, kv_cache
|
122
|
+
input_embeds, rope, mask, input_pos, kv_cache, export_config
|
112
123
|
)
|
113
124
|
|
114
125
|
def forward_with_embeds(
|
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
|
|
118
129
|
mask: torch.Tensor,
|
119
130
|
input_pos: torch.Tensor,
|
120
131
|
kv_cache: kv_utils.KVCache,
|
132
|
+
export_config: Optional[ExportConfig] = None,
|
121
133
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
122
134
|
"""Forwards the model with input embeddings."""
|
123
135
|
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
|
|
137
149
|
updated_kv_entires.append(kv_entry)
|
138
150
|
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
139
151
|
|
152
|
+
if export_config is not None:
|
153
|
+
if (
|
154
|
+
torch.numel(input_pos) > 1
|
155
|
+
and not export_config.output_logits_on_prefill
|
156
|
+
):
|
157
|
+
return {"kv_cache": updated_kv_cache}
|
158
|
+
|
140
159
|
x = self.final_norm(x)
|
141
160
|
logits = self.lm_head(x) # (b, t, vocab_size)
|
142
161
|
return {"logits": logits, "kv_cache": updated_kv_cache}
|
@@ -146,8 +165,9 @@ def build_decoder_only_model(
|
|
146
165
|
checkpoint_path: str,
|
147
166
|
config: cfg.ModelConfig,
|
148
167
|
tensor_names: loading_utils.ModelLoader.TensorNames,
|
149
|
-
|
150
|
-
|
168
|
+
model_class: type[nn.Module] = DecoderOnlyModel,
|
169
|
+
) -> nn.Module:
|
170
|
+
transformer = model_class(config)
|
151
171
|
loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
|
152
172
|
loader.load(
|
153
173
|
transformer, strict=not config.lm_head_share_weight_with_embedding
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
from typing import List
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
22
23
|
import torch
|
23
24
|
|
24
25
|
|
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
|
|
40
41
|
"""
|
41
42
|
super().__init__()
|
42
43
|
self.model = model
|
44
|
+
self.export_config = ExportConfig(output_logits_on_prefill=True)
|
43
45
|
|
44
46
|
def forward(
|
45
47
|
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
103
105
|
Returns:
|
104
106
|
The output logits and the updated KV cache.
|
105
107
|
"""
|
108
|
+
# Verification requires logit outputs on prefill for comparison.
|
109
|
+
if (
|
110
|
+
self.export_config is not None
|
111
|
+
and not self.export_config.output_logits_on_prefill
|
112
|
+
):
|
113
|
+
raise ValueError("Verifier requires logit output on prefill.")
|
106
114
|
# Since the reauthored model doesn't include keyword arguments, pass
|
107
115
|
# pixel_values only when it is not None. Otherwise, it may raise an error.
|
108
116
|
if pixel_values is None:
|
109
|
-
output = self.model.forward(
|
117
|
+
output = self.model.forward(
|
118
|
+
tokens, input_pos, kv_cache, export_config=self.export_config
|
119
|
+
)
|
110
120
|
else:
|
111
121
|
output = self.model.forward(
|
112
|
-
tokens,
|
122
|
+
tokens,
|
123
|
+
input_pos,
|
124
|
+
kv_cache,
|
125
|
+
pixel_values=pixel_values,
|
126
|
+
export_config=self.export_config,
|
113
127
|
)
|
114
128
|
return output["logits"], output["kv_cache"]
|
115
129
|
|
@@ -15,13 +15,15 @@
|
|
15
15
|
|
16
16
|
from typing import Any, Optional
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from ai_edge_torch._convert import signature
|
20
20
|
from ai_edge_torch.quantize import quant_config as qcfg
|
21
21
|
import torch
|
22
22
|
|
23
|
+
config = _config.config
|
24
|
+
|
23
25
|
# isort: off
|
24
|
-
if config.
|
26
|
+
if config.use_torch_xla:
|
25
27
|
from ai_edge_torch.lowertools import torch_xla_utils as utils
|
26
28
|
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
|
27
29
|
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import re
|
17
17
|
from typing import Optional
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from absl.testing import absltest as googletest
|
20
20
|
|
21
|
+
config = _config.config
|
22
|
+
|
21
23
|
|
22
24
|
def _extract_backend_configs(mlir):
|
23
25
|
mlir = mlir.replace("\\22", '"')
|
@@ -38,7 +40,7 @@ def assert_string_count(
|
|
38
40
|
if odml_torch_attr_counter is None:
|
39
41
|
odml_torch_attr_counter = {}
|
40
42
|
|
41
|
-
if config.
|
43
|
+
if config.use_torch_xla:
|
42
44
|
for key in torch_xla_pattern_counter:
|
43
45
|
test_case.assertEqual(
|
44
46
|
mlir.count(key),
|
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
276
276
|
interior_padding if i == dim else 0 for i in range(rank)
|
277
277
|
],
|
278
278
|
)
|
279
|
-
|
280
|
-
|
279
|
+
|
280
|
+
slices = [
|
281
281
|
slice(start, end, step) if i == dim else slice(None, None, None)
|
282
282
|
for i in range(rank)
|
283
|
-
]
|
283
|
+
]
|
284
|
+
pred = np.ones(self.type.shape, dtype=np.bool_)
|
285
|
+
pred[np.index_exp[tuple(slices)]] = False
|
284
286
|
pred = stablehlo.constant(
|
285
287
|
ir.DenseElementsAttr.get(
|
286
288
|
np.packbits(pred, bitorder="little"),
|
@@ -232,7 +232,9 @@ def _aten_convolution(
|
|
232
232
|
|
233
233
|
if bias is not None:
|
234
234
|
# broadcast [C] to [NCHW]
|
235
|
-
broadcasted_bias = stablehlo.broadcast_in_dim(
|
235
|
+
broadcasted_bias = stablehlo.broadcast_in_dim(
|
236
|
+
output_type, bias, ir.DenseI64ArrayAttr.get([1])
|
237
|
+
)
|
236
238
|
res = stablehlo.add(
|
237
239
|
lhs=res,
|
238
240
|
rhs=broadcasted_bias,
|
@@ -16,12 +16,15 @@ import functools
|
|
16
16
|
import logging
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch import jax_bridge
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
21
|
+
import jax.numpy as jnp
|
22
|
+
from jax._src.lib.mlir import ir
|
19
23
|
import torch
|
20
24
|
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
|
21
25
|
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
|
22
26
|
|
23
|
-
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
25
28
|
|
26
29
|
@functools.cache
|
27
30
|
def _log_usage(op):
|
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
|
|
258
261
|
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
|
259
262
|
def _aten_copy(self, src, **kwargs):
|
260
263
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
264
|
+
|
265
|
+
|
266
|
+
# Schema:
|
267
|
+
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
|
268
|
+
# -> Tensor
|
269
|
+
# Torch Reference:
|
270
|
+
# - https://pytorch.org/docs/stable/generated/torch.einsum.html
|
271
|
+
# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
|
272
|
+
@registry.lower(torch.ops.aten.einsum.default)
|
273
|
+
def _aten_einsum_default(
|
274
|
+
lctx: LoweringContext,
|
275
|
+
equation: str,
|
276
|
+
tensors: list[ir.Value],
|
277
|
+
path=None,
|
278
|
+
):
|
279
|
+
_log_usage(torch.ops.aten.einsum.default)
|
280
|
+
|
281
|
+
@jax_bridge.wrap
|
282
|
+
def jax_lowering(operands):
|
283
|
+
# Ignore the input path and let JAX determine the path.
|
284
|
+
return jnp.einsum(equation, *operands, optimize="optimal")
|
285
|
+
|
286
|
+
return jax_lowering(lctx, tuple(tensors))
|
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
|
|
20
20
|
from ai_edge_torch.odml_torch.lowerings import utils
|
21
21
|
from jax._src.lib.mlir import ir
|
22
22
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import numpy as np
|
23
24
|
import torch
|
24
25
|
|
25
26
|
|
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
|
|
66
67
|
normalized_rank = len(normalized_shape)
|
67
68
|
if weight is not None:
|
68
69
|
weight = stablehlo.broadcast_in_dim(
|
69
|
-
data_type,
|
70
|
+
data_type,
|
71
|
+
weight,
|
72
|
+
ir.DenseI64ArrayAttr.get(
|
73
|
+
list(range(data_rank - normalized_rank, data_rank))
|
74
|
+
),
|
70
75
|
)
|
71
76
|
output = stablehlo.multiply(weight, output)
|
72
77
|
if bias is not None:
|
73
78
|
bias = stablehlo.broadcast_in_dim(
|
74
|
-
data_type,
|
79
|
+
data_type,
|
80
|
+
bias,
|
81
|
+
ir.DenseI64ArrayAttr.get(
|
82
|
+
list(range(data_rank - normalized_rank, data_rank))
|
83
|
+
),
|
75
84
|
)
|
76
85
|
output = stablehlo.add(bias, output)
|
77
86
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
|
16
|
-
from typing import Union, cast
|
16
|
+
from typing import Optional, Union, cast
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch.lowerings import context
|
19
19
|
from ai_edge_torch.odml_torch.lowerings import utils
|
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
|
|
30
30
|
|
31
31
|
|
32
32
|
def _uniform_quantized_type(
|
33
|
-
stored_type: str
|
34
|
-
expressed_type: str
|
33
|
+
stored_type: Union[str, ir.Type],
|
34
|
+
expressed_type: Union[str, ir.Type],
|
35
35
|
*,
|
36
|
-
scale=float
|
37
|
-
zero_point=float
|
38
|
-
storage_type_min: int
|
39
|
-
storage_type_max: int
|
40
|
-
channel_axis: int
|
41
|
-
channel_axis_size: int
|
36
|
+
scale=Union[float, list[float], tuple[float]],
|
37
|
+
zero_point=Union[float, list[float], tuple[float]],
|
38
|
+
storage_type_min: Optional[int] = None,
|
39
|
+
storage_type_max: Optional[int] = None,
|
40
|
+
channel_axis: Optional[int] = None,
|
41
|
+
channel_axis_size: Optional[int] = None,
|
42
42
|
):
|
43
43
|
"""Polyfill for quant.UniformQuantizedType."""
|
44
44
|
if storage_type_min and storage_type_max:
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Torch export decompositions to run before lowering."""
|
16
|
+
|
17
|
+
import functools
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
@functools.cache
|
23
|
+
def decompositions():
|
24
|
+
# Base: Core ATen decompositions
|
25
|
+
decompositions = torch._decomp.core_aten_decompositions()
|
26
|
+
|
27
|
+
decompositions.update(
|
28
|
+
torch._decomp.get_decompositions([
|
29
|
+
torch.ops.aten.upsample_nearest2d,
|
30
|
+
torch.ops.aten._native_batch_norm_legit.no_stats,
|
31
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
32
|
+
torch.ops.aten._adaptive_avg_pool2d,
|
33
|
+
torch.ops.aten._adaptive_avg_pool3d,
|
34
|
+
torch.ops.aten.grid_sampler_2d,
|
35
|
+
torch.ops.aten.native_group_norm,
|
36
|
+
torch.ops.aten.native_dropout,
|
37
|
+
torch.ops.aten.reflection_pad1d,
|
38
|
+
torch.ops.aten.reflection_pad2d,
|
39
|
+
torch.ops.aten.reflection_pad3d,
|
40
|
+
torch.ops.aten.replication_pad1d,
|
41
|
+
torch.ops.aten.replication_pad2d,
|
42
|
+
torch.ops.aten.replication_pad3d,
|
43
|
+
torch.ops.aten.addmm,
|
44
|
+
])
|
45
|
+
)
|
46
|
+
|
47
|
+
torch._decomp.remove_decompositions(
|
48
|
+
decompositions,
|
49
|
+
[
|
50
|
+
torch.ops.aten.roll,
|
51
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
+
# optimized through converter than JAX's impl. Disable einsum
|
53
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
+
torch.ops.aten.einsum.default,
|
55
|
+
],
|
56
|
+
)
|
57
|
+
|
58
|
+
# Override _safe_softmax decompositions with regular softmax.
|
59
|
+
# _safe_softmax introduces additional check-select ops to guard extreme
|
60
|
+
# input values to softmax, which could make the converted model inefficient
|
61
|
+
# on-device.
|
62
|
+
if hasattr(torch.ops.aten, "_safe_softmax"):
|
63
|
+
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
|
64
|
+
|
65
|
+
return decompositions
|
@@ -26,7 +26,6 @@ class LoweringRegistry:
|
|
26
26
|
|
27
27
|
def __init__(self):
|
28
28
|
self.registered_ops = {}
|
29
|
-
self.decompositions = {}
|
30
29
|
|
31
30
|
def lookup(self, op_or_name):
|
32
31
|
candidate = self._get_lowering(op_or_name)
|
@@ -52,33 +51,6 @@ class LoweringRegistry:
|
|
52
51
|
|
53
52
|
|
54
53
|
global_registry = LoweringRegistry()
|
55
|
-
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
56
|
-
global_registry.decompositions.update(
|
57
|
-
torch._decomp.get_decompositions([
|
58
|
-
torch.ops.aten.upsample_nearest2d,
|
59
|
-
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
-
torch.ops.aten._native_batch_norm_legit_functional,
|
61
|
-
torch.ops.aten._adaptive_avg_pool2d,
|
62
|
-
torch.ops.aten._adaptive_avg_pool3d,
|
63
|
-
torch.ops.aten.grid_sampler_2d,
|
64
|
-
torch.ops.aten.native_group_norm,
|
65
|
-
torch.ops.aten.native_dropout,
|
66
|
-
torch.ops.aten.reflection_pad1d,
|
67
|
-
torch.ops.aten.reflection_pad2d,
|
68
|
-
torch.ops.aten.reflection_pad3d,
|
69
|
-
torch.ops.aten.replication_pad1d,
|
70
|
-
torch.ops.aten.replication_pad2d,
|
71
|
-
torch.ops.aten.replication_pad3d,
|
72
|
-
torch.ops.aten.addmm,
|
73
|
-
])
|
74
|
-
)
|
75
|
-
|
76
|
-
torch._decomp.remove_decompositions(
|
77
|
-
global_registry.decompositions,
|
78
|
-
[
|
79
|
-
torch.ops.aten.roll,
|
80
|
-
],
|
81
|
-
)
|
82
54
|
|
83
55
|
|
84
56
|
def lookup(op):
|
@@ -91,7 +63,3 @@ def lower(op):
|
|
91
63
|
return lowering
|
92
64
|
|
93
65
|
return inner
|
94
|
-
|
95
|
-
|
96
|
-
def decompositions():
|
97
|
-
return global_registry.decompositions
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241214
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
|
|
11
11
|
Classifier: License :: OSI Approved :: Apache Software License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
17
16
|
Classifier: Topic :: Scientific/Engineering
|
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
19
|
Classifier: Topic :: Software Development
|
21
20
|
Classifier: Topic :: Software Development :: Libraries
|
22
21
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
-
Requires-Python: >=3.
|
22
|
+
Requires-Python: >=3.10
|
24
23
|
Description-Content-Type: text/markdown
|
25
24
|
License-File: LICENSE
|
26
25
|
Requires-Dist: numpy
|
@@ -28,10 +27,13 @@ Requires-Dist: scipy
|
|
28
27
|
Requires-Dist: safetensors
|
29
28
|
Requires-Dist: tabulate
|
30
29
|
Requires-Dist: torch>=2.4.0
|
31
|
-
Requires-Dist:
|
32
|
-
Requires-Dist: tf-nightly>=2.19.0.dev20241121
|
30
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
31
|
Requires-Dist: ai-edge-litert-nightly
|
34
32
|
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: jax
|
34
|
+
Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
|
35
|
+
Provides-Extra: torch-xla
|
36
|
+
Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
|
35
37
|
|
36
38
|
Library that supports converting PyTorch models into a .tflite format, which can
|
37
39
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|