ai-edge-torch-nightly 0.3.0.dev20241206__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

Files changed (56)
  1. ai_edge_torch/__init__.py +1 -1
  2. ai_edge_torch/_config.py +52 -0
  3. ai_edge_torch/_convert/test/test_convert.py +1 -2
  4. ai_edge_torch/debug/test/test_culprit.py +8 -3
  5. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  6. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  7. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  9. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  10. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  11. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  13. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  15. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  16. ai_edge_torch/generative/examples/paligemma/decoder.py +10 -9
  17. ai_edge_torch/generative/examples/paligemma/paligemma.py +11 -1
  18. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  19. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  20. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  21. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  22. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  24. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  25. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  26. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  27. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  29. ai_edge_torch/generative/layers/attention.py +2 -6
  30. ai_edge_torch/generative/layers/kv_cache.py +24 -18
  31. ai_edge_torch/generative/layers/normalization.py +1 -3
  32. ai_edge_torch/generative/test/test_kv_cache.py +3 -3
  33. ai_edge_torch/generative/test/test_model_conversion.py +12 -14
  34. ai_edge_torch/generative/test/test_model_conversion_large.py +63 -59
  35. ai_edge_torch/generative/test/utils.py +31 -6
  36. ai_edge_torch/generative/utilities/converter.py +25 -4
  37. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  38. ai_edge_torch/generative/utilities/verifier.py +16 -2
  39. ai_edge_torch/lowertools/_shim.py +4 -2
  40. ai_edge_torch/lowertools/test_utils.py +4 -2
  41. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
  42. ai_edge_torch/odml_torch/lowerings/_basic.py +5 -3
  43. ai_edge_torch/odml_torch/lowerings/_convolution.py +3 -1
  44. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +28 -2
  45. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +11 -2
  46. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +9 -9
  47. ai_edge_torch/odml_torch/lowerings/decomp.py +65 -0
  48. ai_edge_torch/odml_torch/lowerings/registry.py +0 -32
  49. ai_edge_torch/version.py +1 -1
  50. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA +7 -5
  51. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD +54 -54
  52. ai_edge_torch/config.py +0 -27
  53. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +0 -283
  54. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/LICENSE +0 -0
  55. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/WHEEL +0 -0
  56. {ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -15,13 +15,28 @@
 
  """Common utility functions for model conversion."""
 
- from typing import Union
+ from functools import partial
+ from typing import Any, Union
 
  from ai_edge_torch._convert import converter as converter_utils
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.quantize import quant_recipes
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch
+ import torch.nn as nn
+
+
+ class ExportableModule(torch.nn.Module):
+
+   def __init__(self, module, **extra_kwargs):
+     super().__init__()
+     self.module = module
+     self.extra_kwargs = extra_kwargs
+
+   def forward(self, *export_args, **export_kwargs):
+     full_kwargs = {**export_kwargs, **self.extra_kwargs}
+     return self.module(*export_args, **full_kwargs)
 
 
  def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
      pixel_values_size: torch.Size = None,
      quantize: bool = True,
      config: cfg.ModelConfig = None,
+     export_config: ExportConfig = None,
  ):
    """Converts a nn.Module model to multi-signature tflite model.
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
    )
 
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+
+   # For export, we create a module that captures any non-exportable
+   # arguments, e.g. the generation config object.
+   mod = ExportableModule(pytorch_model, export_config=export_config)
+
    converter = converter_utils.Converter()
    for i in range(len(prefill_seq_lens)):
      prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
      prefill_signature_name = f'prefill_{prefill_seq_len}'
      converter.add_signature(
          prefill_signature_name,
-         pytorch_model,
+         mod,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
      if prefill_pixel_values is not None:
        converter.add_signature(
            prefill_signature_name + '_pixel',
-           pytorch_model,
+           mod,
            sample_kwargs={
                'tokens': prefill_tokens,
                'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
 
    converter.add_signature(
        'decode',
-       pytorch_model,
+       mod,
        sample_kwargs={
            'tokens': decode_token,
            'input_pos': decode_input_pos,
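
Note: ExportableModule exists because torch.export can only trace tensor arguments; it closes over non-exportable keyword arguments (here the ExportConfig) and re-injects them on every call. A minimal standalone sketch of that behavior with a toy module (Toy and its scale kwarg are illustrative, not part of the library):

    import torch
    import torch.nn as nn

    from ai_edge_torch.generative.utilities.converter import ExportableModule


    class Toy(nn.Module):

      def forward(self, x, *, scale: float = 1.0):
        return x * scale


    # Bind the non-tensor kwarg once; export-time callers pass tensors only.
    mod = ExportableModule(Toy(), scale=2.0)
    assert torch.equal(mod(torch.ones(2)), torch.full((2,), 2.0))
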
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -16,7 +16,8 @@
  """Utilities to be used for re-authoring transformer models."""
 
  import copy
- from typing import Tuple
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
 
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
  TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
 
 
+ @dataclass
+ class ExportConfig:
+   """Model generating configuration settings."""
+
+   # On prefill signatures, should the model produce logit output?
+   # When False, only decode signatures will produce output.
+   output_logits_on_prefill: bool = False
+
+
  class DecoderOnlyModel(nn.Module):
    """A simple decoder-only transformer model built from the Edge Generative API.
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
        tokens: torch.Tensor,
        input_pos: torch.Tensor,
        kv_cache: kv_utils.KVCache,
+       export_config: Optional[ExportConfig] = None,
    ) -> dict[torch.Tensor, kv_utils.KVCache]:
      _, seq_len = tokens.size()
      assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
      mask = mask[:, :, :, : self.config.kv_cache_max]
 
      return self.forward_with_embeds(
-         input_embeds, rope, mask, input_pos, kv_cache
+         input_embeds, rope, mask, input_pos, kv_cache, export_config
      )
 
    def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
        mask: torch.Tensor,
        input_pos: torch.Tensor,
        kv_cache: kv_utils.KVCache,
+       export_config: Optional[ExportConfig] = None,
    ) -> dict[torch.Tensor, kv_utils.KVCache]:
      """Forwards the model with input embeddings."""
      assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
        updated_kv_entires.append(kv_entry)
      updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
+     if export_config is not None:
+       if (
+           torch.numel(input_pos) > 1
+           and not export_config.output_logits_on_prefill
+       ):
+         return {"kv_cache": updated_kv_cache}
+
      x = self.final_norm(x)
      logits = self.lm_head(x)  # (b, t, vocab_size)
      return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
      checkpoint_path: str,
      config: cfg.ModelConfig,
      tensor_names: loading_utils.ModelLoader.TensorNames,
- ) -> DecoderOnlyModel:
-   transformer = DecoderOnlyModel(config)
+     model_class: type[nn.Module] = DecoderOnlyModel,
+ ) -> nn.Module:
+   transformer = model_class(config)
    loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
    loader.load(
        transformer, strict=not config.lm_head_share_weight_with_embedding
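
Note: the practical effect of ExportConfig is on prefill outputs. With the default output_logits_on_prefill=False, a forward call whose input_pos spans more than one position returns only the updated KV cache, while a single-position decode call still returns logits. A hedged sketch of the resulting calling convention (model, tokens, positions, and kv_cache are placeholders):

    from ai_edge_torch.generative.utilities.model_builder import ExportConfig

    export_config = ExportConfig()  # output_logits_on_prefill defaults to False

    # Prefill: input_pos spans several positions -> {"kv_cache": ...} only.
    out = model.forward(
        prefill_tokens, prefill_input_pos, kv_cache, export_config=export_config
    )

    # Decode: a single position -> {"logits": ..., "kv_cache": ...}.
    out = model.forward(
        decode_token, decode_input_pos, kv_cache, export_config=export_config
    )
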
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -19,6 +19,7 @@ import logging
  from typing import List
 
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
  import torch
 
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
      """
      super().__init__()
      self.model = model
+     self.export_config = ExportConfig(output_logits_on_prefill=True)
 
    def forward(
        self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
      Returns:
        The output logits and the updated KV cache.
      """
+     # Verification requires logit outputs on prefill for comparison.
+     if (
+         self.export_config is not None
+         and not self.export_config.output_logits_on_prefill
+     ):
+       raise ValueError("Verifier requires logit output on prefill.")
      # Since the reauthored model doesn't include keyword arguments, pass
      # pixel_values only when it is not None. Otherwise, it may raise an error.
      if pixel_values is None:
-       output = self.model.forward(tokens, input_pos, kv_cache)
+       output = self.model.forward(
+           tokens, input_pos, kv_cache, export_config=self.export_config
+       )
      else:
        output = self.model.forward(
-           tokens, input_pos, kv_cache, pixel_values=pixel_values
+           tokens,
+           input_pos,
+           kv_cache,
+           pixel_values=pixel_values,
+           export_config=self.export_config,
        )
      return output["logits"], output["kv_cache"]
 
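Note: verification compares prefill logits against the original checkpoint, which is why ModelWrapper pins output_logits_on_prefill=True. A hedged sketch of the failure mode the new check guards against (the wrapped model and the wrapper constructor arguments are placeholders):

    from ai_edge_torch.generative.utilities.model_builder import ExportConfig
    from ai_edge_torch.generative.utilities.verifier import ReauthoredModelWrapper

    wrapper = ReauthoredModelWrapper(reauthored_model)  # placeholder model
    # Disabling prefill logits makes comparison impossible, so the wrapper
    # fails fast instead of comparing missing outputs:
    wrapper.export_config = ExportConfig(output_logits_on_prefill=False)
    # Any prefill now raises:
    #   ValueError: Verifier requires logit output on prefill.
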
ai_edge_torch/lowertools/_shim.py CHANGED
@@ -15,13 +15,15 @@
 
  from typing import Any, Optional
 
- from ai_edge_torch import config
+ from ai_edge_torch import _config
  from ai_edge_torch._convert import signature
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch
 
+ config = _config.config
+
  # isort: off
- if config.Config.use_torch_xla:
+ if config.use_torch_xla:
    from ai_edge_torch.lowertools import torch_xla_utils as utils
    from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
    from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
ai_edge_torch/lowertools/test_utils.py CHANGED
@@ -15,9 +15,11 @@
 
  import re
  from typing import Optional
- from ai_edge_torch import config
+ from ai_edge_torch import _config
  from absl.testing import absltest as googletest
 
+ config = _config.config
+
 
  def _extract_backend_configs(mlir):
    mlir = mlir.replace("\\22", '"')
@@ -38,7 +40,7 @@ def assert_string_count(
    if odml_torch_attr_counter is None:
      odml_torch_attr_counter = {}
 
-   if config.Config.use_torch_xla:
+   if config.use_torch_xla:
      for key in torch_xla_pattern_counter:
        test_case.assertEqual(
            mlir.count(key),
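
Note: call sites now read flags from the module-level config object published by ai_edge_torch._config instead of the former config.Config class attributes. A minimal sketch of the new access pattern, assuming only the use_torch_xla flag shown above:

    from ai_edge_torch import _config

    config = _config.config

    if config.use_torch_xla:
        pass  # take the torch_xla-backed lowering path
    else:
        pass  # take the default odml_torch path
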
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -21,6 +21,6 @@ from . import _quantized_decomposed
  from . import context
  from . import registry
  from . import utils
- from .registry import decompositions
+ from .decomp import decompositions
  from .registry import lookup
  from .registry import lower
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
          interior_padding if i == dim else 0 for i in range(rank)
      ],
  )
-   pred = np.ones(self.type.shape, dtype=np.bool_)
-   pred[*[
+
+   slices = [
        slice(start, end, step) if i == dim else slice(None, None, None)
        for i in range(rank)
-   ]] = False
+   ]
+   pred = np.ones(self.type.shape, dtype=np.bool_)
+   pred[np.index_exp[tuple(slices)]] = False
    pred = stablehlo.constant(
        ir.DenseElementsAttr.get(
            np.packbits(pred, bitorder="little"),
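
Note: the removed pred[*[...]] form unpacks a list inside a subscript, which is only legal syntax on Python 3.11+; indexing with a plain tuple of slices is equivalent on every supported version (np.index_exp returns a tuple argument unchanged). A standalone NumPy illustration of the mask construction:

    import numpy as np

    shape, dim, start, end, step = (4, 5), 1, 1, 4, 2
    rank = len(shape)

    slices = [
        slice(start, end, step) if i == dim else slice(None, None, None)
        for i in range(rank)
    ]
    # False marks the cells overwritten by src; True keeps self.
    pred = np.ones(shape, dtype=np.bool_)
    pred[tuple(slices)] = False  # same as pred[np.index_exp[tuple(slices)]]
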
ai_edge_torch/odml_torch/lowerings/_convolution.py CHANGED
@@ -232,7 +232,9 @@ def _aten_convolution(
 
    if bias is not None:
      # broadcast [C] to [NCHW]
-     broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
+     broadcasted_bias = stablehlo.broadcast_in_dim(
+         output_type, bias, ir.DenseI64ArrayAttr.get([1])
+     )
      res = stablehlo.add(
          lhs=res,
          rhs=broadcasted_bias,
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -16,12 +16,15 @@ import functools
  import logging
 
  from ai_edge_torch.odml_torch import jax_bridge
+ from ai_edge_torch.odml_torch.lowerings import context
+ from ai_edge_torch.odml_torch.lowerings import registry
+ import jax.numpy as jnp
+ from jax._src.lib.mlir import ir
  import torch
  import torch_xla2.ops.jaten  # Import to load torch_xla2 ops
  import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
- from . import registry
-
+ LoweringContext = context.LoweringContext
 
  @functools.cache
  def _log_usage(op):
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
  @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
  def _aten_copy(self, src, **kwargs):
    return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
+
+
+ # Schema:
+ #   - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
+ #     -> Tensor
+ # Torch Reference:
+ #   - https://pytorch.org/docs/stable/generated/torch.einsum.html
+ #   - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
+ @registry.lower(torch.ops.aten.einsum.default)
+ def _aten_einsum_default(
+     lctx: LoweringContext,
+     equation: str,
+     tensors: list[ir.Value],
+     path=None,
+ ):
+   _log_usage(torch.ops.aten.einsum.default)
+
+   @jax_bridge.wrap
+   def jax_lowering(operands):
+     # Ignore the input path and let JAX determine the path.
+     return jnp.einsum(equation, *operands, optimize="optimal")
+
+   return jax_lowering(lctx, tuple(tensors))
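
Note: this lowering only fires if aten.einsum.default survives decomposition, which is exactly what the new decomp.py below arranges. A sketch of a model that would exercise it after export (whether the op actually reaches this lowering depends on the conversion pipeline):

    import torch
    import torch.nn as nn


    class EinsumModule(nn.Module):

      def forward(self, a, b):
        # Batched matmul expressed as einsum; kept as aten.einsum.default.
        return torch.einsum("bik,bkj->bij", a, b)


    ep = torch.export.export(
        EinsumModule(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
    )
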
ai_edge_torch/odml_torch/lowerings/_layer_norm.py CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
  from ai_edge_torch.odml_torch.lowerings import utils
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import numpy as np
  import torch
 
 
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
    normalized_rank = len(normalized_shape)
    if weight is not None:
      weight = stablehlo.broadcast_in_dim(
-         data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+         data_type,
+         weight,
+         ir.DenseI64ArrayAttr.get(
+             list(range(data_rank - normalized_rank, data_rank))
+         ),
      )
      output = stablehlo.multiply(weight, output)
    if bias is not None:
      bias = stablehlo.broadcast_in_dim(
-         data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+         data_type,
+         bias,
+         ir.DenseI64ArrayAttr.get(
+             list(range(data_rank - normalized_rank, data_rank))
+         ),
      )
      output = stablehlo.add(bias, output)
 
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  """Lowerings for PT2E torch.ops.quantized_decomposed ops."""
- from typing import Union, cast
+ from typing import Optional, Union, cast
 
  from ai_edge_torch.odml_torch.lowerings import context
  from ai_edge_torch.odml_torch.lowerings import utils
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
 
 
  def _uniform_quantized_type(
-     stored_type: str | ir.Type,
-     expressed_type: str | ir.Type,
+     stored_type: Union[str, ir.Type],
+     expressed_type: Union[str, ir.Type],
      *,
-     scale=float | list[float] | tuple[float],
-     zero_point=float | list[float] | tuple[float],
-     storage_type_min: int | None = None,
-     storage_type_max: int | None = None,
-     channel_axis: int | None = None,
-     channel_axis_size: int | None = None,
+     scale=Union[float, list[float], tuple[float]],
+     zero_point=Union[float, list[float], tuple[float]],
+     storage_type_min: Optional[int] = None,
+     storage_type_max: Optional[int] = None,
+     channel_axis: Optional[int] = None,
+     channel_axis_size: Optional[int] = None,
  ):
    """Polyfill for quant.UniformQuantizedType."""
    if storage_type_min and storage_type_max:
ai_edge_torch/odml_torch/lowerings/decomp.py ADDED
@@ -0,0 +1,65 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Torch export decompositions to run before lowering."""
+
+ import functools
+
+ import torch
+
+
+ @functools.cache
+ def decompositions():
+   # Base: Core ATen decompositions
+   decompositions = torch._decomp.core_aten_decompositions()
+
+   decompositions.update(
+       torch._decomp.get_decompositions([
+           torch.ops.aten.upsample_nearest2d,
+           torch.ops.aten._native_batch_norm_legit.no_stats,
+           torch.ops.aten._native_batch_norm_legit_functional,
+           torch.ops.aten._adaptive_avg_pool2d,
+           torch.ops.aten._adaptive_avg_pool3d,
+           torch.ops.aten.grid_sampler_2d,
+           torch.ops.aten.native_group_norm,
+           torch.ops.aten.native_dropout,
+           torch.ops.aten.reflection_pad1d,
+           torch.ops.aten.reflection_pad2d,
+           torch.ops.aten.reflection_pad3d,
+           torch.ops.aten.replication_pad1d,
+           torch.ops.aten.replication_pad2d,
+           torch.ops.aten.replication_pad3d,
+           torch.ops.aten.addmm,
+       ])
+   )
+
+   torch._decomp.remove_decompositions(
+       decompositions,
+       [
+           torch.ops.aten.roll,
+           # Torch's default einsum decomposition lowers less efficiently
+           # through the converter than JAX's implementation. Disable einsum
+           # decomposition so the JAX bridge provides a more efficient lowering.
+           torch.ops.aten.einsum.default,
+       ],
+   )
+
+   # Override _safe_softmax decompositions with regular softmax.
+   # _safe_softmax introduces additional check-select ops to guard extreme
+   # input values to softmax, which could make the converted model inefficient
+   # on-device.
+   if hasattr(torch.ops.aten, "_safe_softmax"):
+     decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
+
+   return decompositions
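
Note: the cached table returned by decompositions() is meant to be applied before lowering; a hedged sketch of how a pipeline might consume it through the standard ExportedProgram.run_decompositions API (model and example_args are placeholders):

    import torch

    from ai_edge_torch.odml_torch.lowerings import decomp

    ep = torch.export.export(model, example_args)  # placeholder model/args
    ep = ep.run_decompositions(decomp.decompositions())
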
ai_edge_torch/odml_torch/lowerings/registry.py CHANGED
@@ -26,7 +26,6 @@ class LoweringRegistry:
 
    def __init__(self):
      self.registered_ops = {}
-     self.decompositions = {}
 
    def lookup(self, op_or_name):
      candidate = self._get_lowering(op_or_name)
@@ -52,33 +51,6 @@
 
 
  global_registry = LoweringRegistry()
- global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
- global_registry.decompositions.update(
-     torch._decomp.get_decompositions([
-         torch.ops.aten.upsample_nearest2d,
-         torch.ops.aten._native_batch_norm_legit.no_stats,
-         torch.ops.aten._native_batch_norm_legit_functional,
-         torch.ops.aten._adaptive_avg_pool2d,
-         torch.ops.aten._adaptive_avg_pool3d,
-         torch.ops.aten.grid_sampler_2d,
-         torch.ops.aten.native_group_norm,
-         torch.ops.aten.native_dropout,
-         torch.ops.aten.reflection_pad1d,
-         torch.ops.aten.reflection_pad2d,
-         torch.ops.aten.reflection_pad3d,
-         torch.ops.aten.replication_pad1d,
-         torch.ops.aten.replication_pad2d,
-         torch.ops.aten.replication_pad3d,
-         torch.ops.aten.addmm,
-     ])
- )
-
- torch._decomp.remove_decompositions(
-     global_registry.decompositions,
-     [
-         torch.ops.aten.roll,
-     ],
- )
 
 
  def lookup(op):
@@ -91,7 +63,3 @@ def lower(op):
      return lowering
 
    return inner
-
-
- def decompositions():
-   return global_registry.decompositions
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20241206"
+ __version__ = "0.3.0.dev20241214"
{ai_edge_torch_nightly-0.3.0.dev20241206.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241206
+ Version: 0.3.0.dev20241214
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3 :: Only
- Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Topic :: Scientific/Engineering
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Classifier: Topic :: Software Development
  Classifier: Topic :: Software Development :: Libraries
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
@@ -28,10 +27,13 @@ Requires-Dist: scipy
  Requires-Dist: safetensors
  Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
- Requires-Dist: torch-xla>=2.4.0
- Requires-Dist: tf-nightly>=2.19.0.dev20241121
+ Requires-Dist: tf-nightly>=2.19.0.dev20241201
  Requires-Dist: ai-edge-litert-nightly
  Requires-Dist: ai-edge-quantizer-nightly
+ Requires-Dist: jax
+ Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
+ Provides-Extra: torch-xla
+ Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
 
  Library that supports converting PyTorch models into a .tflite format, which can
  then be run with TensorFlow Lite and MediaPipe. This enables applications for