ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_config.py +52 -0
- ai_edge_torch/_convert/test/test_convert.py +1 -2
- ai_edge_torch/debug/test/test_culprit.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/llama/llama.py +11 -17
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
- ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/phi/phi2.py +8 -3
- ai_edge_torch/generative/examples/phi/phi3.py +7 -9
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
- ai_edge_torch/generative/layers/attention.py +2 -6
- ai_edge_torch/generative/layers/kv_cache.py +24 -18
- ai_edge_torch/generative/layers/normalization.py +1 -3
- ai_edge_torch/generative/test/test_kv_cache.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +12 -14
- ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
- ai_edge_torch/generative/test/utils.py +31 -6
- ai_edge_torch/generative/utilities/converter.py +25 -4
- ai_edge_torch/generative/utilities/model_builder.py +24 -4
- ai_edge_torch/generative/utilities/verifier.py +16 -2
- ai_edge_torch/lowertools/_shim.py +4 -2
- ai_edge_torch/lowertools/test_utils.py +4 -2
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
- ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
- ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
- ai_edge_torch/config.py +0 -27
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
@@ -15,13 +15,28 @@
|
|
15
15
|
|
16
16
|
"""Common utility functions for model conversion."""
|
17
17
|
|
18
|
-
from
|
18
|
+
from functools import partial
|
19
|
+
from typing import Any, Union
|
19
20
|
|
20
21
|
from ai_edge_torch._convert import converter as converter_utils
|
21
22
|
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
22
23
|
import ai_edge_torch.generative.layers.model_config as cfg
|
23
24
|
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
24
26
|
import torch
|
27
|
+
import torch.nn as nn
|
28
|
+
|
29
|
+
|
30
|
+
class ExportableModule(torch.nn.Module):
|
31
|
+
|
32
|
+
def __init__(self, module, **extra_kwargs):
|
33
|
+
super().__init__()
|
34
|
+
self.module = module
|
35
|
+
self.extra_kwargs = extra_kwargs
|
36
|
+
|
37
|
+
def forward(self, *export_args, **export_kwargs):
|
38
|
+
full_kwargs = {**export_kwargs, **self.extra_kwargs}
|
39
|
+
return self.module(*export_args, **full_kwargs)
|
25
40
|
|
26
41
|
|
27
42
|
def convert_to_tflite(
|
@@ -31,6 +46,7 @@ def convert_to_tflite(
|
|
31
46
|
pixel_values_size: torch.Size = None,
|
32
47
|
quantize: bool = True,
|
33
48
|
config: cfg.ModelConfig = None,
|
49
|
+
export_config: ExportConfig = None,
|
34
50
|
):
|
35
51
|
"""Converts a nn.Module model to multi-signature tflite model.
|
36
52
|
|
@@ -97,6 +113,11 @@ def convert_to_tflite(
|
|
97
113
|
)
|
98
114
|
|
99
115
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
116
|
+
|
117
|
+
# For export, we create a module that captures any non-exportable,
|
118
|
+
# arugments, e.g. the generation config object.
|
119
|
+
mod = ExportableModule(pytorch_model, export_config=export_config)
|
120
|
+
|
100
121
|
converter = converter_utils.Converter()
|
101
122
|
for i in range(len(prefill_seq_lens)):
|
102
123
|
prefill_seq_len = prefill_seq_lens[i]
|
@@ -108,7 +129,7 @@ def convert_to_tflite(
|
|
108
129
|
prefill_signature_name = f'prefill_{prefill_seq_len}'
|
109
130
|
converter.add_signature(
|
110
131
|
prefill_signature_name,
|
111
|
-
|
132
|
+
mod,
|
112
133
|
sample_kwargs={
|
113
134
|
'tokens': prefill_tokens,
|
114
135
|
'input_pos': prefill_input_pos,
|
@@ -118,7 +139,7 @@ def convert_to_tflite(
|
|
118
139
|
if prefill_pixel_values is not None:
|
119
140
|
converter.add_signature(
|
120
141
|
prefill_signature_name + '_pixel',
|
121
|
-
|
142
|
+
mod,
|
122
143
|
sample_kwargs={
|
123
144
|
'tokens': prefill_tokens,
|
124
145
|
'input_pos': prefill_input_pos,
|
@@ -129,7 +150,7 @@ def convert_to_tflite(
|
|
129
150
|
|
130
151
|
converter.add_signature(
|
131
152
|
'decode',
|
132
|
-
|
153
|
+
mod,
|
133
154
|
sample_kwargs={
|
134
155
|
'tokens': decode_token,
|
135
156
|
'input_pos': decode_input_pos,
|
@@ -16,7 +16,8 @@
|
|
16
16
|
"""Utilities to be used for re-authoring transformer models."""
|
17
17
|
|
18
18
|
import copy
|
19
|
-
from
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from typing import Optional, Tuple
|
20
21
|
|
21
22
|
from ai_edge_torch.generative.layers import attention
|
22
23
|
from ai_edge_torch.generative.layers import builder
|
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
|
|
45
46
|
TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
|
46
47
|
|
47
48
|
|
49
|
+
@dataclass
|
50
|
+
class ExportConfig:
|
51
|
+
"""Model generating configuration settings."""
|
52
|
+
|
53
|
+
# On prefill signatures, should the model produce logit output?
|
54
|
+
# When False, only decode signatures will produce output.
|
55
|
+
output_logits_on_prefill: bool = False
|
56
|
+
|
57
|
+
|
48
58
|
class DecoderOnlyModel(nn.Module):
|
49
59
|
"""A simple decoder-only transformer model built from the Edge Generative API.
|
50
60
|
|
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
|
|
93
103
|
tokens: torch.Tensor,
|
94
104
|
input_pos: torch.Tensor,
|
95
105
|
kv_cache: kv_utils.KVCache,
|
106
|
+
export_config: Optional[ExportConfig] = None,
|
96
107
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
97
108
|
_, seq_len = tokens.size()
|
98
109
|
assert self.config.max_seq_len >= seq_len, (
|
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
|
|
108
119
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
109
120
|
|
110
121
|
return self.forward_with_embeds(
|
111
|
-
input_embeds, rope, mask, input_pos, kv_cache
|
122
|
+
input_embeds, rope, mask, input_pos, kv_cache, export_config
|
112
123
|
)
|
113
124
|
|
114
125
|
def forward_with_embeds(
|
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
|
|
118
129
|
mask: torch.Tensor,
|
119
130
|
input_pos: torch.Tensor,
|
120
131
|
kv_cache: kv_utils.KVCache,
|
132
|
+
export_config: Optional[ExportConfig] = None,
|
121
133
|
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
122
134
|
"""Forwards the model with input embeddings."""
|
123
135
|
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
|
|
137
149
|
updated_kv_entires.append(kv_entry)
|
138
150
|
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
139
151
|
|
152
|
+
if export_config is not None:
|
153
|
+
if (
|
154
|
+
torch.numel(input_pos) > 1
|
155
|
+
and not export_config.output_logits_on_prefill
|
156
|
+
):
|
157
|
+
return {"kv_cache": updated_kv_cache}
|
158
|
+
|
140
159
|
x = self.final_norm(x)
|
141
160
|
logits = self.lm_head(x) # (b, t, vocab_size)
|
142
161
|
return {"logits": logits, "kv_cache": updated_kv_cache}
|
@@ -146,8 +165,9 @@ def build_decoder_only_model(
|
|
146
165
|
checkpoint_path: str,
|
147
166
|
config: cfg.ModelConfig,
|
148
167
|
tensor_names: loading_utils.ModelLoader.TensorNames,
|
149
|
-
|
150
|
-
|
168
|
+
model_class: type[nn.Module] = DecoderOnlyModel,
|
169
|
+
) -> nn.Module:
|
170
|
+
transformer = model_class(config)
|
151
171
|
loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
|
152
172
|
loader.load(
|
153
173
|
transformer, strict=not config.lm_head_share_weight_with_embedding
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
from typing import List
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
22
23
|
import torch
|
23
24
|
|
24
25
|
|
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
|
|
40
41
|
"""
|
41
42
|
super().__init__()
|
42
43
|
self.model = model
|
44
|
+
self.export_config = ExportConfig(output_logits_on_prefill=True)
|
43
45
|
|
44
46
|
def forward(
|
45
47
|
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
103
105
|
Returns:
|
104
106
|
The output logits and the updated KV cache.
|
105
107
|
"""
|
108
|
+
# Verification requires logit outputs on prefill for comparison.
|
109
|
+
if (
|
110
|
+
self.export_config is not None
|
111
|
+
and not self.export_config.output_logits_on_prefill
|
112
|
+
):
|
113
|
+
raise ValueError("Verifier requires logit output on prefill.")
|
106
114
|
# Since the reauthored model doesn't include keyword arguments, pass
|
107
115
|
# pixel_values only when it is not None. Otherwise, it may raise an error.
|
108
116
|
if pixel_values is None:
|
109
|
-
output = self.model.forward(
|
117
|
+
output = self.model.forward(
|
118
|
+
tokens, input_pos, kv_cache, export_config=self.export_config
|
119
|
+
)
|
110
120
|
else:
|
111
121
|
output = self.model.forward(
|
112
|
-
tokens,
|
122
|
+
tokens,
|
123
|
+
input_pos,
|
124
|
+
kv_cache,
|
125
|
+
pixel_values=pixel_values,
|
126
|
+
export_config=self.export_config,
|
113
127
|
)
|
114
128
|
return output["logits"], output["kv_cache"]
|
115
129
|
|
@@ -15,13 +15,15 @@
|
|
15
15
|
|
16
16
|
from typing import Any, Optional
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from ai_edge_torch._convert import signature
|
20
20
|
from ai_edge_torch.quantize import quant_config as qcfg
|
21
21
|
import torch
|
22
22
|
|
23
|
+
config = _config.config
|
24
|
+
|
23
25
|
# isort: off
|
24
|
-
if config.
|
26
|
+
if config.use_torch_xla:
|
25
27
|
from ai_edge_torch.lowertools import torch_xla_utils as utils
|
26
28
|
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
|
27
29
|
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
|
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
import re
|
17
17
|
from typing import Optional
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import _config
|
19
19
|
from absl.testing import absltest as googletest
|
20
20
|
|
21
|
+
config = _config.config
|
22
|
+
|
21
23
|
|
22
24
|
def _extract_backend_configs(mlir):
|
23
25
|
mlir = mlir.replace("\\22", '"')
|
@@ -38,7 +40,7 @@ def assert_string_count(
|
|
38
40
|
if odml_torch_attr_counter is None:
|
39
41
|
odml_torch_attr_counter = {}
|
40
42
|
|
41
|
-
if config.
|
43
|
+
if config.use_torch_xla:
|
42
44
|
for key in torch_xla_pattern_counter:
|
43
45
|
test_case.assertEqual(
|
44
46
|
mlir.count(key),
|
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
276
276
|
interior_padding if i == dim else 0 for i in range(rank)
|
277
277
|
],
|
278
278
|
)
|
279
|
-
|
280
|
-
|
279
|
+
|
280
|
+
slices = [
|
281
281
|
slice(start, end, step) if i == dim else slice(None, None, None)
|
282
282
|
for i in range(rank)
|
283
|
-
]
|
283
|
+
]
|
284
|
+
pred = np.ones(self.type.shape, dtype=np.bool_)
|
285
|
+
pred[np.index_exp[tuple(slices)]] = False
|
284
286
|
pred = stablehlo.constant(
|
285
287
|
ir.DenseElementsAttr.get(
|
286
288
|
np.packbits(pred, bitorder="little"),
|
@@ -232,7 +232,9 @@ def _aten_convolution(
|
|
232
232
|
|
233
233
|
if bias is not None:
|
234
234
|
# broadcast [C] to [NCHW]
|
235
|
-
broadcasted_bias = stablehlo.broadcast_in_dim(
|
235
|
+
broadcasted_bias = stablehlo.broadcast_in_dim(
|
236
|
+
output_type, bias, ir.DenseI64ArrayAttr.get([1])
|
237
|
+
)
|
236
238
|
res = stablehlo.add(
|
237
239
|
lhs=res,
|
238
240
|
rhs=broadcasted_bias,
|
@@ -16,12 +16,15 @@ import functools
|
|
16
16
|
import logging
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch import jax_bridge
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
21
|
+
import jax.numpy as jnp
|
22
|
+
from jax._src.lib.mlir import ir
|
19
23
|
import torch
|
20
24
|
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
|
21
25
|
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
|
22
26
|
|
23
|
-
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
25
28
|
|
26
29
|
@functools.cache
|
27
30
|
def _log_usage(op):
|
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
|
|
258
261
|
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
|
259
262
|
def _aten_copy(self, src, **kwargs):
|
260
263
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
264
|
+
|
265
|
+
|
266
|
+
# Schema:
|
267
|
+
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
|
268
|
+
# -> Tensor
|
269
|
+
# Torch Reference:
|
270
|
+
# - https://pytorch.org/docs/stable/generated/torch.einsum.html
|
271
|
+
# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
|
272
|
+
@registry.lower(torch.ops.aten.einsum.default)
|
273
|
+
def _aten_einsum_default(
|
274
|
+
lctx: LoweringContext,
|
275
|
+
equation: str,
|
276
|
+
tensors: list[ir.Value],
|
277
|
+
path=None,
|
278
|
+
):
|
279
|
+
_log_usage(torch.ops.aten.einsum.default)
|
280
|
+
|
281
|
+
@jax_bridge.wrap
|
282
|
+
def jax_lowering(operands):
|
283
|
+
# Ignore the input path and let JAX determine the path.
|
284
|
+
return jnp.einsum(equation, *operands, optimize="optimal")
|
285
|
+
|
286
|
+
return jax_lowering(lctx, tuple(tensors))
|
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
|
|
20
20
|
from ai_edge_torch.odml_torch.lowerings import utils
|
21
21
|
from jax._src.lib.mlir import ir
|
22
22
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import numpy as np
|
23
24
|
import torch
|
24
25
|
|
25
26
|
|
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
|
|
66
67
|
normalized_rank = len(normalized_shape)
|
67
68
|
if weight is not None:
|
68
69
|
weight = stablehlo.broadcast_in_dim(
|
69
|
-
data_type,
|
70
|
+
data_type,
|
71
|
+
weight,
|
72
|
+
ir.DenseI64ArrayAttr.get(
|
73
|
+
list(range(data_rank - normalized_rank, data_rank))
|
74
|
+
),
|
70
75
|
)
|
71
76
|
output = stablehlo.multiply(weight, output)
|
72
77
|
if bias is not None:
|
73
78
|
bias = stablehlo.broadcast_in_dim(
|
74
|
-
data_type,
|
79
|
+
data_type,
|
80
|
+
bias,
|
81
|
+
ir.DenseI64ArrayAttr.get(
|
82
|
+
list(range(data_rank - normalized_rank, data_rank))
|
83
|
+
),
|
75
84
|
)
|
76
85
|
output = stablehlo.add(bias, output)
|
77
86
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
|
16
|
-
from typing import Union, cast
|
16
|
+
from typing import Optional, Union, cast
|
17
17
|
|
18
18
|
from ai_edge_torch.odml_torch.lowerings import context
|
19
19
|
from ai_edge_torch.odml_torch.lowerings import utils
|
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
|
|
30
30
|
|
31
31
|
|
32
32
|
def _uniform_quantized_type(
|
33
|
-
stored_type: str
|
34
|
-
expressed_type: str
|
33
|
+
stored_type: Union[str, ir.Type],
|
34
|
+
expressed_type: Union[str, ir.Type],
|
35
35
|
*,
|
36
|
-
scale=float
|
37
|
-
zero_point=float
|
38
|
-
storage_type_min: int
|
39
|
-
storage_type_max: int
|
40
|
-
channel_axis: int
|
41
|
-
channel_axis_size: int
|
36
|
+
scale=Union[float, list[float], tuple[float]],
|
37
|
+
zero_point=Union[float, list[float], tuple[float]],
|
38
|
+
storage_type_min: Optional[int] = None,
|
39
|
+
storage_type_max: Optional[int] = None,
|
40
|
+
channel_axis: Optional[int] = None,
|
41
|
+
channel_axis_size: Optional[int] = None,
|
42
42
|
):
|
43
43
|
"""Polyfill for quant.UniformQuantizedType."""
|
44
44
|
if storage_type_min and storage_type_max:
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Torch export decompositions to run before lowering."""
|
16
|
+
|
17
|
+
import functools
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
@functools.cache
|
23
|
+
def decompositions():
|
24
|
+
# Base: Core ATen decompositions
|
25
|
+
decompositions = torch._decomp.core_aten_decompositions()
|
26
|
+
|
27
|
+
decompositions.update(
|
28
|
+
torch._decomp.get_decompositions([
|
29
|
+
torch.ops.aten.upsample_nearest2d,
|
30
|
+
torch.ops.aten._native_batch_norm_legit.no_stats,
|
31
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
32
|
+
torch.ops.aten._adaptive_avg_pool2d,
|
33
|
+
torch.ops.aten._adaptive_avg_pool3d,
|
34
|
+
torch.ops.aten.grid_sampler_2d,
|
35
|
+
torch.ops.aten.native_group_norm,
|
36
|
+
torch.ops.aten.native_dropout,
|
37
|
+
torch.ops.aten.reflection_pad1d,
|
38
|
+
torch.ops.aten.reflection_pad2d,
|
39
|
+
torch.ops.aten.reflection_pad3d,
|
40
|
+
torch.ops.aten.replication_pad1d,
|
41
|
+
torch.ops.aten.replication_pad2d,
|
42
|
+
torch.ops.aten.replication_pad3d,
|
43
|
+
torch.ops.aten.addmm,
|
44
|
+
])
|
45
|
+
)
|
46
|
+
|
47
|
+
torch._decomp.remove_decompositions(
|
48
|
+
decompositions,
|
49
|
+
[
|
50
|
+
torch.ops.aten.roll,
|
51
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
+
# optimized through converter than JAX's impl. Disable einsum
|
53
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
+
torch.ops.aten.einsum.default,
|
55
|
+
],
|
56
|
+
)
|
57
|
+
|
58
|
+
# Override _safe_softmax decompositions with regular softmax.
|
59
|
+
# _safe_softmax introduces additional check-select ops to guard extreme
|
60
|
+
# input values to softmax, which could make the converted model inefficient
|
61
|
+
# on-device.
|
62
|
+
if hasattr(torch.ops.aten, "_safe_softmax"):
|
63
|
+
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
|
64
|
+
|
65
|
+
return decompositions
|
@@ -26,7 +26,6 @@ class LoweringRegistry:
|
|
26
26
|
|
27
27
|
def __init__(self):
|
28
28
|
self.registered_ops = {}
|
29
|
-
self.decompositions = {}
|
30
29
|
|
31
30
|
def lookup(self, op_or_name):
|
32
31
|
candidate = self._get_lowering(op_or_name)
|
@@ -52,33 +51,6 @@ class LoweringRegistry:
|
|
52
51
|
|
53
52
|
|
54
53
|
global_registry = LoweringRegistry()
|
55
|
-
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
56
|
-
global_registry.decompositions.update(
|
57
|
-
torch._decomp.get_decompositions([
|
58
|
-
torch.ops.aten.upsample_nearest2d,
|
59
|
-
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
-
torch.ops.aten._native_batch_norm_legit_functional,
|
61
|
-
torch.ops.aten._adaptive_avg_pool2d,
|
62
|
-
torch.ops.aten._adaptive_avg_pool3d,
|
63
|
-
torch.ops.aten.grid_sampler_2d,
|
64
|
-
torch.ops.aten.native_group_norm,
|
65
|
-
torch.ops.aten.native_dropout,
|
66
|
-
torch.ops.aten.reflection_pad1d,
|
67
|
-
torch.ops.aten.reflection_pad2d,
|
68
|
-
torch.ops.aten.reflection_pad3d,
|
69
|
-
torch.ops.aten.replication_pad1d,
|
70
|
-
torch.ops.aten.replication_pad2d,
|
71
|
-
torch.ops.aten.replication_pad3d,
|
72
|
-
torch.ops.aten.addmm,
|
73
|
-
])
|
74
|
-
)
|
75
|
-
|
76
|
-
torch._decomp.remove_decompositions(
|
77
|
-
global_registry.decompositions,
|
78
|
-
[
|
79
|
-
torch.ops.aten.roll,
|
80
|
-
],
|
81
|
-
)
|
82
54
|
|
83
55
|
|
84
56
|
def lookup(op):
|
@@ -91,7 +63,3 @@ def lower(op):
|
|
91
63
|
return lowering
|
92
64
|
|
93
65
|
return inner
|
94
|
-
|
95
|
-
|
96
|
-
def decompositions():
|
97
|
-
return global_registry.decompositions
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241214
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
|
|
11
11
|
Classifier: License :: OSI Approved :: Apache Software License
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
14
|
-
Classifier: Programming Language :: Python :: 3.9
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
17
16
|
Classifier: Topic :: Scientific/Engineering
|
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
19
|
Classifier: Topic :: Software Development
|
21
20
|
Classifier: Topic :: Software Development :: Libraries
|
22
21
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
23
|
-
Requires-Python: >=3.
|
22
|
+
Requires-Python: >=3.10
|
24
23
|
Description-Content-Type: text/markdown
|
25
24
|
License-File: LICENSE
|
26
25
|
Requires-Dist: numpy
|
@@ -28,10 +27,13 @@ Requires-Dist: scipy
|
|
28
27
|
Requires-Dist: safetensors
|
29
28
|
Requires-Dist: tabulate
|
30
29
|
Requires-Dist: torch>=2.4.0
|
31
|
-
Requires-Dist:
|
32
|
-
Requires-Dist: tf-nightly>=2.19.0.dev20241121
|
30
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20241201
|
33
31
|
Requires-Dist: ai-edge-litert-nightly
|
34
32
|
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: jax
|
34
|
+
Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
|
35
|
+
Provides-Extra: torch-xla
|
36
|
+
Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
|
35
37
|
|
36
38
|
Library that supports converting PyTorch models into a .tflite format, which can
|
37
39
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|