ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the registry's release page for more details.

Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .convert.converter import convert
17
17
  from .convert.converter import signature
18
18
  from .convert.to_channel_last_io import to_channel_last_io
19
19
  from .model import Model
20
+ from .version import __version__
20
21
 
21
22
 
22
23
  def load(path: str) -> Model:
@@ -18,10 +18,6 @@ import logging
18
18
  import os
19
19
  from typing import Optional
20
20
 
21
- import torch
22
- from torch.export import ExportedProgram
23
- from torch_xla import stablehlo
24
-
25
21
  from ai_edge_torch import model
26
22
  from ai_edge_torch.convert import conversion_utils as cutils
27
23
  from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
@@ -32,6 +28,9 @@ from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
32
28
  from ai_edge_torch.convert.fx_passes import run_passes
33
29
  from ai_edge_torch.generative.fx_passes import run_generative_passes
34
30
  from ai_edge_torch.quantize import quant_config as qcfg
31
+ import torch
32
+ from torch.export import ExportedProgram
33
+ from torch_xla import stablehlo
35
34
 
36
35
  os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
37
36
 
@@ -61,8 +60,9 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
61
60
  continue
62
61
 
63
62
  message = (
64
- "Your model {sig_name}is converted in training mode. "
65
- "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
63
+ "Your model {sig_name}is converted in training mode. Please set the"
64
+ " module in evaluation mode with `module.eval()` for better on-device"
65
+ " performance and compatibility."
66
66
  )
67
67
  if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
68
68
  # User does not specify any signature names explicitly.
@@ -88,7 +88,9 @@ def convert_signatures(
88
88
  _warn_training_modules(signatures)
89
89
 
90
90
  exported_programs: torch.export.ExportedProgram = [
91
- torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
91
+ torch.export.export(
92
+ sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
93
+ )
92
94
  for sig in signatures
93
95
  ]
94
96
 
@@ -100,7 +102,9 @@ def convert_signatures(
100
102
  ]
101
103
 
102
104
  merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
103
- cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
105
+ cutils.merge_stablehlo_bundles(
106
+ shlo_bundles, signatures, exported_programs
107
+ )
104
108
  )
105
109
  del exported_programs
106
110
  del shlo_bundles
@@ -22,15 +22,15 @@ import logging
22
22
  import tempfile
23
23
  from typing import Any, Dict, List, Optional, Tuple, Union
24
24
 
25
+ from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
26
+ from ai_edge_torch.quantize import quant_config as qcfg
25
27
  import torch
26
28
  import torch.utils._pytree as pytree
27
29
  from torch_xla import stablehlo
28
30
 
29
- from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
30
- from ai_edge_torch.quantize import quant_config as qcfg
31
-
32
31
  try:
33
32
  import tensorflow as tf
33
+
34
34
  from tensorflow.compiler.tf2xla.python import xla as tfxla
35
35
 
36
36
  from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
@@ -90,18 +90,20 @@ class Signature:
90
90
  if context is None:
91
91
  for i, spec in enumerate(specs):
92
92
  if spec.children_specs:
93
- flat_names.extend(
94
- [
95
- f"{i}_{name}"
96
- for name in self._flat_kwarg_names(spec.children_specs, spec.context)
97
- ]
98
- )
93
+ flat_names.extend([
94
+ f"{i}_{name}"
95
+ for name in self._flat_kwarg_names(
96
+ spec.children_specs, spec.context
97
+ )
98
+ ])
99
99
  else:
100
100
  flat_names.append(f"{i}")
101
101
  else:
102
102
  flat_ctx = self._flatten_list(context)
103
103
  for prefix, spec in zip(flat_ctx, specs):
104
- leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
104
+ leaf_flat_names = self._flat_kwarg_names(
105
+ spec.children_specs, spec.context
106
+ )
105
107
  if leaf_flat_names:
106
108
  flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
107
109
  else:
@@ -125,7 +127,8 @@ class Signature:
125
127
 
126
128
 
127
129
  def exported_program_to_stablehlo_bundle(
128
- exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
130
+ exported_program: torch.export.ExportedProgram,
131
+ sample_args: tuple[torch.Tensor],
129
132
  ) -> stablehlo.StableHLOModelBundle:
130
133
  # Setting export_weights to False here so that pytorch/xla avoids copying the weights
131
134
  # to a numpy array which would lead to memory bloat. This means that the state_dict
@@ -146,7 +149,9 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
146
149
  dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
147
150
  tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
148
151
  except Exception:
149
- logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")
152
+ logging.info(
153
+ "Can not use dlpack to convert torch tensors. Falling back to numpy."
154
+ )
150
155
  nparray = torch_tensor.cpu().detach().numpy()
151
156
  tf_tensor = tf.convert_to_tensor(nparray)
152
157
 
@@ -154,7 +159,8 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
154
159
 
155
160
 
156
161
  def _get_states(
157
- exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
162
+ exported_programs: list[torch.export.ExportedProgram],
163
+ signatures: list[Signature],
158
164
  ):
159
165
  for exported_program, signature in zip(exported_programs, signatures):
160
166
  args, _ = exported_program.example_inputs
@@ -166,7 +172,8 @@ def _get_states(
166
172
  # Only interested in Tensors that are part of the state (and not user input).
167
173
  if (
168
174
  not isinstance(tensor, torch.Tensor)
169
- or input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
175
+ or input_spec.kind
176
+ == torch.export.graph_signature.InputKind.USER_INPUT
170
177
  ):
171
178
  continue
172
179
  yield signature, tensor, input_spec
@@ -192,9 +199,13 @@ def _gather_state_dict(
192
199
  deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
193
200
 
194
201
  state_dict = {}
195
- for signature, tensor, input_spec in _get_states(exported_programs, signatures):
202
+ for signature, tensor, input_spec in _get_states(
203
+ exported_programs, signatures
204
+ ):
196
205
  unique_id = _tensor_unique_id(tensor)
197
- state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]
206
+ state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
207
+ unique_id
208
+ ]
198
209
 
199
210
  return state_dict
200
211
 
@@ -236,7 +247,9 @@ def _wrap_as_tf_func(
236
247
  ):
237
248
  def inner(*args):
238
249
  type_info = [sig.dtype for sig in func.meta.output_signature]
239
- shape_info = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
250
+ shape_info = [
251
+ _get_shape_with_dynamic(sig) for sig in func.meta.output_signature
252
+ ]
240
253
  call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
241
254
  return tfxla.call_module(
242
255
  tuple(call_args),
@@ -369,7 +382,9 @@ def convert_stablehlo_to_tflite(
369
382
  )
370
383
  )
371
384
 
372
- tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
385
+ tf_module._variables = (
386
+ list(bundle.state_dict.values()) + bundle.additional_constants
387
+ )
373
388
  del bundle
374
389
  gc.collect()
375
390
 
@@ -385,7 +400,8 @@ def convert_stablehlo_to_tflite(
385
400
  tf_module,
386
401
  temp_dir_path,
387
402
  signatures={
388
- sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
403
+ sig.name: tf_concrete_funcs[idx]
404
+ for idx, sig in enumerate(signatures)
389
405
  },
390
406
  )
391
407
  # Clean up intermediate memory early.
@@ -416,6 +432,8 @@ def convert_stablehlo_to_tflite(
416
432
  and quant_config._quantizer_mode
417
433
  == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
418
434
  ):
419
- tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
435
+ tflite_model = translate_recipe.quantize_model(
436
+ tflite_model, translated_recipe
437
+ )
420
438
 
421
439
  return tflite_model
@@ -17,12 +17,11 @@ from __future__ import annotations
17
17
 
18
18
  from typing import Any, Dict, Optional, Tuple, Union
19
19
 
20
- import torch
21
-
22
20
  from ai_edge_torch import model
23
21
  from ai_edge_torch.convert import conversion
24
22
  from ai_edge_torch.convert import conversion_utils as cutils
25
23
  from ai_edge_torch.quantize import quant_config as qcfg
24
+ import torch
26
25
 
27
26
 
28
27
  class Converter:
@@ -68,14 +67,20 @@ class Converter:
68
67
  """
69
68
 
70
69
  if name in [sig.name for sig in self._signatures]:
71
- raise ValueError(f"A signature with the provided name ({name}) is already added.")
70
+ raise ValueError(
71
+ f"A signature with the provided name ({name}) is already added."
72
+ )
72
73
 
73
74
  if sample_args is None and sample_kwargs is None:
74
75
  raise ValueError("sample_args or sample_kwargs must be provided.")
75
76
 
76
77
  self._signatures.append(
77
78
  cutils.Signature(
78
- name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
79
+ name,
80
+ module,
81
+ sample_args,
82
+ sample_kwargs,
83
+ dynamic_shapes=dynamic_shapes,
79
84
  )
80
85
  )
81
86
  return self
@@ -128,7 +133,8 @@ class Converter:
128
133
  )
129
134
  else: # module is provided but not args
130
135
  raise ValueError(
131
- "sample_args or sample_kwargs must be provided if a module is specified."
136
+ "sample_args or sample_kwargs must be provided if a module is"
137
+ " specified."
132
138
  )
133
139
 
134
140
  return conversion.convert_signatures(
@@ -15,10 +15,6 @@
15
15
 
16
16
  from typing import Sequence, Union
17
17
 
18
- from torch.export import ExportedProgram
19
- from torch.fx.passes.infra.pass_manager import pass_result_wrapper
20
- import torch.utils._pytree as pytree
21
-
22
18
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
23
19
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
24
20
  from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
@@ -28,6 +24,9 @@ from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import Bui
28
24
  from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
29
25
  from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
30
26
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
27
+ from torch.export import ExportedProgram
28
+ from torch.fx.passes.infra.pass_manager import pass_result_wrapper
29
+ import torch.utils._pytree as pytree
31
30
 
32
31
 
33
32
  # TODO(cnchan): make a PassManager class.
@@ -32,14 +32,18 @@ class ExportedProgramPassResult(
32
32
 
33
33
  class ExportedProgramPassBase(abc.ABC):
34
34
 
35
- def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
35
+ def __call__(
36
+ self, exported_program: ExportedProgram
37
+ ) -> ExportedProgramPassResult:
36
38
  self.requires(exported_program)
37
39
  res = self.call(exported_program)
38
40
  self.ensures(exported_program)
39
41
  return res
40
42
 
41
43
  @abc.abstractmethod
42
- def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
44
+ def call(
45
+ self, exported_program: ExportedProgram
46
+ ) -> ExportedProgramPassResult:
43
47
  pass
44
48
 
45
49
  def requires(self, exported_program: ExportedProgram) -> None:
@@ -15,8 +15,10 @@
15
15
 
16
16
  import copy
17
17
  import functools
18
+ from functools import reduce
18
19
  from typing import Any, Callable
19
20
 
21
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
20
22
  import torch
21
23
  from torch.fx import GraphModule
22
24
  from torch.fx import Node
@@ -24,8 +26,6 @@ from torch.fx.passes.infra.pass_base import PassBase
24
26
  from torch.fx.passes.infra.pass_base import PassResult
25
27
  import torch.utils._pytree as pytree
26
28
 
27
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
28
-
29
29
  _composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}
30
30
 
31
31
 
@@ -41,7 +41,9 @@ def _register_composite_builder(op):
41
41
  return inner
42
42
 
43
43
 
44
- def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=True):
44
+ def _tree_map_to_composite_attr_values(
45
+ values, *, stringify_incompatible_values=True
46
+ ):
45
47
 
46
48
  def convert(value):
47
49
  nonlocal stringify_incompatible_values
@@ -65,7 +67,9 @@ class TorchOpArgumentsMapper:
65
67
 
66
68
  assert hasattr(op, "_schema")
67
69
  self.op = op
68
- self.arg_specs = [(spec.name, spec.default_value) for spec in op._schema.arguments]
70
+ self.arg_specs = [
71
+ (spec.name, spec.default_value) for spec in op._schema.arguments
72
+ ]
69
73
 
70
74
  def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
71
75
  """Inspect the op's schema and extract all its args and kwargs
@@ -110,16 +114,17 @@ def _aten_gelu(gm: GraphModule, node: Node):
110
114
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
111
115
 
112
116
  # TFLite supports exact and tanh approximate.
113
- if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
117
+ if (
118
+ full_kwargs["approximate"] != "none"
119
+ and full_kwargs["approximate"] != "tanh"
120
+ ):
114
121
  return op(*args, **kwargs)
115
122
 
116
123
  builder = StableHLOCompositeBuilder(
117
124
  "aten.gelu.default",
118
- attr=_tree_map_to_composite_attr_values(
119
- {
120
- "approximate": full_kwargs["approximate"],
121
- }
122
- ),
125
+ attr=_tree_map_to_composite_attr_values({
126
+ "approximate": full_kwargs["approximate"],
127
+ }),
123
128
  )
124
129
  full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
125
130
  output = op(full_kwargs["self"])
@@ -150,7 +155,10 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
150
155
  ):
151
156
  dim_output_size = int((dim_input_size + dim_stride - 1) / dim_stride)
152
157
  padding_needed = max(
153
- 0, (dim_output_size - 1) * dim_stride + dim_kernel_size - dim_input_size
158
+ 0,
159
+ (dim_output_size - 1) * dim_stride
160
+ + dim_kernel_size
161
+ - dim_input_size,
154
162
  )
155
163
  if padding_needed % 2 != 0:
156
164
  return False
@@ -193,16 +201,14 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
193
201
 
194
202
  builder = StableHLOCompositeBuilder(
195
203
  "aten.avg_pool2d.default",
196
- attr=_tree_map_to_composite_attr_values(
197
- {
198
- "kernel_size": full_kwargs["kernel_size"],
199
- "stride": full_kwargs["stride"],
200
- "padding": full_kwargs["padding"],
201
- "ceil_mode": full_kwargs["ceil_mode"],
202
- "count_include_pad": full_kwargs["count_include_pad"],
203
- "divisor_override": full_kwargs["divisor_override"],
204
- }
205
- ),
204
+ attr=_tree_map_to_composite_attr_values({
205
+ "kernel_size": full_kwargs["kernel_size"],
206
+ "stride": full_kwargs["stride"],
207
+ "padding": full_kwargs["padding"],
208
+ "ceil_mode": full_kwargs["ceil_mode"],
209
+ "count_include_pad": full_kwargs["count_include_pad"],
210
+ "divisor_override": full_kwargs["divisor_override"],
211
+ }),
206
212
  )
207
213
 
208
214
  full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
@@ -223,25 +229,25 @@ def _aten_embedding(gm: GraphModule, node: Node):
223
229
  full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
224
230
  _, embedding_dim = full_kwargs["weight"].size()
225
231
  idx = full_kwargs["indices"]
226
- # TODO(b/356458830): Handle relative positional encoding
227
- if len(idx.size()) == 2:
228
- idx = idx.type(torch.int)
229
- B, T = idx.size()
230
-
231
- idx = torch.reshape(idx, (B * T,))
232
-
233
- builder = StableHLOCompositeBuilder("odml.embedding_lookup")
234
- full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
235
- idx,
236
- full_kwargs["weight"],
237
- )
238
- output = op(**full_kwargs)
239
- output = builder.mark_outputs(output)
240
-
241
- output = torch.reshape(output, (B, T, embedding_dim))
242
- return output
243
- else:
244
- return op(**full_kwargs)
232
+
233
+ # Explicitly cast to INT32. This places the CastOp outside of the HLFB.
234
+ idx = idx.type(torch.int)
235
+ original_idx_shape = idx.size()
236
+
237
+ # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
238
+ idx = torch.reshape(idx, (idx.numel(),))
239
+
240
+ builder = StableHLOCompositeBuilder("odml.embedding_lookup")
241
+ full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
242
+ idx,
243
+ full_kwargs["weight"],
244
+ )
245
+ output = op(**full_kwargs)
246
+ output = builder.mark_outputs(output)
247
+
248
+ # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
249
+ output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
250
+ return output
245
251
 
246
252
  node.target = embedding
247
253
 
@@ -15,23 +15,20 @@
15
15
 
16
16
  import functools
17
17
 
18
- import torch
19
-
20
18
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
21
19
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
22
20
  from ai_edge_torch.hlfb import mark_pattern
21
+ import torch
23
22
 
24
23
  # For torch nightly released after mid June 2024,
25
24
  # torch.nn.functional.interpolate no longer gets exported into decomposed graph
26
25
  # but single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
27
26
  # This behavior would our pattern matching based composite builder.
28
27
  # It requires the pattern and model graph to get decomposed first for backward compatibility.
29
- _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
30
- [
31
- torch.ops.aten.upsample_bilinear2d.vec,
32
- torch.ops.aten.upsample_nearest2d.vec,
33
- ]
34
- )
28
+ _INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
29
+ torch.ops.aten.upsample_bilinear2d.vec,
30
+ torch.ops.aten.upsample_nearest2d.vec,
31
+ ])
35
32
 
36
33
 
37
34
  @functools.cache
@@ -84,7 +81,9 @@ def _get_upsample_bilinear2d_align_corners_pattern():
84
81
  def _get_interpolate_nearest2d_pattern():
85
82
  pattern = mark_pattern.Pattern(
86
83
  "tfl.resize_nearest_neighbor",
87
- lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
84
+ lambda x: torch.nn.functional.interpolate(
85
+ x, scale_factor=2, mode="nearest"
86
+ ),
88
87
  export_args=(torch.rand(1, 3, 100, 100),),
89
88
  decomp_table=_INTERPOLATE_DECOMPOSITIONS,
90
89
  )
@@ -112,7 +111,9 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
112
111
  ]
113
112
 
114
113
  def call(self, exported_program: torch.export.ExportedProgram):
115
- exported_program = exported_program.run_decompositions(_INTERPOLATE_DECOMPOSITIONS)
114
+ exported_program = exported_program.run_decompositions(
115
+ _INTERPOLATE_DECOMPOSITIONS
116
+ )
116
117
 
117
118
  graph_module = exported_program.graph_module
118
119
  for pattern in self._patterns:
@@ -13,11 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import torch
17
- from torch.export import ExportedProgram
18
-
19
16
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
20
17
  from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
18
+ import torch
19
+ from torch.export import ExportedProgram
21
20
 
22
21
  # A dummy decomp table for running ExportedProgram.run_decompositions without
23
22
  # any op decompositions but just aot_export_module. Due to the check in
@@ -15,13 +15,12 @@
15
15
  import dataclasses
16
16
  import operator
17
17
 
18
- import torch
19
- from torch.fx import Node
20
-
21
18
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
22
19
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
23
20
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
24
21
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
22
+ import torch
23
+ from torch.fx import Node
25
24
 
26
25
  aten = torch.ops.aten
27
26
 
@@ -150,7 +149,9 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
150
149
 
151
150
 
152
151
  @layout_sensitive_inputs_getters.register(aten.convolution)
153
- @layout_sensitive_inputs_getters.register(aten._native_batch_norm_legit_no_training)
152
+ @layout_sensitive_inputs_getters.register(
153
+ aten._native_batch_norm_legit_no_training
154
+ )
154
155
  @layout_sensitive_inputs_getters.register(aten.native_group_norm)
155
156
  def _first_arg_getter(node):
156
157
  return [node.args[0]]
@@ -174,7 +175,11 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
174
175
  @nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
175
176
  def _aten_norm_checker(node):
176
177
  val = node.meta.get("val")
177
- if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
178
+ if (
179
+ not isinstance(val, (list, tuple))
180
+ or not val
181
+ or not hasattr(val[0], "shape")
182
+ ):
178
183
  return NHWCable(can_be=False, must_be=False)
179
184
  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
180
185
 
@@ -182,9 +187,15 @@ def _aten_norm_checker(node):
182
187
  @nhwcable_node_checkers.register(aten.native_group_norm)
183
188
  def _aten_native_group_norm_checker(node):
184
189
  val = node.meta.get("val")
185
- if not isinstance(val, (list, tuple)) or not val or not hasattr(val[0], "shape"):
190
+ if (
191
+ not isinstance(val, (list, tuple))
192
+ or not val
193
+ or not hasattr(val[0], "shape")
194
+ ):
186
195
  return NHWCable(can_be=False, must_be=False)
187
- if len(node.args) >= 3 and (node.args[1] is not None or node.args[2] is not None):
196
+ if len(node.args) >= 3 and (
197
+ node.args[1] is not None or node.args[2] is not None
198
+ ):
188
199
  # Disable NHWC rewriter due to precision issue with weight and bias.
189
200
  # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
190
201
  return NHWCable(can_be=False, must_be=False)
@@ -13,10 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import torch
17
-
18
16
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
19
17
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
18
+ import torch
20
19
 
21
20
 
22
21
  def partition(graph_module: torch.fx.GraphModule):
@@ -45,7 +44,9 @@ def partition(graph_module: torch.fx.GraphModule):
45
44
 
46
45
  layout_sensitive_inputs = layout_check.get_layout_sensitive_inputs(node)
47
46
 
48
- should_be_nhwc = any(map(layout_mark.is_nhwc_node, layout_sensitive_inputs))
47
+ should_be_nhwc = any(
48
+ map(layout_mark.is_nhwc_node, layout_sensitive_inputs)
49
+ )
49
50
  for input_node in layout_sensitive_inputs:
50
51
  if not layout_mark.is_nhwc_node(input_node) and not layout_check.is_4d(
51
52
  input_node
@@ -17,13 +17,12 @@ import collections
17
17
  import dataclasses
18
18
  import itertools
19
19
 
20
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
21
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
20
22
  import numpy as np
21
23
  import scipy
22
24
  import torch
23
25
 
24
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
25
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
26
-
27
26
 
28
27
  def can_partition(graph_module: torch.fx.GraphModule):
29
28
  """Returns true if the input graph_module can be partitioned by min cut solver
@@ -83,7 +82,10 @@ class MinCutSolver:
83
82
  def graph(self):
84
83
  edges = np.array(self.edges)
85
84
  return scipy.sparse.csr_matrix(
86
- (np.minimum(edges[:, 2], MinCutSolver.INF_COST), (edges[:, 0], edges[:, 1])),
85
+ (
86
+ np.minimum(edges[:, 2], MinCutSolver.INF_COST),
87
+ (edges[:, 0], edges[:, 1]),
88
+ ),
87
89
  shape=(self._nodes_cnt, self._nodes_cnt),
88
90
  dtype=np.int32,
89
91
  )
@@ -14,13 +14,12 @@
14
14
  # ==============================================================================
15
15
  import operator
16
16
 
17
- import torch
18
- from torch.fx import Node
19
- import torch.utils._pytree as pytree
20
-
21
17
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
22
18
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
23
19
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
20
+ import torch
21
+ from torch.fx import Node
22
+ import torch.utils._pytree as pytree
24
23
 
25
24
  aten = torch.ops.aten
26
25
 
@@ -349,7 +348,12 @@ def _aten_native_group_norm(node):
349
348
  ):
350
349
  input_reshaped = torch.reshape(
351
350
  input,
352
- [batch_size, flattened_inner_size, num_groups, num_channels // num_groups],
351
+ [
352
+ batch_size,
353
+ flattened_inner_size,
354
+ num_groups,
355
+ num_channels // num_groups,
356
+ ],
353
357
  )
354
358
  reduction_dims = [1, 3]
355
359
 
@@ -12,9 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import torch
16
-
17
15
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
16
+ import torch
18
17
 
19
18
 
20
19
  class OpFuncRegistry(dict):