torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of torchax has been flagged as potentially problematic by the registry.
- torchax/CONTRIBUTING.md +10 -5
- torchax/__init__.py +92 -65
- torchax/amp.py +14 -0
- torchax/checkpoint.py +79 -0
- torchax/config.py +14 -0
- torchax/decompositions.py +14 -0
- torchax/device_module.py +14 -0
- torchax/export.py +14 -0
- torchax/flax.py +14 -0
- torchax/interop.py +44 -31
- torchax/mesh_util.py +14 -0
- torchax/ops/__init__.py +14 -0
- torchax/ops/jaten.py +3985 -3686
- torchax/ops/jax_reimplement.py +14 -0
- torchax/ops/jc10d.py +14 -0
- torchax/ops/jimage.py +14 -0
- torchax/ops/jlibrary.py +14 -0
- torchax/ops/jtorch.py +364 -309
- torchax/ops/jtorchvision_nms.py +14 -0
- torchax/ops/mappings.py +26 -4
- torchax/ops/op_base.py +14 -0
- torchax/ops/ops_registry.py +14 -0
- torchax/tensor.py +38 -13
- torchax/train.py +112 -97
- torchax/types.py +14 -0
- torchax/util.py +14 -0
- torchax/view.py +14 -0
- torchax-0.0.10.dev20251116.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251116.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251116.dist-info/licenses/LICENSE +201 -0
- torchax/configuration.py +0 -30
- torchax/environment.py +0 -1
- torchax/tf_integration.py +0 -119
- torchax-0.0.6.dist-info/METADATA +0 -307
- torchax-0.0.6.dist-info/RECORD +0 -33
- torchax-0.0.6.dist-info/licenses/LICENSE +0 -28
- {torchax-0.0.6.dist-info → torchax-0.0.10.dev20251116.dist-info}/WHEEL +0 -0
torchax/CONTRIBUTING.md
CHANGED
@@ -1,9 +1,7 @@
-# Contributing to
+# Contributing to torchax
 
 We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels.
 
-If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
-
 
 # Developer setup
 
@@ -19,9 +17,17 @@ conda activate <your_name>
 pip install --upgrade "jax[cpu]" torch
 pip install -r test_requirements.txt
 pip install -e .
-pytest test
+pip install pytest-xdist # recommended for running test faster
+pytest -n auto test
 ```
 
+## Setup on GPU or TPU
+
+Same as Mac setup, except, if you run test using pytest, please also
+add `JAX_PLATFORMS=cpu`. The reason is because pytest usually runs
+test in multiple threads. CPU device can be accessed concurrently where
+TPU devices usually only allow one accesor per process; so it could deadlock.
+
 ### VSCode
 
 I use vscode on my Mac. I loosely followed instruction in
@@ -35,4 +41,3 @@ The plugins I installed (a subset of the ones listed above) are:
 
 I also changed Python interpreter to point at the one in my conda env.
 That is all the changes I have.
-
torchax/__init__.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import contextlib
 from typing import List, Dict, Any, Optional
 import dataclasses
@@ -8,119 +22,132 @@ from torch.utils import _pytree as pytree
 from torchax import tensor
 from contextlib import contextmanager
 
-__version__ = "0.0.6"
+__version__ = "0.0.10.dev20251116"
 VERSION = __version__
 
+# the "fast path" uses some sparse tensor thingies that currently we
+# don't support
+torch.backends.mha.set_fastpath_enabled(False)
+
+
 __all__ = [
-
-
-
+    "default_env",
+    "extract_jax",
+    "enable_globally",
+    "save_checkpoint",
+    "load_checkpoint",
 ]
 
-from
+from .checkpoint import save_checkpoint, load_checkpoint
 
-os.environ.setdefault(
+os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")
 
 # torchax:oss-begin
-if getattr(jax.config,
-
-
-
+if getattr(jax.config, "jax_pjrt_client_create_options", None):
+  jax.config.update(
+      "jax_pjrt_client_create_options",
+      f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+  )
 # torchax:oss-end
 
 env = None
 
 
 def default_env():
-
+  global env
 
-
-
-
+  if env is None:
+    env = tensor.Environment()
+  return env
 
 
 def extract_jax(mod: torch.nn.Module, env=None):
-
-
-
-
-
+  """Returns a pytree of jax.ndarray and a jax callable."""
+  if env is None:
+    env = default_env()
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-
+  states = env.t2j_copy(states)
 
-
-
-
-
-
-
-
+  # @jax.jit
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
+    with env:
+      res = torch.func.functional_call(
+          mod, states, args, kwargs, tie_weights=False
+      )
+    return env.t2j_iso(res)
 
-
+  return states, jax_func
 
 
 def enable_globally():
-
-
+  env = default_env().enable_torch_modes()
+  return env
 
 
 def disable_globally():
-
-
+  global env
+  default_env().disable_torch_modes()
 
 
 @contextlib.contextmanager
 def disable_temporarily():
-
-
-
-
-
-
+  prev = default_env().enabled
+  if prev:
+    disable_globally()
+  yield ()
+  if prev:
+    enable_globally()
 
 
-torch.utils.rename_privateuse1_backend(
+torch.utils.rename_privateuse1_backend("jax")
 unsupported_dtype = [torch.quint8]
 
 import jax
 import torchax.device_module
 
-torch._register_device_module(
+torch._register_device_module("jax", torchax.device_module)
 
 
 def enable_accuracy_mode():
-
-
-
+  jax.config.update("jax_enable_x64", True)
+  jax.config.update("jax_default_matmul_precision", "highest")
+  default_env().config.internal_respect_torch_return_dtypes = True
 
 
 def enable_performance_mode():
-
-
-
+  jax.config.update("jax_enable_x64", False)
+  jax.config.update("jax_default_matmul_precision", "default")
+  default_env().config.internal_respect_torch_return_dtypes = False
 
 
 @dataclasses.dataclass
 class CompileOptions:
-
-
-
-
-
+  # only valid if compiling nn.Module
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ["forward"]
+  )
+  jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+  mode: str = "jax"  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  options = options or CompileOptions()
+  if options.mode == "jax":
+    from torchax import interop
+
+    if isinstance(fn, torch.nn.Module):
+      module = interop.JittableModule(
+          fn, extra_jit_args=options.jax_jit_kwargs
+      )
+      for n in options.methods_to_compile:
+        module.make_jitted(n)
+      return module
+    else:
+      return interop.jax_jit(fn)
+  elif options.mode == "dynamo":
+    raise RuntimeError("dynamo mode is not supported yet")
+  elif options.mode == "export":
+    raise RuntimeError("export mode is not supported yet")
torchax/amp.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import contextlib
 import enum
 import torch
torchax/checkpoint.py
ADDED
@@ -0,0 +1,79 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import os
+from typing import Any, Dict
+from flax.training import checkpoints
+import jax
+import jax.numpy as jnp
+import numpy as np
+from . import tensor
+
+def _to_jax(pytree):
+  def to_jax_array(x):
+    if isinstance(x, tensor.Tensor):
+      return x.jax()
+    elif isinstance(x, torch.Tensor):
+      return jnp.asarray(x.cpu().numpy())
+    return x
+  return jax.tree_util.tree_map(to_jax_array, pytree)
+
+
+def _to_torch(pytree):
+  return jax.tree_util.tree_map(
+      lambda x: torch.from_numpy(np.asarray(x))
+      if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
+
+
+def save_checkpoint(state: Dict[str, Any], path: str, step: int):
+  """Saves a checkpoint to a file in JAX style.
+
+  Args:
+    state: A dictionary containing the state to save. torch.Tensors will be
+      converted to jax.Array.
+    path: The path to save the checkpoint to. This is a directory.
+    step: The training step.
+  """
+  state = _to_jax(state)
+  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
+
+
+def load_checkpoint(path: str) -> Dict[str, Any]:
+  """Loads a checkpoint and returns it in JAX format.
+
+  This function can load both PyTorch-style (single file) and JAX-style
+  (directory) checkpoints.
+
+  If the checkpoint is in PyTorch format, it will be converted to JAX format.
+
+  Args:
+    path: The path to the checkpoint.
+
+  Returns:
+    The loaded state in JAX format (pytree with jax.Array leaves).
+  """
+  if os.path.isdir(path):
+    # JAX-style checkpoint
+    state = checkpoints.restore_checkpoint(path, target=None)
+    if state is None:
+      raise FileNotFoundError(f"No checkpoint found at {path}")
+    return state
+  elif os.path.isfile(path):
+    # PyTorch-style checkpoint
+    state = torch.load(path, weights_only=False)
+    return _to_jax(state)
+  else:
+    raise FileNotFoundError(f"No such file or directory: {path}")
+
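A short usage sketch for the new checkpoint helpers exported at the package top level; the paths and model below are hypothetical and only illustrate the save/load round trip described in the docstrings above.

```python
# Hypothetical example of the new checkpoint helpers (model and paths are
# illustrative only).
import torch
import torchax

model = torch.nn.Linear(4, 2)
state = {"model": model.state_dict(), "step": 100}

# Saves Flax-style into a directory; torch.Tensors are converted to jax.Array.
torchax.save_checkpoint(state, "/tmp/torchax_ckpt", step=100)

# Accepts either a Flax-style directory or a torch.save() file; the result is
# always a pytree with jax.Array leaves.
restored = torchax.load_checkpoint("/tmp/torchax_ckpt")
```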
torchax/config.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import dataclasses
 
 
torchax/decompositions.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """This file contains some decompositons that are not available in torch stable.
 
 Most likely from Content of
torchax/device_module.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 
 
torchax/export.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # pylint: disable
 """Utilities for exporting a torch program to jax/stablehlo."""
 import copy
torchax/flax.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Flax interop."""
 
 import torch
torchax/interop.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import collections
 import copy
 import functools
@@ -182,7 +196,6 @@ def _torch_view(t: JaxValue) -> TorchValue:
   # t is an object from jax land
   # view it as-if it's a torch land object
   if isinstance(t, jax.Array):
-    # TODO
     return tensor.Tensor(t, torchax.default_env())
   if isinstance(t, jnp.dtype):
     return mappings.j2t_dtype(t)
@@ -237,6 +250,36 @@ def j2t_autograd(fn, call_jax=call_jax):
   the PyTorch autograd framework by saving the residuals into the context object.
   """
 
+  # NOTE(qihqi): This function cannot be inlined from the callsite
+  # Becuase if it does, then it won't hit the compilation cache for
+  # call_jax. Call jax uses functions' id as key.
+  # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+  # wrapped pure function, preventing cache collisions between different pure modules.
+  def _jax_forward(fn, other, tree_def, tensors):
+    """JAX function to compute output and vjp function.
+
+    primals should be a tuple (args, kwargs).
+    """
+    import jax
+    from jax.tree_util import tree_flatten, tree_unflatten
+
+    def fn_wrapper(*tensors):
+      # Reconstruct the original args and kwargs
+      flat_inputs = util.merge(tensors, other)
+      args, kwargs = tree_unflatten(tree_def, flat_inputs)
+      return fn(*args, **kwargs)
+
+    return jax.vjp(fn_wrapper, *tensors)
+
+  def _jax_backward(vjp_spec, saved_tensors, grad_out):
+    """JAX function to compute input gradients.
+
+    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+    """
+    from jax.tree_util import tree_unflatten
+    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+    return fun_vjp(grad_out)
+
   @wraps(fn)
   def inner(*args, **kwargs):
     from jax.tree_util import tree_flatten
@@ -290,36 +333,6 @@ def j2t_autograd(fn, call_jax=call_jax):
   return inner
 
 
-# NOTE(qihqi): This function cannot be inlined from the callsite
-# Becuase if it does, then it won't hit the compilation cache for
-# call_jax. Call jax uses functions' id as key.
-def _jax_forward(fn, other, tree_def, tensors):
-  """JAX function to compute output and vjp function.
-
-  primals should be a tuple (args, kwargs).
-  """
-  import jax
-  from jax.tree_util import tree_flatten, tree_unflatten
-
-  def fn_wrapper(*tensors):
-    # Reconstruct the original args and kwargs
-    flat_inputs = util.merge(tensors, other)
-    args, kwargs = tree_unflatten(tree_def, flat_inputs)
-    return fn(*args, **kwargs)
-
-  return jax.vjp(fn_wrapper, *tensors)
-
-
-def _jax_backward(vjp_spec, saved_tensors, grad_out):
-  """JAX function to compute input gradients.
-
-  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
-  """
-  from jax.tree_util import tree_unflatten
-  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-  return fun_vjp(grad_out)
-
-
 fori_loop = torch_view(jax.lax.fori_loop)
 
 
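The NOTE added above explains why `_jax_forward`/`_jax_backward` moved from module scope into `j2t_autograd`: `call_jax` keys its compilation cache on a function's `id()`, so the helpers must be stable function objects that are nonetheless distinct for each wrapped function. The sketch below is illustrative only (it is not torchax code) and mimics that caching behavior with a plain dictionary.

```python
# Illustrative sketch (not torchax code): a cache keyed on id(fn), like the one
# call_jax uses, needs helpers created once per wrapper: not once per call
# (which would miss the cache) and not once per module (which would collide).
cache = {}

def call_jax_like(fn, *args):
  key = id(fn)                       # compilation cache keyed on function identity
  if key not in cache:
    cache[key] = f"traced<{key}>"    # stand-in for an expensive jax.jit trace
  return cache[key], fn(*args)

def make_wrapper(scale):
  def _forward(x):                   # one helper object per wrapper, reused across calls
    return x * scale
  return lambda x: call_jax_like(_forward, x)

double, triple = make_wrapper(2), make_wrapper(3)
print(double(1))  # traces once for double's helper
print(double(5))  # same id, so the cached trace is reused
print(triple(1))  # distinct helper, so its own cache entry and no collision
```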
torchax/mesh_util.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import jax
 from jax.sharding import PartitionSpec, NamedSharding
 import torch
torchax/ops/__init__.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 def all_aten_jax_ops():
   # to load the ops
   import torchax.ops.jaten  # type: ignore