ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as a potentially problematic release.

Files changed (46)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
  9. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  10. ai_edge_torch/generative/test/test_loader.py +1 -1
  11. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  12. ai_edge_torch/generative/test/test_quantize.py +1 -1
  13. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  14. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  15. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  16. ai_edge_torch/lowertools/test_utils.py +1 -1
  17. ai_edge_torch/odml_torch/__init__.py +20 -0
  18. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  19. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  20. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  21. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  22. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  23. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  24. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  25. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  26. ai_edge_torch/odml_torch/export.py +320 -0
  27. ai_edge_torch/odml_torch/export_utils.py +168 -0
  28. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  29. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  30. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  31. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  32. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  33. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  34. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  35. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  36. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  37. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  38. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  39. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  40. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  41. ai_edge_torch/version.py +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
  44. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/ai_edge_torch/odml_torch/export_utils.py
@@ -0,0 +1,168 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Utilities for ODML Torch export."""
+
+ import functools
+ import re
+ from typing import Sequence, cast
+ import jax._src.interpreters.mlir
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import func
+ import torch
+
+ # std::numeric_limits<int64_t>::min()
+ IR_DYNAMIC = -9223372036854775808
+
+
+ def is_ir_dynamic(v):
+   return v == IR_DYNAMIC
+
+
+ def is_torch_dynamic(v):
+   return isinstance(v, torch.SymInt)
+
+
+ def is_iterable(v):
+   try:
+     iter(v)
+   except TypeError:
+     return False
+   return True
+
+
+ def create_ir_context():
+   # HACK: Use the ir context from JAX as the base for better stability in OSS.
+   # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+   context = jax._src.interpreters.mlir.make_ir_context()
+   context.allow_unregistered_dialects = True
+
+   return context
+
+
+ def inline(
+     symbol_table: ir.SymbolTable,
+     block: ir.Block,
+ ):
+   """Recursively inlines all func.call ops in the block.
+
+   The symbol_table must include all func.func called by func.call ops.
+   This inliner is implemented in Python because the MLIR inline pass from
+   JAX's MLIR pybinding build in OSS cannot properly inline func.call ops.
+   """
+   while True:
+     is_changed = False
+     for op in block.operations:
+       if op.OPERATION_NAME != func.CallOp.OPERATION_NAME:
+         continue
+
+       call_op = cast(func.CallOp, op)
+       func_op = cast(func.FuncOp, symbol_table[call_op.callee.value])
+       with ir.InsertionPoint(op):
+         new_results = clone_func_body_ops(func_op, call_op.operands)
+
+       for old_result, new_result in zip(call_op.results, new_results):
+         old_result = cast(ir.Value, old_result)
+         old_result.replace_all_uses_with(new_result)
+       call_op.erase()
+       is_changed = True
+
+     if not is_changed:
+       break
+
+   for op in block.operations:
+     for region in op.regions:
+       for block in region.blocks:
+         inline(symbol_table, block)
+
+
+ def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
+   """Clones the ops in func_op's body one by one into the current context."""
+   func_args = list(func_op.arguments)
+   ir_inputs = list(ir_inputs)
+   assert len(func_args) == len(ir_inputs)
+
+   value_mapping = {arg: ir_input for arg, ir_input in zip(func_args, ir_inputs)}
+
+   for op in list(func_op.entry_block.operations):
+     cloned_operands = [value_mapping[val] for val in op.operands]
+     if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
+       return cloned_operands
+
+     cloned = cast(ir.Operation, op.operation.clone())
+
+     for i in range(len(op.operands)):
+       cloned.operands[i] = cloned_operands[i]
+
+     for i in range(len(op.results)):
+       value_mapping[op.results[i]] = cloned.results[i]
+
+   return []
+
+
+ def sanitize_aten_op_name(op, chars=":."):
+   return re.sub("[{}]".format(chars), "_", str(op))
+
+
+ def build_ir_attr(val):
+   if val is None:
+     return ir.StringAttr.get("py_None")
+   if isinstance(val, bool):
+     return ir.BoolAttr.get(val)
+   if isinstance(val, int):
+     return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
+   if isinstance(val, float):
+     return ir.FloatAttr.get(ir.F64Type.get(), val)
+   if isinstance(val, str):
+     return ir.StringAttr.get(val)
+   if isinstance(val, dict):
+     return ir.DictAttr.get({k: build_ir_attr(v) for k, v in val.items()})
+   if isinstance(val, (list, tuple)):
+     return ir.ArrayAttr.get([build_ir_attr(v) for v in val])
+
+   # Stringify the value to a StringAttr by default
+   return ir.StringAttr.get(str(val))
+
+
+ def torch_dtype_to_ir_element_type(ctx, dtype):
+   ty_get = {
+       torch.double: ir.F64Type.get,
+       torch.float32: ir.F32Type.get,
+       torch.half: ir.F16Type.get,
+       torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+       torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+       torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+       torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+   }.get(dtype)
+   return ty_get(ctx)
+
+
+ def ir_element_type_to_torch_dtype(ty):
+   if isinstance(ty, ir.F32Type):
+     return torch.float32
+   if isinstance(ty, ir.F64Type):
+     return torch.float64
+   if isinstance(ty, ir.F16Type):
+     return torch.half
+   if isinstance(ty, ir.IntegerType):
+     if ty.is_signless:
+       if ty.width == 64:
+         return torch.long
+       if ty.width == 32:
+         return torch.int32
+       if ty.width == 16:
+         return torch.int16
+       if ty.width == 1:
+         return torch.bool
+   raise RuntimeError(f"Unsupported ir element type: {ty}")
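
For readers new to these helpers, here is a minimal sketch (a hypothetical snippet, not part of the package; it assumes jax and this wheel are installed) of how create_ir_context and build_ir_attr fit together:

from ai_edge_torch.odml_torch import export_utils

# Attribute builders need an active MLIR context; create_ir_context reuses
# JAX's context so the relevant dialects are registered.
with export_utils.create_ir_context():
  none_attr = export_utils.build_ir_attr(None)        # StringAttr "py_None"
  int_attr = export_utils.build_ir_attr(7)            # 64-bit signless IntegerAttr
  nested = export_utils.build_ir_attr({"k": (1, 2)})  # DictAttr of an ArrayAttr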
--- /dev/null
+++ b/ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -0,0 +1,15 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
--- /dev/null
+++ b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -0,0 +1,152 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """APIs to wrap JAX functions for use in ODML Torch lowerings."""
+
+ import functools
+ import inspect
+ from typing import Any, Callable, Optional, cast
+ import uuid
+ from ai_edge_torch.odml_torch import export_utils
+ from ai_edge_torch.odml_torch import passes
+ from ai_edge_torch.odml_torch.jax_bridge import utils
+ import jax
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import func
+ import torch.utils._pytree as pytree
+
+ # JAX double (64-bit) precision is required to generate StableHLO MLIR with
+ # i64/f64 tensors from JAX-bridged lowerings. If it is not set properly, all
+ # 64-bit tensors would be truncated to 32-bit dtypes and potentially break
+ # the lowering.
+ jax.config.update("jax_enable_x64", True)
+
+
+ def _lower_to_ir_text(
+     jaxfn, args, kwargs, ir_input_names: Optional[list[str]] = None
+ ) -> tuple[str, list[ir.Value]]:
+   args = utils.tree_map_list_to_tuple(args)
+   kwargs = utils.tree_map_list_to_tuple(kwargs)
+
+   names_args = [
+       *zip(inspect.signature(jaxfn).parameters.keys(), args),
+       *kwargs.items(),
+   ]
+
+   static_argnames = []
+   jax_lower_static_kwargs = {}
+   jax_lower_args = []
+   jax_lower_argnames = []
+   ir_inputs = []
+
+   for i, (name, arg) in enumerate(names_args):
+     is_positional = i < len(args)
+     if not utils.is_ir_variable(arg):
+       static_argnames.append(name)
+       jax_lower_static_kwargs[name] = arg
+     else:
+       # Enforce that the arg order in the MLIR matches the lowering func.
+       jax_lower_args.append(utils.ir_variable_to_jax(arg))
+
+       if is_positional and len(jax_lower_args) == i + 1:
+         # The first N contiguous tensor args are passed to the lowering func
+         # as positional args when they are also passed to the bridged func
+         # as positional args.
+         jax_lower_argnames.append(None)
+       else:
+         # Otherwise pass the arg to the lowering func as a keyword arg.
+         jax_lower_argnames.append(name)
+
+       if ir_input_names is None or name in ir_input_names:
+         # An ir variable can be a nested tuple, while mlir args must be flat.
+         ir_inputs += [
+             x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
+         ]
+
+   def new_lowering(*args, **jax_lower_static_kwargs):
+     jaxfn_args = []
+     jaxfn_kwargs = jax_lower_static_kwargs.copy()
+     for name, arg in zip(jax_lower_argnames, args):
+       if name is None:
+         jaxfn_args.append(arg)
+       else:
+         jaxfn_kwargs[name] = arg
+
+     return jaxfn(*jaxfn_args, **jaxfn_kwargs)
+
+   return (
+       jax.jit(new_lowering, static_argnames=static_argnames)
+       .lower(*jax_lower_args, **jax_lower_static_kwargs)
+       .as_text()
+   ), ir_inputs
+
+
+ def wrap(jaxfn: Callable[..., Any], ir_input_names: Optional[list[str]] = None):
+   """Returns a wrapped JAX function to be used in ODML Torch lowerings.
+
+   If the given jaxfn has signature `jaxfn(*args, **kwargs) -> return`, the
+   wrapped function would:
+   - Have signature `wrapped(lctx: odml_torch.export.LoweringContext, *args,
+     **kwargs) -> return`.
+   - Accept mlir.ir.Value for all params expecting jax.Array as inputs.
+   - Return mlir.ir.Value for all jax.Array outputs from jaxfn.
+
+   Args:
+     jaxfn: The JAX function to be wrapped.
+     ir_input_names: The input (param) names of the JAX function to be used in
+       the MLIR lowering. This is useful when the JAX impl only depends on
+       specific inputs to the function. If not specified, all ir.Value passed
+       to the wrapped function are assumed to be used in the lowering.
+   """
+
+   @functools.wraps(jaxfn)
+   def wrapped(lctx, *args, **kwargs):
+
+     ir_text, ir_inputs = _lower_to_ir_text(
+         jaxfn,
+         args,
+         kwargs,
+         ir_input_names=ir_input_names,
+     )
+
+     module = ir.Module.parse(ir_text)
+     passes.strip_debuginfo(module)
+
+     symbol_table = ir.SymbolTable(module.operation)
+     main_func = symbol_table["main"]
+
+     with ir.InsertionPoint.at_block_begin(lctx.ir_module.body):
+       cloned_func = cast(func.FuncOp, main_func.clone())
+       cloned_func_name = f"{jaxfn.__name__}_{uuid.uuid4().hex[:8]}"
+       cloned_func.attributes["sym_name"] = ir.StringAttr.get(cloned_func_name)
+       cloned_func.attributes["sym_visibility"] = ir.StringAttr.get("private")
+
+     # HACK: Use the custom inliner implemented in Python because the MLIR
+     # inline pass from JAX's MLIR pybinding build in OSS cannot properly
+     # inline func.call ops.
+     # This should be switched to `passes.inline(module)` when we have our
+     # own MLIR pybinding build.
+     export_utils.inline(symbol_table, cloned_func.entry_block)
+
+     if not cloned_func.arguments:
+       # Known edge case: when the lowering does not depend on the input but
+       # just the meta of the input, like shape or dtype.
+       ir_inputs = []
+
+     results = func.CallOp(cloned_func, ir_inputs).results
+     if len(results) == 1:
+       return results[0]
+     return results
+
+   return wrapped
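
The bridge is meant to be combined with the lowering registry (lowerings/registry.py in the file list above). A hedged sketch of the intended pattern, using a hypothetical softplus lowering as the example op; beta and threshold arrive as plain Python values, so _lower_to_ir_text treats them as static kwargs rather than MLIR inputs:

import jax.numpy as jnp
import torch
from ai_edge_torch.odml_torch import jax_bridge
from ai_edge_torch.odml_torch.lowerings import registry

def _softplus_jax(x, beta=1, threshold=20):
  # Mirrors torch.nn.functional.softplus: linear above the threshold to
  # avoid overflow, otherwise log1p(exp(beta * x)) / beta.
  scaled = beta * x
  return jnp.where(scaled > threshold, x, jnp.log1p(jnp.exp(scaled)) / beta)

# wrap() turns the JAX function into a lowering with signature
# (lctx, *mlir_values, **static_kwargs) -> mlir_value(s).
registry.lower(torch.ops.aten.softplus)(jax_bridge.wrap(_softplus_jax))

(If softplus already has a registered lowering in _jax_lowerings.py, registering a second one may be rejected; the snippet is illustrative only.)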
--- /dev/null
+++ b/ai_edge_torch/odml_torch/jax_bridge/utils.py
@@ -0,0 +1,75 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Utilities for Jax bridge."""
+
+ from ai_edge_torch import odml_torch
+ import jax
+ import jax.numpy as jnp
+ from jax._src.lib.mlir import ir
+ import torch
+
+
+ def t2j_dtype(dtype):
+   return {
+       torch.bfloat16: jnp.bfloat16,
+       torch.half: jnp.float16,
+       torch.float32: jnp.float32,
+       torch.double: jnp.double,
+       torch.long: jnp.int64,
+       torch.int64: jnp.int64,
+       torch.int32: jnp.int32,
+       torch.int16: jnp.int16,
+       torch.int8: jnp.int8,
+       torch.uint8: jnp.uint8,
+       torch.bool: jnp.bool_,
+       torch.complex64: jnp.complex64,
+       torch.complex128: jnp.complex128,
+   }.get(dtype)
+
+
+ def is_ir_variable(value):
+   if isinstance(value, ir.Value):
+     return True
+   if isinstance(value, (list, tuple)):
+     return any(is_ir_variable(x) for x in value)
+   return False
+
+
+ def ir_variable_to_jax(value):
+   if isinstance(value, (list, tuple)):
+     return tuple([ir_variable_to_jax(x) for x in value])
+   elif not isinstance(value, ir.Value):
+     return value
+   elif not isinstance(value.type, ir.RankedTensorType):
+     raise ValueError(
+         f"ir.Value to JAX must be in ir.RankedTensorType, got {value}"
+     )
+
+   return jax.ShapeDtypeStruct(
+       value.type.shape,
+       t2j_dtype(
+           odml_torch.export_utils.ir_element_type_to_torch_dtype(
+               value.type.element_type
+           )
+       ),
+   )
+
+
+ def tree_map_list_to_tuple(value):
+   if isinstance(value, dict):
+     return {k: tree_map_list_to_tuple(v) for k, v in value.items()}
+   if isinstance(value, (list, tuple)):
+     return tuple([tree_map_list_to_tuple(v) for v in value])
+   return value
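
These helpers are easy to sanity-check in isolation; a small sketch (not part of the package; assumes torch and jax are installed):

import jax.numpy as jnp
import torch
from ai_edge_torch.odml_torch.jax_bridge import utils

assert utils.t2j_dtype(torch.bfloat16) == jnp.bfloat16
# Lists are converted to tuples recursively so values can act as static,
# hashable JAX arguments.
assert utils.tree_map_list_to_tuple({"a": [1, [2, 3]]}) == {"a": (1, (2, 3))}
# Only mlir ir.Value instances (possibly nested in lists/tuples) count as
# variables; plain tensors are treated as static.
assert not utils.is_ir_variable(torch.tensor(1.0))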
--- /dev/null
+++ b/ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -0,0 +1,24 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from . import _basic
+ from . import _batch_norm
+ from . import _convolution
+ from . import _jax_lowerings
+ from . import context
+ from . import registry
+ from . import utils
+ from .registry import decompositions
+ from .registry import lookup
+ from .registry import lower
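
The registry itself (registry.py, +87 lines in the file list) is not shown in this diff; assuming `lookup` takes an ATen op and returns its registered lowering, the re-exported surface would be used like this sketch:

import torch
from ai_edge_torch.odml_torch import lowerings

# hardtanh is registered below in _basic.py via @lower(torch.ops.aten.hardtanh).
hardtanh_lowering = lowerings.lookup(torch.ops.aten.hardtanh)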
--- /dev/null
+++ b/ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -0,0 +1,204 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import math
+ from typing import Optional, Union
+
+ from ai_edge_torch.odml_torch.lowerings import utils
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import numpy as np
+ import torch
+
+ from .registry import lower
+
+
+ # add(Tensor self, Tensor other) -> Tensor
+ # @lower(torch.ops.aten.add)
+ def _aten_add(lctx, x: ir.Value, y: ir.Value, alpha=1):
+   x, y = utils.upcast_to_same_type(x, y)
+   x, y = utils.broadcast_args_if_needed(x, y)
+   if alpha == 1:
+     return stablehlo.add(x, y)
+
+   alpha_splat = utils.splat(alpha, y.type.element_type, y.type.shape)
+   return stablehlo.add(x, stablehlo.multiply(y, alpha_splat))
+
+
+ # mul.Tensor(Tensor self, Tensor other) -> Tensor
+ # @lower(torch.ops.aten.mul.Tensor)
+ def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
+   self, other = utils.upcast_to_same_type(self, other)
+   self, other = utils.broadcast_args_if_needed(self, other)
+
+   return stablehlo.multiply(self, other)
+
+
+ # cat(Tensor[] tensors, int dim=0) -> Tensor
+ # @lower(torch.ops.aten.cat)
+ def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 0):
+   return stablehlo.ConcatenateOp(tensors, dim).result
+
+
+ # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+ # @lower(torch.ops.aten.view)
+ def _aten_view(lctx, self: ir.Value, size: list[int]):
+   return stablehlo.ReshapeOp(
+       ir.RankedTensorType.get(size, self.type.element_type), self
+   ).result
+
+
+ # hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+ @lower(torch.ops.aten.hardtanh)
+ def _aten_hardtanh(
+     lctx,
+     self: ir.Value,
+     min_val: Union[int, float] = -1.0,
+     max_val: Union[int, float] = 1.0,
+ ):
+   elty = self.type.element_type
+   min_val = utils.splat(min_val, elty)
+   max_val = utils.splat(max_val, elty)
+
+   return stablehlo.clamp(min_val, self, max_val)
+
+
+ # mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+ # mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *,
+ #   ScalarType? dtype=None) -> Tensor
+ @lower(torch.ops.aten.mean)
+ @lower(torch.ops.aten.mean.dim)
+ def _aten_mean_dim(
+     lctx,
+     self: ir.Value,
+     dim: Optional[list[int]] = None,
+     keepdim: bool = False,
+     *,
+     dtype=None,
+ ):
+   self_shape = self.type.shape
+   self_elty = self.type.element_type
+   if dim is None:
+     dim = list(range(len(self_shape)))
+   dim = [len(self_shape) + d if d < 0 else d for d in dim]
+   dim_ = ir.DenseI64ArrayAttr.get(np.asarray(dim, np.int64))
+   dim_to_keep = [d for d in range(len(self_shape)) if d not in dim]
+   dim_to_keep_ = ir.DenseI64ArrayAttr.get(np.asarray(dim_to_keep, np.int64))
+
+   zero_ = utils.splat(0.0, self_elty)
+
+   reduce_result_shape = [
+       s for d, s in enumerate(self_shape) if d in dim_to_keep
+   ]
+   reduce_result_ty = ir.RankedTensorType.get(reduce_result_shape, self_elty)
+   reduce_op = stablehlo.ReduceOp([reduce_result_ty], [self], [zero_], dim_)
+
+   reducer_arg_ty = ir.RankedTensorType.get(tuple(), self_elty)
+   reducer = reduce_op.regions[0].blocks.append(reducer_arg_ty, reducer_arg_ty)
+   with ir.InsertionPoint(reducer):
+     stablehlo.return_(
+         [stablehlo.add(reducer.arguments[0], reducer.arguments[1])]
+     )
+
+   sum_ = reduce_op.result
+   if keepdim:
+     sum_ = stablehlo.broadcast_in_dim(
+         ir.RankedTensorType.get(
+             [s if d in dim_to_keep else 1 for d, s in enumerate(self_shape)],
+             self_elty,
+         ),
+         sum_,
+         dim_to_keep_,
+     )
+
+   dim_els = math.prod([s for d, s in enumerate(self_shape) if d in dim])
+   dim_els_ = utils.splat(dim_els, self_elty)
+   div_ = stablehlo.broadcast_in_dim(
+       sum_.type, dim_els_, ir.DenseI64ArrayAttr.get([])
+   )
+   mean_ = stablehlo.divide(sum_, div_)
+
+   return mean_
+
+
+ # https://pytorch.org/docs/stable/generated/torch.clone.html
+ # https://github.com/pytorch/pytorch/blob/a95ceb51a23ae33c00b3a99224143c609b1b3eb3/aten/src/ATen/native/TensorFactories.cpp#L1730
+ @lower(torch.ops.aten.clone)
+ def _aten_clone(lctx, x: ir.Value, *, memory_format=None):
+   return x
+
+
+ # https://pytorch.org/docs/stable/generated/torch.permute.html
+ # https://github.com/pytorch/pytorch/blob/519151a062a9bd4f0d32a9c7c7eae47d7ed847b2/aten/src/ATen/native/TensorShape.cpp#L1448
+ # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
+ @lower(torch.ops.aten.permute)
+ def _aten_permute(lctx, x: ir.Value, dims: list[int]):
+   dim = len(x.type.shape)
+   return stablehlo.transpose(x, ir.DenseI64ArrayAttr.get(dims))
+
+
+ # https://pytorch.org/docs/stable/generated/torch.mm.html
+ # https://github.com/pytorch/pytorch/blob/ffabb25c489df1dc631a577c12a0c843c8b202f3/aten/src/ATen/native/LinearAlgebra.cpp#L193
+ # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general
+ @lower(torch.ops.aten.mm)
+ def _aten_mm(mod, mat1: ir.Value, mat2: ir.Value) -> ir.Value:
+   mat1_shape = mat1.type.shape
+   mat2_shape = mat2.type.shape
+   mat1_dims = len(mat1_shape)
+   mat2_dims = len(mat2_shape)
+
+   if mat1_dims != 2 or mat2_dims != 2:
+     raise ValueError(
+         "Both arguments must be 2D matrices, received dimensions %d and %d"
+         % (mat1_dims, mat2_dims)
+     )
+
+   if mat1_shape[1] != mat2_shape[0]:
+     raise ValueError(
+         "mat1 and mat2 shapes cannot be multiplied, received shapes %s and %s"
+         % (mat1_shape, mat2_shape)
+     )
+
+   dot_dnums = stablehlo.DotDimensionNumbers.get(
+       lhs_batching_dimensions=[],
+       rhs_batching_dimensions=[],
+       lhs_contracting_dimensions=(1,),
+       rhs_contracting_dimensions=(0,),
+   )
+   return stablehlo.dot_general(
+       ir.RankedTensorType.get(
+           (mat1.type.shape[0], mat2.type.shape[1]), mat1.type.element_type
+       ),
+       mat1,
+       mat2,
+       dot_dnums,
+   )
+
+
+ # https://pytorch.org/docs/stable/generated/torch.div.html
+ # https://openxla.org/stablehlo/spec#divide
+ # TODO: support rounding mode and type promotion (see torch.div spec).
+ # @lower(torch.ops.aten.div)
+ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
+   # By default, PyTorch performs a "true" division like Python 3. This
+   # requires casting integer input types to float to achieve the same
+   # semantics using stablehlo.divide.
+   if isinstance(x.type.element_type, ir.IntegerType):
+     x = utils.convert_int_to_float(x)
+   if isinstance(y.type.element_type, ir.IntegerType):
+     y = utils.convert_int_to_float(y)
+
+   x, y = utils.broadcast_args_if_needed(x, y)
+
+   return stablehlo.divide(x, y)
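
As a reference for _aten_mean_dim above: the lowering decomposes mean into a StableHLO sum-reduction over `dim`, an optional broadcast_in_dim to restore kept dimensions, and a divide by the number of reduced elements. The same decomposition in plain NumPy (for intuition only, not part of the package):

import math
import numpy as np

def mean_dim_reference(x: np.ndarray, dim=None, keepdim=False):
  # Normalize dim the same way as _aten_mean_dim: default to all axes and
  # wrap negative indices.
  if dim is None:
    dim = list(range(x.ndim))
  dim = [x.ndim + d if d < 0 else d for d in dim]
  s = x.sum(axis=tuple(dim), keepdims=keepdim))   # the ReduceOp
  return s / math.prod(x.shape[d] for d in dim)   # the broadcast + divide

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
assert np.allclose(
    mean_dim_reference(x, [1], keepdim=True), x.mean(axis=1, keepdims=True)
)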