torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,43 @@
1
+ # Contributing to torchax
2
+
3
+ We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread, and we are happy to provide guidance. You are also very welcome to pick up issues labeled "good first issue" or "help wanted".
4
+
5
+
6
+ # Developer setup
7
+
8
+ ## Mac setup:
9
+ @qihqi
10
+
11
+ I am able to develop directly on a Mac (M1) laptop for most parts. Following the steps
12
+ in README.md works. The condensed version, for easy copy & paste:
13
+
14
+ ```bash
15
+ conda create --name <your_name> python=3.10
16
+ conda activate <your_name>
17
+ pip install --upgrade "jax[cpu]" torch
18
+ pip install -r test_requirements.txt
19
+ pip install -e .
20
+ pip install pytest-xdist # recommended for running tests faster
21
+ pytest -n auto test
22
+ ```
23
+
24
+ ## Setup on GPU or TPU
25
+
26
+ Same as the Mac setup, except that if you run tests with pytest, please also
27
+ set `JAX_PLATFORMS=cpu`. The reason is that pytest (with `-n auto`) runs
28
+ tests in multiple worker processes. The CPU device can be accessed concurrently,
29
+ whereas a TPU device usually allows only a single process to access it, so the run could deadlock.
30
+
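For illustration, a minimal (hypothetical, not part of this package) `conftest.py` at the repo root can pin JAX to CPU for every pytest worker before `jax` is imported; prefixing the command with `JAX_PLATFORMS=cpu` works just as well:

```python
# conftest.py -- a minimal sketch (hypothetical file, not shipped with torchax):
# force JAX onto the CPU backend before any test imports jax, so that parallel
# pytest-xdist workers do not all try to grab the TPU.
import os

os.environ.setdefault("JAX_PLATFORMS", "cpu")
```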
31
+ ### VSCode
32
+
33
+ I use VSCode on my Mac. I loosely followed the instructions at
34
+ https://code.visualstudio.com/docs/python/python-tutorial
35
+ to set up a proper Python environment.
36
+
37
+ The plugins I installed (a subset of the ones recommended in that tutorial) are:
38
+ * VSCode's official Python plugin
39
+ * Ruff formatter
40
+ * Python Debugger
41
+
42
+ I also changed the Python interpreter to point at the one in my conda env.
43
+ Those are all the changes I made.
torchax/__init__.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ from typing import List, Dict, Any, Optional
17
+ import dataclasses
18
+ import jax
19
+ import os
20
+ import torch
21
+ from torch.utils import _pytree as pytree
22
+ from torchax import tensor
23
+ from contextlib import contextmanager
24
+
25
+ __version__ = "0.0.10.dev20251117"
26
+ VERSION = __version__
27
+
28
+ # the "fast path" uses some sparse tensor thingies that currently we
29
+ # don't support
30
+ torch.backends.mha.set_fastpath_enabled(False)
31
+
32
+
33
+ __all__ = [
34
+ "default_env",
35
+ "extract_jax",
36
+ "enable_globally",
37
+ "save_checkpoint",
38
+ "load_checkpoint",
39
+ ]
40
+
41
+ from .checkpoint import save_checkpoint, load_checkpoint
42
+
43
+ os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")
44
+
45
+ # torchax:oss-begin
46
+ if getattr(jax.config, "jax_pjrt_client_create_options", None):
47
+ jax.config.update(
48
+ "jax_pjrt_client_create_options",
49
+ f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
50
+ )
51
+ # torchax:oss-end
52
+
53
+ env = None
54
+
55
+
56
+ def default_env():
57
+ global env
58
+
59
+ if env is None:
60
+ env = tensor.Environment()
61
+ return env
62
+
63
+
64
+ def extract_jax(mod: torch.nn.Module, env=None):
65
+ """Returns a pytree of jax.ndarray and a jax callable."""
66
+ if env is None:
67
+ env = default_env()
68
+ states = dict(mod.named_buffers())
69
+ states.update(mod.named_parameters())
70
+
71
+ states = env.t2j_copy(states)
72
+
73
+ # @jax.jit
74
+ def jax_func(states, args, kwargs=None):
75
+ (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
76
+ with env:
77
+ res = torch.func.functional_call(
78
+ mod, states, args, kwargs, tie_weights=False
79
+ )
80
+ return env.t2j_iso(res)
81
+
82
+ return states, jax_func
83
+
84
+
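As a rough editorial illustration (not from the package docs), `extract_jax` above is meant to produce a pure function that can be wrapped in `jax.jit`; the model, shapes, and jit wrapping below are my own example:

```python
# A minimal sketch of using extract_jax (assumes a CPU JAX install).
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(4, 2)

# states: a pytree of jax arrays; jax_func: a pure function of (states, args)
states, jax_func = torchax.extract_jax(model)

jitted = jax.jit(jax_func)
x = jnp.ones((1, 4), dtype=jnp.float32)
out = jitted(states, (x,))  # args are passed as a tuple
print(out.shape)            # expected: (1, 2)
```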
85
+ def enable_globally():
86
+ env = default_env().enable_torch_modes()
87
+ return env
88
+
89
+
90
+ def disable_globally():
91
+ global env
92
+ default_env().disable_torch_modes()
93
+
94
+
95
+ @contextlib.contextmanager
96
+ def disable_temporarily():
97
+ prev = default_env().enabled
98
+ if prev:
99
+ disable_globally()
100
+ yield ()
101
+ if prev:
102
+ enable_globally()
103
+
104
+
105
+ torch.utils.rename_privateuse1_backend("jax")
106
+ unsupported_dtype = [torch.quint8]
107
+
108
+ import jax
109
+ import torchax.device_module
110
+
111
+ torch._register_device_module("jax", torchax.device_module)
112
+
113
+
114
+ def enable_accuracy_mode():
115
+ jax.config.update("jax_enable_x64", True)
116
+ jax.config.update("jax_default_matmul_precision", "highest")
117
+ default_env().config.internal_respect_torch_return_dtypes = True
118
+
119
+
120
+ def enable_performance_mode():
121
+ jax.config.update("jax_enable_x64", False)
122
+ jax.config.update("jax_default_matmul_precision", "default")
123
+ default_env().config.internal_respect_torch_return_dtypes = False
124
+
125
+
126
+ @dataclasses.dataclass
127
+ class CompileOptions:
128
+ # only valid if compiling nn.Module
129
+ methods_to_compile: List[str] = dataclasses.field(
130
+ default_factory=lambda: ["forward"]
131
+ )
132
+ jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
133
+ mode: str = "jax" # or dynamo or export
134
+
135
+
136
+ def compile(fn, options: Optional[CompileOptions] = None):
137
+ options = options or CompileOptions()
138
+ if options.mode == "jax":
139
+ from torchax import interop
140
+
141
+ if isinstance(fn, torch.nn.Module):
142
+ module = interop.JittableModule(
143
+ fn, extra_jit_args=options.jax_jit_kwargs
144
+ )
145
+ for n in options.methods_to_compile:
146
+ module.make_jitted(n)
147
+ return module
148
+ else:
149
+ return interop.jax_jit(fn)
150
+ elif options.mode == "dynamo":
151
+ raise RuntimeError("dynamo mode is not supported yet")
152
+ elif options.mode == "export":
153
+ raise RuntimeError("export mode is not supported yet")
torchax/amp.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import enum
17
+ import torch
18
+ from torch.utils import _pytree as pytree
19
+
20
+
21
+ # enum class CastPolicy : uint8_t {
22
+ # lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
23
+ # // running the op. Currently, lower_precision_fp is
24
+ # // fp16 for AutocastCUDA, and is defined by user
25
+ # // (default bf16) for AutocastCPU or other device.
26
+ # fp32, // Cast all inputs to at::kFloat before running the op.
27
+ # fp32_set_opt_dtype, // Treats functions (like softmax) that
28
+ # // 1. we'd like to run in fp32 and
29
+ # // 2. have a std::optional<ScalarType> arg that controls
30
+ # // the output type.
31
+ # // fp32_set_opt_dtype wrappers' policy is: if the output
32
+ # // type is already set, don't touch it, otherwise, set
33
+ # // it to at::kFloat.
34
+ # fp32_append_dtype, // Treats functions (like norm) that
35
+ # // 1. we'd like to run in fp32 and
36
+ # // 2. have some overloads that accept an output type and
37
+ # // other overloads that don't.
38
+ # // fp32_append_dtype wrappers wrap the overloads that don't
39
+ # // have an output dtype.
40
+ # // The wrapper policy is: append at::kFloat to the args,
41
+ # // and redispatch to the type-aware overload.
42
+ # promote, // Run in the widest dtype among several args.
43
+ # };
44
+ class CastPolicy(enum.Enum):
45
+ LOWER_PRECISION_FP = 0
46
+ FP32 = 1
47
+ FP32_SET_OPT_DTYPE = 2
48
+ FP32_APPEND_DTYPE = 3
49
+ PROMOTE = 4
50
+
51
+
52
+ def execute_policy(policy, args, kwargs, target_lower_fp):
53
+
54
+ def is_float(a):
55
+ return isinstance(a, torch.Tensor) and a.is_floating_point()
56
+ match policy:
57
+ case CastPolicy.LOWER_PRECISION_FP:
58
+ return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
59
+ (args, kwargs))
60
+ case CastPolicy.FP32:
61
+ return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
62
+ (args, kwargs))
63
+ case CastPolicy.PROMOTE:
64
+ dtypes = set(a.dtype for a in args)
65
+ widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
66
+ return pytree.tree_map_only(is_float, lambda a: a.to(widest),
67
+ (args, kwargs))
68
+ case _:
69
+ raise AssertionError(f'Policy {policy} not implemented yet.')
70
+
71
+
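For illustration (editorial sketch, not part of the file), this is what the cast policies above do to an op's arguments when applied directly:

```python
# A minimal sketch of execute_policy on plain float32 tensors.
import torch
from torchax.amp import CastPolicy, execute_policy

a = torch.randn(2, 2, dtype=torch.float32)
b = torch.randn(2, 2, dtype=torch.float32)

# LOWER_PRECISION_FP: cast floating-point args down to the autocast dtype.
args, kwargs = execute_policy(
    CastPolicy.LOWER_PRECISION_FP, (a, b), {}, torch.bfloat16)
print(args[0].dtype)  # expected: torch.bfloat16

# FP32: cast floating-point args up to float32 (a no-op here).
args, kwargs = execute_policy(CastPolicy.FP32, (a, b), {}, torch.bfloat16)
print(args[0].dtype)  # expected: torch.float32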
72
+ @contextlib.contextmanager
73
+ def autocast(device, dtype=torch.bfloat16, env=None):
74
+ del device
75
+ if env is None:
76
+ import torchax
77
+ env = torchax.default_env()
78
+ with env.override_property(autocast_dtype=dtype):
79
+ yield
80
+
81
+
82
+ # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
83
+ autocast_policy = {
84
+ torch.ops.aten.conv1d.default:
85
+ CastPolicy.LOWER_PRECISION_FP,
86
+ torch.ops.aten.conv1d.padding:
87
+ CastPolicy.LOWER_PRECISION_FP,
88
+ torch.ops.aten.conv2d.default:
89
+ CastPolicy.LOWER_PRECISION_FP,
90
+ torch.ops.aten.conv2d.padding:
91
+ CastPolicy.LOWER_PRECISION_FP,
92
+ torch.ops.aten.conv3d.default:
93
+ CastPolicy.LOWER_PRECISION_FP,
94
+ torch.ops.aten.conv3d.padding:
95
+ CastPolicy.LOWER_PRECISION_FP,
96
+ torch.ops.aten.bmm.default:
97
+ CastPolicy.LOWER_PRECISION_FP,
98
+ torch.ops.aten.mm.default:
99
+ CastPolicy.LOWER_PRECISION_FP,
100
+ torch.ops.aten.linalg_vecdot.default:
101
+ CastPolicy.LOWER_PRECISION_FP,
102
+ torch.ops.aten.baddbmm.default:
103
+ CastPolicy.LOWER_PRECISION_FP,
104
+ torch.ops.aten.addmm.default:
105
+ CastPolicy.LOWER_PRECISION_FP,
106
+ torch.ops.aten._addmm_activation.default:
107
+ CastPolicy.LOWER_PRECISION_FP,
108
+ torch.ops.aten.addbmm.default:
109
+ CastPolicy.LOWER_PRECISION_FP,
110
+ torch.ops.aten.linear.default:
111
+ CastPolicy.LOWER_PRECISION_FP,
112
+ torch.ops.aten._convolution.deprecated:
113
+ CastPolicy.LOWER_PRECISION_FP,
114
+ torch.ops.aten.matmul.default:
115
+ CastPolicy.LOWER_PRECISION_FP,
116
+ torch.ops.aten.conv_tbc.default:
117
+ CastPolicy.LOWER_PRECISION_FP,
118
+ torch.ops.aten.mkldnn_rnn_layer.default:
119
+ CastPolicy.LOWER_PRECISION_FP,
120
+ torch.ops.aten.conv_transpose1d.default:
121
+ CastPolicy.LOWER_PRECISION_FP,
122
+ torch.ops.aten.conv_transpose2d.input:
123
+ CastPolicy.LOWER_PRECISION_FP,
124
+ torch.ops.aten.conv_transpose3d.input:
125
+ CastPolicy.LOWER_PRECISION_FP,
126
+ torch.ops.aten.prelu.default:
127
+ CastPolicy.LOWER_PRECISION_FP,
128
+ torch.ops.aten.scaled_dot_product_attention.default:
129
+ CastPolicy.LOWER_PRECISION_FP,
130
+ torch.ops.aten._native_multi_head_attention.default:
131
+ CastPolicy.LOWER_PRECISION_FP,
132
+
133
+ # fp32 cast policy
134
+ torch.ops.aten.avg_pool3d.default:
135
+ CastPolicy.FP32,
136
+ torch.ops.aten.binary_cross_entropy.default:
137
+ CastPolicy.FP32,
138
+ torch.ops.aten.grid_sampler.default:
139
+ CastPolicy.FP32,
140
+ torch.ops.aten.polar.default:
141
+ CastPolicy.FP32,
142
+ torch.ops.aten.prod.default:
143
+ CastPolicy.FP32,
144
+ torch.ops.aten.prod.dim_int:
145
+ CastPolicy.FP32,
146
+ torch.ops.aten.prod.dim_Dimname:
147
+ CastPolicy.FP32,
148
+ torch.ops.aten.quantile.default:
149
+ CastPolicy.FP32,
150
+ torch.ops.aten.quantile.scalar:
151
+ CastPolicy.FP32,
152
+ torch.ops.aten.nanquantile.default:
153
+ CastPolicy.FP32,
154
+ torch.ops.aten.nanquantile.scalar:
155
+ CastPolicy.FP32,
156
+ torch.ops.aten.stft.default:
157
+ CastPolicy.FP32,
158
+ torch.ops.aten.stft.center:
159
+ CastPolicy.FP32,
160
+ torch.ops.aten.cdist.default:
161
+ CastPolicy.FP32,
162
+ torch.ops.aten.grid_sampler_2d.default:
163
+ CastPolicy.FP32,
164
+ torch.ops.aten._grid_sampler_2d_cpu_fallback.default:
165
+ CastPolicy.FP32,
166
+ torch.ops.aten.grid_sampler_3d.default:
167
+ CastPolicy.FP32,
168
+ torch.ops.aten.trace.default:
169
+ CastPolicy.FP32,
170
+ torch.ops.aten.view_as_complex.default:
171
+ CastPolicy.FP32,
172
+ torch.ops.aten.cholesky.default:
173
+ CastPolicy.FP32,
174
+ torch.ops.aten.cholesky_inverse.default:
175
+ CastPolicy.FP32,
176
+ torch.ops.aten.cholesky_solve.default:
177
+ CastPolicy.FP32,
178
+ torch.ops.aten.inverse.default:
179
+ CastPolicy.FP32,
180
+ torch.ops.aten.lu_solve.default:
181
+ CastPolicy.FP32,
182
+ torch.ops.aten.orgqr.default:
183
+ CastPolicy.FP32,
184
+ torch.ops.aten.ormqr.default:
185
+ CastPolicy.FP32,
186
+ torch.ops.aten.pinverse.default:
187
+ CastPolicy.FP32,
188
+ torch.ops.aten.max_pool3d.default:
189
+ CastPolicy.FP32,
190
+ torch.ops.aten.max_unpool2d.default:
191
+ CastPolicy.FP32,
192
+ torch.ops.aten.max_unpool3d.default:
193
+ CastPolicy.FP32,
194
+ torch.ops.aten.adaptive_avg_pool3d.default:
195
+ CastPolicy.FP32,
196
+ torch.ops.aten.reflection_pad1d.default:
197
+ CastPolicy.FP32,
198
+ torch.ops.aten.reflection_pad2d.default:
199
+ CastPolicy.FP32,
200
+ torch.ops.aten.replication_pad1d.default:
201
+ CastPolicy.FP32,
202
+ torch.ops.aten.replication_pad2d.default:
203
+ CastPolicy.FP32,
204
+ torch.ops.aten.replication_pad3d.default:
205
+ CastPolicy.FP32,
206
+ torch.ops.aten.mse_loss.default:
207
+ CastPolicy.FP32,
208
+ torch.ops.aten.cosine_embedding_loss.default:
209
+ CastPolicy.FP32,
210
+ torch.ops.aten.nll_loss.default:
211
+ CastPolicy.FP32,
212
+ torch.ops.aten.nll_loss2d.default:
213
+ CastPolicy.FP32,
214
+ torch.ops.aten.hinge_embedding_loss.default:
215
+ CastPolicy.FP32,
216
+ torch.ops.aten.poisson_nll_loss.default:
217
+ CastPolicy.FP32,
218
+ torch.ops.aten.smooth_l1_loss.default:
219
+ CastPolicy.FP32,
220
+ torch.ops.aten.cross_entropy_loss.default:
221
+ CastPolicy.FP32,
222
+ torch.ops.aten.l1_loss.default:
223
+ CastPolicy.FP32,
224
+ torch.ops.aten.huber_loss.default:
225
+ CastPolicy.FP32,
226
+ torch.ops.aten.margin_ranking_loss.default:
227
+ CastPolicy.FP32,
228
+ torch.ops.aten.soft_margin_loss.default:
229
+ CastPolicy.FP32,
230
+ torch.ops.aten.triplet_margin_loss.default:
231
+ CastPolicy.FP32,
232
+ torch.ops.aten.multi_margin_loss.default:
233
+ CastPolicy.FP32,
234
+ torch.ops.aten.ctc_loss.IntList:
235
+ CastPolicy.FP32,
236
+ torch.ops.aten.ctc_loss.Tensor:
237
+ CastPolicy.FP32,
238
+ torch.ops.aten.kl_div.default:
239
+ CastPolicy.FP32,
240
+ torch.ops.aten.multilabel_margin_loss.default:
241
+ CastPolicy.FP32,
242
+ torch.ops.aten.binary_cross_entropy_with_logits.default:
243
+ CastPolicy.FP32,
244
+ torch.ops.aten.fft_fft.default:
245
+ CastPolicy.FP32,
246
+ torch.ops.aten.fft_ifft.default:
247
+ CastPolicy.FP32,
248
+ torch.ops.aten.fft_fft2.default:
249
+ CastPolicy.FP32,
250
+ torch.ops.aten.fft_ifft2.default:
251
+ CastPolicy.FP32,
252
+ torch.ops.aten.fft_fftn.default:
253
+ CastPolicy.FP32,
254
+ torch.ops.aten.fft_ifftn.default:
255
+ CastPolicy.FP32,
256
+ torch.ops.aten.fft_rfft.default:
257
+ CastPolicy.FP32,
258
+ torch.ops.aten.fft_irfft.default:
259
+ CastPolicy.FP32,
260
+ torch.ops.aten.fft_rfft2.default:
261
+ CastPolicy.FP32,
262
+ torch.ops.aten.fft_irfft2.default:
263
+ CastPolicy.FP32,
264
+ torch.ops.aten.fft_rfftn.default:
265
+ CastPolicy.FP32,
266
+ torch.ops.aten.fft_irfftn.default:
267
+ CastPolicy.FP32,
268
+ torch.ops.aten.fft_hfft.default:
269
+ CastPolicy.FP32,
270
+ torch.ops.aten.fft_ihfft.default:
271
+ CastPolicy.FP32,
272
+ torch.ops.aten.linalg_cond.default:
273
+ CastPolicy.FP32,
274
+ torch.ops.aten.linalg_cond.p_str:
275
+ CastPolicy.FP32,
276
+ torch.ops.aten.linalg_matrix_rank.default:
277
+ CastPolicy.FP32,
278
+ torch.ops.aten.linalg_matrix_rank.tol_tensor:
279
+ CastPolicy.FP32,
280
+ torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
281
+ CastPolicy.FP32,
282
+ torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
283
+ CastPolicy.FP32,
284
+ torch.ops.aten.linalg_solve.default:
285
+ CastPolicy.FP32,
286
+ torch.ops.aten.linalg_cholesky.default:
287
+ CastPolicy.FP32,
288
+ torch.ops.aten.linalg_svdvals.default:
289
+ CastPolicy.FP32,
290
+ torch.ops.aten.linalg_eigvals.default:
291
+ CastPolicy.FP32,
292
+ torch.ops.aten.linalg_eigvalsh.default:
293
+ CastPolicy.FP32,
294
+ torch.ops.aten.linalg_inv.default:
295
+ CastPolicy.FP32,
296
+ torch.ops.aten.linalg_householder_product.default:
297
+ CastPolicy.FP32,
298
+ torch.ops.aten.linalg_tensorinv.default:
299
+ CastPolicy.FP32,
300
+ torch.ops.aten.linalg_tensorsolve.default:
301
+ CastPolicy.FP32,
302
+ torch.ops.aten.fake_quantize_per_tensor_affine.default:
303
+ CastPolicy.FP32,
304
+ torch.ops.aten.geqrf.default:
305
+ CastPolicy.FP32,
306
+ torch.ops.aten._lu_with_info.default:
307
+ CastPolicy.FP32,
308
+ torch.ops.aten.qr.default:
309
+ CastPolicy.FP32,
310
+ torch.ops.aten.svd.default:
311
+ CastPolicy.FP32,
312
+ torch.ops.aten.triangular_solve.default:
313
+ CastPolicy.FP32,
314
+ torch.ops.aten.fractional_max_pool2d.default:
315
+ CastPolicy.FP32,
316
+ torch.ops.aten.fractional_max_pool3d.default:
317
+ CastPolicy.FP32,
318
+ torch.ops.aten.adaptive_max_pool3d.default:
319
+ CastPolicy.FP32,
320
+ torch.ops.aten.multilabel_margin_loss_forward.default:
321
+ CastPolicy.FP32,
322
+ torch.ops.aten.linalg_qr.default:
323
+ CastPolicy.FP32,
324
+ torch.ops.aten.linalg_cholesky_ex.default:
325
+ CastPolicy.FP32,
326
+ torch.ops.aten.linalg_svd.default:
327
+ CastPolicy.FP32,
328
+ torch.ops.aten.linalg_eig.default:
329
+ CastPolicy.FP32,
330
+ torch.ops.aten.linalg_eigh.default:
331
+ CastPolicy.FP32,
332
+ torch.ops.aten.linalg_lstsq.default:
333
+ CastPolicy.FP32,
334
+ torch.ops.aten.linalg_inv_ex.default:
335
+ CastPolicy.FP32,
336
+
337
+ # promote
338
+ torch.ops.aten.stack.default:
339
+ CastPolicy.PROMOTE,
340
+ torch.ops.aten.cat.default:
341
+ CastPolicy.PROMOTE,
342
+ torch.ops.aten.index_copy.default:
343
+ CastPolicy.PROMOTE,
344
+ torch.ops.aten.index_copy.dimname:
345
+ CastPolicy.PROMOTE,
346
+ }
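A rough usage sketch of the `autocast` context manager defined above; it assumes the environment consults `autocast_policy` for the listed ops (e.g. `aten.matmul` has a LOWER_PRECISION_FP policy), which is my reading of the code rather than documented behavior:

```python
# Editorial sketch: ops with a LOWER_PRECISION_FP policy should come back
# in the autocast dtype when run inside the context.
import torch
import torchax
from torchax.amp import autocast

env = torchax.default_env()
with env:
    a = torch.randn(4, 4)           # float32 torchax tensors inside the env
    b = torch.randn(4, 4)
    with autocast("jax", dtype=torch.bfloat16, env=env):
        out = a @ b
    print(out.dtype)                # expected: torch.bfloat16
```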
torchax/checkpoint.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import os
17
+ from typing import Any, Dict
18
+ from flax.training import checkpoints
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from . import tensor
23
+
24
+ def _to_jax(pytree):
25
+ def to_jax_array(x):
26
+ if isinstance(x, tensor.Tensor):
27
+ return x.jax()
28
+ elif isinstance(x, torch.Tensor):
29
+ return jnp.asarray(x.cpu().numpy())
30
+ return x
31
+ return jax.tree_util.tree_map(to_jax_array, pytree)
32
+
33
+
34
+ def _to_torch(pytree):
35
+ return jax.tree_util.tree_map(
36
+ lambda x: torch.from_numpy(np.asarray(x))
37
+ if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
38
+
39
+
40
+ def save_checkpoint(state: Dict[str, Any], path: str, step: int):
41
+ """Saves a checkpoint to a file in JAX style.
42
+
43
+ Args:
44
+ state: A dictionary containing the state to save. torch.Tensors will be
45
+ converted to jax.Array.
46
+ path: The path to save the checkpoint to. This is a directory.
47
+ step: The training step.
48
+ """
49
+ state = _to_jax(state)
50
+ checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
51
+
52
+
53
+ def load_checkpoint(path: str) -> Dict[str, Any]:
54
+ """Loads a checkpoint and returns it in JAX format.
55
+
56
+ This function can load both PyTorch-style (single file) and JAX-style
57
+ (directory) checkpoints.
58
+
59
+ If the checkpoint is in PyTorch format, it will be converted to JAX format.
60
+
61
+ Args:
62
+ path: The path to the checkpoint.
63
+
64
+ Returns:
65
+ The loaded state in JAX format (pytree with jax.Array leaves).
66
+ """
67
+ if os.path.isdir(path):
68
+ # JAX-style checkpoint
69
+ state = checkpoints.restore_checkpoint(path, target=None)
70
+ if state is None:
71
+ raise FileNotFoundError(f"No checkpoint found at {path}")
72
+ return state
73
+ elif os.path.isfile(path):
74
+ # PyTorch-style checkpoint
75
+ state = torch.load(path, weights_only=False)
76
+ return _to_jax(state)
77
+ else:
78
+ raise FileNotFoundError(f"No such file or directory: {path}")
79
+
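For illustration, a minimal round-trip sketch of the two functions above; the temp-directory path and state layout are my own example (assumes flax is installed):

```python
import tempfile

import jax
import torch
from torchax.checkpoint import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 2)
state = {"model": model.state_dict(), "step": 0}

ckpt_dir = tempfile.mkdtemp()            # a JAX-style checkpoint is a directory
save_checkpoint(state, ckpt_dir, step=0)

restored = load_checkpoint(ckpt_dir)     # pytree in JAX format, per the docstring
print(jax.tree_util.tree_map(lambda a: getattr(a, "shape", a), restored))
```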
torchax/config.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Configuration:
20
+ debug_print_each_op: bool = False
21
+ debug_accuracy_for_each_op: bool = False
22
+ debug_mixed_tensor: bool = False
23
+ debug_print_each_op_operands: bool = False
24
+
25
+ use_int32_for_index: bool = False
26
+
27
+ # Normally, math between a CPU torch.Tensor and a torchax.Tensor is not
28
+ # allowed. However, if that torch.Tensor happens to be a scalar, then we
29
+ # can use scalar * tensor math to handle it.
30
+ allow_mixed_math_with_scalar_tensor: bool = True
31
+
32
+ # If true, we will convert Views into torchax.Tensors eagerly
33
+ force_materialize_views: bool = False
34
+
35
+ # Use DLPack for converting between jax.Array and torch.Tensor
36
+ use_dlpack_for_data_conversion: bool = False
37
+
38
+ # Flash attention
39
+ use_tpu_flash_attention: bool = False
40
+ shmap_flash_attention: bool = False
41
+
42
+ # device
43
+ treat_cuda_as_jax_device: bool = True
44
+ internal_respect_torch_return_dtypes: bool = False
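For illustration, these flags are read off the active environment's `config` object (as `__init__.py` does for `internal_respect_torch_return_dtypes`); the comments below paraphrase the field comments in this file:

```python
# Editorial sketch: toggling configuration flags on the default environment.
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True              # print each dispatched op
env.config.force_materialize_views = True          # convert views to torchax.Tensors eagerly
env.config.use_dlpack_for_data_conversion = True   # use DLPack for jax <-> torch conversion
```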