ai-edge-torch-nightly 0.3.0.dev20240828__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
- ai_edge_torch/_convert/test/test_convert.py +1 -1
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +1 -1
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
- ai_edge_torch/debug/test/test_culprit.py +1 -1
- ai_edge_torch/debug/test/test_search_model.py +1 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_quantize.py +1 -1
- ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
- ai_edge_torch/lowertools/test_utils.py +1 -1
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +320 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +45 -21
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
|
@@ -331,7 +331,12 @@ def _aten__native_batch_norm_legit_no_training(node):
|
|
|
331
331
|
def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
  """Inference-mode batch norm used as the replacement target for the node.

  Normalizes `input` with the running statistics, then applies the optional
  affine `weight` and `bias`. `momentum` is accepted only to keep the
  signature and is not used.

  Returns:
    A 3-tuple `(output, None, None)`.
  """
  std = torch.sqrt(running_var + eps)
  normalized = (input - running_mean) / std
  # Each affine term is optional and is skipped when None.
  scaled = normalized if weight is None else normalized * weight
  shifted = scaled if bias is None else scaled + bias
  return shifted, None, None
|
|
335
340
|
|
|
336
341
|
node.target = batch_norm
|
|
337
342
|
|
|
@@ -21,7 +21,7 @@ from ai_edge_torch.testing import model_coverage
|
|
|
21
21
|
import parameterized
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from
|
|
24
|
+
from absl.testing import absltest as googletest
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def _func_to_torch_module(func: Callable[..., torch.Tensor]):
|
|
@@ -21,7 +21,7 @@ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
|
|
21
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from
|
|
24
|
+
from absl.testing import absltest as googletest
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class TestExternalKVLayers(googletest.TestCase):
|
|
@@ -22,7 +22,7 @@ from ai_edge_torch.generative.utilities import loader as loading_utils
|
|
|
22
22
|
import safetensors.torch
|
|
23
23
|
import torch
|
|
24
24
|
|
|
25
|
-
from
|
|
25
|
+
from absl.testing import absltest as googletest
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class TestLoader(googletest.TestCase):
|
|
@@ -28,7 +28,7 @@ from ai_edge_torch.testing import model_coverage
|
|
|
28
28
|
from parameterized import parameterized
|
|
29
29
|
import torch
|
|
30
30
|
|
|
31
|
-
from
|
|
31
|
+
from absl.testing import absltest as googletest
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class TestVerifyRecipes(googletest.TestCase):
|
|
@@ -19,7 +19,7 @@ from ai_edge_torch.hlfb import mark_pattern
|
|
|
19
19
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
|
20
20
|
import torch
|
|
21
21
|
|
|
22
|
-
from
|
|
22
|
+
from absl.testing import absltest as googletest
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def _export_stablehlo_mlir(model, args=None):
|
|
@@ -22,7 +22,7 @@ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
|
22
22
|
import torch
|
|
23
23
|
import torch.nn.functional as F
|
|
24
24
|
|
|
25
|
-
from
|
|
25
|
+
from absl.testing import absltest as googletest
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def _export_stablehlo_mlir(model, args):
|
|
@@ -84,13 +84,17 @@ def _wrap_as_tf_func(
|
|
|
84
84
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
|
|
85
85
|
s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
|
|
86
86
|
call_args = _extract_call_args(bundle, args, tf_state_dict)
|
|
87
|
+
# HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
|
|
88
|
+
# build, which may not have the same StableHLO version as what used in
|
|
89
|
+
# TFLite converter. Therefore we always serialize MLIR module in VHLO.
|
|
90
|
+
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
|
|
87
91
|
call_module_return = tfxla.call_module(
|
|
88
92
|
tuple(call_args),
|
|
89
93
|
version=5,
|
|
90
94
|
Tout=t_outs, # dtype information
|
|
91
95
|
Sout=s_outs, # Shape information
|
|
92
96
|
function_list=[],
|
|
93
|
-
module=bundle.
|
|
97
|
+
module=bundle.module_bytecode_vhlo,
|
|
94
98
|
)
|
|
95
99
|
spec = exported_program.call_spec.out_spec
|
|
96
100
|
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from . import composite
|
|
16
|
+
from . import debuginfo
|
|
17
|
+
from . import export
|
|
18
|
+
from . import export_utils
|
|
19
|
+
from . import lowerings
|
|
20
|
+
from . import passes
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Wrappers for latest torch APIs/utilities to maintain backward compatibility with older torch releases."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch.fx import _pytree as fx_pytree
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
|
|
22
|
+
"""Transform args, kwargs of __call__ to args for graph_module.
|
|
23
|
+
|
|
24
|
+
self.graph_module takes stuff from state dict as inputs.
|
|
25
|
+
The invariant is for ep: ExportedProgram is
|
|
26
|
+
ep(args, kwargs) ==
|
|
27
|
+
ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
|
|
28
|
+
"""
|
|
29
|
+
if hasattr(ep, "_graph_module_flat_inputs"):
|
|
30
|
+
return ep._graph_module_flat_inputs(args, kwargs)
|
|
31
|
+
|
|
32
|
+
if args is None:
|
|
33
|
+
args = tuple()
|
|
34
|
+
if kwargs is None:
|
|
35
|
+
kwargs = {}
|
|
36
|
+
|
|
37
|
+
flat_args = args
|
|
38
|
+
if (in_spec := ep.call_spec.in_spec) is not None:
|
|
39
|
+
if (
|
|
40
|
+
in_spec.type == tuple
|
|
41
|
+
and len(in_spec.children_specs) == 2
|
|
42
|
+
and in_spec.children_specs[0].type == tuple
|
|
43
|
+
and in_spec.children_specs[1].type == dict
|
|
44
|
+
):
|
|
45
|
+
# NOTE: this is the case where in_spec is for both args and kwargs
|
|
46
|
+
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
47
|
+
else:
|
|
48
|
+
flat_args = fx_pytree.tree_flatten_spec(args, in_spec)
|
|
49
|
+
|
|
50
|
+
param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
|
|
51
|
+
param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
|
|
52
|
+
|
|
53
|
+
if hasattr(ep.graph_signature, "lifted_tensor_constants"):
|
|
54
|
+
ordered_tensor_constants = tuple(
|
|
55
|
+
ep.tensor_constants[name]
|
|
56
|
+
for name in ep.graph_signature.lifted_tensor_constants
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
ordered_tensor_constants = tuple()
|
|
60
|
+
|
|
61
|
+
return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Torch library for registering ODML Torch custom ops."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
# Shared torch.library.Library opened in "DEF" mode for the `odml_torch`
# namespace. ODML Torch custom ops (e.g. mark_tensor, write_mlir_debuginfo)
# define their schemas and register their kernels against this object.
ODML_TORCH_LIB = torch.library.Library("odml_torch", "DEF")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from .mark_tensor import mark_tensor_op
|
|
16
|
+
from .stablehlo_composite_builder import StableHLOCompositeBuilder
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import json
|
|
16
|
+
from typing import Sequence, Union
|
|
17
|
+
|
|
18
|
+
from jax._src.lib.mlir import ir
|
|
19
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from .. import _torch_library
|
|
23
|
+
from .. import lowerings
|
|
24
|
+
|
|
25
|
+
# Value types accepted for a single StableHLO composite attribute. Runtime
# checking (including the homogeneity of sequence elements) is done by
# _assert_valid_composite_attr.
_CompositeAttrValue = Union[
    int,
    float,
    bool,
    str,
    Sequence[int],
    Sequence[float],
    Sequence[bool],
]

# Composite attributes: attribute name -> attribute value.
CompositeAttrType = dict[str, _CompositeAttrValue]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _assert_valid_composite_attr(attr: CompositeAttrType):
  """Validate a composite attribute dictionary.

  `attr` may be None (no attributes). Otherwise it must be a dict with str keys
  whose values are a str/float/int/bool scalar, or a non-empty list/tuple
  whose elements all have the same type, one of int, float or bool.

  Args:
    attr: The composite attributes to check, or None.

  Raises:
    ValueError: if `attr`, one of its keys, or one of its values is not of a
      supported type.
  """
  if attr is None:
    return
  if not isinstance(attr, dict):
    raise ValueError("Composite attr must be a Python dictionary.")

  invalid_attr_value_error = ValueError(
      "Composite attr value must be either Python str, float, int, bool,"
      " list[int], list[float], list[bool]."
  )

  for k, v in attr.items():
    if not isinstance(k, str):
      raise ValueError("Composite attr name must be a Python str.")

    if isinstance(v, (list, tuple)):
      # type() (not isinstance) so that bool and int elements count as
      # different types and a mixed [1, True] sequence is rejected.
      eltys = {type(el) for el in v}
      # An empty sequence has no element type to check. Reject it with the
      # documented ValueError instead of letting next() raise StopIteration.
      if len(eltys) != 1 or next(iter(eltys)) not in (int, float, bool):
        raise invalid_attr_value_error
    elif type(v) not in (str, float, int, bool):
      raise invalid_attr_value_error
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@torch._dynamo.assume_constant_result
def serialize_composite_attr(attr: Union[CompositeAttrType, None]):
  """Serialize the composite attr into a dynamo-traceable value.

  Returns None for None. Otherwise `attr` is first checked with
  `_assert_valid_composite_attr` and then returned as a tuple of its
  `(name, value)` items.
  """
  if attr is not None:
    _assert_valid_composite_attr(attr)
    return tuple(attr.items())
  return None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@torch._dynamo.assume_constant_result
def deserialize_composite_attr(serialized_attr) -> CompositeAttrType:
  """Deserialize a dynamo-traceable composite attribute into its raw value.

  Inverse of `serialize_composite_attr`: a sequence of `(name, value)` pairs
  becomes a dict again, and None stays None.
  """
  return None if serialized_attr is None else dict(serialized_attr)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Declare the `odml_torch::mark_tensor` schema on the shared ODML Torch library.
# `attr` is typed `Any?` so the tuple produced by serialize_composite_attr (or
# None) can be passed through the op.
_torch_library.ODML_TORCH_LIB.define(
    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any?"
    " attr=None) -> Tensor"
)

# Handle to the op's default overload. It can only be looked up after the
# define() call above has registered the schema.
mark_tensor_op = torch.ops.odml_torch.mark_tensor.default
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB, "mark_tensor", "CompositeExplicitAutograd"
)
def mark_tensor(
    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """CompositeExplicitAutograd kernel of `odml_torch::mark_tensor`.

  Acts as an identity on `x`. The marker metadata is consumed only by the
  op's lowering (`mark_tensor_lowering`), so it is ignored here.
  """
  del name, pos, id, is_input, attr  # Unused outside the lowering.
  return x
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@torch.library.impl(_torch_library.ODML_TORCH_LIB, "mark_tensor", "Meta")
def mark_tensor_meta(
    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """Meta-dispatch kernel of `odml_torch::mark_tensor`.

  Returns an uninitialized tensor created with `torch.empty_like(x)`. The
  marker metadata does not affect the result.
  """
  del name, pos, id, is_input, attr
  return torch.empty_like(x)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@lowerings.lower(torch.ops.odml_torch.mark_tensor)
def mark_tensor_lowering(
    lctx, x: ir.Value, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """Lower `odml_torch::mark_tensor` to a StableHLO custom call.

  The emitted `custom_call` has target "mark_tensor", takes `x` as its only
  operand, and returns a value of `x`'s type. The marker metadata, including
  the deserialized composite attributes, is stored as a JSON string in
  `backend_config`.
  """
  marker_config = {
      "name": name,
      "pos": pos,
      "id": id,
      "is_input": is_input,
      "attr": deserialize_composite_attr(attr),
  }
  config_attr = ir.StringAttr.get(json.dumps(marker_config))
  return stablehlo.custom_call(
      [x.type],
      inputs=[x],
      call_target_name="mark_tensor",
      backend_config=config_attr,
  )
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import uuid
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from . import mark_tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@torch._dynamo.assume_constant_result
def _get_uuid() -> str:
  """Return a fresh random (uuid4) identifier as a 32-character hex string.

  Marked with `assume_constant_result` so dynamo treats the value as a
  constant.
  """
  token = uuid.uuid4()
  return token.hex
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class StableHLOCompositeBuilder:
  """Builder class for building a StableHLO composite in the lowering."""

  def __init__(self, name: str, attr: mark_tensor.CompositeAttrType = None):
    """Helper for building a StableHLO Composite by marking input and output tensors.

    Every tensor marked through this builder passes through the
    `odml_torch::mark_tensor` op. That op is tagged with this builder's `name`,
    a per-builder random `id`, and the tensor's position. Composite attributes
    are attached only to the output markers.

    Args:
      name (str): The name of the built StableHLO Composite op.
      attr (mark_tensor.CompositeAttrType): Attributes of the StableHLO
        Composite op, or None for no attributes.
    """
    self.attr = attr
    self.name = name
    self.id = _get_uuid()
    # NOTE(review): these are not read anywhere in this class; kept in case
    # external code relies on them.
    self._inputs = []
    self._outputs = []

  def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
    """Mark the input/output tensors of the StableHLO Composite.

    Args:
      *tensors: Torch tensors to mark. Each tensor's position in this call
        becomes the `pos` of its marker.
      is_input: True to mark composite inputs, False to mark outputs.

    Returns:
      The marked tensor if exactly one tensor is given, otherwise a tuple of
      marked tensors in the same order.

    Raises:
      ValueError: if any argument is not a torch.Tensor. All arguments are
        checked before any marker op is emitted.
    """
    # Attributes are serialized (and validated) only for the output markers.
    serialized_attr = (
        mark_tensor.serialize_composite_attr(self.attr)
        if not is_input
        else None
    )

    role = "input" if is_input else "output"
    for tensor in tensors:
      if not isinstance(tensor, torch.Tensor):
        raise ValueError(f"{role} must be a torch tensor. Got {type(tensor)}.")

    marked_tensors = [
        mark_tensor.mark_tensor_op(
            tensor,
            name=self.name,
            pos=pos,
            id=self.id,
            is_input=is_input,
            attr=serialized_attr,
        )
        for pos, tensor in enumerate(tensors)
    ]

    if len(marked_tensors) == 1:
      return marked_tensors[0]
    return tuple(marked_tensors)

  def mark_inputs(self, *tensors: torch.Tensor):
    """Mark the input tensors of the StableHLO Composite.

    This method must only be called once per builder.

    Args:
      *tensors (torch.Tensor): Torch tensors to mark.

    Returns:
      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
        Torch tensors marked as composite inputs. The tensors passed to this
        method should be replaced by the marked tensors in later usages.

    Raises:
      ValueError: if any argument is not a torch.Tensor.
    """
    return self._mark_tensor(*tensors, is_input=True)

  def mark_outputs(self, *tensors: torch.Tensor):
    """Mark the output tensors of the StableHLO Composite.

    This method must only be called once per builder.

    Args:
      *tensors (torch.Tensor): Torch tensors to mark.

    Returns:
      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
        Torch tensors marked as composite outputs. The tensors passed to this
        method should be replaced by the marked tensors in later usages.

    Raises:
      ValueError: if any argument is not a torch.Tensor, or if the builder's
        composite attributes fail validation.
    """
    return self._mark_tensor(*tensors, is_input=False)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from ._build import build_mlir_debuginfo
|
|
16
|
+
from ._op_polyfill import write_mlir_debuginfo_op
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _class_fullname(cls):
  """Return `cls.__qualname__`, prefixed with its module unless it is builtin."""
  qualified = cls.__qualname__
  if cls.__module__ != "builtins":
    qualified = f"{cls.__module__}.{qualified}"
  return qualified
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_hierarchy(node: torch.fx.Node):
  """Render `node.meta["nn_module_stack"]` as a '/'-joined path ending in ';'.

  Each stack value is unpacked as `(name, layer)`. A path segment is `layer`
  itself when it is a str, otherwise `_class_fullname(layer)`. The segment is
  suffixed with `_<last dotted component of name>` when `name` is non-empty.
  A missing stack yields just ";".
  """
  module_stack = node.meta.get("nn_module_stack", {})

  def _segment(name, layer):
    suffix = ("_" + name.split(".")[-1]) if name else ""
    label = layer if isinstance(layer, str) else _class_fullname(layer)
    return label + suffix

  path = "/".join(_segment(name, layer) for name, layer in module_stack.values())
  return path + ";"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_mlir_debuginfo(node: torch.fx.Node):
  """Build the debuginfo string for the given node's lowerings in MLIR.

  Returns None when the node has no `meta` attribute or no "nn_module_stack"
  entry in it. Otherwise returns the module-hierarchy string from
  `_get_hierarchy`.
  """
  has_module_stack = hasattr(node, "meta") and "nn_module_stack" in node.meta
  return _get_hierarchy(node) if has_module_stack else None
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Polyfill op for torch.ops.xla.write_mlir_debuginfo.
|
|
16
|
+
|
|
17
|
+
In odml-torch, MLIR debuginfo is generated in the lowering framework directly
|
|
18
|
+
without the need of an additional torch op to write. This file registers a no-op
|
|
19
|
+
placeholder torch op to replace torch.ops.xla.write_mlir_debuginfo in
|
|
20
|
+
ai-edge-torch.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from jax._src.lib.mlir import ir
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
from .. import _torch_library
|
|
27
|
+
from .. import lowerings
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Declare the placeholder `odml_torch::write_mlir_debuginfo` schema on the
# shared ODML Torch library. The define() must run before the op lookup below.
_torch_library.ODML_TORCH_LIB.define(
    "write_mlir_debuginfo(Tensor x, str data) -> Tensor"
)

# NOTE: this is the op packet itself; no `.default` overload is selected here,
# unlike `mark_tensor_op`.
write_mlir_debuginfo_op = torch.ops.odml_torch.write_mlir_debuginfo
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB,
    "write_mlir_debuginfo",
    "CompositeExplicitAutograd",
)
def write_mlir_debuginfo(x: torch.Tensor, _: str):
  """CompositeExplicitAutograd kernel of the placeholder op: identity on `x`.

  The debuginfo payload is ignored, because odml-torch emits MLIR debuginfo
  from its lowering framework rather than through this op.
  """
  del _
  return x
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB, "write_mlir_debuginfo", "Meta"
)
def write_mlir_debuginfo_meta(x: torch.Tensor, _: str):
  """Meta-dispatch kernel: an uninitialized tensor from `torch.empty_like(x)`."""
  del _
  return torch.empty_like(x)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@lowerings.lower(torch.ops.odml_torch.write_mlir_debuginfo)
def write_mlir_debuginfo_lowering(lctx, x: ir.Value, _: str):
  """Lower the placeholder op as a no-op by forwarding the input MLIR value."""
  del lctx, _
  return x
|