torchax 0.0.10.dev20251118__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +149 -0
- torchax/amp.py +218 -0
- torchax/checkpoint.py +85 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +796 -0
- torchax/device_module.py +47 -0
- torchax/export.py +258 -0
- torchax/flax.py +55 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +236 -0
- torchax/ops/__init__.py +25 -0
- torchax/ops/jaten.py +5753 -0
- torchax/ops/jax_reimplement.py +211 -0
- torchax/ops/jc10d.py +64 -0
- torchax/ops/jimage.py +122 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +608 -0
- torchax/ops/jtorchvision_nms.py +268 -0
- torchax/ops/mappings.py +139 -0
- torchax/ops/op_base.py +137 -0
- torchax/ops/ops_registry.py +74 -0
- torchax/tensor.py +732 -0
- torchax/train.py +130 -0
- torchax/types.py +27 -0
- torchax/util.py +104 -0
- torchax/view.py +399 -0
- torchax-0.0.10.dev20251118.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251118.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251118.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251118.dist-info/licenses/LICENSE +201 -0
torchax/interop.py
ADDED
@@ -0,0 +1,369 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import functools
from functools import wraps
from inspect import signature

import jax
import jax.numpy as jnp
import torch
from jax import tree_util as pytree
from jax.experimental.shard_map import shard_map
from torch.nn.utils import stateless as torch_stateless

import torchax
from torchax import tensor, util
from torchax.ops import mappings
from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue


def extract_all_buffers(m: torch.nn.Module):
  buffers = {}
  params = {}

  def extract_one(module, prefix):
    for k in dir(module):
      try:
        v = getattr(module, k)
      except Exception:
        continue
      qual_name = prefix + k
      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
        params[qual_name] = v
      elif isinstance(v, torch.Tensor):
        buffers[qual_name] = v
    for name, child in module.named_children():
      extract_one(child, prefix + name + ".")

  extract_one(m, "")
  return params, buffers


def set_all_buffers(m, params, buffers):
  def set_one(module, prefix):
    for k in dir(module):
      qual_name = prefix + k
      if (potential_v := buffers.get(qual_name)) is not None:
        setattr(module, k, potential_v)
      elif (potential_v := params.get(qual_name)) is not None:
        print(k, potential_v)
        setattr(module, k, torch.nn.Parameter(potential_v))
    for name, child in module.named_children():
      set_one(child, prefix + name + ".")

  set_one(m, "")


class JittableModule(torch.nn.Module):
  def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True):
    if extra_jit_args is None:
      extra_jit_args = {}
    super().__init__()
    self.params, self.buffers = extract_all_buffers(m)
    self._model = m
    self._jitted = {}

    self._extra_jit_args = extra_jit_args

    self._extra_dumped_weights = {}

    if dedup_parameters:
      temp = collections.defaultdict(list)
      for k, v in self.params.items():
        temp[id(v)].append(k)

      for v in temp.values():
        if len(v) > 1:
          # duplicated weights with different name
          self._extra_dumped_weights[v[0]] = v[1:]
          for extra_keys in v[1:]:
            del self.params[extra_keys]

  @property
  def __class__(self):
    # Lie about the class type so that
    # isinstance(jittable_module, self._model.__class__) works
    return self._model.__class__

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)

  def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    # reinflate the state dict so there are not any missing keys
    for k, v in self._extra_dumped_weights.items():
      for new_key in v:
        params_copy[new_key] = params_copy[k]

    if isinstance(method_or_name, str):
      method = getattr(self._model, method_or_name)
    else:
      if not callable(method_or_name):
        raise TypeError(
            f"method_or_name should be a callable or a string, got {type(method_or_name)}"
        )
      method = method_or_name
      args = (self._model,) + args
    with torch_stateless._reparametrize_module(self._model, params_copy):
      res = method(*args, **kwargs)
    return res

  def jittable_call(self, method_name: str, *args, **kwargs):
    if method_name not in self._jitted:
      jitted = jax_jit(
          functools.partial(self.functional_call, method_name),
          kwargs_for_jax_jit=self._extra_jit_args,
      )

      def jitted_forward(*args, **kwargs):
        return jitted(self.params, self.buffers, *args, **kwargs)

      self._jitted[method_name] = jitted_forward
    return self._jitted[method_name](*args, **kwargs)

  def forward(self, *args, **kwargs):
    return self.jittable_call("forward", *args, **kwargs)

  def __getattr__(self, key):
    if key == "_model":
      return super().__getattr__(key)
    if key in self._jitted:
      return self._jitted[key]
    return getattr(self._model, key)

  def make_jitted(self, key):
    jitted = jax_jit(
        functools.partial(self.functional_call, key),
        kwargs_for_jax_jit=self._extra_jit_args,
    )

    def call(*args, **kwargs):
      return jitted(self.params, self.buffers, *args, **kwargs)

    self._jitted[key] = call
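Note (not part of the file): a minimal usage sketch for JittableModule. It assumes a working JAX backend and that tensors and modules constructed inside torchax.default_env() are jax-backed, as the call_torch helper later in this file suggests; the Linear model and variable names are illustrative only.

import torch
import torchax
from torchax.interop import JittableModule

with torchax.default_env():
  # assumption: weights created inside the env are jax-backed torchax tensors
  model = torch.nn.Linear(16, 4)
  jittable = JittableModule(model)

  x = torch.randn(8, 16)
  y = jittable(x)  # first call jits functional_call("forward"); later calls hit the _jitted cache
  print(y.shape)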

class CompileMixin:
  def functional_call(self, method, params, buffers, *args, **kwargs):
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    with torch_stateless._reparametrize_module(self, params_copy):
      res = method(*args, **kwargs)
    return res

  def jit(self, method):
    jitted = jax_jit(functools.partial(self.functional_call, method))

    def call(*args, **kwargs):
      return jitted(dict(self.named_parameters()), dict(self.named_buffers()), *args, **kwargs)

    return call


def compile_nn_module(m: torch.nn.Module, methods=None):
  if methods is None:
    methods = ["forward"]

  NewParent = type(
      m.__class__.__name__ + "_with_CompileMixin",
      (CompileMixin, m.__class__),
      {},
  )
  m.__class__ = NewParent

def _torch_view(t: JaxValue) -> TorchValue:
  # t is an object from jax land
  # view it as-if it's a torch land object
  if isinstance(t, jax.Array):
    return tensor.Tensor(t, torchax.default_env())
  if isinstance(t, jnp.dtype):
    return mappings.j2t_dtype(t)
  if callable(t):  # t is a JaxCallable
    return functools.partial(call_jax, t)
  # regular types are not changed
  return t


torch_view = functools.partial(pytree.tree_map, _torch_view)


def _jax_view(t: TorchValue) -> JaxValue:
  # t is an object from torch land
  # view it as-if it's a jax land object
  if isinstance(t, torch.Tensor):
    assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
    return t.jax()
  if isinstance(t, type(torch.int32)):
    return mappings.t2j_dtype(t)

  # torch.nn.Module needs special handling
  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
    return functools.partial(call_torch, t)
  # regular types are not changed
  return t


jax_view = functools.partial(pytree.tree_map, _jax_view)


def call_jax(
    jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue
) -> TorchValue:
  args, kwargs = jax_view((args, kwargs))
  res: JaxValue = jax_func(*args, **kwargs)
  return torch_view(res)


def call_torch(
    torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue
) -> JaxValue:
  args, kwargs = torch_view((args, kwargs))
  with torchax.default_env():
    res: TorchValue = torch_func(*args, **kwargs)
  return jax_view(res)
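Note (not part of the file): call_jax and call_torch are the two directions of the same bridge, with jax_view and torch_view doing the conversion at the boundary. A sketch under the same assumption that tensors made inside torchax.default_env() are jax-backed; jax_scale is a made-up function name.

import jax.numpy as jnp
import torch
import torchax
from torchax import interop


def jax_scale(x, factor):
  # plain JAX code; receives jax.Array arguments via jax_view
  return jnp.sin(x) * factor


with torchax.default_env():
  t = torch.ones(4)
  out = interop.call_jax(jax_scale, t, 2.0)  # torch in, torch out; JAX in the middle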

def j2t_autograd(fn, call_jax=call_jax):
  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.

  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
  activations). The wrapped function is then run via `call_jax` and integrated into
  the PyTorch autograd framework by saving the residuals into the context object.
  """

  # NOTE(qihqi): This function cannot be inlined from the callsite
  # because if it is, then it won't hit the compilation cache for
  # call_jax. call_jax uses functions' id as key.
  # It is nested inside j2t_autograd to ensure it gets a unique ID for each
  # wrapped pure function, preventing cache collisions between different pure modules.
  def _jax_forward(fn, other, tree_def, tensors):
    """JAX function to compute output and vjp function.

    primals should be a tuple (args, kwargs).
    """
    import jax
    from jax.tree_util import tree_unflatten

    def fn_wrapper(*tensors):
      # Reconstruct the original args and kwargs
      flat_inputs = util.merge(tensors, other)
      args, kwargs = tree_unflatten(tree_def, flat_inputs)
      return fn(*args, **kwargs)

    return jax.vjp(fn_wrapper, *tensors)

  def _jax_backward(vjp_spec, saved_tensors, grad_out):
    """JAX function to compute input gradients.

    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
    """
    from jax.tree_util import tree_unflatten

    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
    return fun_vjp(grad_out)

  @wraps(fn)
  def inner(*args, **kwargs):
    from jax.tree_util import tree_flatten

    class JaxFun(torch.autograd.Function):
      @staticmethod
      def forward(ctx, tree_def, *flat_args_kwargs):
        tensors, other = util.partition(
            flat_args_kwargs, lambda x: isinstance(x, torch.Tensor)
        )
        # We want the arguments that don't require grads to be closured?

        y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)

        # Save necessary information for backward
        # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
        # `residuals` contains the tensors needed for the backward pass.
        residuals, vjp_spec = tree_flatten(fun_vjp)
        ctx.vjp_spec = vjp_spec
        ctx.save_for_backward(*residuals)
        return y

      @staticmethod
      def backward(ctx, *grad_out):
        assert len(grad_out) > 0
        grad_out = grad_out if len(grad_out) > 1 else grad_out[0]

        input_grads_structured = call_jax(
            _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out
        )

        # Construct the gradient tuple to be returned.
        # It needs to match the inputs to forward: (tree_def, *flat_inputs)
        # The first gradient (for tree_def) is None.
        # The subsequent gradients correspond to flat_inputs.
        # We need to put a None for inputs that did not require gradients.
        final_grads = [None]
        for needs_grad, grad in zip(
            ctx.needs_input_grad[1:], input_grads_structured, strict=True
        ):
          final_grads.append(grad if needs_grad else None)

        return tuple(final_grads)

    sig = signature(fn)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
    y = JaxFun.apply(tree_def, *flat_args_kwargs)
    return y

  return inner
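Note (not part of the file): a sketch of j2t_autograd, assuming torchax tensors participate in torch autograd so that requires_grad and .backward() behave as usual; jax_loss is a made-up function name.

import jax.numpy as jnp
import torch
import torchax
from torchax import interop


def jax_loss(w, x):
  # pure JAX function of its inputs
  return jnp.sum((x @ w) ** 2)


torch_loss = interop.j2t_autograd(jax_loss)

with torchax.default_env():
  w = torch.randn(3, 3, requires_grad=True)
  x = torch.randn(5, 3)
  loss = torch_loss(w, x)  # forward runs jax.vjp and stashes the residuals on ctx
  loss.backward()          # backward replays the vjp via _jax_backward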

fori_loop = torch_view(jax.lax.fori_loop)


def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
  kwargs_for_jax = kwargs_for_jax or {}
  jax_func = jax_view(torch_function)
  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
  return torch_view(jitted)


def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
  return wrap_jax_jit(
      torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit
  )


def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
  return wrap_jax_jit(
      torch_function, jax_jit_func=shard_map, kwargs_for_jax=kwargs_for_jax_shard_map
  )


def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
  return wrap_jax_jit(
      torch_function,
      jax_jit_func=jax.value_and_grad,
      kwargs_for_jax=kwargs_for_value_and_grad,
  )


def gradient_checkpoint(torch_function, kwargs=None):
  return wrap_jax_jit(
      torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs
  )
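Note (not part of the file): the wrappers above all funnel through wrap_jax_jit, which is jax_view, then a JAX transform, then torch_view. A sketch for jax_jit, with torch_fn a made-up function and the usual jax-backed-tensor assumption.

import torch
import torchax
from torchax import interop


def torch_fn(a, b):
  return torch.nn.functional.relu(a) + b


jitted = interop.jax_jit(torch_fn)  # roughly torch_view(jax.jit(jax_view(torch_fn)))

with torchax.default_env():
  a = torch.randn(4, 4)
  b = torch.randn(4, 4)
  out = jitted(a, b)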
torchax/mesh_util.py
ADDED
@@ -0,0 +1,236 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import torch
from jax.sharding import NamedSharding, PartitionSpec

import torchax
from torchax import interop


def _shard_first_multiple_of(axis_name, shape, multiple_of):
  """Creates a PartitionSpec to shard the first dimension divisible by a number.

  Iterates through the dimensions specified by `shape`. Finds the first dimension
  whose size is a multiple of `multiple_of` and returns a PartitionSpec that
  shards that dimension along the given `axis_name`. All preceding dimensions
  are not sharded (marked as None in the PartitionSpec). All subsequent dimensions
  are also marked as None, which is implicitly treated as replicated.

  Args:
    axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl").
    shape: A tuple or list representing the shape of the tensor to be sharded.
    multiple_of: The integer value that a dimension size must be divisible by
      in order to be sharded. Typically the size of the mesh axis.

  Returns:
    A jax.sharding.PartitionSpec object specifying how to shard the tensor.
    For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4,
    it would return PartitionSpec(None, 'x', None).
    If no dimension is divisible, a fully replicated PartitionSpec is returned.
  """
  sharding = []
  found = False
  for size in shape:
    if not found and size % multiple_of == 0:
      found = True
      sharding.append(axis_name)
    else:
      sharding.append(None)
  return PartitionSpec(*sharding)

class SingleAxisSharder:
  """A callable object that generates PartitionSpecs for single-axis sharding.

  This sharder strategy attempts to shard the *first* dimension of a tensor
  that is divisible by the specified `axis_size` along the given `axis_name`.
  It's useful for simple 1D mesh sharding scenarios like FSDP where parameters
  are typically sharded along one dimension.

  Attributes:
    axis_name: The name of the mesh axis to shard along.
    axis_size: The size of the mesh axis (number of devices along that axis).
  """

  def __init__(self, axis_name, axis_size, replicate_unshardable=False):
    """Initializes the SingleAxisSharder.

    Args:
      axis_name: The name of the mesh axis (e.g., "fsdp", "data").
      axis_size: The number of devices along the specified mesh axis.
      replicate_unshardable: Whether to return a replicated sharding (P())
        when none of the dimensions is divisible by the axis size.
    """
    self.axis_name = axis_name
    self.axis_size = axis_size
    self.replicate_unshardable = replicate_unshardable

  def __call__(self, name, shapedtype):
    """Generates a PartitionSpec for a given tensor name and shaped type.

    Args:
      name: The name of the tensor (e.g., parameter name). This argument is
        provided for compatibility with more complex sharders but is not used
        by this simple sharder.
      shapedtype: An object with a `.shape` attribute describing the tensor's
        shape, and `.dtype` describing its dtype. Examples: jax.Array,
        jax.ShapeDtypeStruct, or a torch.Tensor.

    Returns:
      A jax.sharding.PartitionSpec determined by finding the first dimension
      in `shapedtype.shape` divisible by `self.axis_size` using the helper
      `_shard_first_multiple_of`.
    """
    del name
    sharding = _shard_first_multiple_of(
        self.axis_name, shapedtype.shape, self.axis_size
    )
    if not self.replicate_unshardable and all(s is None for s in sharding):
      raise AssertionError(
          f"Unable to find a dim to shard because "
          f"none of the dims ({shapedtype.shape}) in shape is a multiple of {self.axis_size}"
      )
    return sharding
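Note (not part of the file): SingleAxisSharder only inspects shapes, so it can be exercised without real devices. The parameter names below are illustrative.

import jax
import jax.numpy as jnp
from torchax.mesh_util import SingleAxisSharder

sharder = SingleAxisSharder("fsdp", axis_size=4, replicate_unshardable=True)

weight = jax.ShapeDtypeStruct((3, 8), jnp.float32)  # second dim divisible by 4
bias = jax.ShapeDtypeStruct((3,), jnp.float32)      # nothing divisible by 4

print(sharder("weight", weight))  # PartitionSpec(None, 'fsdp')
print(sharder("bias", bias))      # PartitionSpec(None,) -- effectively replicated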

class Mesh:
  """A helper class that wraps a `jax.sharding.Mesh` object.

  The goal of this class is to provide helper methods that facilitate the
  sharding of PyTorch tensors or models given a JAX device mesh configuration.
  It simplifies initializing models directly into a sharded state.

  Attributes:
    jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid
      and axis names.
    _sharder: The default sharding strategy callable (like SingleAxisSharder)
      used to determine the PartitionSpec for each parameter if not overridden
      during method calls. Can be None if no default is appropriate or set.
  """

  @classmethod
  def fsdp_mesh(cls, axis_name="fsdp"):
    """Creates a Mesh instance suitable for 1D FSDP-style sharding.

    This named constructor creates a 1D mesh encompassing all available XLA
    devices. It assigns the specified `axis_name` to this single dimension.
    It then creates a `Mesh` instance using this JAX mesh and a
    `SingleAxisSharder` configured appropriately for this 1D mesh.

    Args:
      axis_name: The name to assign to the single mesh axis (default: "fsdp").
        This name will be used by the default `SingleAxisSharder`.

    Returns:
      A Mesh instance configured with a 1D JAX mesh across all devices and a
      corresponding SingleAxisSharder.
    """
    ndevice = jax.device_count()
    jax_mesh = jax.make_mesh((ndevice,), (axis_name,))
    # replicate_unshardable so scalars and small model attributes are replicated.
    return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True))

  def __init__(self, jax_mesh, sharder=None):
    """Initializes the Mesh helper.

    Args:
      jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the
        physical device grid and logical axis names.
      sharder: An optional callable (e.g., an instance of SingleAxisSharder)
        that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`.
        This serves as the default sharding strategy.
        If None, and the provided `jax_mesh` has exactly one axis, a
        `SingleAxisSharder` is created automatically for that single axis.
        If None and the mesh has multiple axes, `_sharder` remains None, and
        an `override_sharder` must be provided to methods like
        `initialize_model_sharded`.
    """
    self.jax_mesh = jax_mesh
    if sharder is None:
      assert len(self.jax_mesh.axis_names) == 1
      sharder = SingleAxisSharder(
          self.jax_mesh.axis_names[0], len(self.jax_mesh.device_ids)
      )
    self._sharder = sharder

  def initialize_model_sharded(
      self, model_class, init_args, init_kwargs=None, override_sharder=None
  ):
    """Initializes a PyTorch model with its parameters sharded across the mesh.

    This method orchestrates the initialization of a `torch.nn.Module` such
    that its parameters are created directly on the target devices according
    to the sharding specifications derived from the mesh and the chosen sharder.
    It leverages `torchax.interop.jax_jit` to achieve this.

    Args:
      model_class: The PyTorch model class (a subclass of `torch.nn.Module`).
      init_args: A tuple containing the positional arguments required by the
        `model_class.__init__` method.
      init_kwargs: An optional dictionary containing the keyword arguments for
        the `model_class.__init__` method. Defaults to None (treated as {}).
      override_sharder: An optional callable sharding strategy to use
        specifically for this initialization. If provided, it takes precedence
        over the mesh's default `_sharder`. It must accept `(name, shapedtype)`
        and return a `PartitionSpec`. If None, the mesh's default `_sharder`
        is used.

    Returns:
      An instance of `model_class` whose parameters have been initialized and
      are represented by sharded tensors distributed across the devices in the
      `jax_mesh`.

    Raises:
      ValueError: If no sharder is available (i.e., `override_sharder` is None
        and the mesh's default `_sharder` is also None).
      AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`)
        if it fails to determine a valid sharding for any parameter.
      TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`.
      Other errors from JAX JIT compilation or PyTorch model initialization.
    """
    init_kwargs = init_kwargs or {}
    with torch.device("meta"), torchax.disable_temporarily():
      model = model_class(*init_args, **init_kwargs)

    sharder = override_sharder or self._sharder

    states = model.state_dict()
    output_shards = {
        name: NamedSharding(self.jax_mesh, sharder(name, tensor))
        for name, tensor in states.items()
    }

    def model_initializer():
      with torchax.default_env(), torch.device("meta"):
        model = model_class(*init_args, **init_kwargs)
      return dict(model.state_dict())

    jitted = interop.jax_jit(
        model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}
    )
    weights_dict = jitted()

    model.load_state_dict(weights_dict, assign=True)
    return model

  def shard_model(self, model, override_sharder=None):
    sharder = override_sharder or self._sharder
    states = model.state_dict()
    output_shards = {
        name: NamedSharding(self.jax_mesh, sharder(name, tensor))
        for name, tensor in states.items()
    }
    model.load_state_dict(output_shards, assign=True)
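Note (not part of the file): a sketch of the FSDP-style flow, assuming the process sees multiple JAX devices and that each parameter has some dimension divisible by the device count (otherwise the replicate_unshardable default used by fsdp_mesh replicates it). The MLP class is illustrative only.

import torch
import torchax
from torchax.mesh_util import Mesh


class MLP(torch.nn.Module):
  def __init__(self, dim=128):
    super().__init__()
    self.up = torch.nn.Linear(dim, 4 * dim)
    self.down = torch.nn.Linear(4 * dim, dim)

  def forward(self, x):
    return self.down(torch.nn.functional.gelu(self.up(x)))


mesh = Mesh.fsdp_mesh()  # 1D mesh named "fsdp" over all devices
model = mesh.initialize_model_sharded(MLP, init_args=(128,))
# weights are materialized directly with the out_shardings computed by the sharder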
torchax/ops/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def all_aten_jax_ops():
  # to load the ops
  import torchax.ops.jaten  # type: ignore
  import torchax.ops.ops_registry  # type: ignore

  return {
      key: val.func
      for key, val in torchax.ops.ops_registry.all_aten_ops.items()
      if val.is_jax_function
  }
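Note (not part of the file): the accessor can be used to inspect which aten ops have JAX-backed implementations registered; the printing below is illustrative only.

from torchax.ops import all_aten_jax_ops

ops = all_aten_jax_ops()  # mapping of aten op -> jax-implemented function
print(len(ops))
print(sorted(str(k) for k in ops)[:5])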