torchax-0.0.10.dev20251118-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,47 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def _is_in_bad_fork():
+     return False
+
+
+ def manual_seed_all(seed):
+     pass
+
+
+ def device_count():
+     return 1
+
+
+ def get_rng_state():
+     return []
+
+
+ def set_rng_state(new_state, device):
+     pass
+
+
+ def is_available():
+     return True
+
+
+ def current_device():
+     return 0
+
+
+ def get_amp_supported_dtype():
+     return [torch.float16, torch.bfloat16]
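
The stub functions above match the hooks PyTorch looks up on a backend's device module (fork safety, RNG seeding and state, device count/availability, current device, AMP dtypes). As a rough sketch of why such a module exists: a custom backend typically attaches a module with exactly these hooks to torch so generic code paths such as seeding helpers and autocast can find them. Everything below is a hypothetical illustration, not code from this diff; torch._register_device_module is a private PyTorch API and the "privateuseone" name is only an example.

# Hypothetical sketch (not part of the torchax diff): attach a device module
# exposing the hooks above so torch's generic backend machinery can call them.
import types

import torch

device_module = types.SimpleNamespace(
    _is_in_bad_fork=lambda: False,
    manual_seed_all=lambda seed: None,
    device_count=lambda: 1,
    get_rng_state=lambda: [],
    set_rng_state=lambda new_state, device: None,
    is_available=lambda: True,
    current_device=lambda: 0,
    get_amp_supported_dtype=lambda: [torch.float16, torch.bfloat16],
)

# Private API; the required hooks and behavior can vary across torch versions.
torch._register_device_module("privateuseone", device_module)
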
torchax/export.py ADDED
@@ -0,0 +1,258 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable
+ """Utilities for exporting a torch program to jax/stablehlo."""
+
+ import copy
+ from typing import Any
+
+ import jax
+ import jax.export
+ import sympy
+ import torch
+ import torch._refs
+ from torch._decomp import get_decompositions
+ from torch.utils import _pytree as pytree
+
+ import torchax
+ from torchax import decompositions
+ from torchax.ops import mappings, ops_registry
+
+ DEBUG = False
+
+
+ class JaxInterpreter(torch.fx.Interpreter):
+     """Experimental."""
+
+     def __init__(self, graph_module):
+         super().__init__(graph_module)
+
+     def call_function(self, target, args: tuple, kwargs: dict) -> Any:
+         if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+             return super().call_function(target, args, kwargs)
+
+         if DEBUG:
+             print("Running ", target.name(), "--------")
+
+         op = ops_registry.all_aten_ops.get(target)
+         if op is None:
+             op = ops_registry.all_aten_ops.get(target.overloadpacket)
+             assert op is not None, target
+             assert op.is_jax_function, op
+         if op is None:
+             op = ops_registry.all_aten_ops.get(target.overloadpacket)
+         if op is None:
+             print(target.name(), target.tags)
+             raise RuntimeError("No lowering found for", target.name())
+         return op.func(*args, **kwargs)
+
+     def run_node(self, n) -> Any:
+         res = super().run_node(n)
+         if DEBUG:
+             if n.op == "call_function":
+                 if hasattr(res, "shape"):
+                     print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape)
+         return res
+
+
+ _extra_decomp = get_decompositions([torch.ops.aten.unfold])
+
+
+ def _extract_states_from_exported_program(exported_model):
+     # NOTE call convention: (parameters, buffers, user_inputs)
+     param_and_buffer_keys = (
+         exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
+     )
+     state_dict = copy.copy(exported_model.state_dict)
+     if (constants := getattr(exported_model, "constants", None)) is not None:
+         state_dict.update(constants)
+     param_buffer_values = [state_dict[key] for key in param_and_buffer_keys]
+
+     if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
+         for name in exported_model.graph_signature.lifted_tensor_constants:
+             param_buffer_values.append(exported_model.tensor_constants[name])
+
+     return param_and_buffer_keys, param_buffer_values
+
+
+ def exported_program_to_jax(exported_program, export_raw: bool = False):
+     """returns a pytree of jax arrays(state), and
+
+     a callable(func) that is jax function.
+
+     func(state, input) would be how you call it.
+     """
+     if torch.__version__ >= "2.2":
+         # torch version 2.1 didn't expose this yet
+         exported_program = exported_program.run_decompositions()
+         exported_program = exported_program.run_decompositions(
+             decompositions.DECOMPOSITIONS
+         )
+     if DEBUG:
+         print(exported_program.graph_module.code)
+
+     names, states = _extract_states_from_exported_program(exported_program)
+
+     def _extract_args(args, kwargs):
+         flat_args, received_spec = pytree.tree_flatten((args, kwargs))  # type: ignore[possibly-undefined]
+         return flat_args
+
+     num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+
+     def func(states, inputs):
+         args = _extract_args(inputs, {})
+         res = JaxInterpreter(exported_program.graph_module).run(
+             *states,
+             *args,
+             enable_io_processing=False,
+         )
+         res = res[num_mutations:]
+         return res
+
+     if export_raw:
+         return names, states, func
+     env = torchax.default_env()
+     states = env.t2j_copy(states)
+     return states, func
+
+
+ def extract_avals(exported):
+     """Return JAX Abstract Value shapes for all input parameters of the exported
+     program. This supports dynamic batch dimensions, including with constraints.
+     """
+
+     def _to_aval(arg_meta, symbolic_shapes):
+         """Convert from torch type to jax abstract value for export tracing"""
+
+         def _get_dim(d):
+             if isinstance(d, torch.SymInt):
+                 return symbolic_shapes[str(d)]
+             return d
+
+         val = arg_meta["val"]
+         is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
+         if is_scalar:
+             return jax.ShapeDtypeStruct([], type(arg_meta["val"]))
+
+         tensor_meta = arg_meta["tensor_meta"]
+         shape = [_get_dim(d) for d in tensor_meta.shape]
+         return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
+
+     def _get_inputs(exported):
+         """Return placeholders with input metadata"""
+         placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
+         input_placeholders = [
+             p
+             for p, s in zip(placeholders, exported.graph_signature.input_specs, strict=False)
+             if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+         ]
+         return input_placeholders
+
+     def _build_symbolic_shapes(range_constraints):
+         """Convert torch SymInt to JAX symbolic_shape and store it in a map keyed by
+         the string name of the torch symbolic int.
+
+         TODO: There is probably a better way of storing a key for a symbolic int.
+         This value needs to be looked up again in `_to_aval` to figure out which
+         JAX symbolic to map to for a given torch tensor.
+         """
+         if len(range_constraints) == 0:
+             return None
+
+         def _build_symbolic_constraints(symbol_name, torch_constraint):
+             """Convert torch SymInt constraints to strings for JAX symbolic_shape.
+             Using sympy may be overkill here; currently PyTorch only uses ValueRanges,
+             which allow specifying the min and the max of a value, for example:
+                 torch.export.Dim("a", min=5, max=10)
+                 ==> ("a >= 5", "a <= 10",)
+             """
+             if (
+                 not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges)
+                 or torch_constraint.is_bool
+             ):
+                 raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
+
+             constraints = []
+             symbol = sympy.Symbol(symbol_name)
+             if torch_constraint.lower != 2:
+                 constraints.append(symbol >= torch_constraint.lower)
+             from sympy.core.singleton import S
+
+             if (
+                 not torch_constraint.upper.is_infinite
+                 and torch_constraint.upper is not S.IntInfinity
+             ):
+                 constraints.append(symbol <= torch_constraint.upper)
+
+             return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
+
+         def _build_symbolic_shape(sym, constraint, free_symbols):
+             """Returns a JAX symbolic shape for a given symbol and constraint.
+
+             There are two possible sympy `sym` inputs:
+             1. Symbol - (s0) These can have custom constraints.
+             2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
+
+             Currently support is limited to operations with a symbol and an int,
+             in `torch/export/dynamic_shapes.py`:
+             "Only increasing linear operations with integer coefficients are supported."
+             """
+             symbol_name = str(sym)
+             constraints = _build_symbolic_constraints(symbol_name, constraint)
+             if sym.is_symbol:
+                 symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
+             else:
+                 assert len(sym.free_symbols) > 0
+                 scope = free_symbols[str(list(sym.free_symbols)[0])].scope
+                 symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope)
+             assert len(symbolic_shape) == 1
+             return symbolic_shape[0]
+
+         # Populate symbol variables before expressions; exprs need to use the same
+         # symbolic scope as the variable they operate on. Expressions can only be
+         # integer computations on symbol variables, so each symbol variable is OK to
+         # have its own scope.
+         symbolic_shapes = {}
+         symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
+         symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
+         for sym, constraint in symbol_variables + symbol_exprs:
+             symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
+             symbolic_shapes[str(sym)] = symbolic_shape
+         return symbolic_shapes
+
+     symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
+     args = _get_inputs(exported)
+
+     if DEBUG:
+         print("Inputs to aval:", args, "--------")
+         print("Symbolic shapes:", symbolic_shapes)
+         for arg in args:
+             print("Meta2Aval", arg.meta, "--> ", _to_aval(arg.meta, symbolic_shapes))
+
+     return [_to_aval(arg.meta, symbolic_shapes) for arg in args]
+
+
+ def exported_program_to_stablehlo(exported_program):
+     """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo
+
+     Convert a program exported via torch.export to StableHLO.
+
+     This supports dynamic dimension sizes and generates explicit checks for
+     dynamo guards in the IR using shape_assertion custom_call ops.
+     """
+     weights, func = exported_program_to_jax(exported_program)
+     jax_avals = extract_avals(exported_program)
+     jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
+     return weights, jax_export
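
For context, exported_program_to_jax and exported_program_to_stablehlo consume the output of torch.export.export. Below is a minimal, hedged sketch of how the module could be driven end to end; the model, shapes, and dynamic-dimension name are illustrative assumptions, not taken from this diff.

# Illustrative only: a toy model exported with torch.export, then lowered with
# the functions above. Model, shapes, and Dim names are assumptions.
import jax.numpy as jnp
import torch

from torchax import export as tx_export


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


batch = torch.export.Dim("batch", min=2, max=64)  # dynamic batch dimension
exported = torch.export.export(
    TinyModel(), (torch.randn(8, 4),), dynamic_shapes=({0: batch},)
)

# state is a pytree of jax arrays; func(state, inputs) runs the graph in JAX.
state, func = tx_export.exported_program_to_jax(exported)
out = func(state, (jnp.ones((8, 4), dtype=jnp.float32),))

# Or lower all the way to a jax.export artifact carrying StableHLO.
weights, jax_exported = tx_export.exported_program_to_stablehlo(exported)
print(jax_exported.mlir_module())
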
torchax/flax.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Flax interop."""
+
+ import torch
+
+ import torchax as tx
+ import torchax.interop
+
+
+ class FlaxNNModule(torch.nn.Module):
+     def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+         super().__init__()
+         prng = env.prng_key
+         sample_kwargs = sample_kwargs or {}
+         parameter_dict = tx.interop.call_jax(
+             flax_module.init, prng, *sample_args, **sample_kwargs
+         )
+
+         self._params = self._encode_nested_dict(parameter_dict)
+
+         self._flax_module = flax_module
+
+     def _encode_nested_dict(self, nested_dict):
+         child_module = torch.nn.Module()
+         for k, v in nested_dict.items():
+             if isinstance(v, dict):
+                 child_module.add_module(k, self._encode_nested_dict(v))
+             else:
+                 child_module.register_parameter(k, torch.nn.Parameter(v))
+         return child_module
+
+     def _decode_nested_dict(self, child_module):
+         result = dict(child_module.named_parameters(recurse=False))
+         for k, v in child_module.named_children():
+             result[k] = self._decode_nested_dict(v)
+         return result
+
+     def forward(self, *args, **kwargs):
+         nested_dict_params = self._decode_nested_dict(self._params)
+         return tx.interop.call_jax(
+             self._flax_module.apply, nested_dict_params, *args, **kwargs
+         )
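
A hedged usage sketch of FlaxNNModule: it initializes a flax.linen module once from sample inputs, stores the resulting parameter pytree as nested torch.nn.Parameters, and rebuilds the pytree on every forward call before dispatching to flax_module.apply via call_jax. The specific Flax layer, shapes, and the use of the environment as a context manager below are illustrative assumptions.

# Illustrative only: wrap a flax.linen layer so it behaves like a torch.nn.Module.
# Assumes flax is installed; the Dense layer and input shapes are arbitrary.
import flax.linen as nn
import torch

import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
with env:  # run inside the torchax environment so tensors are jax-backed
    sample = torch.randn(1, 8)  # sample input used only to initialize parameters
    wrapped = FlaxNNModule(env, nn.Dense(features=4), (sample,))
    out = wrapped(torch.randn(2, 8))
    # Flax parameters are now visible to torch, e.g. for optimizers:
    print([name for name, _ in wrapped.named_parameters()])
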