ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
- ai_edge_torch/_convert/test/test_convert.py +1 -1
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
- ai_edge_torch/debug/test/test_culprit.py +1 -1
- ai_edge_torch/debug/test/test_search_model.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
- ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_quantize.py +1 -1
- ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
- ai_edge_torch/lowertools/test_utils.py +1 -1
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +320 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/export_utils.py
@@ -0,0 +1,168 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for ODML Torch export."""
+
+import functools
+import re
+from typing import Sequence, cast
+import jax._src.interpreters.mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch
+
+# std::numeric_limits<int64_t>::min()
+IR_DYNAMIC = -9223372036854775808
+
+
+def is_ir_dynamic(v):
+  return v == IR_DYNAMIC
+
+
+def is_torch_dynamic(v):
+  return isinstance(v, torch.SymInt)
+
+
+def is_iterable(v):
+  try:
+    iter(v)
+  except TypeError:
+    return False
+  return True
+
+
+def create_ir_context():
+  # HACK: Use ir context from JAX as base for better stability in OSS.
+  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+  context = jax._src.interpreters.mlir.make_ir_context()
+  context.allow_unregistered_dialects = True
+
+  return context
+
+
+def inline(
+    symbol_table: ir.SymbolTable,
+    block: ir.Block,
+):
+  """Recursively inlines all func.call ops in the block.
+
+  The symbol_table must include all func.func called by func.call ops.
+  This inliner in Python is implemented because MLIR inline pass from JAX's
+  MLIR pybinding build in OSS cannot properly inline func.call ops.
+  """
+  while True:
+    is_changed = False
+    for op in block.operations:
+      if op.OPERATION_NAME != func.CallOp.OPERATION_NAME:
+        continue
+
+      call_op = cast(func.CallOp, op)
+      func_op = cast(func.FuncOp, symbol_table[call_op.callee.value])
+      with ir.InsertionPoint(op):
+        new_results = clone_func_body_ops(func_op, call_op.operands)
+
+      for old_result, new_result in zip(call_op.results, new_results):
+        old_result = cast(ir.Value, old_result)
+        old_result.replace_all_uses_with(new_result)
+      call_op.erase()
+      is_changed = True
+
+    if not is_changed:
+      break
+
+  for op in block.operations:
+    for region in op.regions:
+      for block in region.blocks:
+        inline(symbol_table, block)
+
+
+def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
+  """Clone operations in the func_op's body by one into the current context."""
+  func_args = list(func_op.arguments)
+  ir_inputs = list(ir_inputs)
+  assert len(func_args) == len(ir_inputs)
+
+  value_mapping = {arg: ir_input for arg, ir_input in zip(func_args, ir_inputs)}
+
+  for op in list(func_op.entry_block.operations):
+    cloned_operands = [value_mapping[val] for val in op.operands]
+    if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
+      return cloned_operands
+
+    cloned = cast(ir.Operation, op.operation.clone())
+
+    for i in range(len(op.operands)):
+      cloned.operands[i] = cloned_operands[i]
+
+    for i in range(len(op.results)):
+      value_mapping[op.results[i]] = cloned.results[i]
+
+  return []
+
+
+def sanitize_aten_op_name(op, chars=":."):
+  return re.sub("[{}]".format(chars), "_", str(op))
+
+
+def build_ir_attr(val):
+  if val is None:
+    return ir.StringAttr.get("py_None")
+  if isinstance(val, bool):
+    return ir.BoolAttr.get(val)
+  if isinstance(val, int):
+    return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
+  if isinstance(val, float):
+    return ir.BoolAttr.get(val)
+  if isinstance(val, str):
+    return ir.StringAttr.get(val)
+  if isinstance(val, dict):
+    return ir.DictAttr.get({k: build_ir_attr(v) for k, v in val.items()})
+  if isinstance(val, (list, tuple)):
+    return ir.ArrayAttr.get([build_ir_attr(v) for v in val])
+
+  # Stringify the value to a StringAttr by default
+  return ir.StringAttr.get(str(val))
+
+
+def torch_dtype_to_ir_element_type(ctx, dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }.get(dtype)
+  return ty_get(ctx)
+
+
+def ir_element_type_to_torch_dtype(ty):
+  if isinstance(ty, ir.F32Type):
+    return torch.float32
+  if isinstance(ty, ir.F64Type):
+    return torch.float64
+  if isinstance(ty, ir.F16Type):
+    return torch.half
+  if isinstance(ty, ir.IntegerType):
+    if ty.is_signless:
+      if ty.width == 64:
+        return torch.long
+      if ty.width == 32:
+        return torch.int32
+      if ty.width == 16:
+        return torch.int16
+      if ty.width == 1:
+        return torch.bool
+  raise RuntimeError(f"Unsupported ir element type: {ty}")
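Reading note: a minimal sketch of how the helpers in export_utils.py compose. The dictionary passed to build_ir_attr and the assert are illustrative assumptions, not code from this wheel; it presumes the nightly wheel (and its jax dependency) is installed, since the MLIR context comes from JAX's bundled pybindings.

import torch
from ai_edge_torch.odml_torch import export_utils

ctx = export_utils.create_ir_context()  # MLIR context borrowed from JAX
with ctx:
  # Nested Python values become nested MLIR attributes; unrecognized types
  # fall back to a StringAttr of str(value).
  attr = export_utils.build_ir_attr({"stride": [2, 2], "ceil_mode": False})
  # Torch dtypes map to MLIR element types and back for tensor signatures.
  f32 = export_utils.torch_dtype_to_ir_element_type(ctx, torch.float32)
  assert export_utils.ir_element_type_to_torch_dtype(f32) == torch.float32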
ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -0,0 +1,152 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to wrap JAX functions for using in ODML Torch lowerings."""
+
+import functools
+import inspect
+from typing import Any, Callable, cast
+import uuid
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch import passes
+from ai_edge_torch.odml_torch.jax_bridge import utils
+import jax
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch.utils._pytree as pytree
+
+# Jax double (64bit) precision is required to generate StableHLO mlir with
+# i64/f64 tensors from Jax bridged lowerings. If not set properly, all the
+# 64bit tensors would be truncated to 32bit dtype and potentially break the
+# lowering.
+jax.config.update("jax_enable_x64", True)
+
+
+def _lower_to_ir_text(
+    jaxfn, args, kwargs, ir_input_names: list[str] = None
+) -> str:
+  args = utils.tree_map_list_to_tuple(args)
+  kwargs = utils.tree_map_list_to_tuple(kwargs)
+
+  names_args = [
+      *zip(inspect.signature(jaxfn).parameters.keys(), args),
+      *kwargs.items(),
+  ]
+
+  static_argnames = []
+  jax_lower_static_kwargs = {}
+  jax_lower_args = []
+  jax_lower_argnames = []
+  ir_inputs = []
+
+  for i, (name, arg) in enumerate(names_args):
+    is_positional = i < len(args)
+    if not utils.is_ir_variable(arg):
+      static_argnames.append(name)
+      jax_lower_static_kwargs[name] = arg
+    else:
+      # Enforce the arg order in the mlir is the same as the lowering func
+      jax_lower_args.append(utils.ir_variable_to_jax(arg))
+
+      if is_positional and len(jax_lower_args) == i + 1:
+        # The first N continuous tensor args are passed to the lowering func
+        # as positional args, when they passed to the bridged func as
+        # positional args also.
+        jax_lower_argnames.append(None)
+      else:
+        # Otherwise pass the arg to the lowering func as keyword arg.
+        jax_lower_argnames.append(name)
+
+      if ir_input_names is None or name in ir_input_names:
+        # ir variable can be a nested tuple, while mlir args should be flat.
+        ir_inputs += [
+            x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
+        ]
+
+  def new_lowering(*args, **jax_lower_static_kwargs):
+    jaxfn_args = []
+    jaxfn_kwargs = jax_lower_static_kwargs.copy()
+    for name, arg in zip(jax_lower_argnames, args):
+      if name is None:
+        jaxfn_args.append(arg)
+      else:
+        jaxfn_kwargs[name] = arg
+
+    return jaxfn(*jaxfn_args, **jaxfn_kwargs)
+
+  return (
+      jax.jit(new_lowering, static_argnames=static_argnames)
+      .lower(*jax_lower_args, **jax_lower_static_kwargs)
+      .as_text()
+  ), ir_inputs
+
+
+def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
+  """Return the wrapped JAX function to be used in ODMLTorch lowerings.
+
+  If the given jaxfn has signature `jaxfn(*args, **kwargs) -> return`, the
+  wrapped function would:
+  - Have signature `wrapped(lctx: odml_torch.export.LoweringContext, *args,
+    **kwargs) -> return`.
+  - Accept mlir.ir.Value for all params expecting jax.Array as inputs.
+  - Return mlir.ir.Value for all jax.Array outputs from jaxfn.
+
+  Args:
+    jaxfn: The JAX function to be wrapped.
+    ir_input_names: The input (param) names of the JAX function to be used in
+      the MLIR lowering. This is useful when the JAX impl only depends on
+      specific inputs to the function. If not specified, all ir.Value passed to
+      the wrapped function are assumed to be used in the lowering.
+  """
+
+  @functools.wraps(jaxfn)
+  def wrapped(lctx, *args, **kwargs):
+
+    ir_text, ir_inputs = _lower_to_ir_text(
+        jaxfn,
+        args,
+        kwargs,
+        ir_input_names=ir_input_names,
+    )
+
+    module = ir.Module.parse(ir_text)
+    passes.strip_debuginfo(module)
+
+    symbol_table = ir.SymbolTable(module.operation)
+    main_func = symbol_table["main"]
+
+    with ir.InsertionPoint.at_block_begin(lctx.ir_module.body):
+      cloned_func = cast(func.FuncOp, main_func.clone())
+      cloned_func_name = f"{jaxfn.__name__}_{uuid.uuid4().hex[:8]}"
+      cloned_func.attributes["sym_name"] = ir.StringAttr.get(cloned_func_name)
+      cloned_func.attributes["sym_visibility"] = ir.StringAttr.get("private")
+
+    # HACK: Use the custom inliner implemented in Python because MLIR inline
+    # pass from JAX's MLIR pybinding build in OSS cannot properly inline
+    # func.call ops.
+    # This should be switched to `passes.inline(module)` when we have our own
+    # MLIR pybinding build.
+    export_utils.inline(symbol_table, cloned_func.entry_block)
+
+    if not cloned_func.arguments:
+      # Known edge case: when the lowering does not depend on input but
+      # just the meta of input like shape or dtype.
+      ir_inputs = []
+
+    results = func.CallOp(cloned_func, ir_inputs).results
+    if len(results) == 1:
+      return results[0]
+    return results
+
+  return wrapped
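Reading note: a hedged sketch of the calling convention that wrap()'s docstring describes. The choice of jnp.sin and the lowering body are illustrative assumptions, not code shipped in this wheel.

import jax.numpy as jnp
from ai_edge_torch.odml_torch import jax_bridge

# wrap(jaxfn) returns wrapped(lctx, *args, **kwargs): it traces jaxfn through
# jax.jit to StableHLO, inlines that module into lctx.ir_module, and exchanges
# mlir ir.Value for every jax.Array input and output.
wrapped_sin = jax_bridge.wrap(jnp.sin)

def _sin_lowering(lctx, x):  # x: ir.Value from the exported graph
  return wrapped_sin(lctx, x)  # ir.Value produced by the bridged JAX lowering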
ai_edge_torch/odml_torch/jax_bridge/utils.py
@@ -0,0 +1,75 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Jax bridge."""
+
+from ai_edge_torch import odml_torch
+import jax
+import jax.numpy as jnp
+from jax._src.lib.mlir import ir
+import torch
+
+
+def t2j_dtype(dtype):
+  return {
+      torch.bfloat16: jnp.bfloat16,
+      torch.half: jnp.float16,
+      torch.float32: jnp.float32,
+      torch.double: jnp.double,
+      torch.long: jnp.int64,
+      torch.int64: jnp.int64,
+      torch.int32: jnp.int32,
+      torch.int16: jnp.int16,
+      torch.int8: jnp.int8,
+      torch.uint8: jnp.uint8,
+      torch.bool: jnp.bool_,
+      torch.complex64: jnp.complex64,
+      torch.complex128: jnp.complex128,
+  }.get(dtype)
+
+
+def is_ir_variable(value):
+  if isinstance(value, ir.Value):
+    return True
+  if isinstance(value, (list, tuple)):
+    return any(is_ir_variable(x) for x in value)
+  return False
+
+
+def ir_variable_to_jax(value):
+  if isinstance(value, (list, tuple)):
+    return tuple([ir_variable_to_jax(x) for x in value])
+  elif not isinstance(value, ir.Value):
+    return value
+  elif not isinstance(value.type, ir.RankedTensorType):
+    raise ValueError(
+        f"ir.Value to JAX must be in ir.RankedTensorType, got {value}"
+    )
+
+  return jax.ShapeDtypeStruct(
+      value.type.shape,
+      t2j_dtype(
+          odml_torch.export_utils.ir_element_type_to_torch_dtype(
+              value.type.element_type
+          )
+      ),
+  )
+
+
+def tree_map_list_to_tuple(value):
+  if isinstance(value, dict):
+    return {k: tree_map_list_to_tuple(v) for k, v in value.items()}
+  if isinstance(value, (list, tuple)):
+    return tuple([tree_map_list_to_tuple(v) for v in value])
+  return value
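Reading note: a quick sanity sketch of the two pure-Python helpers above; the example values are arbitrary and assume the wheel is installed.

import torch
from ai_edge_torch.odml_torch.jax_bridge import utils

# torch.long and torch.int64 resolve to the same JAX dtype.
assert utils.t2j_dtype(torch.long) == utils.t2j_dtype(torch.int64)
# Lists are normalized to tuples recursively before lowering.
assert utils.tree_map_list_to_tuple({"a": [1, [2, 3]]}) == {"a": (1, (2, 3))}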
ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from . import _basic
+from . import _batch_norm
+from . import _convolution
+from . import _jax_lowerings
+from . import context
+from . import registry
+from . import utils
+from .registry import decompositions
+from .registry import lookup
+from .registry import lower
ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -0,0 +1,204 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import math
+from typing import Optional, Union
+
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+
+from .registry import lower
+
+
+# add(Tensor self, Tensor other) -> Tensor
+# @lower(torch.ops.aten.add)
+def _aten_add(lctx, x: ir.Value, y: ir.Value, alpha=1):
+  x, y = utils.upcast_to_same_type(x, y)
+  x, y = utils.broadcast_args_if_needed(x, y)
+  if alpha == 1:
+    return stablehlo.add(x, y)
+
+  alpha_splat = utils.splat(alpha, y.type.element_type, y.type.shape)
+  return stablehlo.add(x, stablehlo.multiply(y, alpha_splat))
+
+
+# mul.Tensor(Tensor self, Tensor other) -> Tensor
+# @lower(torch.ops.aten.mul.Tensor)
+def _aten_mul_tensor(lctx, self: ir.Value, other: ir.Value):
+  self, other = utils.upcast_to_same_type(self, other)
+  self, other = utils.broadcast_args_if_needed(self, other)
+
+  return stablehlo.multiply(self, other)
+
+
+# cat(Tensor[] tensors, int dim=0) -> Tensor
+# @lower(torch.ops.aten.cat)
+def _aten_cat(lctx, tensors: list[ir.Value], dim: int = 1):
+  return stablehlo.ConcatenateOp(tensors, dim).result
+
+
+# view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+# @lower(torch.ops.aten.view)
+def _aten_view(lctx, self: ir.Value, size: list[int]):
+  return stablehlo.ReshapeOp(
+      ir.RankedTensorType.get(size, self.type.element_type), self
+  ).result
+
+
+# hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+@lower(torch.ops.aten.hardtanh)
+def _aten_hardtanh(
+    lctx,
+    self: ir.Value,
+    min_val: Union[int, float] = -1.0,
+    max_val: Union[int, float] = 1.0,
+):
+  elty = self.type.element_type
+  min_val = utils.splat(min_val, elty)
+  max_val = utils.splat(max_val, elty)
+
+  return stablehlo.clamp(min_val, self, max_val)
+
+
+# mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+# mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *,
+#   ScalarType? dtype=None) -> Tensor
+@lower(torch.ops.aten.mean)
+@lower(torch.ops.aten.mean.dim)
+def _aten_mean_dim(
+    lctx,
+    self: ir.Value,
+    dim: Optional[list[int]] = None,
+    keepdim: bool = False,
+    *,
+    dtype=None,
+):
+  self_shape = self.type.shape
+  self_elty = self.type.element_type
+  if dim is None:
+    dim = list(range(len(self_shape)))
+  dim = [len(self_shape) + d if d < 0 else d for d in dim]
+  dim_ = ir.DenseI64ArrayAttr.get(np.asarray(dim, np.int64))
+  dim_to_keep = [d for d in range(len(self_shape)) if d not in dim]
+  dim_to_keep_ = ir.DenseI64ArrayAttr.get(np.asarray(dim_to_keep, np.int64))
+
+  zero_ = utils.splat(0.0, self_elty)
+
+  reduce_result_shape = [
+      s for d, s in enumerate(self_shape) if d in dim_to_keep
+  ]
+  reduce_result_ty = ir.RankedTensorType.get(reduce_result_shape, self_elty)
+  reduce_op = stablehlo.ReduceOp([reduce_result_ty], [self], [zero_], dim_)
+
+  reducer_arg_ty = ir.RankedTensorType.get(tuple(), self_elty)
+  reducer = reduce_op.regions[0].blocks.append(reducer_arg_ty, reducer_arg_ty)
+  with ir.InsertionPoint(reducer):
+    stablehlo.return_(
+        [stablehlo.add(reducer.arguments[0], reducer.arguments[1])]
+    )
+
+  sum_ = reduce_op.result
+  if keepdim:
+    sum_ = stablehlo.broadcast_in_dim(
+        ir.RankedTensorType.get(
+            [s if d in dim_to_keep else 1 for d, s in enumerate(self_shape)],
+            self_elty,
+        ),
+        sum_,
+        dim_to_keep_,
+    )
+
+  dim_els = math.prod([s for d, s in enumerate(self_shape) if d in dim])
+  dim_els_ = utils.splat(dim_els, self_elty)
+  div_ = stablehlo.broadcast_in_dim(
+      sum_.type, dim_els_, ir.DenseI64ArrayAttr.get([])
+  )
+  mean_ = stablehlo.divide(sum_, div_)
+
+  return mean_
+
+
+# https://pytorch.org/docs/stable/generated/torch.clone.html
+# https://github.com/pytorch/pytorch/blob/a95ceb51a23ae33c00b3a99224143c609b1b3eb3/aten/src/ATen/native/TensorFactories.cpp#L1730
+@lower(torch.ops.aten.clone)
+def _aten_clone(lctx, x: ir.Value, *, memory_format=None):
+  return x
+
+
+# https://pytorch.org/docs/stable/generated/torch.permute.html
+# https://github.com/pytorch/pytorch/blob/519151a062a9bd4f0d32a9c7c7eae47d7ed847b2/aten/src/ATen/native/TensorShape.cpp#L1448
+# https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
+@lower(torch.ops.aten.permute)
+def _aten_permute(lctx, x: ir.Value, dims: list[int]):
+  dim = len(x.type.shape)
+  return stablehlo.transpose(x, ir.DenseI64ArrayAttr.get(dims))
+
+
+# https://pytorch.org/docs/stable/generated/torch.mm.html
+# https://github.com/pytorch/pytorch/blob/ffabb25c489df1dc631a577c12a0c843c8b202f3/aten/src/ATen/native/LinearAlgebra.cpp#L193
+# https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general
+@lower(torch.ops.aten.mm)
+def _aten_mm(mod, mat1: ir.Value, mat2: ir.Value) -> ir.Value:
+  mat1_shape = mat1.type.shape
+  mat2_shape = mat2.type.shape
+  mat1_dims = len(mat1_shape)
+  mat2_dims = len(mat2_shape)
+
+  if mat1_dims != 2 or mat1_dims != 2:
+    raise ValueError(
+        "Both arguments must be 2D matrices, received dimensions %d and %d"
+        % (mat1_dims, mat2_dims)
+    )
+
+  if mat1_shape[1] != mat2_shape[0]:
+    raise ValueError(
+        "mat1 and mat2 shapes cannot be multiplied, received shapes %s and %s"
+        % (mat1_shape, mat2_shape)
+    )
+
+  dot_dnums = stablehlo.DotDimensionNumbers.get(
+      lhs_batching_dimensions=[],
+      rhs_batching_dimensions=[],
+      lhs_contracting_dimensions=(1,),
+      rhs_contracting_dimensions=(0,),
+  )
+  return stablehlo.dot_general(
+      ir.RankedTensorType.get(
+          (mat1.type.shape[0], mat2.type.shape[1]), mat1.type.element_type
+      ),
+      mat1,
+      mat2,
+      dot_dnums,
+  )
+
+
+# https://pytorch.org/docs/stable/generated/torch.div.html
+# https://openxla.org/stablehlo/spec#divide
+# TODO: support rounding mode and type promotion (see torch.div spec).
+# @lower(torch.ops.aten.div)
+def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
+  # By default, PyTorch performs a "true" division like Python 3. This requires
+  # casting integer input types to float to achieve the same semantics using
+  # stablehlo.divide.
+  if isinstance(x.type.element_type, ir.IntegerType):
+    x = utils.convert_int_to_float(x)
+  if isinstance(y.type.element_type, ir.IntegerType):
+    y = utils.convert_int_to_float(y)
+
+  x, y = utils.broadcast_args_if_needed(x, y)
+
+  return stablehlo.divide(x, y)