ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_config.py +52 -0
  3. ai_edge_torch/_convert/test/test_convert.py +1 -2
  4. ai_edge_torch/debug/test/test_culprit.py +8 -3
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  7. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  10. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  11. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  15. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  17. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  18. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  19. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  21. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  22. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  24. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  25. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  27. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  29. ai_edge_torch/generative/layers/attention.py +2 -6
  30. ai_edge_torch/generative/layers/kv_cache.py +24 -18
  31. ai_edge_torch/generative/layers/normalization.py +1 -3
  32. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  33. ai_edge_torch/generative/test/test_model_conversion.py +12 -14
  34. ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
  35. ai_edge_torch/generative/test/utils.py +31 -6
  36. ai_edge_torch/generative/utilities/converter.py +25 -4
  37. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  38. ai_edge_torch/generative/utilities/verifier.py +16 -2
  39. ai_edge_torch/lowertools/_shim.py +4 -2
  40. ai_edge_torch/lowertools/test_utils.py +4 -2
  41. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  42. ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
  43. ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
  44. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  45. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
  46. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
  47. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  48. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  49. ai_edge_torch/version.py +1 -1
  50. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
  51. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
  52. ai_edge_torch/config.py +0 -27
  53. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
  54. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,13 +15,28 @@
15
15
 
16
16
  """Common utility functions for model conversion."""
17
17
 
18
- from typing import Union
18
+ from functools import partial
19
+ from typing import Any, Union
19
20
 
20
21
  from ai_edge_torch._convert import converter as converter_utils
21
22
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
22
23
  import ai_edge_torch.generative.layers.model_config as cfg
23
24
  from ai_edge_torch.generative.quantize import quant_recipes
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
24
26
  import torch
27
+ import torch.nn as nn
28
+
29
+
30
+ class ExportableModule(torch.nn.Module):
31
+
32
+ def __init__(self, module, **extra_kwargs):
33
+ super().__init__()
34
+ self.module = module
35
+ self.extra_kwargs = extra_kwargs
36
+
37
+ def forward(self, *export_args, **export_kwargs):
38
+ full_kwargs = {**export_kwargs, **self.extra_kwargs}
39
+ return self.module(*export_args, **full_kwargs)
25
40
 
26
41
 
27
42
  def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
31
46
  pixel_values_size: torch.Size = None,
32
47
  quantize: bool = True,
33
48
  config: cfg.ModelConfig = None,
49
+ export_config: ExportConfig = None,
34
50
  ):
35
51
  """Converts a nn.Module model to multi-signature tflite model.
36
52
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
97
113
  )
98
114
 
99
115
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
116
+
117
+ # For export, we create a module that captures any non-exportable,
118
+ # arugments, e.g. the generation config object.
119
+ mod = ExportableModule(pytorch_model, export_config=export_config)
120
+
100
121
  converter = converter_utils.Converter()
101
122
  for i in range(len(prefill_seq_lens)):
102
123
  prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
108
129
  prefill_signature_name = f'prefill_{prefill_seq_len}'
109
130
  converter.add_signature(
110
131
  prefill_signature_name,
111
- pytorch_model,
132
+ mod,
112
133
  sample_kwargs={
113
134
  'tokens': prefill_tokens,
114
135
  'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
118
139
  if prefill_pixel_values is not None:
119
140
  converter.add_signature(
120
141
  prefill_signature_name + '_pixel',
121
- pytorch_model,
142
+ mod,
122
143
  sample_kwargs={
123
144
  'tokens': prefill_tokens,
124
145
  'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
129
150
 
130
151
  converter.add_signature(
131
152
  'decode',
132
- pytorch_model,
153
+ mod,
133
154
  sample_kwargs={
134
155
  'tokens': decode_token,
135
156
  'input_pos': decode_input_pos,
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -16,7 +16,8 @@
16
16
  """Utilities to be used for re-authoring transformer models."""
17
17
 
18
18
  import copy
19
- from typing import Tuple
19
+ from dataclasses import dataclass
20
+ from typing import Optional, Tuple
20
21
 
21
22
  from ai_edge_torch.generative.layers import attention
22
23
  from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
45
46
  TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
46
47
 
47
48
 
49
+ @dataclass
50
+ class ExportConfig:
51
+ """Model generating configuration settings."""
52
+
53
+ # On prefill signatures, should the model produce logit output?
54
+ # When False, only decode signatures will produce output.
55
+ output_logits_on_prefill: bool = False
56
+
57
+
48
58
  class DecoderOnlyModel(nn.Module):
49
59
  """A simple decoder-only transformer model built from the Edge Generative API.
50
60
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
93
103
  tokens: torch.Tensor,
94
104
  input_pos: torch.Tensor,
95
105
  kv_cache: kv_utils.KVCache,
106
+ export_config: Optional[ExportConfig] = None,
96
107
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
97
108
  _, seq_len = tokens.size()
98
109
  assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
108
119
  mask = mask[:, :, :, : self.config.kv_cache_max]
109
120
 
110
121
  return self.forward_with_embeds(
111
- input_embeds, rope, mask, input_pos, kv_cache
122
+ input_embeds, rope, mask, input_pos, kv_cache, export_config
112
123
  )
113
124
 
114
125
  def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
118
129
  mask: torch.Tensor,
119
130
  input_pos: torch.Tensor,
120
131
  kv_cache: kv_utils.KVCache,
132
+ export_config: Optional[ExportConfig] = None,
121
133
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
122
134
  """Forwards the model with input embeddings."""
123
135
  assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
137
149
  updated_kv_entires.append(kv_entry)
138
150
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
139
151
 
152
+ if export_config is not None:
153
+ if (
154
+ torch.numel(input_pos) > 1
155
+ and not export_config.output_logits_on_prefill
156
+ ):
157
+ return {"kv_cache": updated_kv_cache}
158
+
140
159
  x = self.final_norm(x)
141
160
  logits = self.lm_head(x) # (b, t, vocab_size)
142
161
  return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
146
165
  checkpoint_path: str,
147
166
  config: cfg.ModelConfig,
148
167
  tensor_names: loading_utils.ModelLoader.TensorNames,
149
- ) -> DecoderOnlyModel:
150
- transformer = DecoderOnlyModel(config)
168
+ model_class: type[nn.Module] = DecoderOnlyModel,
169
+ ) -> nn.Module:
170
+ transformer = model_class(config)
151
171
  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
152
172
  loader.load(
153
173
  transformer, strict=not config.lm_head_share_weight_with_embedding
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -19,6 +19,7 @@ import logging
19
19
  from typing import List
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
22
23
  import torch
23
24
 
24
25
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
40
41
  """
41
42
  super().__init__()
42
43
  self.model = model
44
+ self.export_config = ExportConfig(output_logits_on_prefill=True)
43
45
 
44
46
  def forward(
45
47
  self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
103
105
  Returns:
104
106
  The output logits and the updated KV cache.
105
107
  """
108
+ # Verification requires logit outputs on prefill for comparison.
109
+ if (
110
+ self.export_config is not None
111
+ and not self.export_config.output_logits_on_prefill
112
+ ):
113
+ raise ValueError("Verifier requires logit output on prefill.")
106
114
  # Since the reauthored model doesn't include keyword arguments, pass
107
115
  # pixel_values only when it is not None. Otherwise, it may raise an error.
108
116
  if pixel_values is None:
109
- output = self.model.forward(tokens, input_pos, kv_cache)
117
+ output = self.model.forward(
118
+ tokens, input_pos, kv_cache, export_config=self.export_config
119
+ )
110
120
  else:
111
121
  output = self.model.forward(
112
- tokens, input_pos, kv_cache, pixel_values=pixel_values
122
+ tokens,
123
+ input_pos,
124
+ kv_cache,
125
+ pixel_values=pixel_values,
126
+ export_config=self.export_config,
113
127
  )
114
128
  return output["logits"], output["kv_cache"]
115
129
 
ai_edge_torch/lowertools/_shim.py CHANGED
@@ -15,13 +15,15 @@
15
15
 
16
16
  from typing import Any, Optional
17
17
 
18
- from ai_edge_torch import config
18
+ from ai_edge_torch import _config
19
19
  from ai_edge_torch._convert import signature
20
20
  from ai_edge_torch.quantize import quant_config as qcfg
21
21
  import torch
22
22
 
23
+ config = _config.config
24
+
23
25
  # isort: off
24
- if config.Config.use_torch_xla:
26
+ if config.use_torch_xla:
25
27
  from ai_edge_torch.lowertools import torch_xla_utils as utils
26
28
  from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
27
29
  from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
ai_edge_torch/lowertools/test_utils.py CHANGED
@@ -15,9 +15,11 @@
15
15
 
16
16
  import re
17
17
  from typing import Optional
18
- from ai_edge_torch import config
18
+ from ai_edge_torch import _config
19
19
  from absl.testing import absltest as googletest
20
20
 
21
+ config = _config.config
22
+
21
23
 
22
24
  def _extract_backend_configs(mlir):
23
25
  mlir = mlir.replace("\\22", '"')
@@ -38,7 +40,7 @@ def assert_string_count(
38
40
  if odml_torch_attr_counter is None:
39
41
  odml_torch_attr_counter = {}
40
42
 
41
- if config.Config.use_torch_xla:
43
+ if config.use_torch_xla:
42
44
  for key in torch_xla_pattern_counter:
43
45
  test_case.assertEqual(
44
46
  mlir.count(key),
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -21,6 +21,6 @@ from . import _quantized_decomposed
21
21
  from . import context
22
22
  from . import registry
23
23
  from . import utils
24
- from .registry import decompositions
24
+ from .decomp import decompositions
25
25
  from .registry import lookup
26
26
  from .registry import lower
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
276
276
  interior_padding if i == dim else 0 for i in range(rank)
277
277
  ],
278
278
  )
279
- pred = np.ones(self.type.shape, dtype=np.bool_)
280
- pred[*[
279
+
280
+ slices = [
281
281
  slice(start, end, step) if i == dim else slice(None, None, None)
282
282
  for i in range(rank)
283
- ]] = False
283
+ ]
284
+ pred = np.ones(self.type.shape, dtype=np.bool_)
285
+ pred[np.index_exp[tuple(slices)]] = False
284
286
  pred = stablehlo.constant(
285
287
  ir.DenseElementsAttr.get(
286
288
  np.packbits(pred, bitorder="little"),
ai_edge_torch/odml_torch/lowerings/_convolution.py CHANGED
@@ -232,7 +232,9 @@ def _aten_convolution(
232
232
 
233
233
  if bias is not None:
234
234
  # broadcast [C] to [NCHW]
235
- broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
235
+ broadcasted_bias = stablehlo.broadcast_in_dim(
236
+ output_type, bias, ir.DenseI64ArrayAttr.get([1])
237
+ )
236
238
  res = stablehlo.add(
237
239
  lhs=res,
238
240
  rhs=broadcasted_bias,
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -16,12 +16,15 @@ import functools
16
16
  import logging
17
17
 
18
18
  from ai_edge_torch.odml_torch import jax_bridge
19
+ from ai_edge_torch.odml_torch.lowerings import context
20
+ from ai_edge_torch.odml_torch.lowerings import registry
21
+ import jax.numpy as jnp
22
+ from jax._src.lib.mlir import ir
19
23
  import torch
20
24
  import torch_xla2.ops.jaten # Import to load torch_xla2 ops
21
25
  import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
22
26
 
23
- from . import registry
24
-
27
+ LoweringContext = context.LoweringContext
25
28
 
26
29
  @functools.cache
27
30
  def _log_usage(op):
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
258
261
  @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
259
262
  def _aten_copy(self, src, **kwargs):
260
263
  return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
264
+
265
+
266
+ # Schema:
267
+ # - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
268
+ # -> Tensor
269
+ # Torch Reference:
270
+ # - https://pytorch.org/docs/stable/generated/torch.einsum.html
271
+ # - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
272
+ @registry.lower(torch.ops.aten.einsum.default)
273
+ def _aten_einsum_default(
274
+ lctx: LoweringContext,
275
+ equation: str,
276
+ tensors: list[ir.Value],
277
+ path=None,
278
+ ):
279
+ _log_usage(torch.ops.aten.einsum.default)
280
+
281
+ @jax_bridge.wrap
282
+ def jax_lowering(operands):
283
+ # Ignore the input path and let JAX determine the path.
284
+ return jnp.einsum(equation, *operands, optimize="optimal")
285
+
286
+ return jax_lowering(lctx, tuple(tensors))
ai_edge_torch/odml_torch/lowerings/_layer_norm.py CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
20
20
  from ai_edge_torch.odml_torch.lowerings import utils
21
21
  from jax._src.lib.mlir import ir
22
22
  from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import numpy as np
23
24
  import torch
24
25
 
25
26
 
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
66
67
  normalized_rank = len(normalized_shape)
67
68
  if weight is not None:
68
69
  weight = stablehlo.broadcast_in_dim(
69
- data_type, weight, list(range(data_rank - normalized_rank, data_rank))
70
+ data_type,
71
+ weight,
72
+ ir.DenseI64ArrayAttr.get(
73
+ list(range(data_rank - normalized_rank, data_rank))
74
+ ),
70
75
  )
71
76
  output = stablehlo.multiply(weight, output)
72
77
  if bias is not None:
73
78
  bias = stablehlo.broadcast_in_dim(
74
- data_type, bias, list(range(data_rank - normalized_rank, data_rank))
79
+ data_type,
80
+ bias,
81
+ ir.DenseI64ArrayAttr.get(
82
+ list(range(data_rank - normalized_rank, data_rank))
83
+ ),
75
84
  )
76
85
  output = stablehlo.add(bias, output)
77
86
 
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """Lowerings for PT2E torch.ops.quantized_decomposed ops."""
16
- from typing import Union, cast
16
+ from typing import Optional, Union, cast
17
17
 
18
18
  from ai_edge_torch.odml_torch.lowerings import context
19
19
  from ai_edge_torch.odml_torch.lowerings import utils
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
30
30
 
31
31
 
32
32
  def _uniform_quantized_type(
33
- stored_type: str | ir.Type,
34
- expressed_type: str | ir.Type,
33
+ stored_type: Union[str, ir.Type],
34
+ expressed_type: Union[str, ir.Type],
35
35
  *,
36
- scale=float | list[float] | tuple[float],
37
- zero_point=float | list[float] | tuple[float],
38
- storage_type_min: int | None = None,
39
- storage_type_max: int | None = None,
40
- channel_axis: int | None = None,
41
- channel_axis_size: int | None = None,
36
+ scale=Union[float, list[float], tuple[float]],
37
+ zero_point=Union[float, list[float], tuple[float]],
38
+ storage_type_min: Optional[int] = None,
39
+ storage_type_max: Optional[int] = None,
40
+ channel_axis: Optional[int] = None,
41
+ channel_axis_size: Optional[int] = None,
42
42
  ):
43
43
  """Polyfill for quant.UniformQuantizedType."""
44
44
  if storage_type_min and storage_type_max:
ai_edge_torch/odml_torch/lowerings/decomp.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch export decompositions to run before lowering."""
16
+
17
+ import functools
18
+
19
+ import torch
20
+
21
+
22
+ @functools.cache
23
+ def decompositions():
24
+ # Base: Core ATen decompositions
25
+ decompositions = torch._decomp.core_aten_decompositions()
26
+
27
+ decompositions.update(
28
+ torch._decomp.get_decompositions([
29
+ torch.ops.aten.upsample_nearest2d,
30
+ torch.ops.aten._native_batch_norm_legit.no_stats,
31
+ torch.ops.aten._native_batch_norm_legit_functional,
32
+ torch.ops.aten._adaptive_avg_pool2d,
33
+ torch.ops.aten._adaptive_avg_pool3d,
34
+ torch.ops.aten.grid_sampler_2d,
35
+ torch.ops.aten.native_group_norm,
36
+ torch.ops.aten.native_dropout,
37
+ torch.ops.aten.reflection_pad1d,
38
+ torch.ops.aten.reflection_pad2d,
39
+ torch.ops.aten.reflection_pad3d,
40
+ torch.ops.aten.replication_pad1d,
41
+ torch.ops.aten.replication_pad2d,
42
+ torch.ops.aten.replication_pad3d,
43
+ torch.ops.aten.addmm,
44
+ ])
45
+ )
46
+
47
+ torch._decomp.remove_decompositions(
48
+ decompositions,
49
+ [
50
+ torch.ops.aten.roll,
51
+ # Torch's default einsum impl/decompositions is less efficient and
52
+ # optimized through converter than JAX's impl. Disable einsum
53
+ # decomposition to use JAX bridge for a more efficient lowering.
54
+ torch.ops.aten.einsum.default,
55
+ ],
56
+ )
57
+
58
+ # Override _safe_softmax decompositions with regular softmax.
59
+ # _safe_softmax introduces additional check-select ops to guard extreme
60
+ # input values to softmax, which could make the converted model inefficient
61
+ # on-device.
62
+ if hasattr(torch.ops.aten, "_safe_softmax"):
63
+ decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
64
+
65
+ return decompositions
ai_edge_torch/odml_torch/lowerings/registry.py CHANGED
@@ -26,7 +26,6 @@ class LoweringRegistry:
26
26
 
27
27
  def __init__(self):
28
28
  self.registered_ops = {}
29
- self.decompositions = {}
30
29
 
31
30
  def lookup(self, op_or_name):
32
31
  candidate = self._get_lowering(op_or_name)
@@ -52,33 +51,6 @@ class LoweringRegistry:
52
51
 
53
52
 
54
53
  global_registry = LoweringRegistry()
55
- global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
56
- global_registry.decompositions.update(
57
- torch._decomp.get_decompositions([
58
- torch.ops.aten.upsample_nearest2d,
59
- torch.ops.aten._native_batch_norm_legit.no_stats,
60
- torch.ops.aten._native_batch_norm_legit_functional,
61
- torch.ops.aten._adaptive_avg_pool2d,
62
- torch.ops.aten._adaptive_avg_pool3d,
63
- torch.ops.aten.grid_sampler_2d,
64
- torch.ops.aten.native_group_norm,
65
- torch.ops.aten.native_dropout,
66
- torch.ops.aten.reflection_pad1d,
67
- torch.ops.aten.reflection_pad2d,
68
- torch.ops.aten.reflection_pad3d,
69
- torch.ops.aten.replication_pad1d,
70
- torch.ops.aten.replication_pad2d,
71
- torch.ops.aten.replication_pad3d,
72
- torch.ops.aten.addmm,
73
- ])
74
- )
75
-
76
- torch._decomp.remove_decompositions(
77
- global_registry.decompositions,
78
- [
79
- torch.ops.aten.roll,
80
- ],
81
- )
82
54
 
83
55
 
84
56
  def lookup(op):
@@ -91,7 +63,3 @@ def lower(op):
91
63
  return lowering
92
64
 
93
65
  return inner
94
-
95
-
96
- def decompositions():
97
- return global_registry.decompositions
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241206"
16
+ __version__ = "0.3.0.dev20241214"
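{ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@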
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241206
3
+ Version: 0.3.0.dev20241214
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
11
11
  Classifier: License :: OSI Approved :: Apache Software License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3 :: Only
14
- Classifier: Programming Language :: Python :: 3.9
15
14
  Classifier: Programming Language :: Python :: 3.10
16
15
  Classifier: Programming Language :: Python :: 3.11
17
16
  Classifier: Topic :: Scientific/Engineering
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
19
  Classifier: Topic :: Software Development
21
20
  Classifier: Topic :: Software Development :: Libraries
22
21
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
- Requires-Python: >=3.9
22
+ Requires-Python: >=3.10
24
23
  Description-Content-Type: text/markdown
25
24
  License-File: LICENSE
26
25
  Requires-Dist: numpy
@@ -28,10 +27,13 @@ Requires-Dist: scipy
28
27
  Requires-Dist: safetensors
29
28
  Requires-Dist: tabulate
30
29
  Requires-Dist: torch>=2.4.0
31
- Requires-Dist: torch-xla>=2.4.0
32
- Requires-Dist: tf-nightly>=2.19.0.dev20241121
30
+ Requires-Dist: tf-nightly>=2.19.0.dev20241201
33
31
  Requires-Dist: ai-edge-litert-nightly
34
32
  Requires-Dist: ai-edge-quantizer-nightly
33
+ Requires-Dist: jax
34
+ Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
35
+ Provides-Extra: torch-xla
36
+ Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
35
37
 
36
38
  Library that supports converting PyTorch models into a .tflite format, which can
37
39
  then be run with TensorFlow Lite and MediaPipe. This enables applications for