torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


torchax/CONTRIBUTING.md CHANGED
@@ -1,9 +1,7 @@
- # Contributing to TorchXLA2
+ # Contributing to torchax
 
  We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels.
 
- If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
-
 
  # Developer setup
 
@@ -19,9 +17,17 @@ conda activate <your_name>
  pip install --upgrade "jax[cpu]" torch
  pip install -r test_requirements.txt
  pip install -e .
- pytest test
+ pip install pytest-xdist # recommended for running test faster
+ pytest -n auto test
  ```
 
+ ## Setup on GPU or TPU
+
+ Same as Mac setup, except, if you run test using pytest, please also
+ add `JAX_PLATFORMS=cpu`. The reason is because pytest usually runs
+ test in multiple threads. CPU device can be accessed concurrently where
+ TPU devices usually only allow one accesor per process; so it could deadlock.
+
  ### VSCode
 
  I use vscode on my Mac. I loosely followed instruction in
@@ -35,4 +41,3 @@ The plugins I installed (a subset of the ones listed above) are:
 
  I also changed Python interpreter to point at the one in my conda env.
  That is all the changes I have.
-
torchax/__init__.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import contextlib
  from typing import List, Dict, Any, Optional
  import dataclasses
@@ -8,119 +22,132 @@ from torch.utils import _pytree as pytree
  from torchax import tensor
  from contextlib import contextmanager
 
- __version__ = "0.0.6"
+ __version__ = "0.0.10.dev20251116"
  VERSION = __version__
 
+ # the "fast path" uses some sparse tensor thingies that currently we
+ # don't support
+ torch.backends.mha.set_fastpath_enabled(False)
+
+
  __all__ = [
- 'default_env',
- 'extract_jax',
- 'enable_globally',
+ "default_env",
+ "extract_jax",
+ "enable_globally",
+ "save_checkpoint",
+ "load_checkpoint",
  ]
 
- from jax._src import xla_bridge
+ from .checkpoint import save_checkpoint, load_checkpoint
 
- os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
+ os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")
 
  # torchax:oss-begin
- if getattr(jax.config, 'jax_pjrt_client_create_options', None):
- jax.config.update(
- 'jax_pjrt_client_create_options',
- f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
+ if getattr(jax.config, "jax_pjrt_client_create_options", None):
+ jax.config.update(
+ "jax_pjrt_client_create_options",
+ f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+ )
  # torchax:oss-end
 
  env = None
 
 
  def default_env():
- global env
+ global env
 
- if env is None:
- env = tensor.Environment()
- return env
+ if env is None:
+ env = tensor.Environment()
+ return env
 
 
  def extract_jax(mod: torch.nn.Module, env=None):
- """Returns a pytree of jax.ndarray and a jax callable."""
- if env is None:
- env = default_env()
- states = dict(mod.named_buffers())
- states.update(mod.named_parameters())
+ """Returns a pytree of jax.ndarray and a jax callable."""
+ if env is None:
+ env = default_env()
+ states = dict(mod.named_buffers())
+ states.update(mod.named_parameters())
 
- states = env.t2j_copy(states)
+ states = env.t2j_copy(states)
 
- #@jax.jit
- def jax_func(states, args, kwargs=None):
- (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
- with env:
- res = torch.func.functional_call(
- mod, states, args, kwargs, tie_weights=False)
- return env.t2j_iso(res)
+ # @jax.jit
+ def jax_func(states, args, kwargs=None):
+ (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
+ with env:
+ res = torch.func.functional_call(
+ mod, states, args, kwargs, tie_weights=False
+ )
+ return env.t2j_iso(res)
 
- return states, jax_func
+ return states, jax_func
 
 
  def enable_globally():
- env = default_env().enable_torch_modes()
- return env
+ env = default_env().enable_torch_modes()
+ return env
 
 
  def disable_globally():
- global env
- default_env().disable_torch_modes()
+ global env
+ default_env().disable_torch_modes()
 
 
  @contextlib.contextmanager
  def disable_temporarily():
- prev = default_env().enabled
- if prev:
- disable_globally()
- yield ()
- if prev:
- enable_globally()
+ prev = default_env().enabled
+ if prev:
+ disable_globally()
+ yield ()
+ if prev:
+ enable_globally()
 
 
- torch.utils.rename_privateuse1_backend('jax')
+ torch.utils.rename_privateuse1_backend("jax")
  unsupported_dtype = [torch.quint8]
 
  import jax
  import torchax.device_module
 
- torch._register_device_module('jax', torchax.device_module)
+ torch._register_device_module("jax", torchax.device_module)
 
 
  def enable_accuracy_mode():
- jax.config.update('jax_enable_x64', True)
- jax.config.update('jax_default_matmul_precision', 'highest')
- default_env().config.internal_respect_torch_return_dtypes = True
+ jax.config.update("jax_enable_x64", True)
+ jax.config.update("jax_default_matmul_precision", "highest")
+ default_env().config.internal_respect_torch_return_dtypes = True
 
 
  def enable_performance_mode():
- jax.config.update('jax_enable_x64', False)
- jax.config.update('jax_default_matmul_precision', 'default')
- default_env().config.internal_respect_torch_return_dtypes = False
+ jax.config.update("jax_enable_x64", False)
+ jax.config.update("jax_default_matmul_precision", "default")
+ default_env().config.internal_respect_torch_return_dtypes = False
 
 
  @dataclasses.dataclass
  class CompileOptions:
- # only valid if compiling nn.Module
- methods_to_compile: List[str] = dataclasses.field(
- default_factory=lambda: ['forward'])
- jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
- mode: str = 'jax' # or dynamo or export
+ # only valid if compiling nn.Module
+ methods_to_compile: List[str] = dataclasses.field(
+ default_factory=lambda: ["forward"]
+ )
+ jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ mode: str = "jax" # or dynamo or export
 
 
  def compile(fn, options: Optional[CompileOptions] = None):
- options = options or CompileOptions()
- if options.mode == 'jax':
- from torchax import interop
- if isinstance(fn, torch.nn.Module):
- module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
- for n in options.methods_to_compile:
- module.make_jitted(n)
- return module
- else:
- return interop.jax_jit(fn)
- elif options.mode == 'dynamo':
- raise RuntimeError('dynamo mode is not supported yet')
- elif options.mode == 'export':
- raise RuntimeError('export mode is not supported yet')
+ options = options or CompileOptions()
+ if options.mode == "jax":
+ from torchax import interop
+
+ if isinstance(fn, torch.nn.Module):
+ module = interop.JittableModule(
+ fn, extra_jit_args=options.jax_jit_kwargs
+ )
+ for n in options.methods_to_compile:
+ module.make_jitted(n)
+ return module
+ else:
+ return interop.jax_jit(fn)
+ elif options.mode == "dynamo":
+ raise RuntimeError("dynamo mode is not supported yet")
+ elif options.mode == "export":
+ raise RuntimeError("export mode is not supported yet")
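
For orientation, here is a minimal usage sketch of the two public entry points visible in this file, `extract_jax` and `compile`. It is assembled only from the signatures shown in the diff above; the model, shapes, and call pattern are illustrative assumptions, not code shipped in the package.

```python
# Hypothetical example based on the signatures in this diff.
import torch
import jax
import jax.numpy as jnp
import torchax

model = torch.nn.Linear(4, 2)

# extract_jax returns (states, jax_func): `states` is a pytree of jax arrays,
# and jax_func(states, args, kwargs=None) runs the module functionally.
states, jax_func = torchax.extract_jax(model)
x = jnp.ones((1, 4), dtype=jnp.float32)
out = jax.jit(jax_func)(states, (x,))  # positional args are passed as a tuple

# compile() with the default CompileOptions wraps an nn.Module and jits
# `forward` ("jax" mode); "dynamo" and "export" modes raise RuntimeError.
jitted_model = torchax.compile(model)
```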
torchax/amp.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import contextlib
  import enum
  import torch
torchax/checkpoint.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import os
+ from typing import Any, Dict
+ from flax.training import checkpoints
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from . import tensor
+
+ def _to_jax(pytree):
+ def to_jax_array(x):
+ if isinstance(x, tensor.Tensor):
+ return x.jax()
+ elif isinstance(x, torch.Tensor):
+ return jnp.asarray(x.cpu().numpy())
+ return x
+ return jax.tree_util.tree_map(to_jax_array, pytree)
+
+
+ def _to_torch(pytree):
+ return jax.tree_util.tree_map(
+ lambda x: torch.from_numpy(np.asarray(x))
+ if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
+
+
+ def save_checkpoint(state: Dict[str, Any], path: str, step: int):
+ """Saves a checkpoint to a file in JAX style.
+
+ Args:
+ state: A dictionary containing the state to save. torch.Tensors will be
+ converted to jax.Array.
+ path: The path to save the checkpoint to. This is a directory.
+ step: The training step.
+ """
+ state = _to_jax(state)
+ checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
+
+
+ def load_checkpoint(path: str) -> Dict[str, Any]:
+ """Loads a checkpoint and returns it in JAX format.
+
+ This function can load both PyTorch-style (single file) and JAX-style
+ (directory) checkpoints.
+
+ If the checkpoint is in PyTorch format, it will be converted to JAX format.
+
+ Args:
+ path: The path to the checkpoint.
+
+ Returns:
+ The loaded state in JAX format (pytree with jax.Array leaves).
+ """
+ if os.path.isdir(path):
+ # JAX-style checkpoint
+ state = checkpoints.restore_checkpoint(path, target=None)
+ if state is None:
+ raise FileNotFoundError(f"No checkpoint found at {path}")
+ return state
+ elif os.path.isfile(path):
+ # PyTorch-style checkpoint
+ state = torch.load(path, weights_only=False)
+ return _to_jax(state)
+ else:
+ raise FileNotFoundError(f"No such file or directory: {path}")
+
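
The new checkpoint helpers are re-exported from the package root (see the `__init__.py` diff above). Below is a minimal usage sketch that follows the docstrings; the state layout and the `/tmp` path are illustrative assumptions.

```python
# Hypothetical example following the docstrings in torchax/checkpoint.py.
import torch
import torchax

model = torch.nn.Linear(4, 2)
state = {"model": model.state_dict(), "step": 0}

# torch.Tensors are converted to jax.Array and written as a flax-style
# checkpoint directory.
torchax.save_checkpoint(state, "/tmp/torchax_ckpt", step=0)

# Loading accepts either a checkpoint directory or a torch.save() file;
# either way the result is a pytree with jax.Array leaves.
restored = torchax.load_checkpoint("/tmp/torchax_ckpt")
```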
torchax/config.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import dataclasses
 
 
torchax/decompositions.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  """This file contains some decompositons that are not available in torch stable.
 
  Most likely from Content of
torchax/device_module.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import torch
 
 
torchax/export.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  # pylint: disable
  """Utilities for exporting a torch program to jax/stablehlo."""
  import copy
torchax/flax.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  """Flax interop."""
 
  import torch
torchax/interop.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import collections
  import copy
  import functools
@@ -182,7 +196,6 @@ def _torch_view(t: JaxValue) -> TorchValue:
  # t is an object from jax land
  # view it as-if it's a torch land object
  if isinstance(t, jax.Array):
- # TODO
  return tensor.Tensor(t, torchax.default_env())
  if isinstance(t, jnp.dtype):
  return mappings.j2t_dtype(t)
@@ -237,6 +250,36 @@ def j2t_autograd(fn, call_jax=call_jax):
  the PyTorch autograd framework by saving the residuals into the context object.
  """
 
+ # NOTE(qihqi): This function cannot be inlined from the callsite
+ # Becuase if it does, then it won't hit the compilation cache for
+ # call_jax. Call jax uses functions' id as key.
+ # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+ # wrapped pure function, preventing cache collisions between different pure modules.
+ def _jax_forward(fn, other, tree_def, tensors):
+ """JAX function to compute output and vjp function.
+
+ primals should be a tuple (args, kwargs).
+ """
+ import jax
+ from jax.tree_util import tree_flatten, tree_unflatten
+
+ def fn_wrapper(*tensors):
+ # Reconstruct the original args and kwargs
+ flat_inputs = util.merge(tensors, other)
+ args, kwargs = tree_unflatten(tree_def, flat_inputs)
+ return fn(*args, **kwargs)
+
+ return jax.vjp(fn_wrapper, *tensors)
+
+ def _jax_backward(vjp_spec, saved_tensors, grad_out):
+ """JAX function to compute input gradients.
+
+ Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+ """
+ from jax.tree_util import tree_unflatten
+ fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+ return fun_vjp(grad_out)
+
  @wraps(fn)
  def inner(*args, **kwargs):
  from jax.tree_util import tree_flatten
@@ -290,36 +333,6 @@ def j2t_autograd(fn, call_jax=call_jax):
  return inner
 
 
- # NOTE(qihqi): This function cannot be inlined from the callsite
- # Becuase if it does, then it won't hit the compilation cache for
- # call_jax. Call jax uses functions' id as key.
- def _jax_forward(fn, other, tree_def, tensors):
- """JAX function to compute output and vjp function.
-
- primals should be a tuple (args, kwargs).
- """
- import jax
- from jax.tree_util import tree_flatten, tree_unflatten
-
- def fn_wrapper(*tensors):
- # Reconstruct the original args and kwargs
- flat_inputs = util.merge(tensors, other)
- args, kwargs = tree_unflatten(tree_def, flat_inputs)
- return fn(*args, **kwargs)
-
- return jax.vjp(fn_wrapper, *tensors)
-
-
- def _jax_backward(vjp_spec, saved_tensors, grad_out):
- """JAX function to compute input gradients.
-
- Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
- """
- from jax.tree_util import tree_unflatten
- fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
- return fun_vjp(grad_out)
-
-
  fori_loop = torch_view(jax.lax.fori_loop)
 
 
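
The NOTE moved into `j2t_autograd` explains this restructuring: `call_jax` caches compiled functions keyed by the function object's `id`, so module-level helpers shared by every wrapped function could collide in that cache, while helpers nested inside `j2t_autograd` get a fresh function object per wrap. A standalone toy sketch of that id-keyed caching concern (the cache below is a stand-in, not torchax's `call_jax`):

```python
# Toy illustration of an id-keyed compilation cache; NOT torchax's call_jax.
_cache = {}

def call_cached(fn, *args):
  # "Compile" (expensively) once per distinct function object.
  key = id(fn)
  if key not in _cache:
    _cache[key] = f"compiled-{key}"  # stand-in for a jitted artifact
  return _cache[key]

def wrap(pure_fn):
  # Nested helper: every wrap() call creates a new function object, so each
  # wrapped module gets its own cache entry instead of sharing one key.
  def _forward(*args):
    return pure_fn(*args)
  return _forward

f1 = wrap(lambda x: x + 1)
f2 = wrap(lambda x: x * 2)
assert id(f1) != id(f2)                    # distinct cache keys
assert call_cached(f1) != call_cached(f2)  # no collision between wraps
```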
torchax/mesh_util.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import jax
  from jax.sharding import PartitionSpec, NamedSharding
  import torch
torchax/ops/__init__.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  def all_aten_jax_ops():
  # to load the ops
  import torchax.ops.jaten # type: ignore