torchax 0.0.10.dev20251118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


torchax/interop.py ADDED
@@ -0,0 +1,369 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ import copy
+ import functools
+ from functools import wraps
+ from inspect import signature
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax import tree_util as pytree
+ from jax.experimental.shard_map import shard_map
+ from torch.nn.utils import stateless as torch_stateless
+
+ import torchax
+ from torchax import tensor, util
+ from torchax.ops import mappings
+ from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue
+
+
+ def extract_all_buffers(m: torch.nn.Module):
+     buffers = {}
+     params = {}
+
+     def extract_one(module, prefix):
+         for k in dir(module):
+             try:
+                 v = getattr(module, k)
+             except Exception:
+                 continue
+             qual_name = prefix + k
+             if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+                 params[qual_name] = v
+             elif isinstance(v, torch.Tensor):
+                 buffers[qual_name] = v
+         for name, child in module.named_children():
+             extract_one(child, prefix + name + ".")
+
+     extract_one(m, "")
+     return params, buffers
+
+
+ def set_all_buffers(m, params, buffers):
+     def set_one(module, prefix):
+         for k in dir(module):
+             qual_name = prefix + k
+             if (potential_v := buffers.get(qual_name)) is not None:
+                 setattr(module, k, potential_v)
+             elif (potential_v := params.get(qual_name)) is not None:
+                 print(k, potential_v)
+                 setattr(module, k, torch.nn.Parameter(potential_v))
+         for name, child in module.named_children():
+             set_one(child, prefix + name + ".")
+
+     set_one(m, "")
+
+
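A minimal usage sketch for `extract_all_buffers` / `set_all_buffers`; the `torch.nn.Linear` module below is illustrative and not part of the package:

    import torch
    from torchax.interop import extract_all_buffers, set_all_buffers

    lin = torch.nn.Linear(4, 2)
    params, buffers = extract_all_buffers(lin)
    # params holds trainable parameters keyed by qualified name (e.g. "weight", "bias");
    # buffers holds non-trainable tensors. Writing them back restores the module.
    set_all_buffers(lin, params, buffers)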
+ class JittableModule(torch.nn.Module):
+     def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True):
+         if extra_jit_args is None:
+             extra_jit_args = {}
+         super().__init__()
+         self.params, self.buffers = extract_all_buffers(m)
+         self._model = m
+         self._jitted = {}
+
+         self._extra_jit_args = extra_jit_args
+
+         self._extra_dumped_weights = {}
+
+         if dedup_parameters:
+             temp = collections.defaultdict(list)
+             for k, v in self.params.items():
+                 temp[id(v)].append(k)
+
+             for v in temp.values():
+                 if len(v) > 1:
+                     # duplicated weights with different names
+                     self._extra_dumped_weights[v[0]] = v[1:]
+                     for extra_keys in v[1:]:
+                         del self.params[extra_keys]
+
+     @property
+     def __class__(self):
+         # Lie about the class type so that
+         # isinstance(jittable_module, self._model.__class__) works
+         return self._model.__class__
+
+     def __call__(self, *args, **kwargs):
+         return self.forward(*args, **kwargs)
+
+     def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
+         kwargs = kwargs or {}
+         params_copy = copy.copy(params)
+         params_copy.update(buffers)
+         # re-inflate the state dict so there aren't any missing keys
+         for k, v in self._extra_dumped_weights.items():
+             for new_key in v:
+                 params_copy[new_key] = params_copy[k]
+
+         if isinstance(method_or_name, str):
+             method = getattr(self._model, method_or_name)
+         else:
+             if not callable(method_or_name):
+                 raise TypeError(
+                     f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+                 )
+             method = method_or_name
+             args = (self._model,) + args
+         with torch_stateless._reparametrize_module(self._model, params_copy):
+             res = method(*args, **kwargs)
+         return res
+
+     def jittable_call(self, method_name: str, *args, **kwargs):
+         if method_name not in self._jitted:
+             jitted = jax_jit(
+                 functools.partial(self.functional_call, method_name),
+                 kwargs_for_jax_jit=self._extra_jit_args,
+             )
+
+             def jitted_forward(*args, **kwargs):
+                 return jitted(self.params, self.buffers, *args, **kwargs)
+
+             self._jitted[method_name] = jitted_forward
+         return self._jitted[method_name](*args, **kwargs)
+
+     def forward(self, *args, **kwargs):
+         return self.jittable_call("forward", *args, **kwargs)
+
+     def __getattr__(self, key):
+         if key == "_model":
+             return super().__getattr__(key)
+         if key in self._jitted:
+             return self._jitted[key]
+         return getattr(self._model, key)
+
+     def make_jitted(self, key):
+         jitted = jax_jit(
+             functools.partial(self.functional_call, key),
+             kwargs_for_jax_jit=self._extra_jit_args,
+         )
+
+         def call(*args, **kwargs):
+             return jitted(self.params, self.buffers, *args, **kwargs)
+
+         self._jitted[key] = call
+
+
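A minimal sketch of how `JittableModule` might be used; the `torch.nn.Linear` model and shapes are illustrative, and a torchax environment is assumed so tensors are jax-backed:

    import torch
    import torchax
    from torchax.interop import JittableModule

    with torchax.default_env():
        model = torch.nn.Linear(8, 8)
        jittable = JittableModule(model)
        x = torch.randn(2, 8)
        y = jittable(x)  # first call jits "forward" via jax.jit; later calls reuse the cached version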
+ class CompileMixin:
+     def functional_call(self, method, params, buffers, *args, **kwargs):
+         kwargs = kwargs or {}
+         params_copy = copy.copy(params)
+         params_copy.update(buffers)
+         with torch_stateless._reparametrize_module(self, params_copy):
+             res = method(*args, **kwargs)
+         return res
+
+     def jit(self, method):
+         jitted = jax_jit(functools.partial(self.functional_call, method))
+
+         def call(*args, **kwargs):
+             return jitted(
+                 dict(self.named_parameters()), dict(self.named_buffers()), *args, **kwargs
+             )
+
+         return call
+
+
+ def compile_nn_module(m: torch.nn.Module, methods=None):
+     if methods is None:
+         methods = ["forward"]
+
+     NewParent = type(
+         m.__class__.__name__ + "_with_CompileMixin",
+         (CompileMixin, m.__class__),
+         {},
+     )
+     m.__class__ = NewParent
+
+
+ def _torch_view(t: JaxValue) -> TorchValue:
+     # t is an object from jax land
+     # view it as-if it's a torch land object
+     if isinstance(t, jax.Array):
+         return tensor.Tensor(t, torchax.default_env())
+     if isinstance(t, jnp.dtype):
+         return mappings.j2t_dtype(t)
+     if callable(t):  # t is a JaxCallable
+         return functools.partial(call_jax, t)
+     # regular types are not changed
+     return t
+
+
+ torch_view = functools.partial(pytree.tree_map, _torch_view)
+
+
+ def _jax_view(t: TorchValue) -> JaxValue:
+     # t is an object from torch land
+     # view it as-if it's a jax land object
+     if isinstance(t, torch.Tensor):
+         assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
+         return t.jax()
+     if isinstance(t, type(torch.int32)):
+         return mappings.t2j_dtype(t)
+
+     # torch.nn.Module needs special handling
+     if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+         return functools.partial(call_torch, t)
+     # regular types are not changed
+     return t
+
+
+ jax_view = functools.partial(pytree.tree_map, _jax_view)
+
+
+ def call_jax(
+     jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue
+ ) -> TorchValue:
+     args, kwargs = jax_view((args, kwargs))
+     res: JaxValue = jax_func(*args, **kwargs)
+     return torch_view(res)
+
+
+ def call_torch(
+     torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue
+ ) -> JaxValue:
+     args, kwargs = torch_view((args, kwargs))
+     with torchax.default_env():
+         res: TorchValue = torch_func(*args, **kwargs)
+     return jax_view(res)
+
+
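A small sketch of crossing the torch/JAX boundary with `call_jax`; the function and shapes are illustrative, and the inputs are expected to be torchax tensors:

    import jax.numpy as jnp
    import torch
    import torchax
    from torchax.interop import call_jax

    def my_jax_fn(a, b):
        return jnp.dot(a, b) + 1.0

    with torchax.default_env():
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        c = call_jax(my_jax_fn, a, b)  # args are viewed as jax.Arrays; c is viewed back as a torch tensor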
+ def j2t_autograd(fn, call_jax=call_jax):
+     """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
+
+     It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+     activations). The wrapped function is then run via `call_jax` and integrated into
+     the PyTorch autograd framework by saving the residuals into the context object.
+     """
+
+     # NOTE(qihqi): This function cannot be inlined from the callsite
+     # because if it were, it would not hit the compilation cache for
+     # call_jax. call_jax uses the function's id as the cache key.
+     # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+     # wrapped pure function, preventing cache collisions between different pure modules.
+     def _jax_forward(fn, other, tree_def, tensors):
+         """JAX function to compute output and vjp function.
+
+         primals should be a tuple (args, kwargs).
+         """
+         import jax
+         from jax.tree_util import tree_unflatten
+
+         def fn_wrapper(*tensors):
+             # Reconstruct the original args and kwargs
+             flat_inputs = util.merge(tensors, other)
+             args, kwargs = tree_unflatten(tree_def, flat_inputs)
+             return fn(*args, **kwargs)
+
+         return jax.vjp(fn_wrapper, *tensors)
+
+     def _jax_backward(vjp_spec, saved_tensors, grad_out):
+         """JAX function to compute input gradients.
+
+         Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+         """
+         from jax.tree_util import tree_unflatten
+
+         fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+         return fun_vjp(grad_out)
+
+     @wraps(fn)
+     def inner(*args, **kwargs):
+         from jax.tree_util import tree_flatten
+
+         class JaxFun(torch.autograd.Function):
+             @staticmethod
+             def forward(ctx, tree_def, *flat_args_kwargs):
+                 tensors, other = util.partition(
+                     flat_args_kwargs, lambda x: isinstance(x, torch.Tensor)
+                 )
+                 # We want the arguments that don't require grads to be closed over?
+
+                 y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
+
+                 # Save necessary information for backward
+                 # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
+                 # `residuals` contains the tensors needed for the backward pass.
+                 residuals, vjp_spec = tree_flatten(fun_vjp)
+                 ctx.vjp_spec = vjp_spec
+                 ctx.save_for_backward(*residuals)
+                 return y
+
+             @staticmethod
+             def backward(ctx, *grad_out):
+                 assert len(grad_out) > 0
+                 grad_out = grad_out if len(grad_out) > 1 else grad_out[0]
+
+                 input_grads_structured = call_jax(
+                     _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out
+                 )
+
+                 # Construct the gradient tuple to be returned.
+                 # It needs to match the inputs to forward: (tree_def, *flat_inputs)
+                 # The first gradient (for tree_def) is None.
+                 # The subsequent gradients correspond to flat_inputs.
+                 # We need to put a None for inputs that did not require gradients.
+                 final_grads = [None]
+                 for needs_grad, grad in zip(
+                     ctx.needs_input_grad[1:], input_grads_structured, strict=True
+                 ):
+                     final_grads.append(grad if needs_grad else None)
+
+                 return tuple(final_grads)
+
+         sig = signature(fn)
+         bound = sig.bind(*args, **kwargs)
+         bound.apply_defaults()
+         flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
+         y = JaxFun.apply(tree_def, *flat_args_kwargs)
+         return y
+
+     return inner
+
+
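A sketch of using `j2t_autograd` to backpropagate through a JAX function from PyTorch; the loss function and shapes are illustrative:

    import jax.numpy as jnp
    import torch
    import torchax
    from torchax.interop import j2t_autograd

    loss_fn = j2t_autograd(lambda x, w: jnp.mean((x @ w) ** 2))

    with torchax.default_env():
        x = torch.randn(3, 3)
        w = torch.randn(3, 3, requires_grad=True)
        loss = loss_fn(x, w)
        loss.backward()  # gradients flow through jax.vjp back to w.grad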
+ fori_loop = torch_view(jax.lax.fori_loop)
+
+
+ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
+     kwargs_for_jax = kwargs_for_jax or {}
+     jax_func = jax_view(torch_function)
+     jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+     return torch_view(jitted)
+
+
+ def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
+     return wrap_jax_jit(
+         torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit
+     )
+
+
+ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
+     return wrap_jax_jit(
+         torch_function, jax_jit_func=shard_map, kwargs_for_jax=kwargs_for_jax_shard_map
+     )
+
+
+ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
+     return wrap_jax_jit(
+         torch_function,
+         jax_jit_func=jax.value_and_grad,
+         kwargs_for_jax=kwargs_for_value_and_grad,
+     )
+
+
+ def gradient_checkpoint(torch_function, kwargs=None):
+     return wrap_jax_jit(
+         torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs
+     )
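A sketch of `jax_jit` applied to a pure function written with torch ops; the function and sizes are illustrative:

    import torch
    import torchax
    from torchax.interop import jax_jit

    @jax_jit
    def scaled_sum(a, b):
        return torch.sin(a) + 2.0 * b

    with torchax.default_env():
        a = torch.randn(16)
        b = torch.randn(16)
        out = scaled_sum(a, b)  # traced once under jax.jit, cached afterwards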
torchax/mesh_util.py ADDED
@@ -0,0 +1,236 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import jax
+ import torch
+ from jax.sharding import NamedSharding, PartitionSpec
+
+ import torchax
+ from torchax import interop
+
+
+ def _shard_first_multiple_of(axis_name, shape, multiple_of):
+     """Creates a PartitionSpec that shards the first dimension divisible by a number.
+
+     Iterates through the dimensions of `shape` and finds the first dimension
+     whose size is a multiple of `multiple_of`, returning a PartitionSpec that
+     shards that dimension along the given `axis_name`. All other dimensions are
+     left unsharded (None in the PartitionSpec), i.e. replicated.
+
+     Args:
+         axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl").
+         shape: A tuple or list representing the shape of the tensor to be sharded.
+         multiple_of: The integer value that a dimension size must be divisible by
+             in order to be sharded. Typically the size of the mesh axis.
+
+     Returns:
+         A jax.sharding.PartitionSpec specifying how to shard the tensor.
+         For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4,
+         it returns PartitionSpec(None, 'x', None).
+         If no dimension is divisible, a fully replicated PartitionSpec
+         (all None) is returned.
+     """
+     sharding = []
+     found = False
+     for size in shape:
+         if not found and size % multiple_of == 0:
+             found = True
+             sharding.append(axis_name)
+         else:
+             sharding.append(None)
+     return PartitionSpec(*sharding)
+
+
+ class SingleAxisSharder:
+     """A callable object that generates PartitionSpecs for single-axis sharding.
+
+     This sharding strategy attempts to shard the *first* dimension of a tensor
+     that is divisible by the specified `axis_size` along the given `axis_name`.
+     It is useful for simple 1D mesh sharding scenarios like FSDP, where parameters
+     are typically sharded along one dimension.
+
+     Attributes:
+         axis_name: The name of the mesh axis to shard along.
+         axis_size: The size of the mesh axis (number of devices along that axis).
+     """
+
+     def __init__(self, axis_name, axis_size, replicate_unshardable=False):
+         """Initializes the SingleAxisSharder.
+
+         Args:
+             axis_name: The name of the mesh axis (e.g., "fsdp", "data").
+             axis_size: The number of devices along the specified mesh axis.
+             replicate_unshardable: Whether to return a replicated sharding
+                 (PartitionSpec()) when no dimension is divisible by the axis size;
+                 if False, an error is raised instead.
+         """
+         self.axis_name = axis_name
+         self.axis_size = axis_size
+         self.replicate_unshardable = replicate_unshardable
+
+     def __call__(self, name, shapedtype):
+         """Generates a PartitionSpec for a given tensor name and shaped type.
+
+         Args:
+             name: The name of the tensor (e.g., parameter name). This argument is
+                 provided for compatibility with more complex sharders but is not used
+                 by this simple sharder.
+             shapedtype: An object with a `.shape` attribute describing the tensor's
+                 shape and a `.dtype` attribute describing its dtype (e.g. jax.Array,
+                 jax.ShapeDtypeStruct, or a torch.Tensor).
+
+         Returns:
+             A jax.sharding.PartitionSpec determined by finding the first dimension
+             in `shapedtype.shape` divisible by `self.axis_size` using the helper
+             `_shard_first_multiple_of`.
+         """
+         del name
+         sharding = _shard_first_multiple_of(
+             self.axis_name, shapedtype.shape, self.axis_size
+         )
+         if not self.replicate_unshardable and all(s is None for s in sharding):
+             raise AssertionError(
+                 f"Unable to find a dim to shard because none of the dims "
+                 f"({shapedtype.shape}) is a multiple of {self.axis_size}"
+             )
+         return sharding
+
+
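A sketch of the PartitionSpecs this sharder produces; the shapes and the axis size of 4 are illustrative:

    import jax
    import jax.numpy as jnp
    from torchax.mesh_util import SingleAxisSharder

    sharder = SingleAxisSharder("fsdp", axis_size=4, replicate_unshardable=True)

    # 10 is not divisible by 4, 8 is -> the second dim gets sharded along "fsdp".
    print(sharder("w", jax.ShapeDtypeStruct((10, 8), jnp.float32)))  # PartitionSpec(None, 'fsdp')

    # No dimension divisible -> fully replicated spec.
    print(sharder("b", jax.ShapeDtypeStruct((3,), jnp.float32)))  # PartitionSpec(None,)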
+ class Mesh:
+     """A helper class that wraps a `jax.sharding.Mesh` object.
+
+     The goal of this class is to provide helper methods that facilitate the
+     sharding of PyTorch tensors or models given a JAX device mesh configuration.
+     It simplifies initializing models directly into a sharded state.
+
+     Attributes:
+         jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid
+             and axis names.
+         _sharder: The default sharding strategy callable (like SingleAxisSharder)
+             used to determine the PartitionSpec for each parameter if not overridden
+             during method calls. Can be None if no default is appropriate or set.
+     """
+
+     @classmethod
+     def fsdp_mesh(cls, axis_name="fsdp"):
+         """Creates a Mesh instance suitable for 1D FSDP-style sharding.
+
+         This named constructor creates a 1D mesh encompassing all available XLA
+         devices. It assigns the specified `axis_name` to this single dimension.
+         It then creates a `Mesh` instance using this JAX mesh and a
+         `SingleAxisSharder` configured appropriately for this 1D mesh.
+
+         Args:
+             axis_name: The name to assign to the single mesh axis (default: "fsdp").
+                 This name will be used by the default `SingleAxisSharder`.
+
+         Returns:
+             A Mesh instance configured with a 1D JAX mesh across all devices and a
+             corresponding SingleAxisSharder.
+         """
+         ndevice = jax.device_count()
+         jax_mesh = jax.make_mesh((ndevice,), (axis_name,))
+         # replicate_unshardable so scalars and small model attributes are replicated.
+         return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True))
+
+     def __init__(self, jax_mesh, sharder=None):
+         """Initializes the Mesh helper.
+
+         Args:
+             jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the
+                 physical device grid and logical axis names.
+             sharder: An optional callable (e.g., an instance of SingleAxisSharder)
+                 that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`.
+                 This serves as the default sharding strategy.
+                 If None, and the provided `jax_mesh` has exactly one axis, a
+                 `SingleAxisSharder` is created automatically for that single axis.
+                 If None and the mesh has multiple axes, `_sharder` remains None, and
+                 an `override_sharder` must be provided to methods like
+                 `initialize_model_sharded`.
+         """
+         self.jax_mesh = jax_mesh
+         if sharder is None:
+             assert len(self.jax_mesh.axis_names) == 1
+             sharder = SingleAxisSharder(
+                 self.jax_mesh.axis_names[0], len(self.jax_mesh.device_ids)
+             )
+         self._sharder = sharder
+
+     def initialize_model_sharded(
+         self, model_class, init_args, init_kwargs=None, override_sharder=None
+     ):
+         """Initializes a PyTorch model with its parameters sharded across the mesh.
+
+         This method orchestrates the initialization of a `torch.nn.Module` such
+         that its parameters are created directly on the target devices according
+         to the sharding specifications derived from the mesh and the chosen sharder.
+         It leverages `torchax.interop.jax_jit` to achieve this.
+
+         Args:
+             model_class: The PyTorch model class (a subclass of `torch.nn.Module`).
+             init_args: A tuple containing the positional arguments required by the
+                 `model_class.__init__` method.
+             init_kwargs: An optional dictionary containing the keyword arguments for
+                 the `model_class.__init__` method. Defaults to None (treated as {}).
+             override_sharder: An optional callable sharding strategy to use
+                 specifically for this initialization. If provided, it takes precedence
+                 over the mesh's default `_sharder`. It must accept `(name, shapedtype)`
+                 and return a `PartitionSpec`. If None, the mesh's default `_sharder`
+                 is used.
+
+         Returns:
+             An instance of `model_class` whose parameters have been initialized and
+             are represented by sharded tensors distributed across the devices in the
+             `jax_mesh`.
+
+         Raises:
+             ValueError: If no sharder is available (i.e., `override_sharder` is None
+                 and the mesh's default `_sharder` is also None).
+             AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`)
+                 if it fails to determine a valid sharding for any parameter.
+             TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`.
+             Other errors from JAX JIT compilation or PyTorch model initialization.
+         """
+         init_kwargs = init_kwargs or {}
+         with torch.device("meta"), torchax.disable_temporarily():
+             model = model_class(*init_args, **init_kwargs)
+
+         sharder = override_sharder or self._sharder
+
+         states = model.state_dict()
+         output_shards = {
+             name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+             for name, tensor in states.items()
+         }
+
+         def model_initializer():
+             with torchax.default_env(), torch.device("meta"):
+                 model = model_class(*init_args, **init_kwargs)
+             return dict(model.state_dict())
+
+         jitted = interop.jax_jit(
+             model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}
+         )
+         weights_dict = jitted()
+
+         model.load_state_dict(weights_dict, assign=True)
+         return model
+
+     def shard_model(self, model, override_sharder=None):
+         """Reshards the weights of an existing torchax-backed model across the mesh."""
+         sharder = override_sharder or self._sharder
+         states = model.state_dict()
+         # Apply the computed NamedSharding to each weight's underlying jax.Array,
+         # then load the resharded tensors back into the model.
+         sharded_weights = {
+             name: interop.torch_view(
+                 jax.device_put(
+                     interop.jax_view(tensor),
+                     NamedSharding(self.jax_mesh, sharder(name, tensor)),
+                 )
+             )
+             for name, tensor in states.items()
+         }
+         model.load_state_dict(sharded_weights, assign=True)
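A sketch of initializing a model directly in sharded form; the model class and sizes are illustrative, and dimensions are assumed to be divisible by the device count (otherwise they are replicated):

    import torch
    from torchax.mesh_util import Mesh

    mesh = Mesh.fsdp_mesh()  # 1D mesh over all devices, axis named "fsdp"
    model = mesh.initialize_model_sharded(
        torch.nn.Linear, init_args=(1024, 1024), init_kwargs={"bias": False}
    )
    # Each parameter is now a torchax tensor backed by a jax.Array sharded along "fsdp".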
@@ -0,0 +1,25 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ def all_aten_jax_ops():
+     # to load the ops
+     import torchax.ops.jaten  # type: ignore
+     import torchax.ops.ops_registry  # type: ignore
+
+     return {
+         key: val.func
+         for key, val in torchax.ops.ops_registry.all_aten_ops.items()
+         if val.is_jax_function
+     }
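A small sketch of inspecting the returned mapping, assuming this helper is importable from the torchax ops package; the key type is whatever the registry uses:

    ops = all_aten_jax_ops()
    some_op, some_impl = next(iter(ops.items()))
    print(len(ops), some_op, some_impl)  # number of aten ops with a jax implementation, plus one entry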