torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax has been flagged as possibly problematic.

torchax/train.py CHANGED
@@ -13,120 +13,118 @@
 # limitations under the License.
 
 import collections
-import functools
-import torch
+
 import jax
+import optax
+import torch
+
 import torchax
 from torchax import interop
-from torchax.interop import torch_view, jax_view
-import optax
+from torchax.interop import torch_view
 
 remat = torch_view(jax.remat)
 mark_sharding = torch_view(jax.lax.with_sharding_constraint)
 
 
 def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
-  """Make a function that do one train step given model and loss.
-
-  model_fn: a function representing the model's forward:
-    i.e. has signature Callable[weights, buffers, args] -> result. Where,
-    weights is a pytree of trainable parameters
-    buffers is a pytree of non-trainable parameters / constants
-    args is the input data loaded from the data set
-    result is the return value of the model
-  loss_fn: a function to compute loss.
-    i.e. it has signature of Callable[result, label] -> loss
-    where, result is what model_fn returned
-    loss is loaded from the dataloader.
-  optax_optimizer: the optimizer from optax library. for example, optax.adam
-  remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
-    to do gradient checkpointing. If None, then it means checkpoint everything.
-  """
-  env = torchax.default_env()
-
-  def loss(weights, buffers, args, label):  # inputs are XLATensor
-    with env, jax.named_scope("compute_loss"):
-      res = model_fn(weights, buffers, args)
-      l = loss_fn(res, label)
-      return l
-
-  # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
-  grad_fn = interop.jax_value_and_grad(loss)
-
-  def step(weights, buffers, opt_state, args, label):  # inputs are array
-    with jax.named_scope("compute_gradient"):
-      loss, gradient = grad_fn(weights, buffers, args, label)
-
-    with jax.named_scope("optimizer_updates"):
-      updates, opt_state = interop.call_jax(
-          optax_optimizer.update, gradient, opt_state, weights
-      )
-      weights = interop.call_jax(optax.apply_updates, weights, updates)
-    return loss, weights, opt_state
-
-  # TODO: apply jax.jit so the user don't have to.
-  return step
+    """Make a function that do one train step given model and loss.
+
+    model_fn: a function representing the model's forward:
+      i.e. has signature Callable[weights, buffers, args] -> result. Where,
+      weights is a pytree of trainable parameters
+      buffers is a pytree of non-trainable parameters / constants
+      args is the input data loaded from the data set
+      result is the return value of the model
+    loss_fn: a function to compute loss.
+      i.e. it has signature of Callable[result, label] -> loss
+      where, result is what model_fn returned
+      loss is loaded from the dataloader.
+    optax_optimizer: the optimizer from optax library. for example, optax.adam
+    remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
+      to do gradient checkpointing. If None, then it means checkpoint everything.
+    """
+    env = torchax.default_env()
+
+    def loss(weights, buffers, args, label):  # inputs are XLATensor
+        with env, jax.named_scope("compute_loss"):
+            res = model_fn(weights, buffers, args)
+            l = loss_fn(res, label)  # noqa: E741
+            return l
+
+    # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+    grad_fn = interop.jax_value_and_grad(loss)
+
+    def step(weights, buffers, opt_state, args, label):  # inputs are array
+        with jax.named_scope("compute_gradient"):
+            loss, gradient = grad_fn(weights, buffers, args, label)
+
+        with jax.named_scope("optimizer_updates"):
+            updates, opt_state = interop.call_jax(
+                optax_optimizer.update, gradient, opt_state, weights
+            )
+            weights = interop.call_jax(optax.apply_updates, weights, updates)
+        return loss, weights, opt_state
+
+    # TODO: apply jax.jit so the user don't have to.
+    return step
 
 
 class Container:
-  pass
+    pass
 
 
 class ScannedModule(torch.nn.Module):
-  def __init__(self, module_list, checkpoint_policy=None):
-    super().__init__()
-
-    self.c = None
-    assert module_list
-    self.c = Container()
-    self.c.one_mod = module_list[0]
-    self.checkpoint_policy = checkpoint_policy
-
-    weights = self._stack_layer_weights(module_list)
-    self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
-    self.params = torch.nn.ParameterDict(
-        {self._param_name_new(k): v for k, v in weights.items()}
-    )
-
-  def _stack_layer_weights(self, module_list):
-    # Create weights such that, for every [n, m] weights
-    # becomes [k, n, m] where k is number of layer
-    # i.e. stacking layer weights together
-    temp = collections.defaultdict(list)
-    for m in module_list:
-      for k, v in m.state_dict().items():
-        temp[k].append(v)
-    res = {k: torch.stack(v) for k, v in temp.items()}
-    return res
-
-  def _param_name_new(self, old):
-    return "___".join(old.split("."))
-
-  def _param_name_old(self, new):
-    return ".".join(new.split("___"))
-
-  def forward(self, *args, **kwargs):
-    assert not kwargs
-    weights = {
-        k: self.params[self._param_name_new(k)]
-        for k in self.layer_weights_keys
-    }
-    scan = interop.torch_view(jax.lax.scan)
-
-    def eval_one_layer(args, weight):
-      # unpack args
-      h, *rest = args
-      newh = torch.func.functional_call(self.c.one_mod, weight, args)
-      # next layer's input; and residual to be added to list
-      return (newh, *rest), None
-
-    _eval_one_layer = interop.gradient_checkpoint(
-        eval_one_layer,
-        kwargs={"policy": self.checkpoint_policy},
-    )
-    h, _ = scan(
-        _eval_one_layer,
-        args,
-        weights,
-    )
-    return h[0]
+    def __init__(self, module_list, checkpoint_policy=None):
+        super().__init__()
+
+        self.c = None
+        assert module_list
+        self.c = Container()
+        self.c.one_mod = module_list[0]
+        self.checkpoint_policy = checkpoint_policy
+
+        weights = self._stack_layer_weights(module_list)
+        self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+        self.params = torch.nn.ParameterDict(
+            {self._param_name_new(k): v for k, v in weights.items()}
+        )
+
+    def _stack_layer_weights(self, module_list):
+        # Create weights such that, for every [n, m] weights
+        # becomes [k, n, m] where k is number of layer
+        # i.e. stacking layer weights together
+        temp = collections.defaultdict(list)
+        for m in module_list:
+            for k, v in m.state_dict().items():
+                temp[k].append(v)
+        res = {k: torch.stack(v) for k, v in temp.items()}
+        return res
+
+    def _param_name_new(self, old):
+        return "___".join(old.split("."))
+
+    def _param_name_old(self, new):
+        return ".".join(new.split("___"))
+
+    def forward(self, *args, **kwargs):
+        assert not kwargs
+        weights = {k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys}
+        scan = interop.torch_view(jax.lax.scan)
+
+        def eval_one_layer(args, weight):
+            # unpack args
+            h, *rest = args
+            newh = torch.func.functional_call(self.c.one_mod, weight, args)
+            # next layer's input; and residual to be added to list
+            return (newh, *rest), None
+
+        _eval_one_layer = interop.gradient_checkpoint(
+            eval_one_layer,
+            kwargs={"policy": self.checkpoint_policy},
+        )
+        h, _ = scan(
+            _eval_one_layer,
+            args,
+            weights,
+        )
+        return h[0]
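
Aside from dropping the unused functools and jax_view imports, regrouping the import block, and adding a `# noqa: E741` to the single-letter loss variable, the train.py changes appear to be pure reformatting (re-indentation and line re-wrapping), so the calling convention of make_train_step looks unchanged. As a rough usage sketch of that convention, with a made-up linear model, loss, and batch (none of these names come from torchax itself), and assuming the tensors are jax-backed torchax tensors created under the default env:

    import optax
    import torch
    import torchax
    from torchax import interop, train


    def model_fn(weights, buffers, args):
        # Toy linear model; `buffers` is unused in this sketch.
        return args @ weights["w"] + weights["b"]


    def loss_fn(result, label):
        return ((result - label) ** 2).mean()


    optimizer = optax.adam(1e-3)
    step = train.make_train_step(model_fn, loss_fn, optimizer)

    env = torchax.default_env()
    with env:
        # Tensors created under the env are jax-backed.
        weights = {"w": torch.randn(8, 1), "b": torch.zeros(1)}
        buffers = {}
        opt_state = interop.call_jax(optimizer.init, weights)

        batch = torch.randn(4, 8)
        label = torch.randn(4, 1)
        # step(weights, buffers, opt_state, args, label) -> (loss, weights, opt_state)
        loss, weights, opt_state = step(weights, buffers, opt_state, batch, label)

Per the TODO retained in the diff, step is not jitted automatically; callers still need to arrange jitting themselves.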
torchax/types.py CHANGED
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Any, Union, ParamSpec, TypeAlias
-import torch
+from collections.abc import Callable
+from typing import Any, ParamSpec, TypeAlias, Union
+
 import jax
 import jax.numpy as jnp
-import sys
+import torch
 
-P = ParamSpec('P')
+P = ParamSpec("P")
 
-TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]
+TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, "TorchCallable", Any]
 TorchCallable: TypeAlias = Callable[P, TorchValue]
-JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any]
-JaxCallable: TypeAlias = Callable[P, JaxValue]
+JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, "JaxCallable", Any]
+JaxCallable: TypeAlias = Callable[P, JaxValue]
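
The types.py changes are import hygiene (Callable now comes from collections.abc, the unused `import sys` is dropped) and quote normalization; the aliases themselves are unchanged. For orientation, a small sketch of how such aliases can be used as annotations (the helper below is hypothetical, not part of torchax):

    from torchax.types import TorchCallable, TorchValue


    def run_twice(fn: TorchCallable, x: TorchValue) -> TorchValue:
        # Purely illustrative: TorchCallable is Callable[P, TorchValue],
        # so any torch-level callable and value fit these annotations.
        return fn(fn(x))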
torchax/util.py CHANGED
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 
-def partition(original: list[Any],
-              func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]:
+def partition(
+    original: list[Any], func: Callable[[Any], bool]
+) -> tuple[list[Any], list[Any]]:
     """Partitions elements into two parallel lists based on a predicate function.
 
     Iterates through the 'original' list, applying 'func' to each element 'a'.
@@ -97,6 +99,6 @@ def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
     """
     assert len(list1) == len(list2)
    res = []
-    for a, b in zip(list1, list2):
+    for a, b in zip(list1, list2, strict=False):
        res.append(b if a is None else a)
    return res
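
The `strict=False` added to zip makes the existing behavior explicit (most likely to satisfy a zip-strictness lint rule); since merge already asserts that the two lists have equal length, nothing changes at runtime. A quick illustration of merge with made-up values:

    from torchax.util import merge

    # Wherever list1 has None, the corresponding element of list2 is used instead.
    assert merge([1, None, 3], [None, 2, None]) == [1, 2, 3]
    assert merge(["keep", None], ["x", "fallback"]) == ["keep", "fallback"]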