torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def _is_in_bad_fork():
+   return False
+
+
+ def manual_seed_all(seed):
+   pass
+
+
+ def device_count():
+   return 1
+
+
+ def get_rng_state():
+   return []
+
+
+ def set_rng_state(new_state, device):
+   pass
+
+
+ def is_available():
+   return True
+
+
+ def current_device():
+   return 0
+
+
+ def get_amp_supported_dtype():
+   return [torch.float16, torch.bfloat16]
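
Taken together, the stubs above mirror the surface PyTorch queries on a per-backend device module (fork safety, seeding, RNG state, device count, AMP dtypes). A heavily hedged registration sketch, not part of this wheel: the import path and the backend name 'jax' are illustrative assumptions; torch._register_device_module is an existing PyTorch hook for custom backends.

import torch
import torchax_device_stub as stub  # hypothetical path; the actual module name is not shown in this hunk

# Register the stub as the device module for a custom backend; autocast for that
# backend can then consult stub.get_amp_supported_dtype().
torch._register_device_module('jax', stub)
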
torchax/export.py ADDED
@@ -0,0 +1,259 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable
+ """Utilities for exporting a torch program to jax/stablehlo."""
+ import copy
+ from typing import Any, Dict, Tuple
+ import torch
+ from torch.utils import _pytree as pytree
+ import torchax
+ from torchax import tensor
+ from torchax.ops import ops_registry, mappings
+ from torchax import decompositions
+ import jax
+ import jax.export
+ import sympy
+
+ DEBUG = False
+
+
+ class JaxInterpreter(torch.fx.Interpreter):
+   """Experimental."""
+
+   def __init__(self, graph_module):
+     super().__init__(graph_module)
+     import torchax.ops.jaten
+     import torchax.ops.jtorch
+
+   def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
+     if not isinstance(target,
+                       (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+       return super().call_function(target, args, kwargs)
+
+     if DEBUG:
+       print('Running ', target.name(), '--------')
+
+     op = ops_registry.all_aten_ops.get(target)
+     if op is None:
+       op = ops_registry.all_aten_ops.get(target.overloadpacket)
+     assert op is not None, target
+     assert op.is_jax_function, op
+     if op is None:
+       op = ops_registry.all_aten_ops.get(target.overloadpacket)
+     if op is None:
+       print(target.name(), target.tags)
+       raise RuntimeError('No lowering found for', target.name())
+     return op.func(*args, **kwargs)
+
+   def run_node(self, n) -> Any:
+     res = super().run_node(n)
+     if DEBUG:
+       if n.op == 'call_function':
+         if hasattr(res, 'shape'):
+           print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
+     return res
+
+
+ from torch._decomp import get_decompositions
+ import torch._refs
+
+ _extra_decomp = get_decompositions([torch.ops.aten.unfold])
+
+
+ def _extract_states_from_exported_program(exported_model):
+   # NOTE call convention: (parameters, buffers, user_inputs)
+   param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
+   state_dict = copy.copy(exported_model.state_dict)
+   if (constants := getattr(exported_model, 'constants', None)) is not None:
+     state_dict.update(constants)
+   param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys)
+
+   if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
+     for name in exported_model.graph_signature.lifted_tensor_constants:
+       param_buffer_values.append(exported_model.tensor_constants[name])
+
+   return param_and_buffer_keys, param_buffer_values
+
+
+ def exported_program_to_jax(exported_program, export_raw: bool = False):
+   """Returns a pytree of jax arrays (the state), and
+
+   a callable (func) that is a jax function.
+
+   func(state, input) would be how you call it.
+   """
+   if torch.__version__ >= '2.2':
+     # torch version 2.1 didn't expose this yet
+     exported_program = exported_program.run_decompositions()
+     exported_program = exported_program.run_decompositions(
+         decompositions.DECOMPOSITIONS)
+   if DEBUG:
+     print(exported_program.graph_module.code)
+
+   names, states = _extract_states_from_exported_program(exported_program)
+
+   def _extract_args(args, kwargs):
+     flat_args, received_spec = pytree.tree_flatten(
+         (args, kwargs))  # type: ignore[possibly-undefined]
+     return flat_args
+
+   num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+
+   def func(states, inputs):
+     args = _extract_args(inputs, {})
+     res = JaxInterpreter(exported_program.graph_module).run(
+         *states,
+         *args,
+         enable_io_processing=False,
+     )
+     res = res[num_mutations:]
+     return res
+
+   if export_raw:
+     return names, states, func
+   env = torchax.default_env()
+   states = env.t2j_copy(states)
+   return states, func
+
+
+ def extract_avals(exported):
+   """Return JAX Abstract Value shapes for all input parameters of the exported
+   program. This supports dynamic batch dimensions, including with constraints.
+   """
+
+   def _to_aval(arg_meta, symbolic_shapes):
+     """Convert from torch type to jax abstract value for export tracing.
+     """
+
+     def _get_dim(d):
+       if isinstance(d, torch.SymInt):
+         return symbolic_shapes[str(d)]
+       return d
+
+     val = arg_meta['val']
+     is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+         val, bool)
+     if is_scalar:
+       return jax.ShapeDtypeStruct([], type(arg_meta['val']))
+
+     tensor_meta = arg_meta['tensor_meta']
+     shape = [_get_dim(d) for d in tensor_meta.shape]
+     return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
+
+   def _get_inputs(exported):
+     """Return placeholders with input metadata"""
+     placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
+     input_placeholders = [
+         p for p, s in zip(placeholders, exported.graph_signature.input_specs)
+         if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+     ]
+     return input_placeholders
+
+   def _build_symbolic_shapes(range_constraints):
+     """Convert torch SymInt to JAX symbolic_shape and store it in a map using the
+     string name of the torch symbolic int.
+
+     TODO: There is probably a better way of storing a key for a symbolic int.
+     This value needs to be looked up again in `_to_aval` to figure out which
+     JAX symbolic to map to for a given torch tensor.
+     """
+     if len(range_constraints) == 0:
+       return None
+
+     def _build_symbolic_constraints(symbol_name, torch_constraint):
+       """Convert torch SymInt constraints to strings for JAX symbolic_shape.
+       Using sympy may be overkill here; currently PyTorch only uses ValueRanges,
+       which allow specifying the min and the max of a value, for example:
+         torch.export.Dim("a", min=5, max=10)
+         ==> ("a >= 5", "a <= 10",)
+       """
+       if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
+                         ValueRanges) or torch_constraint.is_bool:
+         raise TypeError(
+             f"No symbolic constraint handler for: {torch_constraint}")
+
+       constraints = []
+       symbol = sympy.Symbol(symbol_name)
+       if torch_constraint.lower != 2:
+         constraints.append(symbol >= torch_constraint.lower)
+       from sympy.core.singleton import S
+       if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity:
+         constraints.append(symbol <= torch_constraint.upper)
+
+       return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
+
+     def _build_symbolic_shape(sym, constraint, free_symbols):
+       """Returns a JAX symbolic shape for a given symbol and constraint
+
+       There are two possible sympy `sym` inputs:
+       1. Symbol - (s0) These can have custom constraints.
+       2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
+
+       Currently, support is limited to operations with a symbol and an int;
+       in `torch/export/dynamic_shapes.py`:
+       "Only increasing linear operations with integer coefficients are supported."
+       """
+       symbol_name = str(sym)
+       constraints = _build_symbolic_constraints(symbol_name, constraint)
+       if sym.is_symbol:
+         symbolic_shape = jax.export.symbolic_shape(
+             symbol_name, constraints=constraints)
+       else:
+         assert len(sym.free_symbols) > 0
+         scope = free_symbols[str(list(sym.free_symbols)[0])].scope
+         symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope)
+       assert len(symbolic_shape) == 1
+       return symbolic_shape[0]
+
+     # Populate symbol variables before expressions; exprs need to use the same
+     # symbolic scope as the variable they operate on. Expressions can only be
+     # integer computations on symbol variables, so each symbol variable is OK to
+     # have its own scope.
+     symbolic_shapes = {}
+     symbol_variables = [
+         (s, v) for s, v in range_constraints.items() if s.is_symbol
+     ]
+     symbol_exprs = [
+         (s, v) for s, v in range_constraints.items() if not s.is_symbol
+     ]
+     for sym, constraint in symbol_variables + symbol_exprs:
+       symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
+       symbolic_shapes[str(sym)] = symbolic_shape
+     return symbolic_shapes
+
+   symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
+   args = _get_inputs(exported)
+
+   if DEBUG:
+     print('Inputs to aval:', args, '--------')
+     print('Symbolic shapes:', symbolic_shapes)
+     for arg in args:
+       print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes))
+
+   return [_to_aval(arg.meta, symbolic_shapes) for arg in args]
+
+
+ def exported_program_to_stablehlo(exported_program):
+   """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo
+
+   Convert a program exported via torch.export to StableHLO.
+
+   This supports dynamic dimension sizes and generates explicit checks for
+   dynamo guards in the IR using shape_assertion custom_call ops.
+   """
+   weights, func = exported_program_to_jax(exported_program)
+   jax_avals = extract_avals(exported_program)
+   jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
+   return weights, jax_export
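
A minimal usage sketch for the export path above, not part of the package: the Toy module, shapes, and the dynamic 'batch' dimension are illustrative assumptions; exported_program_to_stablehlo and exported_program_to_jax are the functions defined in this file.

import torch
import torchax.export

class Toy(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 2)

  def forward(self, x):
    return self.linear(x)

# Mark the batch dimension dynamic so extract_avals emits a symbolic shape.
batch = torch.export.Dim('batch', min=2, max=64)
ep = torch.export.export(Toy(), (torch.randn(8, 4),), dynamic_shapes=({0: batch},))

# StableHLO via jax.export; `weights` holds the JAX-converted parameters and buffers.
weights, jax_exported = torchax.export.exported_program_to_stablehlo(ep)
print(jax_exported.mlir_module())

# Or get a plain JAX callable; per the docstring, call it as func(state, inputs).
state, func = torchax.export.exported_program_to_jax(ep)
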
torchax/flax.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Flax interop."""
+
+ import torch
+ import torchax as tx
+ import torchax.interop
+
+
+ class FlaxNNModule(torch.nn.Module):
+
+   def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+     super().__init__()
+     prng = env.prng_key
+     sample_kwargs = sample_kwargs or {}
+     parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
+                                          **sample_kwargs)
+
+     self._params = self._encode_nested_dict(parameter_dict)
+
+     self._flax_module = flax_module
+
+   def _encode_nested_dict(self, nested_dict):
+     child_module = torch.nn.Module()
+     for k, v in nested_dict.items():
+       if isinstance(v, dict):
+         child_module.add_module(k, self._encode_nested_dict(v))
+       else:
+         child_module.register_parameter(k, torch.nn.Parameter(v))
+     return child_module
+
+   def _decode_nested_dict(self, child_module):
+     result = dict(child_module.named_parameters(recurse=False))
+     for k, v in child_module.named_children():
+       result[k] = self._decode_nested_dict(v)
+     return result
+
+   def forward(self, *args, **kwargs):
+     nested_dict_params = self._decode_nested_dict(self._params)
+     return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
+                                *args, **kwargs)
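
A hedged usage sketch for FlaxNNModule, not taken from the package: the flax.linen.Dense module and shapes are illustrative, and the 'jax' device string plus using the default environment as a context manager are assumptions about the surrounding torchax runtime.

import torch
import flax.linen as nn
import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
with env:  # assumption: the default torchax environment is usable as a context manager
  sample_args = (torch.randn(1, 16, device='jax'),)  # assumption: torchax registers a 'jax' device
  wrapped = FlaxNNModule(env, nn.Dense(features=4), sample_args)  # flax params become torch Parameters
  out = wrapped(torch.randn(1, 16, device='jax'))  # forward() calls flax_module.apply via call_jax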