ai-edge-torch-nightly 0.3.0.dev20240828__py3-none-any.whl → 0.3.0.dev20240830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.

Files changed (45)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +1 -1
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  9. ai_edge_torch/generative/test/test_loader.py +1 -1
  10. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  11. ai_edge_torch/generative/test/test_quantize.py +1 -1
  12. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  13. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  14. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  15. ai_edge_torch/lowertools/test_utils.py +1 -1
  16. ai_edge_torch/odml_torch/__init__.py +20 -0
  17. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  18. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  19. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  20. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  21. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  22. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  23. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  24. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  25. ai_edge_torch/odml_torch/export.py +320 -0
  26. ai_edge_torch/odml_torch/export_utils.py +168 -0
  27. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  28. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  29. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  31. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  32. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  33. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  34. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  35. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  36. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  37. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  38. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  39. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  40. ai_edge_torch/version.py +1 -1
  41. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/METADATA +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/RECORD +45 -21
  43. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/LICENSE +0 -0
  44. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/WHEEL +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/top_level.txt +0 -0
@@ -331,7 +331,12 @@ def _aten__native_batch_norm_legit_no_training(node):
331
331
  def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
332
332
  a = input - running_mean
333
333
  b = torch.sqrt(running_var + eps)
334
- return a / b * weight + bias, None, None
334
+ out = a / b
335
+ if weight is not None:
336
+ out = out * weight
337
+ if bias is not None:
338
+ out = out + bias
339
+ return out, None, None
335
340
 
336
341
  node.target = batch_norm
337
342
 
@@ -27,7 +27,7 @@ import tensorflow as tf
27
27
  import torch
28
28
  import torchvision
29
29
 
30
- from tensorflow.python.platform import googletest
30
+ from absl.testing import absltest as googletest
31
31
 
32
32
 
33
33
  @dataclasses.dataclass
@@ -21,7 +21,7 @@ from ai_edge_torch.testing import model_coverage
21
21
  import parameterized
22
22
  import torch
23
23
 
24
- from tensorflow.python.platform import googletest
24
+ from absl.testing import absltest as googletest
25
25
 
26
26
 
27
27
  def _func_to_torch_module(func: Callable[..., torch.Tensor]):
@@ -19,7 +19,7 @@ from ai_edge_torch.testing import model_coverage
19
19
  import torch
20
20
  from torch import nn
21
21
 
22
- from tensorflow.python.platform import googletest
22
+ from absl.testing import absltest as googletest
23
23
 
24
24
 
25
25
  class FullyConnectedModel(nn.Module):
@@ -17,7 +17,7 @@
17
17
  import ai_edge_torch
18
18
  import torch
19
19
 
20
- from tensorflow.python.platform import googletest
20
+ from absl.testing import absltest as googletest
21
21
 
22
22
 
23
23
  class Identity(torch.nn.Module):
@@ -21,7 +21,7 @@ import sys
21
21
  from ai_edge_torch.debug import find_culprits
22
22
  import torch
23
23
 
24
- from tensorflow.python.platform import googletest
24
+ from absl.testing import absltest as googletest
25
25
 
26
26
  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
27
27
 
@@ -17,7 +17,7 @@
17
17
  from ai_edge_torch.debug import _search_model
18
18
  import torch
19
19
 
20
- from tensorflow.python.platform import googletest
20
+ from absl.testing import absltest as googletest
21
21
 
22
22
 
23
23
  class TestSearchModel(googletest.TestCase):
@@ -21,7 +21,7 @@ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
21
21
  import ai_edge_torch.generative.layers.model_config as cfg
22
22
  import torch
23
23
 
24
- from tensorflow.python.platform import googletest
24
+ from absl.testing import absltest as googletest
25
25
 
26
26
 
27
27
  class TestExternalKVLayers(googletest.TestCase):
@@ -22,7 +22,7 @@ from ai_edge_torch.generative.utilities import loader as loading_utils
22
22
  import safetensors.torch
23
23
  import torch
24
24
 
25
- from tensorflow.python.platform import googletest
25
+ from absl.testing import absltest as googletest
26
26
 
27
27
 
28
28
  class TestLoader(googletest.TestCase):
@@ -24,7 +24,7 @@ from ai_edge_torch.testing import model_coverage
24
24
  import numpy as np
25
25
  import torch
26
26
 
27
- from tensorflow.python.platform import googletest
27
+ from absl.testing import absltest as googletest
28
28
 
29
29
 
30
30
  class TestModelConversion(googletest.TestCase):
@@ -28,7 +28,7 @@ from ai_edge_torch.testing import model_coverage
28
28
  from parameterized import parameterized
29
29
  import torch
30
30
 
31
- from tensorflow.python.platform import googletest
31
+ from absl.testing import absltest as googletest
32
32
 
33
33
 
34
34
  class TestVerifyRecipes(googletest.TestCase):
@@ -19,7 +19,7 @@ from ai_edge_torch.hlfb import mark_pattern
19
19
  from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
20
20
  import torch
21
21
 
22
- from tensorflow.python.platform import googletest
22
+ from absl.testing import absltest as googletest
23
23
 
24
24
 
25
25
  def _export_stablehlo_mlir(model, args=None):
@@ -22,7 +22,7 @@ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
22
22
  import torch
23
23
  import torch.nn.functional as F
24
24
 
25
- from tensorflow.python.platform import googletest
25
+ from absl.testing import absltest as googletest
26
26
 
27
27
 
28
28
  def _export_stablehlo_mlir(model, args):
@@ -84,13 +84,17 @@ def _wrap_as_tf_func(
84
84
  t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
85
85
  s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
86
86
  call_args = _extract_call_args(bundle, args, tf_state_dict)
87
+ # HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
88
+ # build, which may not have the same StableHLO version as what used in
89
+ # TFLite converter. Therefore we always serialize MLIR module in VHLO.
90
+ # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
87
91
  call_module_return = tfxla.call_module(
88
92
  tuple(call_args),
89
93
  version=5,
90
94
  Tout=t_outs, # dtype information
91
95
  Sout=s_outs, # Shape information
92
96
  function_list=[],
93
- module=bundle.module_bytecode,
97
+ module=bundle.module_bytecode_vhlo,
94
98
  )
95
99
  spec = exported_program.call_spec.out_spec
96
100
 
@@ -16,7 +16,7 @@
16
16
  import re
17
17
  from typing import Optional
18
18
  from ai_edge_torch import config
19
- from tensorflow.python.platform import googletest
19
+ from absl.testing import absltest as googletest
20
20
 
21
21
 
22
22
  def _extract_backend_configs(mlir):
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from . import composite
16
+ from . import debuginfo
17
+ from . import export
18
+ from . import export_utils
19
+ from . import lowerings
20
+ from . import passes
@@ -0,0 +1,61 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Wrappers for latest torch APIs/utilities to maintain backward compatibility with older torch releases."""
16
+
17
+ import torch
18
+ from torch.fx import _pytree as fx_pytree
19
+
20
+
21
+ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
22
+ """Transform args, kwargs of __call__ to args for graph_module.
23
+
24
+ self.graph_module takes stuff from state dict as inputs.
25
+ The invariant is for ep: ExportedProgram is
26
+ ep(args, kwargs) ==
27
+ ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
28
+ """
29
+ if hasattr(ep, "_graph_module_flat_inputs"):
30
+ return ep._graph_module_flat_inputs(args, kwargs)
31
+
32
+ if args is None:
33
+ args = tuple()
34
+ if kwargs is None:
35
+ kwargs = {}
36
+
37
+ flat_args = args
38
+ if (in_spec := ep.call_spec.in_spec) is not None:
39
+ if (
40
+ in_spec.type == tuple
41
+ and len(in_spec.children_specs) == 2
42
+ and in_spec.children_specs[0].type == tuple
43
+ and in_spec.children_specs[1].type == dict
44
+ ):
45
+ # NOTE: this is the case where in_spec is for both args and kwargs
46
+ flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
47
+ else:
48
+ flat_args = fx_pytree.tree_flatten_spec(args, in_spec)
49
+
50
+ param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
51
+ param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
52
+
53
+ if hasattr(ep.graph_signature, "lifted_tensor_constants"):
54
+ ordered_tensor_constants = tuple(
55
+ ep.tensor_constants[name]
56
+ for name in ep.graph_signature.lifted_tensor_constants
57
+ )
58
+ else:
59
+ ordered_tensor_constants = tuple()
60
+
61
+ return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
@@ -0,0 +1,19 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Torch library for registering ODML Torch custom ops."""
16
+
17
+ import torch
18
+
19
+ ODML_TORCH_LIB = torch.library.Library("odml_torch", "DEF")
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from .mark_tensor import mark_tensor_op
16
+ from .stablehlo_composite_builder import StableHLOCompositeBuilder
@@ -0,0 +1,120 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import json
16
+ from typing import Sequence, Union
17
+
18
+ from jax._src.lib.mlir import ir
19
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
20
+ import torch
21
+
22
+ from .. import _torch_library
23
+ from .. import lowerings
24
+
25
+ CompositeAttrType = dict[
26
+ str,
27
+ Union[
28
+ int,
29
+ float,
30
+ bool,
31
+ str,
32
+ Sequence[int],
33
+ Sequence[float],
34
+ Sequence[bool],
35
+ ],
36
+ ]
37
+
38
+
39
+ def _assert_valid_composite_attr(attr: CompositeAttrType):
40
+ if attr is None:
41
+ return
42
+ if not isinstance(attr, dict):
43
+ raise ValueError("Composite attr must be a Python dictionary.")
44
+
45
+ for k, v in attr.items():
46
+ if not isinstance(k, str):
47
+ raise ValueError("Composite attr name must be a Python str.")
48
+
49
+ invalid_attr_value_error = ValueError(
50
+ "Composite attr value must be either Python str, float, int, bool,"
51
+ " list[int], list[float], list[bool]."
52
+ )
53
+ if isinstance(v, (list, tuple)):
54
+ eltys = {type(el) for el in v}
55
+ if len(eltys) > 1 or next(iter(eltys)) not in (int, float, bool):
56
+ raise invalid_attr_value_error
57
+ elif type(v) not in (str, float, int, bool):
58
+ raise invalid_attr_value_error
59
+
60
+
61
+ @torch._dynamo.assume_constant_result
62
+ def serialize_composite_attr(attr: Union[CompositeAttrType, None]):
63
+ """Serialize the composite attr into a dynamo-tracable value."""
64
+ if attr is None:
65
+ return None
66
+ _assert_valid_composite_attr(attr)
67
+ return tuple(attr.items())
68
+
69
+
70
+ @torch._dynamo.assume_constant_result
71
+ def deserialize_composite_attr(serialized_attr) -> CompositeAttrType:
72
+ """Deserialize dynamo-tracable composite attribute into its raw value."""
73
+ if serialized_attr is None:
74
+ return None
75
+ return dict(serialized_attr)
76
+
77
+
78
+ _torch_library.ODML_TORCH_LIB.define(
79
+ "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any?"
80
+ " attr=None) -> Tensor"
81
+ )
82
+
83
+ mark_tensor_op = torch.ops.odml_torch.mark_tensor.default
84
+
85
+
86
+ @torch.library.impl(
87
+ _torch_library.ODML_TORCH_LIB, "mark_tensor", "CompositeExplicitAutograd"
88
+ )
89
+ def mark_tensor(
90
+ x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
91
+ ):
92
+ return x
93
+
94
+
95
+ @torch.library.impl(_torch_library.ODML_TORCH_LIB, "mark_tensor", "Meta")
96
+ def mark_tensor_meta(
97
+ x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
98
+ ):
99
+ return torch.empty_like(x)
100
+
101
+
102
+ @lowerings.lower(torch.ops.odml_torch.mark_tensor)
103
+ def mark_tensor_lowering(
104
+ lctx, x: ir.Value, name: str, pos: int, id: str, is_input: bool, attr=None
105
+ ):
106
+ attr = deserialize_composite_attr(attr)
107
+ return stablehlo.custom_call(
108
+ [x.type],
109
+ inputs=[x],
110
+ call_target_name="mark_tensor",
111
+ backend_config=ir.StringAttr.get(
112
+ json.dumps({
113
+ "name": name,
114
+ "pos": pos,
115
+ "id": id,
116
+ "is_input": is_input,
117
+ "attr": attr,
118
+ })
119
+ ),
120
+ )
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import uuid
16
+
17
+ import torch
18
+
19
+ from . import mark_tensor
20
+
21
+
22
+ @torch._dynamo.assume_constant_result
23
+ def _get_uuid() -> str:
24
+ return uuid.uuid4().hex
25
+
26
+
27
+ class StableHLOCompositeBuilder:
28
+ """Builder class for building a StableHLO composite in the lowering."""
29
+
30
+ def __init__(self, name: str, attr: mark_tensor.CompositeAttrType = None):
31
+ """Helper for building a StableHLO Composite by marking input and output tensors.
32
+
33
+ It should be used with the StableHLO converters from `torch_xla.stablehlo`.
34
+
35
+ Args:
36
+ name (str): The name of the built StableHLO Composite op.
37
+ attr (mark_tensor.CompositeAttrType): Attributes of the StableHLO
38
+ Composite op.
39
+ """
40
+
41
+ self.attr = attr
42
+ self.name = name
43
+ self.id = _get_uuid()
44
+ self._inputs = []
45
+ self._outputs = []
46
+
47
+ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
48
+ """Mark the input/output tensors of the StableHLO Composite."""
49
+ marked_tensors = []
50
+ serialized_attr = (
51
+ mark_tensor.serialize_composite_attr(self.attr)
52
+ if not is_input
53
+ else None
54
+ )
55
+
56
+ for pos, tensor in enumerate(tensors):
57
+ if not isinstance(tensor, torch.Tensor):
58
+ raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
59
+ marked_tensors.append(
60
+ mark_tensor.mark_tensor_op(
61
+ tensor,
62
+ name=self.name,
63
+ pos=pos,
64
+ id=self.id,
65
+ is_input=is_input,
66
+ attr=serialized_attr,
67
+ )
68
+ )
69
+
70
+ if len(marked_tensors) == 1:
71
+ return marked_tensors[0]
72
+ return tuple(marked_tensors)
73
+
74
+ def mark_inputs(self, *tensors: torch.Tensor):
75
+ """Mark the input tensors of the StableHLO Composite.
76
+
77
+ This method must only be called once per builder.
78
+
79
+ Args:
80
+ *tensors (torch.Tensor): Torch tensors to mark.
81
+
82
+ Returns:
83
+ marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
84
+ Torch tensors marked as composite inputs. The tensor inputs of this
85
+ method
86
+ should be replaced by the marked tensors in later usages.
87
+ """
88
+
89
+ return self._mark_tensor(*tensors, is_input=True)
90
+
91
+ def mark_outputs(self, *tensors: torch.Tensor):
92
+ """Mark the output tensors of the StableHLO Composite.
93
+
94
+ This method must only be called once per builder.
95
+
96
+ Args:
97
+ *tensors (torch.Tensor): Torch tensors to mark.
98
+
99
+ Returns:
100
+ marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
101
+ Torch tensors marked as composite outputs. The tensor inputs of this
102
+ method
103
+ should be replaced by the marked tensors in later usages.
104
+ """
105
+
106
+ return self._mark_tensor(*tensors, is_input=False)
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from ._build import build_mlir_debuginfo
16
+ from ._op_polyfill import write_mlir_debuginfo_op
@@ -0,0 +1,43 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import torch
16
+
17
+
18
+ def _class_fullname(cls):
19
+ module = cls.__module__
20
+ if module == "builtins":
21
+ return cls.__qualname__
22
+ return module + "." + cls.__qualname__
23
+
24
+
25
+ def _get_hierarchy(node: torch.fx.Node):
26
+ nn_module_stack = node.meta.get("nn_module_stack", {})
27
+ layers = []
28
+ for name, layer in nn_module_stack.values():
29
+ iid = ("_" + name.split(".")[-1]) if name else ""
30
+ layer_str = layer if isinstance(layer, str) else _class_fullname(layer)
31
+ layers.append(layer_str + iid)
32
+
33
+ hierachy_str = "/".join(layers) + ";"
34
+ return hierachy_str
35
+
36
+
37
+ def build_mlir_debuginfo(node: torch.fx.Node):
38
+ """Build the debuginfo string for the given node's lowerings in MLIR."""
39
+
40
+ if not hasattr(node, "meta") or "nn_module_stack" not in node.meta:
41
+ return None
42
+
43
+ return _get_hierarchy(node)
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Polyfill op for torch.ops.xla.write_mlir_debuginfo.
16
+
17
+ In odml-torch, MLIR debuginfo is generated in the lowering framework directly
18
+ without the need of an additional torch op to write. This file register a no-op
19
+ placeholder torch op to replace torch.ops.xla.write_mlir_debuginfo in
20
+ ai-edge-torch.
21
+ """
22
+
23
+ from jax._src.lib.mlir import ir
24
+ import torch
25
+
26
+ from .. import _torch_library
27
+ from .. import lowerings
28
+
29
+
30
+ _torch_library.ODML_TORCH_LIB.define(
31
+ "write_mlir_debuginfo(Tensor x, str data) -> Tensor"
32
+ )
33
+
34
+ write_mlir_debuginfo_op = torch.ops.odml_torch.write_mlir_debuginfo
35
+
36
+
37
+ @torch.library.impl(
38
+ _torch_library.ODML_TORCH_LIB,
39
+ "write_mlir_debuginfo",
40
+ "CompositeExplicitAutograd",
41
+ )
42
+ def write_mlir_debuginfo(x: torch.Tensor, _: str):
43
+ return x
44
+
45
+
46
+ @torch.library.impl(
47
+ _torch_library.ODML_TORCH_LIB, "write_mlir_debuginfo", "Meta"
48
+ )
49
+ def write_mlir_debuginfo_meta(x: torch.Tensor, _: str):
50
+ return torch.empty_like(x)
51
+
52
+
53
+ @lowerings.lower(torch.ops.odml_torch.write_mlir_debuginfo)
54
+ def write_mlir_debuginfo_lowering(lctx, x: ir.Value, _: str):
55
+ return x