torchax-0.0.10.dev20251118-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

@@ -0,0 +1,43 @@
+ # Contributing to torchax
+
+ We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide guidance. You are also very welcome to pick up issues with the "good first issue" and "help wanted" labels.
+
+
+ # Developer setup
+
+ ## Mac setup
+ @qihqi
+
+ I am able to develop directly on a Mac (M1) laptop for most parts. Following the steps
+ in README.md works. The condensed version for easy copy & paste:
+
+ ```bash
+ conda create --name <your_name> python=3.10
+ conda activate <your_name>
+ pip install --upgrade "jax[cpu]" torch
+ pip install -r test_requirements.txt
+ pip install -e .
+ pip install pytest-xdist # recommended for running tests faster
+ pytest -n auto test
+ ```
+
+ ## Setup on GPU or TPU
+
+ Same as the Mac setup, except that if you run tests with pytest, please also
+ set `JAX_PLATFORMS=cpu` (e.g. `JAX_PLATFORMS=cpu pytest -n auto test`). This is because pytest usually runs
+ tests in multiple workers. The CPU device can be accessed concurrently, whereas
+ TPU devices usually allow only one accessor per process, so the run could deadlock.
+
+ ### VSCode
+
+ I use VSCode on my Mac. I loosely followed the instructions at
+ https://code.visualstudio.com/docs/python/python-tutorial
+ to set up a proper Python environment.
+
+ The plugins I installed (a subset of the ones listed there) are:
+ * VSCode's official Python plugin
+ * Ruff formatter
+ * Python Debugger
+
+ I also changed the Python interpreter to point at the one in my conda env.
+ Those are all the changes I made.
torchax/__init__.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ import dataclasses
+ import os
+ from contextlib import contextmanager
+ from typing import Any
+
+ import jax
+ import torch
+ from torch.utils import _pytree as pytree
+
+ import torchax.device_module
+ from torchax import tensor
+
+ from .checkpoint import load_checkpoint, save_checkpoint
+
+ __version__ = "0.0.10.dev20251118"
+ VERSION = __version__
+
+ # the "fast path" uses some sparse tensor thingies that currently we
+ # don't support
+ torch.backends.mha.set_fastpath_enabled(False)
+
+
+ __all__ = [
+     "default_env",
+     "extract_jax",
+     "enable_globally",
+     "save_checkpoint",
+     "load_checkpoint",
+ ]
+
+
+ os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")
+
+ # torchax:oss-begin
+ if getattr(jax.config, "jax_pjrt_client_create_options", None):
+     jax.config.update(
+         "jax_pjrt_client_create_options",
+         f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+     )
+ # torchax:oss-end
+
+ env = None
+
+
+ def default_env():
+     global env
+
+     if env is None:
+         env = tensor.Environment()
+     return env
+
+
+ def extract_jax(mod: torch.nn.Module, env=None):
+     """Returns a pytree of jax.ndarray and a jax callable."""
+     if env is None:
+         env = default_env()
+     states = dict(mod.named_buffers())
+     states.update(mod.named_parameters())
+
+     states = env.t2j_copy(states)
+
+     # @jax.jit
+     def jax_func(states, args, kwargs=None):
+         (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
+         with env:
+             res = torch.func.functional_call(mod, states, args, kwargs, tie_weights=False)
+         return env.t2j_iso(res)
+
+     return states, jax_func
+
+
+ def enable_globally():
+     env = default_env().enable_torch_modes()
+     return env
+
+
+ def disable_globally():
+     global env
+     default_env().disable_torch_modes()
+
+
+ @contextlib.contextmanager
+ def disable_temporarily():
+     prev = default_env().enabled
+     if prev:
+         disable_globally()
+     yield ()
+     if prev:
+         enable_globally()
+
+
+ torch.utils.rename_privateuse1_backend("jax")
+ unsupported_dtype = [torch.quint8]
+
+
+ torch._register_device_module("jax", torchax.device_module)
+
+
+ def enable_accuracy_mode():
+     jax.config.update("jax_enable_x64", True)
+     jax.config.update("jax_default_matmul_precision", "highest")
+     default_env().config.internal_respect_torch_return_dtypes = True
+
+
+ def enable_performance_mode():
+     jax.config.update("jax_enable_x64", False)
+     jax.config.update("jax_default_matmul_precision", "default")
+     default_env().config.internal_respect_torch_return_dtypes = False
+
+
+ @dataclasses.dataclass
+ class CompileOptions:
+     # only valid if compiling nn.Module
+     methods_to_compile: list[str] = dataclasses.field(default_factory=lambda: ["forward"])
+     jax_jit_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
+     mode: str = "jax"  # or dynamo or export
+
+
+ def compile(fn, options: CompileOptions | None = None):
+     options = options or CompileOptions()
+     if options.mode == "jax":
+         from torchax import interop
+
+         if isinstance(fn, torch.nn.Module):
+             module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
+             for n in options.methods_to_compile:
+                 module.make_jitted(n)
+             return module
+         else:
+             return interop.jax_jit(fn)
+     elif options.mode == "dynamo":
+         raise RuntimeError("dynamo mode is not supported yet")
+     elif options.mode == "export":
+         raise RuntimeError("export mode is not supported yet")
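
For orientation, here is a minimal usage sketch of the entry points defined in `torchax/__init__.py` above (`extract_jax`, `enable_globally`, `compile`). The module, shapes, and the exact call pattern are illustrative assumptions, not part of the package:

```python
import jax.numpy as jnp
import torch
import torchax

# Functional route: weights become a JAX pytree, the module becomes a pure function.
model = torch.nn.Linear(4, 2)
states, jax_func = torchax.extract_jax(model)
x = jnp.ones((8, 4), dtype=jnp.float32)
out = jax_func(states, (x,))  # positional args are passed as a tuple

# Global route: intercept torch ops, then wrap the module so that "forward"
# is jitted (per the CompileOptions defaults above).
torchax.enable_globally()
jitted = torchax.compile(model)
```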
torchax/amp.py ADDED
@@ -0,0 +1,218 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ import enum
+
+ import torch
+ from torch.utils import _pytree as pytree
+
+
+ # enum class CastPolicy : uint8_t {
+ #   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
+ #                           // running the op. Currently, lower_precision_fp is
+ #                           // fp16 for AutocastCUDA, and is defined by user
+ #                           // (default bf16) for AutocastCPU or other device.
+ #   fp32, // Cast all inputs to at::kFloat before running the op.
+ #   fp32_set_opt_dtype, // Treats functions (like softmax) that
+ #                       // 1. we'd like to run in fp32 and
+ #                       // 2. have a std::optional<ScalarType> arg that controls
+ #                       // the output type.
+ #                       // fp32_set_opt_dtype wrappers' policy is: if the output
+ #                       // type is already set, don't touch it, otherwise, set
+ #                       // it to at::kFloat.
+ #   fp32_append_dtype, // Treats functions (like norm) that
+ #                      // 1. we'd like to run in fp32 and
+ #                      // 2. have some overloads that accept an output type and
+ #                      // other overloads that don't.
+ #                      // fp32_append_dtype wrappers wrap the overloads that don't
+ #                      // have an output dtype.
+ #                      // The wrapper policy is: append at::kFloat to the args,
+ #                      // and redispatch to the type-aware overload.
+ #   promote, // Run in the widest dtype among several args.
+ # };
+ class CastPolicy(enum.Enum):
+     LOWER_PRECISION_FP = 0
+     FP32 = 1
+     FP32_SET_OPT_DTYPE = 2
+     FP32_APPEND_DTYPE = 3
+     PROMOTE = 4
+
+
+ def execute_policy(policy, args, kwargs, target_lower_fp):
+     def is_float(a):
+         return isinstance(a, torch.Tensor) and a.is_floating_point()
+
+     match policy:
+         case CastPolicy.LOWER_PRECISION_FP:
+             return pytree.tree_map_only(
+                 is_float, lambda a: a.to(target_lower_fp), (args, kwargs)
+             )
+         case CastPolicy.FP32:
+             return pytree.tree_map_only(
+                 is_float, lambda a: a.to(torch.float32), (args, kwargs)
+             )
+         case CastPolicy.PROMOTE:
+             dtypes = {a.dtype for a in args}
+             widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
+             return pytree.tree_map_only(is_float, lambda a: a.to(widest), (args, kwargs))
+         case _:
+             raise AssertionError(f"Policy {policy} not implemented yet.")
+
+
+ @contextlib.contextmanager
+ def autocast(device, dtype=torch.bfloat16, env=None):
+     del device
+     if env is None:
+         import torchax
+
+         env = torchax.default_env()
+     with env.override_property(autocast_dtype=dtype):
+         yield
+
+
+ # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
+ autocast_policy = {
+     torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
+     torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,
+     # fp32 cast policy
+     torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
+     torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
+     torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
+     torch.ops.aten.polar.default: CastPolicy.FP32,
+     torch.ops.aten.prod.default: CastPolicy.FP32,
+     torch.ops.aten.prod.dim_int: CastPolicy.FP32,
+     torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
+     torch.ops.aten.quantile.default: CastPolicy.FP32,
+     torch.ops.aten.quantile.scalar: CastPolicy.FP32,
+     torch.ops.aten.nanquantile.default: CastPolicy.FP32,
+     torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
+     torch.ops.aten.stft.default: CastPolicy.FP32,
+     torch.ops.aten.stft.center: CastPolicy.FP32,
+     torch.ops.aten.cdist.default: CastPolicy.FP32,
+     torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
+     torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
+     torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
+     torch.ops.aten.trace.default: CastPolicy.FP32,
+     torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
+     torch.ops.aten.cholesky.default: CastPolicy.FP32,
+     torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
+     torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
+     torch.ops.aten.inverse.default: CastPolicy.FP32,
+     torch.ops.aten.lu_solve.default: CastPolicy.FP32,
+     torch.ops.aten.orgqr.default: CastPolicy.FP32,
+     torch.ops.aten.ormqr.default: CastPolicy.FP32,
+     torch.ops.aten.pinverse.default: CastPolicy.FP32,
+     torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
+     torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
+     torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
+     torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
+     torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
+     torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
+     torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
+     torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
+     torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
+     torch.ops.aten.mse_loss.default: CastPolicy.FP32,
+     torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
+     torch.ops.aten.nll_loss.default: CastPolicy.FP32,
+     torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
+     torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
+     torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
+     torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
+     torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
+     torch.ops.aten.l1_loss.default: CastPolicy.FP32,
+     torch.ops.aten.huber_loss.default: CastPolicy.FP32,
+     torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
+     torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
+     torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
+     torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
+     torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
+     torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
+     torch.ops.aten.kl_div.default: CastPolicy.FP32,
+     torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
+     torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
+     torch.ops.aten.fft_fft.default: CastPolicy.FP32,
+     torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
+     torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
+     torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
+     torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
+     torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
+     torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
+     torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
+     torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
+     torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
+     torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
+     torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
+     torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
+     torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
+     torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
+     torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
+     torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
+     torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
+     torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
+     torch.ops.aten.geqrf.default: CastPolicy.FP32,
+     torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
+     torch.ops.aten.qr.default: CastPolicy.FP32,
+     torch.ops.aten.svd.default: CastPolicy.FP32,
+     torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
+     torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
+     torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
+     torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
+     torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
+     torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,
+     # promote
+     torch.ops.aten.stack.default: CastPolicy.PROMOTE,
+     torch.ops.aten.cat.default: CastPolicy.PROMOTE,
+     torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
+     torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
+ }
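
To make the flow above concrete: `autocast` overrides the environment's `autocast_dtype`, and ops listed in `autocast_policy` are then handled per their `CastPolicy` (matmul-like ops drop to the lower-precision dtype, loss/FFT/linalg ops are forced to fp32, `stack`/`cat` promote to the widest input dtype). A hedged usage sketch, assuming the "jax" device registered in `__init__.py` above accepts tensors in the usual way:

```python
import torch
import torchax
from torchax import amp

torchax.enable_globally()

model = torch.nn.Linear(128, 128).to("jax")
x = torch.randn(32, 128, device="jax")

# Inside the context, LOWER_PRECISION_FP ops such as aten.linear run in bfloat16,
# while FP32-policy ops (losses, FFTs, linalg) stay in float32.
with amp.autocast("jax", dtype=torch.bfloat16):
    y = model(x)
```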
torchax/checkpoint.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import torch
+ from flax.training import checkpoints
+
+ from . import tensor
+
+
+ def _to_jax(pytree):
+     def to_jax_array(x):
+         if isinstance(x, tensor.Tensor):
+             return x.jax()
+         elif isinstance(x, torch.Tensor):
+             return jnp.asarray(x.cpu().numpy())
+         return x
+
+     return jax.tree_util.tree_map(to_jax_array, pytree)
+
+
+ def _to_torch(pytree):
+     return jax.tree_util.tree_map(
+         lambda x: torch.from_numpy(np.asarray(x))
+         if isinstance(x, (jnp.ndarray, jax.Array))
+         else x,
+         pytree,
+     )
+
+
+ def save_checkpoint(state: dict[str, Any], path: str, step: int):
+     """Saves a checkpoint to a file in JAX style.
+
+     Args:
+         state: A dictionary containing the state to save. torch.Tensors will be
+             converted to jax.Array.
+         path: The path to save the checkpoint to. This is a directory.
+         step: The training step.
+     """
+     state = _to_jax(state)
+     checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
+
+
+ def load_checkpoint(path: str) -> dict[str, Any]:
+     """Loads a checkpoint and returns it in JAX format.
+
+     This function can load both PyTorch-style (single file) and JAX-style
+     (directory) checkpoints.
+
+     If the checkpoint is in PyTorch format, it will be converted to JAX format.
+
+     Args:
+         path: The path to the checkpoint.
+
+     Returns:
+         The loaded state in JAX format (pytree with jax.Array leaves).
+     """
+     if os.path.isdir(path):
+         # JAX-style checkpoint
+         state = checkpoints.restore_checkpoint(path, target=None)
+         if state is None:
+             raise FileNotFoundError(f"No checkpoint found at {path}")
+         return state
+     elif os.path.isfile(path):
+         # PyTorch-style checkpoint
+         state = torch.load(path, weights_only=False)
+         return _to_jax(state)
+     else:
+         raise FileNotFoundError(f"No such file or directory: {path}")
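
A hedged round-trip sketch for the two helpers above. They delegate to `flax.training.checkpoints`, so `flax` must be installed; the checkpoint directory is an arbitrary example path:

```python
import torch
from torchax.checkpoint import load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 2)

# torch.Tensor leaves are converted to jax.Array before writing (see _to_jax above).
save_checkpoint({"model": model.state_dict()}, "/tmp/torchax_ckpt", step=0)

# Returns a pytree with jax.Array leaves. A single-file torch.save() checkpoint
# would also be accepted and converted on load.
state = load_checkpoint("/tmp/torchax_ckpt")
```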
torchax/config.py ADDED
@@ -0,0 +1,44 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+
+
+ @dataclasses.dataclass
+ class Configuration:
+     debug_print_each_op: bool = False
+     debug_accuracy_for_each_op: bool = False
+     debug_mixed_tensor: bool = False
+     debug_print_each_op_operands: bool = False
+
+     use_int32_for_index: bool = False
+
+     # Normally, math between a CPU torch.Tensor and a torchax.Tensor is not
+     # allowed. However, if that torch.Tensor happens to be a scalar, then we
+     # can use scalar * tensor math to handle it.
+     allow_mixed_math_with_scalar_tensor: bool = True
+
+     # If true, we will convert Views into torchax.Tensors eagerly
+     force_materialize_views: bool = False
+
+     # Use DLPack for converting between jax.Array and torch.Tensor
+     use_dlpack_for_data_conversion: bool = False
+
+     # Flash attention
+     use_tpu_flash_attention: bool = False
+     shmap_flash_attention: bool = False
+
+     # device
+     treat_cuda_as_jax_device: bool = True
+     internal_respect_torch_return_dtypes: bool = False
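
These flags are reached through the active environment's `config` attribute, the same way `enable_accuracy_mode` in `__init__.py` flips `internal_respect_torch_return_dtypes`. A small sketch, with the particular flags chosen purely as examples:

```python
import torchax

env = torchax.default_env()

# Print every dispatched op while debugging (default False, per Configuration above).
env.config.debug_print_each_op = True

# Route tensors addressed to "cuda" onto the JAX device (default True).
env.config.treat_cuda_as_jax_device = True
```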