torchax-0.0.10.dev20251117-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +153 -0
- torchax/amp.py +346 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +790 -0
- torchax/device_module.py +47 -0
- torchax/export.py +259 -0
- torchax/flax.py +53 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +234 -0
- torchax/ops/__init__.py +24 -0
- torchax/ops/jaten.py +5937 -0
- torchax/ops/jax_reimplement.py +185 -0
- torchax/ops/jc10d.py +66 -0
- torchax/ops/jimage.py +127 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +631 -0
- torchax/ops/jtorchvision_nms.py +248 -0
- torchax/ops/mappings.py +161 -0
- torchax/ops/op_base.py +145 -0
- torchax/ops/ops_registry.py +69 -0
- torchax/tensor.py +736 -0
- torchax/train.py +132 -0
- torchax/types.py +26 -0
- torchax/util.py +102 -0
- torchax/view.py +391 -0
- torchax-0.0.10.dev20251117.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251117.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251117.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251117.dist-info/licenses/LICENSE +201 -0
torchax/CONTRIBUTING.md
ADDED
@@ -0,0 +1,43 @@
# Contributing to torchax

We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide guidance. You are also very welcome to pick up issues with the "good first issue" and "help wanted" labels.


# Developer setup

## Mac setup

@qihqi

I am able to develop directly on a Mac (M1) laptop for most parts. The steps
in README.md work. The condensed version for easy copy & paste:

```bash
conda create --name <your_name> python=3.10
conda activate <your_name>
pip install --upgrade "jax[cpu]" torch
pip install -r test_requirements.txt
pip install -e .
pip install pytest-xdist  # recommended for running tests faster
pytest -n auto test
```

## Setup on GPU or TPU

Same as the Mac setup, except that if you run tests with pytest, please also
set `JAX_PLATFORMS=cpu`. The reason is that pytest usually runs tests in
multiple workers. A CPU device can be accessed concurrently, whereas TPU
devices usually allow only one accessor per process, so the tests could deadlock.
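
For example, the test invocation from the Mac section then becomes (same commands, just with the platform pinned):

```bash
# Pin JAX to CPU so parallel pytest workers do not contend for the TPU/GPU.
JAX_PLATFORMS=cpu pytest -n auto test
```
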
### VSCode

I use VSCode on my Mac. I loosely followed the instructions in
https://code.visualstudio.com/docs/python/python-tutorial
to set up a proper Python environment.

The plugins I installed (a subset of the ones suggested in that tutorial) are:

* VSCode's official Python plugin
* Ruff formatter
* Python Debugger

I also changed the Python interpreter to point at the one in my conda env.
That is all the changes I have.
torchax/__init__.py
ADDED
@@ -0,0 +1,153 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import List, Dict, Any, Optional
import dataclasses
import jax
import os
import torch
from torch.utils import _pytree as pytree
from torchax import tensor
from contextlib import contextmanager

__version__ = "0.0.10.dev20251117"
VERSION = __version__

# the "fast path" uses some sparse tensor thingies that currently we
# don't support
torch.backends.mha.set_fastpath_enabled(False)


__all__ = [
    "default_env",
    "extract_jax",
    "enable_globally",
    "save_checkpoint",
    "load_checkpoint",
]

from .checkpoint import save_checkpoint, load_checkpoint

os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")

# torchax:oss-begin
if getattr(jax.config, "jax_pjrt_client_create_options", None):
  jax.config.update(
      "jax_pjrt_client_create_options",
      f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
  )
# torchax:oss-end

env = None


def default_env():
  global env

  if env is None:
    env = tensor.Environment()
  return env


def extract_jax(mod: torch.nn.Module, env=None):
  """Returns a pytree of jax.ndarray and a jax callable."""
  if env is None:
    env = default_env()
  states = dict(mod.named_buffers())
  states.update(mod.named_parameters())

  states = env.t2j_copy(states)

  # @jax.jit
  def jax_func(states, args, kwargs=None):
    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
    with env:
      res = torch.func.functional_call(
          mod, states, args, kwargs, tie_weights=False
      )
    return env.t2j_iso(res)

  return states, jax_func
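
A usage sketch for `extract_jax` (not part of the file; the `Linear` module and shapes are illustrative): the returned `weights` is a pytree of `jax.Array`, and `jax_func(weights, args, kwargs=None)` is a pure function that can be wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp
import torch
import torchax

model = torch.nn.Linear(4, 2)                   # any nn.Module
weights, jax_func = torchax.extract_jax(model)  # weights: pytree of jax.Array

jitted = jax.jit(jax_func)
x = jnp.ones((8, 4), dtype=jnp.float32)
out = jitted(weights, (x,))                     # positional args go in as a tuple
print(out.shape)                                # (8, 2)
```
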

def enable_globally():
  env = default_env().enable_torch_modes()
  return env


def disable_globally():
  global env
  default_env().disable_torch_modes()


@contextlib.contextmanager
def disable_temporarily():
  prev = default_env().enabled
  if prev:
    disable_globally()
  yield ()
  if prev:
    enable_globally()


torch.utils.rename_privateuse1_backend("jax")
unsupported_dtype = [torch.quint8]

import jax
import torchax.device_module

torch._register_device_module("jax", torchax.device_module)


def enable_accuracy_mode():
  jax.config.update("jax_enable_x64", True)
  jax.config.update("jax_default_matmul_precision", "highest")
  default_env().config.internal_respect_torch_return_dtypes = True


def enable_performance_mode():
  jax.config.update("jax_enable_x64", False)
  jax.config.update("jax_default_matmul_precision", "default")
  default_env().config.internal_respect_torch_return_dtypes = False


@dataclasses.dataclass
class CompileOptions:
  # only valid if compiling nn.Module
  methods_to_compile: List[str] = dataclasses.field(
      default_factory=lambda: ["forward"]
  )
  jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
  mode: str = "jax"  # or dynamo or export


def compile(fn, options: Optional[CompileOptions] = None):
  options = options or CompileOptions()
  if options.mode == "jax":
    from torchax import interop

    if isinstance(fn, torch.nn.Module):
      module = interop.JittableModule(
          fn, extra_jit_args=options.jax_jit_kwargs
      )
      for n in options.methods_to_compile:
        module.make_jitted(n)
      return module
    else:
      return interop.jax_jit(fn)
  elif options.mode == "dynamo":
    raise RuntimeError("dynamo mode is not supported yet")
  elif options.mode == "export":
    raise RuntimeError("export mode is not supported yet")
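
A hedged usage sketch for `compile` with the default `mode="jax"` (the module, shapes, and jit kwargs below are made up for illustration): an `nn.Module` is wrapped in `interop.JittableModule` and the listed methods are jitted, while a plain callable goes through `interop.jax_jit`.

```python
import torch
import torchax

torchax.enable_globally()                  # dispatch torch ops to JAX

model = torch.nn.Linear(16, 4).to("jax")   # move parameters to the "jax" device
options = torchax.CompileOptions(
    methods_to_compile=["forward"],        # the default, shown for clarity
    jax_jit_kwargs={},                     # forwarded to jax.jit
)
compiled = torchax.compile(model, options)

x = torch.randn(2, 16, device="jax")
y = compiled(x)                            # runs the jitted forward
```
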
torchax/amp.py
ADDED
@@ -0,0 +1,346 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import enum
import torch
from torch.utils import _pytree as pytree


# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
#                           // running the op. Currently, lower_precision_fp is
#                           // fp16 for AutocastCUDA, and is defined by user
#                           // (default bf16) for AutocastCPU or other device.
#   fp32, // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype, // Treats functions (like softmax) that
#                       //   1. we'd like to run in fp32 and
#                       //   2. have a std::optional<ScalarType> arg that controls
#                       //      the output type.
#                       // fp32_set_opt_dtype wrappers' policy is: if the output
#                       // type is already set, don't touch it, otherwise, set
#                       // it to at::kFloat.
#   fp32_append_dtype, // Treats functions (like norm) that
#                      //   1. we'd like to run in fp32 and
#                      //   2. have some overloads that accept an output type and
#                      //      other overloads that don't.
#                      // fp32_append_dtype wrappers wrap the overloads that don't
#                      // have an output dtype.
#                      // The wrapper policy is: append at::kFloat to the args,
#                      // and redispatch to the type-aware overload.
#   promote, // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()

  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      dtypes = set(a.dtype for a in args)
      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented yet.')


@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  del device
  if env is None:
    import torchax
    env = torchax.default_env()
  with env.override_property(autocast_dtype=dtype):
    yield


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

    # fp32 cast policy
    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
    torch.ops.aten.polar.default: CastPolicy.FP32,
    torch.ops.aten.prod.default: CastPolicy.FP32,
    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
    torch.ops.aten.quantile.default: CastPolicy.FP32,
    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
    torch.ops.aten.stft.default: CastPolicy.FP32,
    torch.ops.aten.stft.center: CastPolicy.FP32,
    torch.ops.aten.cdist.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
    torch.ops.aten.trace.default: CastPolicy.FP32,
    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
    torch.ops.aten.cholesky.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
    torch.ops.aten.inverse.default: CastPolicy.FP32,
    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
    torch.ops.aten.orgqr.default: CastPolicy.FP32,
    torch.ops.aten.ormqr.default: CastPolicy.FP32,
    torch.ops.aten.pinverse.default: CastPolicy.FP32,
    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
    torch.ops.aten.kl_div.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
    torch.ops.aten.geqrf.default: CastPolicy.FP32,
    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
    torch.ops.aten.qr.default: CastPolicy.FP32,
    torch.ops.aten.svd.default: CastPolicy.FP32,
    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

    # promote
    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
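
A minimal sketch of how the autocast hook above might be used (assuming the default environment and the registered "jax" device from `__init__.py`; whether the output dtype actually drops to bfloat16 depends on the environment applying `autocast_policy`): `aten.matmul` is listed under `LOWER_PRECISION_FP`, so its floating-point inputs are expected to be cast to the autocast dtype.

```python
import torch
import torchax
from torchax import amp

torchax.enable_globally()

a = torch.randn(8, 8, device="jax")   # float32 operands
b = torch.randn(8, 8, device="jax")

with amp.autocast("jax", dtype=torch.bfloat16):
  c = a @ b                           # matmul: CastPolicy.LOWER_PRECISION_FP
print(c.dtype)                        # expected torch.bfloat16 under autocast
```
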
torchax/checkpoint.py
ADDED
@@ -0,0 +1,79 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
from typing import Any, Dict
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import numpy as np
from . import tensor


def _to_jax(pytree):

  def to_jax_array(x):
    if isinstance(x, tensor.Tensor):
      return x.jax()
    elif isinstance(x, torch.Tensor):
      return jnp.asarray(x.cpu().numpy())
    return x

  return jax.tree_util.tree_map(to_jax_array, pytree)


def _to_torch(pytree):
  return jax.tree_util.tree_map(
      lambda x: torch.from_numpy(np.asarray(x))
      if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)


def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  """Saves a checkpoint to a file in JAX style.

  Args:
    state: A dictionary containing the state to save. torch.Tensors will be
      converted to jax.Array.
    path: The path to save the checkpoint to. This is a directory.
    step: The training step.
  """
  state = _to_jax(state)
  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)


def load_checkpoint(path: str) -> Dict[str, Any]:
  """Loads a checkpoint and returns it in JAX format.

  This function can load both PyTorch-style (single file) and JAX-style
  (directory) checkpoints.

  If the checkpoint is in PyTorch format, it will be converted to JAX format.

  Args:
    path: The path to the checkpoint.

  Returns:
    The loaded state in JAX format (pytree with jax.Array leaves).
  """
  if os.path.isdir(path):
    # JAX-style checkpoint
    state = checkpoints.restore_checkpoint(path, target=None)
    if state is None:
      raise FileNotFoundError(f"No checkpoint found at {path}")
    return state
  elif os.path.isfile(path):
    # PyTorch-style checkpoint
    state = torch.load(path, weights_only=False)
    return _to_jax(state)
  else:
    raise FileNotFoundError(f"No such file or directory: {path}")
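
A usage sketch for the two helpers (the model, directory, and step number are placeholders): `save_checkpoint` converts `torch.Tensor` leaves to `jax.Array` and writes a Flax-style checkpoint directory, and `load_checkpoint` restores either that directory or a `torch.save`-style file as a pytree of `jax.Array`.

```python
import torch
import torchax

model = torch.nn.Linear(4, 2)
state = {"model": model.state_dict(), "step": 100}

# Writes a Flax checkpoint under /tmp/torchax_ckpt (tensors become jax.Array).
torchax.save_checkpoint(state, "/tmp/torchax_ckpt", step=100)

# Returns a pytree with jax.Array leaves.
restored = torchax.load_checkpoint("/tmp/torchax_ckpt")
```
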
torchax/config.py
ADDED
@@ -0,0 +1,44 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses


@dataclasses.dataclass
class Configuration:
  debug_print_each_op: bool = False
  debug_accuracy_for_each_op: bool = False
  debug_mixed_tensor: bool = False
  debug_print_each_op_operands: bool = False

  use_int32_for_index: bool = False

  # normally, math between CPU torch.Tensor with torchax.Tensor is not
  # allowed. However, if that torch.Tensor happens to be scalar, then we
  # can use scalar * tensor math to handle it
  allow_mixed_math_with_scalar_tensor: bool = True

  # If true, we will convert Views into torchax.Tensors eagerly
  force_materialize_views: bool = False

  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
  use_dlpack_for_data_conversion: bool = False

  # Flash attention
  use_tpu_flash_attention: bool = False
  shmap_flash_attention: bool = False

  # device
  treat_cuda_as_jax_device: bool = True
  internal_respect_torch_return_dtypes: bool = False
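
These flags are read from the active environment's `config` attribute. A sketch of toggling them at runtime, mirroring how `enable_accuracy_mode` in `__init__.py` sets `internal_respect_torch_return_dtypes` (the flag descriptions in the comments are inferred from the flag names):

```python
import torchax

env = torchax.default_env()

# Print every dispatched op, e.g. when hunting for a missing lowering.
env.config.debug_print_each_op = True

# Materialize view objects into torchax tensors eagerly.
env.config.force_materialize_views = True
```
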