torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202617__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202617.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/licenses/LICENSE +0 -0
torchax/interop.py
CHANGED
@@ -15,20 +15,24 @@
 import collections
 import copy
 import functools
-import torch
-from inspect import signature
 from functools import wraps
-from
+from inspect import signature
+
 import jax
 import jax.numpy as jnp
+import torch
 from jax import tree_util as pytree
-from
-
-from torchax import util
-from torchax.ops import mappings
+from torch.nn.utils import stateless as torch_stateless
+
 import torchax
+from torchax import tensor, util
+from torchax.ops import mappings
+from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue

-
+try:
+  from jax import shard_map as shard_map  # for jax since v0.8.0
+except ImportError:
+  from jax.experimental.shard_map import shard_map


 def extract_all_buffers(m: torch.nn.Module):
@@ -39,7 +43,7 @@ def extract_all_buffers(m: torch.nn.Module):
     for k in dir(module):
       try:
         v = getattr(module, k)
-      except:
+      except Exception:
         continue
       qual_name = prefix + k
       if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
@@ -47,14 +51,13 @@ def extract_all_buffers(m: torch.nn.Module):
       elif isinstance(v, torch.Tensor):
         buffers[qual_name] = v
     for name, child in module.named_children():
-      extract_one(child, prefix + name +
+      extract_one(child, prefix + name + ".")

-  extract_one(m,
+  extract_one(m, "")
   return params, buffers


 def set_all_buffers(m, params, buffers):
-
   def set_one(module, prefix):
     for k in dir(module):
       qual_name = prefix + k
@@ -64,17 +67,15 @@ def set_all_buffers(m, params, buffers):
         print(k, potential_v)
         setattr(module, k, torch.nn.Parameter(potential_v))
     for name, child in module.named_children():
-      set_one(child, prefix + name +
+      set_one(child, prefix + name + ".")

-  set_one(m,
+  set_one(m, "")


 class JittableModule(torch.nn.Module):
-
-
-
-      extra_jit_args={},
-      dedup_parameters=True):
+  def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True):
+    if extra_jit_args is None:
+      extra_jit_args = {}
     super().__init__()
     self.params, self.buffers = extract_all_buffers(m)
     self._model = m
@@ -119,7 +120,7 @@ class JittableModule(torch.nn.Module):
     else:
       if not callable(method_or_name):
         raise TypeError(
-
+          f"method_or_name should be a callable or a string, got {type(method_or_name)}"
         )
       method = method_or_name
       args = (self._model,) + args
@@ -130,8 +131,8 @@ class JittableModule(torch.nn.Module):
   def jittable_call(self, method_name: str, *args, **kwargs):
     if method_name not in self._jitted:
       jitted = jax_jit(
-
-
+        functools.partial(self.functional_call, method_name),
+        kwargs_for_jax_jit=self._extra_jit_args,
       )

       def jitted_forward(*args, **kwargs):
@@ -141,10 +142,10 @@ class JittableModule(torch.nn.Module):
     return self._jitted[method_name](*args, **kwargs)

   def forward(self, *args, **kwargs):
-    return self.jittable_call(
+    return self.jittable_call("forward", *args, **kwargs)

   def __getattr__(self, key):
-    if key ==
+    if key == "_model":
       return super().__getattr__(key)
     if key in self._jitted:
       return self._jitted[key]
@@ -152,8 +153,9 @@ class JittableModule(torch.nn.Module):

   def make_jitted(self, key):
     jitted = jax_jit(
-
-
+      functools.partial(self.functional_call, key),
+      kwargs_for_jax_jit=self._extra_jit_args,
+    )

     def call(*args, **kwargs):
       return jitted(self.params, self.buffers, *args, **kwargs)
@@ -162,7 +164,6 @@


 class CompileMixin:
-
   def functional_call(self, method, params, buffers, *args, **kwargs):
     kwargs = kwargs or {}
     params_copy = copy.copy(params)
@@ -172,24 +173,23 @@ class CompileMixin:
     return res

   def jit(self, method):
-    jitted = jax_jit(functools.partial(self.functional_call, method_name))
+    jitted = jax_jit(functools.partial(self.functional_call, method_name))  # noqa: F821

     def call(*args, **kwargs):
-      return jitted(self.named_paramters(), self.named_buffers(), *args,
-          **kwargs)
+      return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs)

     return call


 def compile_nn_module(m: torch.nn.Module, methods=None):
   if methods is None:
-    methods = [
+    methods = ["forward"]

-
-
-
+  type(
+    m.__class__.__name__ + "_with_CompileMixin",
+    (CompileMixin, m.__class__),
   )
-  m.__class__ = NewParent
+  m.__class__ = NewParent  # noqa: F821


 def _torch_view(t: JaxValue) -> TorchValue:
@@ -227,15 +227,17 @@ def _jax_view(t: TorchValue) -> JaxValue:
 jax_view = functools.partial(pytree.tree_map, _jax_view)


-def call_jax(
-
+def call_jax(
+  jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue
+) -> TorchValue:
   args, kwargs = jax_view((args, kwargs))
   res: JaxValue = jax_func(*args, **kwargs)
   return torch_view(res)


-def call_torch(
-
+def call_torch(
+  torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue
+) -> JaxValue:
   args, kwargs = torch_view((args, kwargs))
   with torchax.default_env():
     res: TorchValue = torch_func(*args, **kwargs)
@@ -245,10 +247,10 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue,
 def j2t_autograd(fn, call_jax=call_jax):
   """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.

-
-
-
-
+  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+  activations). The wrapped function is then run via `call_jax` and integrated into
+  the PyTorch autograd framework by saving the residuals into the context object.
+  """

   # NOTE(qihqi): This function cannot be inlined from the callsite
   # Becuase if it does, then it won't hit the compilation cache for
@@ -261,7 +263,7 @@ def j2t_autograd(fn, call_jax=call_jax):
     primals should be a tuple (args, kwargs).
     """
     import jax
-    from jax.tree_util import
+    from jax.tree_util import tree_unflatten

     def fn_wrapper(*tensors):
       # Reconstruct the original args and kwargs
@@ -277,6 +279,7 @@ def j2t_autograd(fn, call_jax=call_jax):
     Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
     """
     from jax.tree_util import tree_unflatten
+
     fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
     return fun_vjp(grad_out)

@@ -285,12 +288,11 @@ def j2t_autograd(fn, call_jax=call_jax):
   from jax.tree_util import tree_flatten

   class JaxFun(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, tree_def, *flat_args_kwargs):
-
-
-
+      tensors, other = util.partition(
+        flat_args_kwargs, lambda x: isinstance(x, torch.Tensor)
+      )
       # We want the arguments that don't require grads to be closured?

       y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
@@ -308,8 +310,9 @@ def j2t_autograd(fn, call_jax=call_jax):
       assert len(grad_out) > 0
       grad_out = grad_out if len(grad_out) > 1 else grad_out[0]

-      input_grads_structured = call_jax(
-
+      input_grads_structured = call_jax(
+        _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out
+      )

       # Construct the gradient tuple to be returned.
       # It needs to match the inputs to forward: (tree_def, *flat_inputs)
@@ -318,7 +321,8 @@ def j2t_autograd(fn, call_jax=call_jax):
       # We need to put a None for inputs that did not require gradients.
       final_grads = [None]
       for needs_grad, grad in zip(
-
+        ctx.needs_input_grad[1:], input_grads_structured, strict=True
+      ):
         final_grads.append(grad if needs_grad else None)

       return tuple(final_grads)
@@ -343,27 +347,27 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
   return torch_view(jitted)


-def jax_jit(torch_function,
-    kwargs_for_jax_jit=None,
-    fix_for_buffer_donation=False):
+def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
   return wrap_jax_jit(
-
+    torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit
+  )


 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
   return wrap_jax_jit(
-
-
-      kwargs_for_jax=kwargs_for_jax_shard_map)
+    torch_function, jax_jit_func=shard_map, kwargs_for_jax=kwargs_for_jax_shard_map
+  )


 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
   return wrap_jax_jit(
-
-
-
+    torch_function,
+    jax_jit_func=jax.value_and_grad,
+    kwargs_for_jax=kwargs_for_value_and_grad,
+  )


 def gradient_checkpoint(torch_function, kwargs=None):
   return wrap_jax_jit(
-
+    torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs
+  )
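For orientation, here is a minimal usage sketch of the two interop entry points whose signatures are reformatted above, call_jax and jax_jit. It is not taken from the package itself; it assumes the usual torchax setup in which tensors created under torchax.default_env() are backed by JAX arrays.

import jax.numpy as jnp
import torch

import torchax
from torchax import interop

env = torchax.default_env()

with env:
    x = torch.ones(4, 4)  # torchax tensor backed by a JAX array

    # call_jax: run a plain JAX function on torch-side values; per the diff,
    # arguments pass through jax_view and the result comes back via torch_view.
    y = interop.call_jax(jnp.sin, x)

    # jax_jit: wrap a torch-level callable so it is traced and compiled by jax.jit.
    def double(t):
        return t + t

    double_jitted = interop.jax_jit(double)
    z = double_jitted(x)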
torchax/mesh_util.py
CHANGED
@@ -13,8 +13,9 @@
 # limitations under the License.

 import jax
-from jax.sharding import PartitionSpec, NamedSharding
 import torch
+from jax.sharding import NamedSharding, PartitionSpec
+
 import torchax
 from torchax import interop

@@ -94,12 +95,13 @@ class SingleAxisSharder:
     `_shard_first_multiple_of`.
     """
     del name
-    sharding = _shard_first_multiple_of(
-
+    sharding = _shard_first_multiple_of(
+      self.axis_name, shapedtype.shape, self.axis_size
+    )
     if not self.replicate_unshardable and all(s is None for s in sharding):
       raise AssertionError(
-
-
+        f"Unable to find a dim to shard because "
+        f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}"
       )
     return sharding

@@ -159,15 +161,14 @@ class Mesh:
     self.jax_mesh = jax_mesh
     if sharder is None:
       assert len(self.jax_mesh.axis_names) == 1
-      sharder = SingleAxisSharder(
-
+      sharder = SingleAxisSharder(
+        self.jax_mesh.axis_names[0], len(self.mesh.device_ids)
+      )
     self._sharder = sharder

-  def initialize_model_sharded(
-
-
-      init_kwargs=None,
-      override_sharder=None):
+  def initialize_model_sharded(
+    self, model_class, init_args, init_kwargs=None, override_sharder=None
+  ):
     """Initializes a PyTorch model with its parameters sharded across the mesh.

     This method orchestrates the initialization of a `torch.nn.Module` such
@@ -208,17 +209,18 @@ class Mesh:

     states = model.state_dict()
     output_shards = {
-
-
+      name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+      for name, tensor in states.items()
     }

     def model_initializer():
-      with torchax.default_env(), torch.device(
+      with torchax.default_env(), torch.device("meta"):
        model = model_class(*init_args, **init_kwargs)
        return dict(model.state_dict())

     jitted = interop.jax_jit(
-
+      model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}
+    )
     weights_dict = jitted()

     model.load_state_dict(weights_dict, assign=True)
@@ -228,7 +230,7 @@ class Mesh:
     sharder = override_sharder or self._sharder
     states = model.state_dict()
     output_shards = {
-
-
+      name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+      for name, tensor in states.items()
     }
     model.load_state_dict(output_shards, assign=True)
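As context for the mesh_util changes, a rough sketch of how Mesh and SingleAxisSharder fit together. The Mesh constructor arguments and the return value of initialize_model_sharded are assumptions inferred from the hunks above, not confirmed API.

import jax
import torch
from jax.sharding import Mesh as JaxMesh

from torchax import mesh_util

# One-axis device mesh over all visible devices; the axis name is arbitrary here.
jax_mesh = JaxMesh(jax.devices(), axis_names=("data",))

# With no explicit sharder, Mesh falls back to a SingleAxisSharder over the
# mesh's single axis (assumed constructor: Mesh(jax_mesh, sharder=None)).
mesh = mesh_util.Mesh(jax_mesh)

# Builds the model on the meta device, then materializes each state_dict entry
# under the NamedSharding chosen by the sharder, as the jitted initializer above shows.
model = mesh.initialize_model_sharded(
    torch.nn.Linear, init_args=(1024, 1024), init_kwargs={}
)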
torchax/ops/__init__.py
CHANGED
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 def all_aten_jax_ops():
   # to load the ops
   import torchax.ops.jaten  # type: ignore
   import torchax.ops.ops_registry  # type: ignore

   return {
-
-
-
+    key: val.func
+    for key, val in torchax.ops.ops_registry.all_aten_ops.items()
+    if val.is_jax_function
   }
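For reference, all_aten_jax_ops (only reformatted above; its behavior is unchanged) can be used to inspect which aten ops have a JAX-function lowering registered:

from torchax import ops

# Maps each registered aten op to its JAX-level implementation, keeping only
# entries flagged as plain JAX functions (is_jax_function).
aten_to_jax = ops.all_aten_jax_ops()

print(f"{len(aten_to_jax)} aten ops lowered to JAX functions")
for op, fn in list(aten_to_jax.items())[:5]:
    print(op, "->", fn)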