torchax 0.0.6__py3-none-any.whl → 0.0.10.dev20251116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  """
  Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
  """
torchax/ops/mappings.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from jax import dlpack as jaxdl
  import jax.numpy as jnp
  import numpy
@@ -6,6 +20,14 @@ import torch.func
  import torch.utils.dlpack as torchdl
  import torch.utils._mode_utils as mode_utils

+ NUMPY_UNSUPPORTED_DTYPES = {
+ torch.bfloat16: jnp.bfloat16,
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+ torch.float8_e5m2: jnp.float8_e5m2,
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+ }
+

  def t2j(t, use_dlpack=True):
  is_bool = False
@@ -28,14 +50,14 @@ def t2j(t, use_dlpack=True):
  if res is None:
  # https://github.com/google/jax/issues/7657
  # https://github.com/google/jax/issues/17784
- if t.dtype == torch.bfloat16:
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
  nparray = (t.cpu().detach().to(torch.float32).numpy()
- ) # numpy don't support bfloat16
+ ) # handle dtypes not supported by numpy
  else:
  nparray = t.cpu().detach().numpy()
  res = jnp.asarray(nparray)
- if t.dtype == torch.bfloat16:
- res = res.astype(jnp.bfloat16)
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+ res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])

  if is_bool:
  res = res.astype(jnp.bool_)
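
Note: a minimal usage sketch of the widened fallback above, assuming torchax is installed and that passing use_dlpack=False makes t2j take the numpy conversion branch. numpy has no bfloat16/float8 representation, so the tensor is staged through float32 and cast back to the matching jax dtype.

import torch
import jax.numpy as jnp
from torchax.ops import mappings

t = torch.randn(4, 4, dtype=torch.bfloat16)
arr = mappings.t2j(t, use_dlpack=False)  # skip dlpack so the numpy fallback runs
assert arr.dtype == jnp.bfloat16         # cast back via NUMPY_UNSUPPORTED_DTYPES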
torchax/ops/op_base.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import functools
  import jax
  import jax.numpy as jnp
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import dataclasses
  import logging
  from torchax.types import JaxCallable, TorchCallable
torchax/tensor.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import threading
  import logging
  import sys
@@ -357,16 +371,26 @@ class Environment(contextlib.ContextDecorator):

  _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
  self._property = threading.local()
- self._property.content = [
- RuntimeProperty(
- mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
- ]
+ self._initial_content = RuntimeProperty(
+ mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)

  @property
  def param(self):
+ if not hasattr(self._property, 'content'):
+ self._property.content = [
+ self._initial_content
+ ]
  return self._property.content[-1]

  def manual_seed(self, key):
+ if isinstance(key, torch.Tensor):
+ assert key.ndim == 0, 'manual seed can only take scalars'
+ assert not key.dtype.is_floating_point, 'manual seed can only be integers'
+
+ if isinstance(key, Tensor):
+ key = key._elem
+ else:
+ key = key.item()
  jax_key = jax.random.PRNGKey(key)
  new_prop = self.param.override(prng=jax_key)
  self._property.content.append(new_prop)
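
Note: a hedged sketch of what the reworked manual_seed accepts: plain ints as before, plus 0-dim integer tensors (floating-point or non-scalar tensors trip the added asserts). The per-thread RuntimeProperty stack is now created lazily on first access to param, so this also works from threads other than the one that built the Environment.

import torch
import torchax

env = torchax.default_env()
env.manual_seed(42)                  # plain Python int, unchanged behavior
env.manual_seed(torch.tensor(1234))  # 0-dim integer tensor is now accepted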
@@ -469,12 +493,12 @@ class Environment(contextlib.ContextDecorator):
  arr = self.t2j_copy(the_tensor)
  res = Tensor(arr, self, the_tensor.requires_grad)

- if new_dtype is not None and new_dtype != the_tensor.dtype:
- if isinstance(the_tensor, Tensor):
+ if new_dtype is not None and new_dtype != res.dtype:
+ if isinstance(res, Tensor):
  res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
  else:
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
- return the_tensor.to(device=new_device, dtype=new_dtype)
+ return res.to(device=new_device, dtype=new_dtype)
  return res

  def get_and_rotate_prng_key(self,
@@ -634,14 +658,14 @@ class Environment(contextlib.ContextDecorator):

  def t2j_iso(self, torchtensors):
  """Convert torchax Tensor to jax array.
-
+
  This function will not copy, will just unwrap the inner jax array out.
  Note: iso is short for "isomorphic"
  """

  def to_jax(x):
  if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
- x, Tensor):
+ x, Tensor) and not isinstance(x, View):
  if x.squeeze().ndim == 0:
  return x.item()
  if isinstance(
@@ -667,7 +691,7 @@ class Environment(contextlib.ContextDecorator):

  def j2t_iso(self, jaxarray):
  """Convert jax array to torchax Tensor.
-
+
  This function will not copy, will just wrap the jax array with a torchax Tensor
  Note: iso is short for "isomorphic"
  """
@@ -676,7 +700,7 @@ class Environment(contextlib.ContextDecorator):

  def j2t_copy(self, args):
  """Convert torch.Tensor in cpu to a jax array
-
+
  This might involves copying the data (depending if dlpack is enabled)
  """
  return torch_pytree.tree_map_only(
@@ -686,7 +710,7 @@ class Environment(contextlib.ContextDecorator):

  def t2j_copy(self, args):
  """Convert jax array to torch.Tensor in cpu.
-
+
  This might involves copying the data (depending if dlpack is enabled)
  """
  return torch_pytree.tree_map_only(
@@ -694,13 +718,14 @@ class Environment(contextlib.ContextDecorator):
  lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
  args)

- def override_op_definition(self, op_to_override, op_impl):
+ def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
  self._ops[op_to_override] = ops_registry.Operator(
  op_to_override,
  op_impl,
  is_jax_function=False,
  is_user_defined=True,
  needs_env=False,
+ is_view_op=is_view_op,
  )

  @contextlib.contextmanager
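
Note: a hedged sketch of the extended override_op_definition signature. The op chosen and the replacement implementation below are hypothetical; only the is_view_op keyword comes from this release, and the impl is a torch-level callable because the registration sets is_jax_function=False.

import torch
import torchax

env = torchax.default_env()

def my_select_impl(self, dim, index):
  # hypothetical torch-level replacement for aten::select
  return torch.narrow(self, dim, index, 1).squeeze(dim)

env.override_op_definition(torch.ops.aten.select, my_select_impl, is_view_op=True)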
torchax/train.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import collections
  import functools
  import torch
@@ -12,106 +26,107 @@ mark_sharding = torch_view(jax.lax.with_sharding_constraint)


  def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
- """Make a function that do one train step given model and loss.
-
- model_fn: a function representing the model's forward:
- i.e. has signature Callable[weights, buffers, args] -> result. Where,
- weights is a pytree of trainable parameters
- buffers is a pytree of non-trainable parameters / constants
- args is the input data loaded from the data set
- result is the return value of the model
- loss_fn: a function to compute loss.
- i.e. it has signature of Callable[result, label] -> loss
- where, result is what model_fn returned
- loss is loaded from the dataloader.
- optax_optimizer: the optimizer from optax library. for example, optax.adam
- remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
- to do gradient checkpointing. If None, then it means checkpoint everything.
- """
- env = torchax.default_env()
-
- def loss(weights, buffers, args, label): # inputs are XLATensor
- with env, jax.named_scope('compute_loss'):
- res = model_fn(weights, buffers, args)
- l = loss_fn(res, label)
- return l
-
- loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
- grad_fn = interop.jax_value_and_grad(loss)
-
- def step(weights, buffers, opt_state, args, label): #inputs are array
- with jax.named_scope('compute_gradient'):
- loss, gradient = grad_fn(weights, buffers, args, label)
-
- with jax.named_scope("optimizer_updates"):
- updates, opt_state = interop.call_jax(optax_optimizer.update, gradient,
- opt_state, weights)
- weights = interop.call_jax(optax.apply_updates, weights, updates)
- return loss, weights, opt_state
-
- # TODO: apply jax.jit so the user don't have to.
- return step
+ """Make a function that do one train step given model and loss.
+
+ model_fn: a function representing the model's forward:
+ i.e. has signature Callable[weights, buffers, args] -> result. Where,
+ weights is a pytree of trainable parameters
+ buffers is a pytree of non-trainable parameters / constants
+ args is the input data loaded from the data set
+ result is the return value of the model
+ loss_fn: a function to compute loss.
+ i.e. it has signature of Callable[result, label] -> loss
+ where, result is what model_fn returned
+ loss is loaded from the dataloader.
+ optax_optimizer: the optimizer from optax library. for example, optax.adam
+ remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
+ to do gradient checkpointing. If None, then it means checkpoint everything.
+ """
+ env = torchax.default_env()
+
+ def loss(weights, buffers, args, label): # inputs are XLATensor
+ with env, jax.named_scope("compute_loss"):
+ res = model_fn(weights, buffers, args)
+ l = loss_fn(res, label)
+ return l
+
+ # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+ grad_fn = interop.jax_value_and_grad(loss)
+
+ def step(weights, buffers, opt_state, args, label): # inputs are array
+ with jax.named_scope("compute_gradient"):
+ loss, gradient = grad_fn(weights, buffers, args, label)
+
+ with jax.named_scope("optimizer_updates"):
+ updates, opt_state = interop.call_jax(
+ optax_optimizer.update, gradient, opt_state, weights
+ )
+ weights = interop.call_jax(optax.apply_updates, weights, updates)
+ return loss, weights, opt_state
+
+ # TODO: apply jax.jit so the user don't have to.
+ return step


  class Container:
- pass
+ pass


  class ScannedModule(torch.nn.Module):
-
- def __init__(self, module_list, checkpoint_policy=None):
- super().__init__()
-
- self.c = None
- assert module_list
- self.c = Container()
- self.c.one_mod = module_list[0]
- self.checkpoint_policy = checkpoint_policy
-
- weights = self._stack_layer_weights(module_list)
- self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
- self.params = torch.nn.ParameterDict({
- self._param_name_new(k): v for k, v in weights.items()
- })
-
- def _stack_layer_weights(self, module_list):
- # Create weights such that, for every [n, m] weights
- # becomes [k, n, m] where k is number of layer
- # i.e. stacking layer weights together
- temp = collections.defaultdict(list)
- for m in module_list:
- for k, v in m.state_dict().items():
- temp[k].append(v)
- res = {k: torch.stack(v) for k, v in temp.items()}
- return res
-
- def _param_name_new(self, old):
- return '___'.join(old.split('.'))
-
- def _param_name_old(self, new):
- return '.'.join(new.split('___'))
-
- def forward(self, *args, **kwargs):
- assert not kwargs
- weights = {
- k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys
- }
- scan = interop.torch_view(jax.lax.scan)
-
- def eval_one_layer(args, weight):
- # unpack args
- h, *rest = args
- newh = torch.func.functional_call(self.c.one_mod, weight, args)
- # next layer's input; and residual to be added to list
- return (newh, *rest), None
-
- _eval_one_layer = interop.gradient_checkpoint(
- eval_one_layer,
- kwargs={'policy': self.checkpoint_policy},
- )
- h, _ = scan(
- _eval_one_layer,
- args,
- weights,
- )
- return h[0]
+ def __init__(self, module_list, checkpoint_policy=None):
+ super().__init__()
+
+ self.c = None
+ assert module_list
+ self.c = Container()
+ self.c.one_mod = module_list[0]
+ self.checkpoint_policy = checkpoint_policy
+
+ weights = self._stack_layer_weights(module_list)
+ self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+ self.params = torch.nn.ParameterDict(
+ {self._param_name_new(k): v for k, v in weights.items()}
+ )
+
+ def _stack_layer_weights(self, module_list):
+ # Create weights such that, for every [n, m] weights
+ # becomes [k, n, m] where k is number of layer
+ # i.e. stacking layer weights together
+ temp = collections.defaultdict(list)
+ for m in module_list:
+ for k, v in m.state_dict().items():
+ temp[k].append(v)
+ res = {k: torch.stack(v) for k, v in temp.items()}
+ return res
+
+ def _param_name_new(self, old):
+ return "___".join(old.split("."))
+
+ def _param_name_old(self, new):
+ return ".".join(new.split("___"))
+
+ def forward(self, *args, **kwargs):
+ assert not kwargs
+ weights = {
+ k: self.params[self._param_name_new(k)]
+ for k in self.layer_weights_keys
+ }
+ scan = interop.torch_view(jax.lax.scan)
+
+ def eval_one_layer(args, weight):
+ # unpack args
+ h, *rest = args
+ newh = torch.func.functional_call(self.c.one_mod, weight, args)
+ # next layer's input; and residual to be added to list
+ return (newh, *rest), None
+
+ _eval_one_layer = interop.gradient_checkpoint(
+ eval_one_layer,
+ kwargs={"policy": self.checkpoint_policy},
+ )
+ h, _ = scan(
+ _eval_one_layer,
+ args,
+ weights,
+ )
+ return h[0]
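
Note: a hedged end-to-end sketch of make_train_step; the model, loss, and optimizer below are illustrative stand-ins, only the function signatures come from the code above.

import optax
import torch
from torchax import train

def model_fn(weights, buffers, args):
  # hypothetical linear model; args is the tuple of inputs from the dataloader
  (x,) = args
  return x @ weights['w'] + weights['b']

def loss_fn(result, label):
  return ((result - label) ** 2).mean()

optimizer = optax.adam(1e-3)
step = train.make_train_step(model_fn, loss_fn, optimizer)
# step(weights, buffers, opt_state, args, label) -> (loss, new_weights, new_opt_state)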
torchax/types.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Callable, Any, Union, ParamSpec, TypeAlias
  import torch
  import jax
torchax/util.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Any, Callable


torchax/view.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import torch
  import torch.utils._pytree as torch_pytree
  import jax