torchax 0.0.10.dev20251118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +149 -0
- torchax/amp.py +218 -0
- torchax/checkpoint.py +85 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +796 -0
- torchax/device_module.py +47 -0
- torchax/export.py +258 -0
- torchax/flax.py +55 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +236 -0
- torchax/ops/__init__.py +25 -0
- torchax/ops/jaten.py +5753 -0
- torchax/ops/jax_reimplement.py +211 -0
- torchax/ops/jc10d.py +64 -0
- torchax/ops/jimage.py +122 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +608 -0
- torchax/ops/jtorchvision_nms.py +268 -0
- torchax/ops/mappings.py +139 -0
- torchax/ops/op_base.py +137 -0
- torchax/ops/ops_registry.py +74 -0
- torchax/tensor.py +732 -0
- torchax/train.py +130 -0
- torchax/types.py +27 -0
- torchax/util.py +104 -0
- torchax/view.py +399 -0
- torchax-0.0.10.dev20251118.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251118.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251118.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251118.dist-info/licenses/LICENSE +201 -0
torchax/device_module.py
ADDED
@@ -0,0 +1,47 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def _is_in_bad_fork():
    return False


def manual_seed_all(seed):
    pass


def device_count():
    return 1


def get_rng_state():
    return []


def set_rng_state(new_state, device):
    pass


def is_available():
    return True


def current_device():
    return 0


def get_amp_supported_dtype():
    return [torch.float16, torch.bfloat16]
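For context: these stubs provide the hooks PyTorch expects from a backend device module (the same surface `torch.cuda` exposes). A minimal sketch of how such a module is typically wired up follows; the `"jax"` device name and the registration call are illustrative assumptions, not something shown in this diff.

import torch

from torchax import device_module

# Registering the module exposes it as torch.jax, so generic code can query it
# the same way it queries torch.cuda. Guarded in case torchax already
# registered it as a side effect of being imported (an assumption).
if not hasattr(torch, "jax"):
    torch._register_device_module("jax", device_module)

print(torch.jax.is_available())             # True
print(torch.jax.device_count())             # 1
print(torch.jax.get_amp_supported_dtype())  # [torch.float16, torch.bfloat16]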
torchax/export.py
ADDED
@@ -0,0 +1,258 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable
"""Utilities for exporting a torch program to jax/stablehlo."""

import copy
from typing import Any

import jax
import jax.export
import sympy
import torch
import torch._refs
from torch._decomp import get_decompositions
from torch.utils import _pytree as pytree

import torchax
from torchax import decompositions
from torchax.ops import mappings, ops_registry

DEBUG = False


class JaxInterpreter(torch.fx.Interpreter):
    """Experimental."""

    def __init__(self, graph_module):
        super().__init__(graph_module)

    def call_function(self, target, args: tuple, kwargs: dict) -> Any:
        if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
            return super().call_function(target, args, kwargs)

        if DEBUG:
            print("Running ", target.name(), "--------")

        op = ops_registry.all_aten_ops.get(target)
        if op is None:
            op = ops_registry.all_aten_ops.get(target.overloadpacket)
        assert op is not None, target
        assert op.is_jax_function, op
        if op is None:
            op = ops_registry.all_aten_ops.get(target.overloadpacket)
        if op is None:
            print(target.name(), target.tags)
            raise RuntimeError("No lowering found for", target.name())
        return op.func(*args, **kwargs)

    def run_node(self, n) -> Any:
        res = super().run_node(n)
        if DEBUG:
            if n.op == "call_function":
                if hasattr(res, "shape"):
                    print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape)
        return res


_extra_decomp = get_decompositions([torch.ops.aten.unfold])


def _extract_states_from_exported_program(exported_model):
    # NOTE call convention: (parameters, buffers, user_inputs)
    param_and_buffer_keys = (
        exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
    )
    state_dict = copy.copy(exported_model.state_dict)
    if (constants := getattr(exported_model, "constants", None)) is not None:
        state_dict.update(constants)
    param_buffer_values = [state_dict[key] for key in param_and_buffer_keys]

    if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
        for name in exported_model.graph_signature.lifted_tensor_constants:
            param_buffer_values.append(exported_model.tensor_constants[name])

    return param_and_buffer_keys, param_buffer_values


def exported_program_to_jax(exported_program, export_raw: bool = False):
    """returns a pytree of jax arrays(state), and

    a callable(func) that is jax function.

    func(state, input) would be how you call it.
    """
    if torch.__version__ >= "2.2":
        # torch version 2.1 didn't expose this yet
        exported_program = exported_program.run_decompositions()
        exported_program = exported_program.run_decompositions(
            decompositions.DECOMPOSITIONS
        )
    if DEBUG:
        print(exported_program.graph_module.code)

    names, states = _extract_states_from_exported_program(exported_program)

    def _extract_args(args, kwargs):
        flat_args, received_spec = pytree.tree_flatten((args, kwargs))  # type: ignore[possibly-undefined]
        return flat_args

    num_mutations = len(exported_program.graph_signature.buffers_to_mutate)

    def func(states, inputs):
        args = _extract_args(inputs, {})
        res = JaxInterpreter(exported_program.graph_module).run(
            *states,
            *args,
            enable_io_processing=False,
        )
        res = res[num_mutations:]
        return res

    if export_raw:
        return names, states, func
    env = torchax.default_env()
    states = env.t2j_copy(states)
    return states, func


def extract_avals(exported):
    """Return JAX Abstract Value shapes for all input parameters of the exported
    program. This supports dynamic batch dimensions, including with constraints.
    """

    def _to_aval(arg_meta, symbolic_shapes):
        """Convert from torch type to jax abstract value for export tracing"""

        def _get_dim(d):
            if isinstance(d, torch.SymInt):
                return symbolic_shapes[str(d)]
            return d

        val = arg_meta["val"]
        is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
        if is_scalar:
            return jax.ShapeDtypeStruct([], type(arg_meta["val"]))

        tensor_meta = arg_meta["tensor_meta"]
        shape = [_get_dim(d) for d in tensor_meta.shape]
        return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))

    def _get_inputs(exported):
        """Return placeholders with input metadata"""
        placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, exported.graph_signature.input_specs, strict=False)
            if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
        ]
        return input_placeholders

    def _build_symbolic_shapes(range_constraints):
        """Convert torch SymInt to JAX symbolic_shape and stores in a map using the
        string name of the torch symbolic int.

        TODO: There is probably a better way of storing a key for a symbolic int.
        This value needs to be looked up again in `_to_aval` to figure out which
        JAX symbolic to map to for a given torch tensor.
        """
        if len(range_constraints) == 0:
            return None

        def _build_symbolic_constraints(symbol_name, torch_constraint):
            """Convert torch SymInt constraints to string for JAX symbolic_shape
            Using sympy may be overkill here, currently PyTorch only uses ValueRanges
            which allow specifying the min and the max of a value, for example:
            torch.export.Dim("a", min=5, max=10)
            ==> ("a >= 5", "a <= 10",)
            """
            if (
                not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges)
                or torch_constraint.is_bool
            ):
                raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")

            constraints = []
            symbol = sympy.Symbol(symbol_name)
            if torch_constraint.lower != 2:
                constraints.append(symbol >= torch_constraint.lower)
            from sympy.core.singleton import S

            if (
                not torch_constraint.upper.is_infinite
                and torch_constraint.upper is not S.IntInfinity
            ):
                constraints.append(symbol <= torch_constraint.upper)

            return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)

        def _build_symbolic_shape(sym, constraint, free_symbols):
            """Returns a JAX symbolic shape for a given symbol and constraint

            There are two possible sympy `sym` inputs:
            1. Symbol - (s0) These can have custom constraints.
            2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.

            Currently support is limited to operations with a symbol and an int,
            in `torch/export/dynamic_shapes.py`:
            "Only increasing linear operations with integer coefficients are supported."
            """
            symbol_name = str(sym)
            constraints = _build_symbolic_constraints(symbol_name, constraint)
            if sym.is_symbol:
                symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
            else:
                assert len(sym.free_symbols) > 0
                scope = free_symbols[str(list(sym.free_symbols)[0])].scope
                symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope)
            assert len(symbolic_shape) == 1
            return symbolic_shape[0]

        # Populate symbol variables before expressions, exprs need to use the same
        # Symbolic scope as the variable they operate on. Expressions can only be
        # integer computations on symbol variables, so each symbol variable is OK to
        # have its own scope.
        symbolic_shapes = {}
        symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
        symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
        for sym, constraint in symbol_variables + symbol_exprs:
            symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
            symbolic_shapes[str(sym)] = symbolic_shape
        return symbolic_shapes

    symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
    args = _get_inputs(exported)

    if DEBUG:
        print("Inputs to aval:", args, "--------")
        print("Symbolic shapes:", symbolic_shapes)
        for arg in args:
            print("Meta2Aval", arg.meta, "--> ", _to_aval(arg.meta, symbolic_shapes))

    return [_to_aval(arg.meta, symbolic_shapes) for arg in args]


def exported_program_to_stablehlo(exported_program):
    """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo

    Convert a program exported via torch.export to StableHLO.

    This supports dynamic dimension sizes and generates explicit checks for
    dynamo guards in the IR using shape_assertion custom_call ops.
    """
    weights, func = exported_program_to_jax(exported_program)
    jax_avals = extract_avals(exported_program)
    jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
    return weights, jax_export
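The intended call path, end to end: export a model with torch.export (optionally with dynamic dimensions), convert it to a JAX callable, and, if needed, lower it to StableHLO. A hedged sketch follows; the toy model, dimension name, and shapes are illustrative assumptions, not taken from this diff.

import torch
import jax.numpy as jnp

from torchax import export as tx_export


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(10, 5))

    def forward(self, x):
        return x @ self.w


# Export with a dynamic batch dimension; extract_avals maps the resulting
# range_constraints onto JAX symbolic shapes.
batch = torch.export.Dim("batch", min=2, max=64)
exported = torch.export.export(
    TinyLinear(), (torch.randn(8, 10),), dynamic_shapes=({0: batch},)
)

# weights is a pytree of jax arrays; func(weights, (x,)) runs the graph in JAX.
weights, func = tx_export.exported_program_to_jax(exported)
out = func(weights, (jnp.ones((8, 10), dtype=jnp.float32),))

# Or lower all the way to StableHLO through jax.export.
weights, jax_exported = tx_export.exported_program_to_stablehlo(exported)
print(jax_exported.mlir_module())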
torchax/flax.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax interop."""

import torch

import torchax as tx
import torchax.interop


class FlaxNNModule(torch.nn.Module):
    def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
        super().__init__()
        prng = env.prng_key
        sample_kwargs = sample_kwargs or {}
        parameter_dict = tx.interop.call_jax(
            flax_module.init, prng, *sample_args, **sample_kwargs
        )

        self._params = self._encode_nested_dict(parameter_dict)

        self._flax_module = flax_module

    def _encode_nested_dict(self, nested_dict):
        child_module = torch.nn.Module()
        for k, v in nested_dict.items():
            if isinstance(v, dict):
                child_module.add_module(k, self._encode_nested_dict(v))
            else:
                child_module.register_parameter(k, torch.nn.Parameter(v))
        return child_module

    def _decode_nested_dict(self, child_module):
        result = dict(child_module.named_parameters(recurse=False))
        for k, v in child_module.named_children():
            result[k] = self._decode_nested_dict(v)
        return result

    def forward(self, *args, **kwargs):
        nested_dict_params = self._decode_nested_dict(self._params)
        return tx.interop.call_jax(
            self._flax_module.apply, nested_dict_params, *args, **kwargs
        )
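A usage sketch for FlaxNNModule, assuming a small flax.linen layer; the Dense layer, the shapes, and the `with env:` pattern are illustrative assumptions rather than something this file documents.

import flax.linen as nn
import torch

import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
with env:
    # init is driven by env.prng_key and the sample input shape.
    sample = torch.randn(4, 16)
    wrapped = FlaxNNModule(env, nn.Dense(features=8), (sample,))

    # forward() decodes the stored parameters back into a nested dict and
    # dispatches to flax_module.apply through interop.call_jax.
    out = wrapped(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 8])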