ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -48
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
|
@@ -18,10 +18,6 @@ import logging
|
|
|
18
18
|
import os
|
|
19
19
|
from typing import Optional
|
|
20
20
|
|
|
21
|
-
import torch
|
|
22
|
-
from torch.export import ExportedProgram
|
|
23
|
-
from torch_xla import stablehlo
|
|
24
|
-
|
|
25
21
|
from ai_edge_torch import model
|
|
26
22
|
from ai_edge_torch.convert import conversion_utils as cutils
|
|
27
23
|
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
|
|
@@ -32,6 +28,9 @@ from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
|
|
|
32
28
|
from ai_edge_torch.convert.fx_passes import run_passes
|
|
33
29
|
from ai_edge_torch.generative.fx_passes import run_generative_passes
|
|
34
30
|
from ai_edge_torch.quantize import quant_config as qcfg
|
|
31
|
+
import torch
|
|
32
|
+
from torch.export import ExportedProgram
|
|
33
|
+
from torch_xla import stablehlo
|
|
35
34
|
|
|
36
35
|
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
|
37
36
|
|
|
@@ -61,8 +60,9 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
|
|
|
61
60
|
continue
|
|
62
61
|
|
|
63
62
|
message = (
|
|
64
|
-
"Your model {sig_name}is converted in training mode. "
|
|
65
|
-
"
|
|
63
|
+
"Your model {sig_name}is converted in training mode. Please set the"
|
|
64
|
+
" module in evaluation mode with `module.eval()` for better on-device"
|
|
65
|
+
" performance and compatibility."
|
|
66
66
|
)
|
|
67
67
|
if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
|
|
68
68
|
# User does not specify any signature names explicitly.
|
|
@@ -88,7 +88,9 @@ def convert_signatures(
|
|
|
88
88
|
_warn_training_modules(signatures)
|
|
89
89
|
|
|
90
90
|
exported_programs: torch.export.ExportedProgram = [
|
|
91
|
-
torch.export.export(
|
|
91
|
+
torch.export.export(
|
|
92
|
+
sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
|
|
93
|
+
)
|
|
92
94
|
for sig in signatures
|
|
93
95
|
]
|
|
94
96
|
|
|
@@ -100,7 +102,9 @@ def convert_signatures(
|
|
|
100
102
|
]
|
|
101
103
|
|
|
102
104
|
merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
|
|
103
|
-
cutils.merge_stablehlo_bundles(
|
|
105
|
+
cutils.merge_stablehlo_bundles(
|
|
106
|
+
shlo_bundles, signatures, exported_programs
|
|
107
|
+
)
|
|
104
108
|
)
|
|
105
109
|
del exported_programs
|
|
106
110
|
del shlo_bundles
|
|
@@ -22,15 +22,15 @@ import logging
|
|
|
22
22
|
import tempfile
|
|
23
23
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24
24
|
|
|
25
|
+
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
26
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
25
27
|
import torch
|
|
26
28
|
import torch.utils._pytree as pytree
|
|
27
29
|
from torch_xla import stablehlo
|
|
28
30
|
|
|
29
|
-
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
30
|
-
from ai_edge_torch.quantize import quant_config as qcfg
|
|
31
|
-
|
|
32
31
|
try:
|
|
33
32
|
import tensorflow as tf
|
|
33
|
+
|
|
34
34
|
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
|
35
35
|
|
|
36
36
|
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
|
|
@@ -90,18 +90,20 @@ class Signature:
|
|
|
90
90
|
if context is None:
|
|
91
91
|
for i, spec in enumerate(specs):
|
|
92
92
|
if spec.children_specs:
|
|
93
|
-
flat_names.extend(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
)
|
|
93
|
+
flat_names.extend([
|
|
94
|
+
f"{i}_{name}"
|
|
95
|
+
for name in self._flat_kwarg_names(
|
|
96
|
+
spec.children_specs, spec.context
|
|
97
|
+
)
|
|
98
|
+
])
|
|
99
99
|
else:
|
|
100
100
|
flat_names.append(f"{i}")
|
|
101
101
|
else:
|
|
102
102
|
flat_ctx = self._flatten_list(context)
|
|
103
103
|
for prefix, spec in zip(flat_ctx, specs):
|
|
104
|
-
leaf_flat_names = self._flat_kwarg_names(
|
|
104
|
+
leaf_flat_names = self._flat_kwarg_names(
|
|
105
|
+
spec.children_specs, spec.context
|
|
106
|
+
)
|
|
105
107
|
if leaf_flat_names:
|
|
106
108
|
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
107
109
|
else:
|
|
@@ -125,7 +127,8 @@ class Signature:
|
|
|
125
127
|
|
|
126
128
|
|
|
127
129
|
def exported_program_to_stablehlo_bundle(
|
|
128
|
-
exported_program: torch.export.ExportedProgram,
|
|
130
|
+
exported_program: torch.export.ExportedProgram,
|
|
131
|
+
sample_args: tuple[torch.Tensor],
|
|
129
132
|
) -> stablehlo.StableHLOModelBundle:
|
|
130
133
|
# Setting export_weights to False here so that pytorch/xla avoids copying the weights
|
|
131
134
|
# to a numpy array which would lead to memory bloat. This means that the state_dict
|
|
@@ -146,7 +149,9 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
|
|
|
146
149
|
dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
|
|
147
150
|
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
|
|
148
151
|
except Exception:
|
|
149
|
-
logging.info(
|
|
152
|
+
logging.info(
|
|
153
|
+
"Can not use dlpack to convert torch tensors. Falling back to numpy."
|
|
154
|
+
)
|
|
150
155
|
nparray = torch_tensor.cpu().detach().numpy()
|
|
151
156
|
tf_tensor = tf.convert_to_tensor(nparray)
|
|
152
157
|
|
|
@@ -154,7 +159,8 @@ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
|
|
|
154
159
|
|
|
155
160
|
|
|
156
161
|
def _get_states(
|
|
157
|
-
exported_programs: list[torch.export.ExportedProgram],
|
|
162
|
+
exported_programs: list[torch.export.ExportedProgram],
|
|
163
|
+
signatures: list[Signature],
|
|
158
164
|
):
|
|
159
165
|
for exported_program, signature in zip(exported_programs, signatures):
|
|
160
166
|
args, _ = exported_program.example_inputs
|
|
@@ -166,7 +172,8 @@ def _get_states(
|
|
|
166
172
|
# Only interested in Tensors that are part of the state (and not user input).
|
|
167
173
|
if (
|
|
168
174
|
not isinstance(tensor, torch.Tensor)
|
|
169
|
-
or input_spec.kind
|
|
175
|
+
or input_spec.kind
|
|
176
|
+
== torch.export.graph_signature.InputKind.USER_INPUT
|
|
170
177
|
):
|
|
171
178
|
continue
|
|
172
179
|
yield signature, tensor, input_spec
|
|
@@ -192,9 +199,13 @@ def _gather_state_dict(
|
|
|
192
199
|
deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
|
|
193
200
|
|
|
194
201
|
state_dict = {}
|
|
195
|
-
for signature, tensor, input_spec in _get_states(
|
|
202
|
+
for signature, tensor, input_spec in _get_states(
|
|
203
|
+
exported_programs, signatures
|
|
204
|
+
):
|
|
196
205
|
unique_id = _tensor_unique_id(tensor)
|
|
197
|
-
state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
|
|
206
|
+
state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
|
|
207
|
+
unique_id
|
|
208
|
+
]
|
|
198
209
|
|
|
199
210
|
return state_dict
|
|
200
211
|
|
|
@@ -236,7 +247,9 @@ def _wrap_as_tf_func(
|
|
|
236
247
|
):
|
|
237
248
|
def inner(*args):
|
|
238
249
|
type_info = [sig.dtype for sig in func.meta.output_signature]
|
|
239
|
-
shape_info = [
|
|
250
|
+
shape_info = [
|
|
251
|
+
_get_shape_with_dynamic(sig) for sig in func.meta.output_signature
|
|
252
|
+
]
|
|
240
253
|
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
|
|
241
254
|
return tfxla.call_module(
|
|
242
255
|
tuple(call_args),
|
|
@@ -369,7 +382,9 @@ def convert_stablehlo_to_tflite(
|
|
|
369
382
|
)
|
|
370
383
|
)
|
|
371
384
|
|
|
372
|
-
tf_module._variables =
|
|
385
|
+
tf_module._variables = (
|
|
386
|
+
list(bundle.state_dict.values()) + bundle.additional_constants
|
|
387
|
+
)
|
|
373
388
|
del bundle
|
|
374
389
|
gc.collect()
|
|
375
390
|
|
|
@@ -385,7 +400,8 @@ def convert_stablehlo_to_tflite(
|
|
|
385
400
|
tf_module,
|
|
386
401
|
temp_dir_path,
|
|
387
402
|
signatures={
|
|
388
|
-
sig.name: tf_concrete_funcs[idx]
|
|
403
|
+
sig.name: tf_concrete_funcs[idx]
|
|
404
|
+
for idx, sig in enumerate(signatures)
|
|
389
405
|
},
|
|
390
406
|
)
|
|
391
407
|
# Clean up intermediate memory early.
|
|
@@ -416,6 +432,8 @@ def convert_stablehlo_to_tflite(
|
|
|
416
432
|
and quant_config._quantizer_mode
|
|
417
433
|
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
|
418
434
|
):
|
|
419
|
-
tflite_model = translate_recipe.quantize_model(
|
|
435
|
+
tflite_model = translate_recipe.quantize_model(
|
|
436
|
+
tflite_model, translated_recipe
|
|
437
|
+
)
|
|
420
438
|
|
|
421
439
|
return tflite_model
|
|
@@ -17,12 +17,11 @@ from __future__ import annotations
|
|
|
17
17
|
|
|
18
18
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
19
19
|
|
|
20
|
-
import torch
|
|
21
|
-
|
|
22
20
|
from ai_edge_torch import model
|
|
23
21
|
from ai_edge_torch.convert import conversion
|
|
24
22
|
from ai_edge_torch.convert import conversion_utils as cutils
|
|
25
23
|
from ai_edge_torch.quantize import quant_config as qcfg
|
|
24
|
+
import torch
|
|
26
25
|
|
|
27
26
|
|
|
28
27
|
class Converter:
|
|
@@ -68,14 +67,20 @@ class Converter:
|
|
|
68
67
|
"""
|
|
69
68
|
|
|
70
69
|
if name in [sig.name for sig in self._signatures]:
|
|
71
|
-
raise ValueError(
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"A signature with the provided name ({name}) is already added."
|
|
72
|
+
)
|
|
72
73
|
|
|
73
74
|
if sample_args is None and sample_kwargs is None:
|
|
74
75
|
raise ValueError("sample_args or sample_kwargs must be provided.")
|
|
75
76
|
|
|
76
77
|
self._signatures.append(
|
|
77
78
|
cutils.Signature(
|
|
78
|
-
name,
|
|
79
|
+
name,
|
|
80
|
+
module,
|
|
81
|
+
sample_args,
|
|
82
|
+
sample_kwargs,
|
|
83
|
+
dynamic_shapes=dynamic_shapes,
|
|
79
84
|
)
|
|
80
85
|
)
|
|
81
86
|
return self
|
|
@@ -128,7 +133,8 @@ class Converter:
|
|
|
128
133
|
)
|
|
129
134
|
else: # module is provided but not args
|
|
130
135
|
raise ValueError(
|
|
131
|
-
"sample_args or sample_kwargs must be provided if a module is
|
|
136
|
+
"sample_args or sample_kwargs must be provided if a module is"
|
|
137
|
+
" specified."
|
|
132
138
|
)
|
|
133
139
|
|
|
134
140
|
return conversion.convert_signatures(
|
|
@@ -15,10 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Sequence, Union
|
|
17
17
|
|
|
18
|
-
from torch.export import ExportedProgram
|
|
19
|
-
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
|
|
20
|
-
import torch.utils._pytree as pytree
|
|
21
|
-
|
|
22
18
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
|
|
23
19
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
|
|
24
20
|
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
|
|
@@ -28,6 +24,9 @@ from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import Bui
|
|
|
28
24
|
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
|
|
29
25
|
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
|
|
30
26
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
|
|
27
|
+
from torch.export import ExportedProgram
|
|
28
|
+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
|
|
29
|
+
import torch.utils._pytree as pytree
|
|
31
30
|
|
|
32
31
|
|
|
33
32
|
# TODO(cnchan): make a PassManager class.
|
|
@@ -32,14 +32,18 @@ class ExportedProgramPassResult(
|
|
|
32
32
|
|
|
33
33
|
class ExportedProgramPassBase(abc.ABC):
|
|
34
34
|
|
|
35
|
-
def __call__(
|
|
35
|
+
def __call__(
|
|
36
|
+
self, exported_program: ExportedProgram
|
|
37
|
+
) -> ExportedProgramPassResult:
|
|
36
38
|
self.requires(exported_program)
|
|
37
39
|
res = self.call(exported_program)
|
|
38
40
|
self.ensures(exported_program)
|
|
39
41
|
return res
|
|
40
42
|
|
|
41
43
|
@abc.abstractmethod
|
|
42
|
-
def call(
|
|
44
|
+
def call(
|
|
45
|
+
self, exported_program: ExportedProgram
|
|
46
|
+
) -> ExportedProgramPassResult:
|
|
43
47
|
pass
|
|
44
48
|
|
|
45
49
|
def requires(self, exported_program: ExportedProgram) -> None:
|
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import copy
|
|
17
17
|
import functools
|
|
18
|
+
from functools import reduce
|
|
18
19
|
from typing import Any, Callable
|
|
19
20
|
|
|
21
|
+
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
20
22
|
import torch
|
|
21
23
|
from torch.fx import GraphModule
|
|
22
24
|
from torch.fx import Node
|
|
@@ -24,8 +26,6 @@ from torch.fx.passes.infra.pass_base import PassBase
|
|
|
24
26
|
from torch.fx.passes.infra.pass_base import PassResult
|
|
25
27
|
import torch.utils._pytree as pytree
|
|
26
28
|
|
|
27
|
-
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
28
|
-
|
|
29
29
|
_composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}
|
|
30
30
|
|
|
31
31
|
|
|
@@ -41,7 +41,9 @@ def _register_composite_builder(op):
|
|
|
41
41
|
return inner
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def _tree_map_to_composite_attr_values(
|
|
44
|
+
def _tree_map_to_composite_attr_values(
|
|
45
|
+
values, *, stringify_incompatible_values=True
|
|
46
|
+
):
|
|
45
47
|
|
|
46
48
|
def convert(value):
|
|
47
49
|
nonlocal stringify_incompatible_values
|
|
@@ -65,7 +67,9 @@ class TorchOpArgumentsMapper:
|
|
|
65
67
|
|
|
66
68
|
assert hasattr(op, "_schema")
|
|
67
69
|
self.op = op
|
|
68
|
-
self.arg_specs = [
|
|
70
|
+
self.arg_specs = [
|
|
71
|
+
(spec.name, spec.default_value) for spec in op._schema.arguments
|
|
72
|
+
]
|
|
69
73
|
|
|
70
74
|
def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
|
|
71
75
|
"""Inspect the op's schema and extract all its args and kwargs
|
|
@@ -110,16 +114,17 @@ def _aten_gelu(gm: GraphModule, node: Node):
|
|
|
110
114
|
full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
|
|
111
115
|
|
|
112
116
|
# TFLite supports exact and tanh approximate.
|
|
113
|
-
if
|
|
117
|
+
if (
|
|
118
|
+
full_kwargs["approximate"] != "none"
|
|
119
|
+
and full_kwargs["approximate"] != "tanh"
|
|
120
|
+
):
|
|
114
121
|
return op(*args, **kwargs)
|
|
115
122
|
|
|
116
123
|
builder = StableHLOCompositeBuilder(
|
|
117
124
|
"aten.gelu.default",
|
|
118
|
-
attr=_tree_map_to_composite_attr_values(
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
}
|
|
122
|
-
),
|
|
125
|
+
attr=_tree_map_to_composite_attr_values({
|
|
126
|
+
"approximate": full_kwargs["approximate"],
|
|
127
|
+
}),
|
|
123
128
|
)
|
|
124
129
|
full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
|
|
125
130
|
output = op(full_kwargs["self"])
|
|
@@ -150,7 +155,10 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
|
|
|
150
155
|
):
|
|
151
156
|
dim_output_size = int((dim_input_size + dim_stride - 1) / dim_stride)
|
|
152
157
|
padding_needed = max(
|
|
153
|
-
0,
|
|
158
|
+
0,
|
|
159
|
+
(dim_output_size - 1) * dim_stride
|
|
160
|
+
+ dim_kernel_size
|
|
161
|
+
- dim_input_size,
|
|
154
162
|
)
|
|
155
163
|
if padding_needed % 2 != 0:
|
|
156
164
|
return False
|
|
@@ -193,16 +201,14 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
|
|
|
193
201
|
|
|
194
202
|
builder = StableHLOCompositeBuilder(
|
|
195
203
|
"aten.avg_pool2d.default",
|
|
196
|
-
attr=_tree_map_to_composite_attr_values(
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
}
|
|
205
|
-
),
|
|
204
|
+
attr=_tree_map_to_composite_attr_values({
|
|
205
|
+
"kernel_size": full_kwargs["kernel_size"],
|
|
206
|
+
"stride": full_kwargs["stride"],
|
|
207
|
+
"padding": full_kwargs["padding"],
|
|
208
|
+
"ceil_mode": full_kwargs["ceil_mode"],
|
|
209
|
+
"count_include_pad": full_kwargs["count_include_pad"],
|
|
210
|
+
"divisor_override": full_kwargs["divisor_override"],
|
|
211
|
+
}),
|
|
206
212
|
)
|
|
207
213
|
|
|
208
214
|
full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
|
|
@@ -223,25 +229,25 @@ def _aten_embedding(gm: GraphModule, node: Node):
|
|
|
223
229
|
full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
|
|
224
230
|
_, embedding_dim = full_kwargs["weight"].size()
|
|
225
231
|
idx = full_kwargs["indices"]
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
232
|
+
|
|
233
|
+
# Explicitly cast to INT32. This places the CastOp outside of the HLFB.
|
|
234
|
+
idx = idx.type(torch.int)
|
|
235
|
+
original_idx_shape = idx.size()
|
|
236
|
+
|
|
237
|
+
# Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
|
|
238
|
+
idx = torch.reshape(idx, (idx.numel(),))
|
|
239
|
+
|
|
240
|
+
builder = StableHLOCompositeBuilder("odml.embedding_lookup")
|
|
241
|
+
full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
|
|
242
|
+
idx,
|
|
243
|
+
full_kwargs["weight"],
|
|
244
|
+
)
|
|
245
|
+
output = op(**full_kwargs)
|
|
246
|
+
output = builder.mark_outputs(output)
|
|
247
|
+
|
|
248
|
+
# Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
|
|
249
|
+
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
|
|
250
|
+
return output
|
|
245
251
|
|
|
246
252
|
node.target = embedding
|
|
247
253
|
|
|
@@ -15,23 +15,20 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
|
|
18
|
-
import torch
|
|
19
|
-
|
|
20
18
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
|
|
21
19
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
|
|
22
20
|
from ai_edge_torch.hlfb import mark_pattern
|
|
21
|
+
import torch
|
|
23
22
|
|
|
24
23
|
# For torch nightly released after mid June 2024,
|
|
25
24
|
# torch.nn.functional.interpolate no longer gets exported into decomposed graph
|
|
26
25
|
# but single aten op torch.ops.aten.upsample_nearest2d.vec/torch.ops.aten.upsample_bilinear2d.vec.
|
|
27
26
|
# This behavior would our pattern matching based composite builder.
|
|
28
27
|
# It requires the pattern and model graph to get decomposed first for backward compatibility.
|
|
29
|
-
_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
]
|
|
34
|
-
)
|
|
28
|
+
_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions([
|
|
29
|
+
torch.ops.aten.upsample_bilinear2d.vec,
|
|
30
|
+
torch.ops.aten.upsample_nearest2d.vec,
|
|
31
|
+
])
|
|
35
32
|
|
|
36
33
|
|
|
37
34
|
@functools.cache
|
|
@@ -84,7 +81,9 @@ def _get_upsample_bilinear2d_align_corners_pattern():
|
|
|
84
81
|
def _get_interpolate_nearest2d_pattern():
|
|
85
82
|
pattern = mark_pattern.Pattern(
|
|
86
83
|
"tfl.resize_nearest_neighbor",
|
|
87
|
-
lambda x: torch.nn.functional.interpolate(
|
|
84
|
+
lambda x: torch.nn.functional.interpolate(
|
|
85
|
+
x, scale_factor=2, mode="nearest"
|
|
86
|
+
),
|
|
88
87
|
export_args=(torch.rand(1, 3, 100, 100),),
|
|
89
88
|
decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
|
90
89
|
)
|
|
@@ -112,7 +111,9 @@ class BuildInterpolateCompositePass(ExportedProgramPassBase):
|
|
|
112
111
|
]
|
|
113
112
|
|
|
114
113
|
def call(self, exported_program: torch.export.ExportedProgram):
|
|
115
|
-
exported_program = exported_program.run_decompositions(
|
|
114
|
+
exported_program = exported_program.run_decompositions(
|
|
115
|
+
_INTERPOLATE_DECOMPOSITIONS
|
|
116
|
+
)
|
|
116
117
|
|
|
117
118
|
graph_module = exported_program.graph_module
|
|
118
119
|
for pattern in self._patterns:
|
|
@@ -13,11 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
|
-
from torch.export import ExportedProgram
|
|
18
|
-
|
|
19
16
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
|
|
20
17
|
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
|
|
18
|
+
import torch
|
|
19
|
+
from torch.export import ExportedProgram
|
|
21
20
|
|
|
22
21
|
# A dummy decomp table for running ExportedProgram.run_decompositions without
|
|
23
22
|
# any op decompositions but just aot_export_module. Due to the check in
|
|
@@ -15,13 +15,12 @@
|
|
|
15
15
|
import dataclasses
|
|
16
16
|
import operator
|
|
17
17
|
|
|
18
|
-
import torch
|
|
19
|
-
from torch.fx import Node
|
|
20
|
-
|
|
21
18
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
22
19
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
|
|
23
20
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
24
21
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
|
|
22
|
+
import torch
|
|
23
|
+
from torch.fx import Node
|
|
25
24
|
|
|
26
25
|
aten = torch.ops.aten
|
|
27
26
|
|
|
@@ -150,7 +149,9 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
|
|
|
150
149
|
|
|
151
150
|
|
|
152
151
|
@layout_sensitive_inputs_getters.register(aten.convolution)
|
|
153
|
-
@layout_sensitive_inputs_getters.register(
|
|
152
|
+
@layout_sensitive_inputs_getters.register(
|
|
153
|
+
aten._native_batch_norm_legit_no_training
|
|
154
|
+
)
|
|
154
155
|
@layout_sensitive_inputs_getters.register(aten.native_group_norm)
|
|
155
156
|
def _first_arg_getter(node):
|
|
156
157
|
return [node.args[0]]
|
|
@@ -174,7 +175,11 @@ def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
|
|
174
175
|
@nhwcable_node_checkers.register(aten._native_batch_norm_legit_no_training)
|
|
175
176
|
def _aten_norm_checker(node):
|
|
176
177
|
val = node.meta.get("val")
|
|
177
|
-
if
|
|
178
|
+
if (
|
|
179
|
+
not isinstance(val, (list, tuple))
|
|
180
|
+
or not val
|
|
181
|
+
or not hasattr(val[0], "shape")
|
|
182
|
+
):
|
|
178
183
|
return NHWCable(can_be=False, must_be=False)
|
|
179
184
|
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
|
180
185
|
|
|
@@ -182,9 +187,15 @@ def _aten_norm_checker(node):
|
|
|
182
187
|
@nhwcable_node_checkers.register(aten.native_group_norm)
|
|
183
188
|
def _aten_native_group_norm_checker(node):
|
|
184
189
|
val = node.meta.get("val")
|
|
185
|
-
if
|
|
190
|
+
if (
|
|
191
|
+
not isinstance(val, (list, tuple))
|
|
192
|
+
or not val
|
|
193
|
+
or not hasattr(val[0], "shape")
|
|
194
|
+
):
|
|
186
195
|
return NHWCable(can_be=False, must_be=False)
|
|
187
|
-
if len(node.args) >= 3 and (
|
|
196
|
+
if len(node.args) >= 3 and (
|
|
197
|
+
node.args[1] is not None or node.args[2] is not None
|
|
198
|
+
):
|
|
188
199
|
# Disable NHWC rewriter due to precision issue with weight and bias.
|
|
189
200
|
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
|
|
190
201
|
return NHWCable(can_be=False, must_be=False)
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py
CHANGED
|
@@ -13,10 +13,9 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
|
-
|
|
18
16
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
19
17
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
18
|
+
import torch
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
def partition(graph_module: torch.fx.GraphModule):
|
|
@@ -45,7 +44,9 @@ def partition(graph_module: torch.fx.GraphModule):
|
|
|
45
44
|
|
|
46
45
|
layout_sensitive_inputs = layout_check.get_layout_sensitive_inputs(node)
|
|
47
46
|
|
|
48
|
-
should_be_nhwc = any(
|
|
47
|
+
should_be_nhwc = any(
|
|
48
|
+
map(layout_mark.is_nhwc_node, layout_sensitive_inputs)
|
|
49
|
+
)
|
|
49
50
|
for input_node in layout_sensitive_inputs:
|
|
50
51
|
if not layout_mark.is_nhwc_node(input_node) and not layout_check.is_4d(
|
|
51
52
|
input_node
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py
CHANGED
|
@@ -17,13 +17,12 @@ import collections
|
|
|
17
17
|
import dataclasses
|
|
18
18
|
import itertools
|
|
19
19
|
|
|
20
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
21
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
20
22
|
import numpy as np
|
|
21
23
|
import scipy
|
|
22
24
|
import torch
|
|
23
25
|
|
|
24
|
-
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
|
25
|
-
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
26
|
-
|
|
27
26
|
|
|
28
27
|
def can_partition(graph_module: torch.fx.GraphModule):
|
|
29
28
|
"""Returns true if the input graph_module can be partitioned by min cut solver
|
|
@@ -83,7 +82,10 @@ class MinCutSolver:
|
|
|
83
82
|
def graph(self):
|
|
84
83
|
edges = np.array(self.edges)
|
|
85
84
|
return scipy.sparse.csr_matrix(
|
|
86
|
-
(
|
|
85
|
+
(
|
|
86
|
+
np.minimum(edges[:, 2], MinCutSolver.INF_COST),
|
|
87
|
+
(edges[:, 0], edges[:, 1]),
|
|
88
|
+
),
|
|
87
89
|
shape=(self._nodes_cnt, self._nodes_cnt),
|
|
88
90
|
dtype=np.int32,
|
|
89
91
|
)
|
|
@@ -14,13 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
import operator
|
|
16
16
|
|
|
17
|
-
import torch
|
|
18
|
-
from torch.fx import Node
|
|
19
|
-
import torch.utils._pytree as pytree
|
|
20
|
-
|
|
21
17
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
|
22
18
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
23
19
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.op_func_registry import OpFuncRegistry # NOQA
|
|
20
|
+
import torch
|
|
21
|
+
from torch.fx import Node
|
|
22
|
+
import torch.utils._pytree as pytree
|
|
24
23
|
|
|
25
24
|
aten = torch.ops.aten
|
|
26
25
|
|
|
@@ -349,7 +348,12 @@ def _aten_native_group_norm(node):
|
|
|
349
348
|
):
|
|
350
349
|
input_reshaped = torch.reshape(
|
|
351
350
|
input,
|
|
352
|
-
[
|
|
351
|
+
[
|
|
352
|
+
batch_size,
|
|
353
|
+
flattened_inner_size,
|
|
354
|
+
num_groups,
|
|
355
|
+
num_channels // num_groups,
|
|
356
|
+
],
|
|
353
357
|
)
|
|
354
358
|
reduction_dims = [1, 3]
|
|
355
359
|
|
|
@@ -12,9 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
import torch
|
|
16
|
-
|
|
17
15
|
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
|
|
16
|
+
import torch
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
class OpFuncRegistry(dict):
|