torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/train.py
CHANGED
@@ -13,120 +13,118 @@
 # limitations under the License.

 import collections
-
-import torch
+
 import jax
+import optax
+import torch
+
 import torchax
 from torchax import interop
-from torchax.interop import torch_view
-import optax
+from torchax.interop import torch_view

 remat = torch_view(jax.remat)
 mark_sharding = torch_view(jax.lax.with_sharding_constraint)


 def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None):
-  … (old lines 29-68 removed; their content is not rendered in the source view)
+  """Make a function that do one train step given model and loss.
+
+  model_fn: a function representing the model's forward:
+    i.e. has signature Callable[weights, buffers, args] -> result. Where,
+    weights is a pytree of trainable parameters
+    buffers is a pytree of non-trainable parameters / constants
+    args is the input data loaded from the data set
+    result is the return value of the model
+  loss_fn: a function to compute loss.
+    i.e. it has signature of Callable[result, label] -> loss
+    where, result is what model_fn returned
+    loss is loaded from the dataloader.
+  optax_optimizer: the optimizer from optax library. for example, optax.adam
+  remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
+    to do gradient checkpointing. If None, then it means checkpoint everything.
+  """
+  env = torchax.default_env()
+
+  def loss(weights, buffers, args, label):  # inputs are XLATensor
+    with env, jax.named_scope("compute_loss"):
+      res = model_fn(weights, buffers, args)
+      l = loss_fn(res, label)  # noqa: E741
+      return l
+
+  # loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
+  grad_fn = interop.jax_value_and_grad(loss)
+
+  def step(weights, buffers, opt_state, args, label):  # inputs are array
+    with jax.named_scope("compute_gradient"):
+      loss, gradient = grad_fn(weights, buffers, args, label)
+
+    with jax.named_scope("optimizer_updates"):
+      updates, opt_state = interop.call_jax(
+        optax_optimizer.update, gradient, opt_state, weights
+      )
+      weights = interop.call_jax(optax.apply_updates, weights, updates)
+    return loss, weights, opt_state
+
+  # TODO: apply jax.jit so the user don't have to.
+  return step


 class Container:
-
+  pass


 class ScannedModule(torch.nn.Module):
-  … (old lines 76-129 removed; their content is not rendered in the source view)
-      weights,
-    )
-    return h[0]
+  def __init__(self, module_list, checkpoint_policy=None):
+    super().__init__()
+
+    self.c = None
+    assert module_list
+    self.c = Container()
+    self.c.one_mod = module_list[0]
+    self.checkpoint_policy = checkpoint_policy
+
+    weights = self._stack_layer_weights(module_list)
+    self.layer_weights_keys = list(self.c.one_mod.state_dict().keys())
+    self.params = torch.nn.ParameterDict(
+      {self._param_name_new(k): v for k, v in weights.items()}
+    )
+
+  def _stack_layer_weights(self, module_list):
+    # Create weights such that, for every [n, m] weights
+    # becomes [k, n, m] where k is number of layer
+    # i.e. stacking layer weights together
+    temp = collections.defaultdict(list)
+    for m in module_list:
+      for k, v in m.state_dict().items():
+        temp[k].append(v)
+    res = {k: torch.stack(v) for k, v in temp.items()}
+    return res
+
+  def _param_name_new(self, old):
+    return "___".join(old.split("."))
+
+  def _param_name_old(self, new):
+    return ".".join(new.split("___"))
+
+  def forward(self, *args, **kwargs):
+    assert not kwargs
+    weights = {k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys}
+    scan = interop.torch_view(jax.lax.scan)
+
+    def eval_one_layer(args, weight):
+      # unpack args
+      h, *rest = args
+      newh = torch.func.functional_call(self.c.one_mod, weight, args)
+      # next layer's input; and residual to be added to list
+      return (newh, *rest), None
+
+    _eval_one_layer = interop.gradient_checkpoint(
+      eval_one_layer,
+      kwargs={"policy": self.checkpoint_policy},
+    )
+    h, _ = scan(
+      _eval_one_layer,
+      args,
+      weights,
+    )
+    return h[0]
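For orientation, the rewritten make_train_step builds a loss over (weights, buffers, args, label), differentiates it with interop.jax_value_and_grad, and applies the optax update inside the returned step function, which is still not jitted (see the retained TODO). Below is a minimal usage sketch, not taken from the package: the toy model, loss, and synthetic batches are hypothetical placeholders, and it assumes step can be driven with torch-side tensors created inside the default env; only make_train_step, torchax.default_env, and interop.call_jax come from the diff above.

import optax
import torch
import torchax
from torchax import interop
from torchax.train import make_train_step

env = torchax.default_env()

with env:
  model = torch.nn.Linear(8, 2)  # hypothetical toy model
  weights = dict(model.named_parameters())
  buffers = dict(model.named_buffers())

  def model_fn(w, b, args):
    # Functional forward pass over the supplied parameter/buffer pytrees.
    return torch.func.functional_call(model, {**w, **b}, args)

  def loss_fn(result, label):
    return torch.nn.functional.cross_entropy(result, label)

  optimizer = optax.adam(1e-3)
  # The optax state is created on the jax side of the interop boundary.
  opt_state = interop.call_jax(optimizer.init, weights)

  step = make_train_step(model_fn, loss_fn, optimizer)
  batches = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(3)]  # synthetic data
  for args, label in batches:
    loss, weights, opt_state = step(weights, buffers, opt_state, args, label)

Because step is left un-jitted, callers who want a compiled step still need to wrap it themselves, e.g. with jax.jit through whichever interop helper their torchax version provides.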
torchax/types.py
CHANGED
(Removed lines shown with a trailing … are truncated in the source view.)

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from …
-import …
+from collections.abc import Callable
+from typing import Any, ParamSpec, TypeAlias, Union
+
 import jax
 import jax.numpy as jnp
-import …
+import torch

-P = ParamSpec(…
+P = ParamSpec("P")

-TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, …
+TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, "TorchCallable", Any]
 TorchCallable: TypeAlias = Callable[P, TorchValue]
-JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, …
-JaxCallable: TypeAlias = Callable[P, JaxValue]
+JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, "JaxCallable", Any]
+JaxCallable: TypeAlias = Callable[P, JaxValue]
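These aliases document the two sides of the interop boundary: a TorchCallable is any callable over torch-side values and a JaxCallable any callable over jax-side values, with interop.torch_view (used above in train.py to wrap jax.remat and jax.lax.scan) bridging from the jax side to the torch side. A tiny illustration; the as_torch helper is a hypothetical example, not a torchax API:

import jax.numpy as jnp

from torchax import interop
from torchax.types import JaxCallable, TorchCallable


def as_torch(fn: JaxCallable) -> TorchCallable:
  # torch_view wraps a jax-side callable so it can be invoked with
  # torch-side values, which is exactly what the aliases describe.
  return interop.torch_view(fn)


exp_t: TorchCallable = as_torch(jnp.exp)  # hypothetical usage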
torchax/util.py
CHANGED
(Removed lines shown with a trailing … are truncated in the source view.)

@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from …
+from collections.abc import Callable
+from typing import Any


-def partition(…
-    …
+def partition(
+    original: list[Any], func: Callable[[Any], bool]
+) -> tuple[list[Any], list[Any]]:
   """Partitions elements into two parallel lists based on a predicate function.

   Iterates through the 'original' list, applying 'func' to each element 'a'.

@@ -97,6 +99,6 @@ def merge(list1: list[Any], list2: list[Any]) -> list[Any]:
   """
   assert len(list1) == len(list2)
   res = []
-  for a, b in zip(list1, list2):
+  for a, b in zip(list1, list2, strict=False):
     res.append(b if a is None else a)
   return res
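For orientation, partition and merge are used as a pair: partition splits a list into two position-aligned lists using a predicate (the placeholder convention is assumed here), and merge, as the hunk above shows, takes list2's entry wherever list1 holds None. The strict=False added to zip keeps zip's default truncation behavior and only makes that choice explicit for linters; the preceding assert already guarantees equal lengths. A small sketch with made-up values:

from torchax.util import merge, partition

# Split by a predicate; the two returned lists stay aligned by position
# (assumed: slots that do not match the predicate are filled with None).
evens, others = partition([1, 2, 3, 4], lambda x: x % 2 == 0)

# merge prefers list1's entry and falls back to list2 where list1 is None.
restored = merge(evens, others)
print(restored)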