torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchax/interop.py ADDED
@@ -0,0 +1,369 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ import copy
+ import functools
+ import torch
+ from inspect import signature
+ from functools import wraps
+ from torch.nn.utils import stateless as torch_stateless
+ import jax
+ import jax.numpy as jnp
+ from jax import tree_util as pytree
+ from jax.experimental.shard_map import shard_map
+ from torchax import tensor
+ from torchax import util
+ from torchax.ops import mappings
+ import torchax
+
+ from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
+
+
+ def extract_all_buffers(m: torch.nn.Module):
+   buffers = {}
+   params = {}
+
+   def extract_one(module, prefix):
+     for k in dir(module):
+       try:
+         v = getattr(module, k)
+       except Exception:
+         continue
+       qual_name = prefix + k
+       if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+         params[qual_name] = v
+       elif isinstance(v, torch.Tensor):
+         buffers[qual_name] = v
+     for name, child in module.named_children():
+       extract_one(child, prefix + name + '.')
+
+   extract_one(m, '')
+   return params, buffers
+
+
+ def set_all_buffers(m, params, buffers):
+
+   def set_one(module, prefix):
+     for k in dir(module):
+       qual_name = prefix + k
+       if (potential_v := buffers.get(qual_name)) is not None:
+         setattr(module, k, potential_v)
+       elif (potential_v := params.get(qual_name)) is not None:
+         print(k, potential_v)
+         setattr(module, k, torch.nn.Parameter(potential_v))
+     for name, child in module.named_children():
+       set_one(child, prefix + name + '.')
+
+   set_one(m, '')
+
+
+ class JittableModule(torch.nn.Module):
+
+   def __init__(self,
+                m: torch.nn.Module,
+                extra_jit_args={},
+                dedup_parameters=True):
+     super().__init__()
+     self.params, self.buffers = extract_all_buffers(m)
+     self._model = m
+     self._jitted = {}
+
+     self._extra_jit_args = extra_jit_args
+
+     self._extra_dumped_weights = {}
+
+     if dedup_parameters:
+       temp = collections.defaultdict(list)
+       for k, v in self.params.items():
+         temp[id(v)].append(k)
+
+       for v in temp.values():
+         if len(v) > 1:
+           # duplicated weights with different names
+           self._extra_dumped_weights[v[0]] = v[1:]
+           for extra_keys in v[1:]:
+             del self.params[extra_keys]
+
+   @property
+   def __class__(self):
+     # Lie about the class type so that
+     # isinstance(jittable_module, self._model.__class__) works
+     return self._model.__class__
+
+   def __call__(self, *args, **kwargs):
+     return self.forward(*args, **kwargs)
+
+   def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
+     kwargs = kwargs or {}
+     params_copy = copy.copy(params)
+     params_copy.update(buffers)
+     # Re-inflate the state dict so there are no missing keys.
+     for k, v in self._extra_dumped_weights.items():
+       for new_key in v:
+         params_copy[new_key] = params_copy[k]
+
+     if isinstance(method_or_name, str):
+       method = getattr(self._model, method_or_name)
+     else:
+       if not callable(method_or_name):
+         raise TypeError(
+             f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+         )
+       method = method_or_name
+       args = (self._model,) + args
+     with torch_stateless._reparametrize_module(self._model, params_copy):
+       res = method(*args, **kwargs)
+     return res
+
+   def jittable_call(self, method_name: str, *args, **kwargs):
+     if method_name not in self._jitted:
+       jitted = jax_jit(
+           functools.partial(self.functional_call, method_name),
+           kwargs_for_jax_jit=self._extra_jit_args,
+       )
+
+       def jitted_forward(*args, **kwargs):
+         return jitted(self.params, self.buffers, *args, **kwargs)
+
+       self._jitted[method_name] = jitted_forward
+     return self._jitted[method_name](*args, **kwargs)
+
+   def forward(self, *args, **kwargs):
+     return self.jittable_call('forward', *args, **kwargs)
+
+   def __getattr__(self, key):
+     if key == '_model':
+       return super().__getattr__(key)
+     if key in self._jitted:
+       return self._jitted[key]
+     return getattr(self._model, key)
+
+   def make_jitted(self, key):
+     jitted = jax_jit(
+         functools.partial(self.functional_call, key),
+         kwargs_for_jax_jit=self._extra_jit_args)
+
+     def call(*args, **kwargs):
+       return jitted(self.params, self.buffers, *args, **kwargs)
+
+     self._jitted[key] = call
+
+
+ class CompileMixin:
+
+   def functional_call(self, method, params, buffers, *args, **kwargs):
+     kwargs = kwargs or {}
+     params_copy = copy.copy(params)
+     params_copy.update(buffers)
+     with torch_stateless._reparametrize_module(self, params_copy):
+       res = method(*args, **kwargs)
+     return res
+
+   def jit(self, method):
+     jitted = jax_jit(functools.partial(self.functional_call, method))
+
+     def call(*args, **kwargs):
+       return jitted(dict(self.named_parameters()), dict(self.named_buffers()),
+                     *args, **kwargs)
+
+     return call
+
+
+ def compile_nn_module(m: torch.nn.Module, methods=None):
+   if methods is None:
+     methods = ['forward']
+
+   new_parent = type(
+       m.__class__.__name__ + '_with_CompileMixin',
+       (CompileMixin, m.__class__),
+       {},
+   )
+   m.__class__ = new_parent
+
+
+ def _torch_view(t: JaxValue) -> TorchValue:
+   # t is an object from jax land
+   # view it as if it's a torch land object
+   if isinstance(t, jax.Array):
+     return tensor.Tensor(t, torchax.default_env())
+   if isinstance(t, jnp.dtype):
+     return mappings.j2t_dtype(t)
+   if callable(t):  # t is a JaxCallable
+     return functools.partial(call_jax, t)
+   # regular types are not changed
+   return t
+
+
+ torch_view = functools.partial(pytree.tree_map, _torch_view)
+
+
+ def _jax_view(t: TorchValue) -> JaxValue:
+   # t is an object from torch land
+   # view it as if it's a jax land object
+   if isinstance(t, torch.Tensor):
+     assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
+     return t.jax()
+   if isinstance(t, type(torch.int32)):
+     return mappings.t2j_dtype(t)
+
+   # torch.nn.Module needs special handling
+   if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+     return functools.partial(call_torch, t)
+   # regular types are not changed
+   return t
+
+
+ jax_view = functools.partial(pytree.tree_map, _jax_view)
+
+
+ def call_jax(jax_func: JaxCallable, *args: TorchValue,
+              **kwargs: TorchValue) -> TorchValue:
+   args, kwargs = jax_view((args, kwargs))
+   res: JaxValue = jax_func(*args, **kwargs)
+   return torch_view(res)
+
+
+ def call_torch(torch_func: TorchCallable, *args: JaxValue,
+                **kwargs: JaxValue) -> JaxValue:
+   args, kwargs = torch_view((args, kwargs))
+   with torchax.default_env():
+     res: TorchValue = torch_func(*args, **kwargs)
+   return jax_view(res)
+
+
+ def j2t_autograd(fn, call_jax=call_jax):
+   """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
+
+   It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+   activations). The wrapped function is then run via `call_jax` and integrated into
+   the PyTorch autograd framework by saving the residuals into the context object.
+   """
+
+   # NOTE(qihqi): This function cannot be inlined from the callsite,
+   # because if it were, it would not hit the compilation cache for
+   # call_jax; call_jax uses the function's id as the cache key.
+   # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+   # wrapped pure function, preventing cache collisions between different pure modules.
+   def _jax_forward(fn, other, tree_def, tensors):
+     """JAX function to compute output and vjp function.
+
+     primals should be a tuple (args, kwargs).
+     """
+     import jax
+     from jax.tree_util import tree_flatten, tree_unflatten
+
+     def fn_wrapper(*tensors):
+       # Reconstruct the original args and kwargs
+       flat_inputs = util.merge(tensors, other)
+       args, kwargs = tree_unflatten(tree_def, flat_inputs)
+       return fn(*args, **kwargs)
+
+     return jax.vjp(fn_wrapper, *tensors)
+
+   def _jax_backward(vjp_spec, saved_tensors, grad_out):
+     """JAX function to compute input gradients.
+
+     Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+     """
+     from jax.tree_util import tree_unflatten
+     fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+     return fun_vjp(grad_out)
+
+   @wraps(fn)
+   def inner(*args, **kwargs):
+     from jax.tree_util import tree_flatten
+
+     class JaxFun(torch.autograd.Function):
+
+       @staticmethod
+       def forward(ctx, tree_def, *flat_args_kwargs):
+
+         tensors, other = util.partition(flat_args_kwargs,
+                                         lambda x: isinstance(x, torch.Tensor))
+         # Non-tensor arguments are closed over rather than traced.
+
+         y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
+
+         # Save necessary information for backward.
+         # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
+         # `residuals` contains the tensors needed for the backward pass.
+         residuals, vjp_spec = tree_flatten(fun_vjp)
+         ctx.vjp_spec = vjp_spec
+         ctx.save_for_backward(*residuals)
+         return y
+
+       @staticmethod
+       def backward(ctx, *grad_out):
+         assert len(grad_out) > 0
+         grad_out = grad_out if len(grad_out) > 1 else grad_out[0]
+
+         input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec,
+                                           ctx.saved_tensors, grad_out)
+
+         # Construct the gradient tuple to be returned.
+         # It needs to match the inputs to forward: (tree_def, *flat_inputs)
+         # The first gradient (for tree_def) is None.
+         # The subsequent gradients correspond to flat_inputs.
+         # We need to put a None for inputs that did not require gradients.
+         final_grads = [None]
+         for needs_grad, grad in zip(
+             ctx.needs_input_grad[1:], input_grads_structured, strict=True):
+           final_grads.append(grad if needs_grad else None)
+
+         return tuple(final_grads)
+
+     sig = signature(fn)
+     bound = sig.bind(*args, **kwargs)
+     bound.apply_defaults()
+     flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
+     y = JaxFun.apply(tree_def, *flat_args_kwargs)
+     return y
+
+   return inner
+
+
+ fori_loop = torch_view(jax.lax.fori_loop)
+
+
+ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
+   kwargs_for_jax = kwargs_for_jax or {}
+   jax_func = jax_view(torch_function)
+   jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+   return torch_view(jitted)
+
+
+ def jax_jit(torch_function,
+             kwargs_for_jax_jit=None,
+             fix_for_buffer_donation=False):
+   return wrap_jax_jit(
+       torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)
+
+
+ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
+   return wrap_jax_jit(
+       torch_function,
+       jax_jit_func=shard_map,
+       kwargs_for_jax=kwargs_for_jax_shard_map)
+
+
+ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
+   return wrap_jax_jit(
+       torch_function,
+       jax_jit_func=jax.value_and_grad,
+       kwargs_for_jax=kwargs_for_value_and_grad)
+
+
+ def gradient_checkpoint(torch_function, kwargs=None):
+   return wrap_jax_jit(
+       torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
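
Note: the functions in this file convert values between "torch land" and "jax land"; `jax_jit` wraps a Torch callable in `jax.jit` by composing `jax_view` and `torch_view`, and `JittableModule` does the analogous thing for a whole `nn.Module` via its `functional_call`. The following is a minimal usage sketch based only on the functions shown above, not an official example from the package; the helper `scale_and_add` is hypothetical and a working JAX backend is assumed.

    import jax.numpy as jnp
    from torchax import interop

    def scale_and_add(a, b):
      # A plain PyTorch-style function; it runs under jax.jit via call_torch.
      return a * 2.0 + b

    jitted = interop.jax_jit(scale_and_add)

    # Build inputs on the JAX side and view them as torch tensors.
    a = interop.torch_view(jnp.arange(8.0))
    b = interop.torch_view(jnp.ones(8))

    out = jitted(a, b)            # a torchax tensor backed by a jax.Array
    print(interop.jax_view(out))  # the underlying jax.Array
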
torchax/mesh_util.py ADDED
@@ -0,0 +1,234 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import jax
+ from jax.sharding import PartitionSpec, NamedSharding
+ import torch
+ import torchax
+ from torchax import interop
+
+
+ def _shard_first_multiple_of(axis_name, shape, multiple_of):
+   """Creates a PartitionSpec that shards the first dimension divisible by a number.
+
+   Iterates through the dimensions specified by `shape`, finds the first dimension
+   whose size is a multiple of `multiple_of`, and returns a PartitionSpec that
+   shards that dimension along the given `axis_name`. Every other dimension is
+   left unsharded (None in the PartitionSpec), which is equivalent to
+   replicating along it.
+
+   Args:
+     axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl").
+     shape: A tuple or list representing the shape of the tensor to be sharded.
+     multiple_of: The integer value that a dimension size must be divisible by
+       in order to be sharded. Typically the size of the mesh axis.
+
+   Returns:
+     A jax.sharding.PartitionSpec object specifying how to shard the tensor.
+     For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4,
+     it would return PartitionSpec(None, 'x', None).
+     If no dimension is divisible, a fully replicated PartitionSpec is returned.
+   """
+   sharding = []
+   found = False
+   for size in shape:
+     if not found and size % multiple_of == 0:
+       found = True
+       sharding.append(axis_name)
+     else:
+       sharding.append(None)
+   return PartitionSpec(*sharding)
+
+
+ class SingleAxisSharder:
+   """A callable object that generates PartitionSpecs for single-axis sharding.
+
+   This sharder strategy attempts to shard the *first* dimension of a tensor
+   that is divisible by the specified `axis_size` along the given `axis_name`.
+   It's useful for simple 1D mesh sharding scenarios like FSDP where parameters
+   are typically sharded along one dimension.
+
+   Attributes:
+     axis_name: The name of the mesh axis to shard along.
+     axis_size: The size of the mesh axis (number of devices along that axis).
+   """
+
+   def __init__(self, axis_name, axis_size, replicate_unshardable=False):
+     """Initializes the SingleAxisSharder.
+
+     Args:
+       axis_name: The name of the mesh axis (e.g., "fsdp", "data").
+       axis_size: The number of devices along the specified mesh axis.
+       replicate_unshardable: Whether to return a replicated sharding (P())
+         when no dimension is divisible by the axis size.
+     """
+     self.axis_name = axis_name
+     self.axis_size = axis_size
+     self.replicate_unshardable = replicate_unshardable
+
+   def __call__(self, name, shapedtype):
+     """Generates a PartitionSpec for a given tensor name and shaped type.
+
+     Args:
+       name: The name of the tensor (e.g., parameter name). This argument is
+         provided for compatibility with more complex sharders but is not used
+         by this simple sharder.
+       shapedtype: An object with a `.shape` attribute describing the tensor's
+         shape and a `.dtype` attribute describing its dtype (e.g., jax.Array,
+         jax.ShapeDtypeStruct, or a torch.Tensor).
+
+     Returns:
+       A jax.sharding.PartitionSpec determined by finding the first dimension
+       in `shapedtype.shape` divisible by `self.axis_size` using the helper
+       `_shard_first_multiple_of`.
+     """
+     del name
+     sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape,
+                                         self.axis_size)
+     if not self.replicate_unshardable and all(s is None for s in sharding):
+       raise AssertionError(
+           f"Unable to find a dim to shard because "
+           f"none of the dims ({shapedtype.shape}) is a multiple of {self.axis_size}"
+       )
+     return sharding
+
+
+ class Mesh:
+   """A helper class that wraps a `jax.sharding.Mesh` object.
+
+   The goal of this class is to provide helper methods that facilitate the
+   sharding of PyTorch tensors or models given a JAX device mesh configuration.
+   It simplifies initializing models directly into a sharded state.
+
+   Attributes:
+     jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid
+       and axis names.
+     _sharder: The default sharding strategy callable (like SingleAxisSharder)
+       used to determine the PartitionSpec for each parameter if not overridden
+       during method calls. Can be None if no default is appropriate or set.
+   """
+
+   @classmethod
+   def fsdp_mesh(cls, axis_name="fsdp"):
+     """Creates a Mesh instance suitable for 1D FSDP-style sharding.
+
+     This named constructor creates a 1D mesh encompassing all available XLA
+     devices. It assigns the specified `axis_name` to this single dimension.
+     It then creates a `Mesh` instance using this JAX mesh and a
+     `SingleAxisSharder` configured appropriately for this 1D mesh.
+
+     Args:
+       axis_name: The name to assign to the single mesh axis (default: "fsdp").
+         This name will be used by the default `SingleAxisSharder`.
+
+     Returns:
+       A Mesh instance configured with a 1D JAX mesh across all devices and a
+       corresponding SingleAxisSharder.
+     """
+     ndevice = jax.device_count()
+     jax_mesh = jax.make_mesh((ndevice,), (axis_name,))
+     # replicate_unshardable so scalars and small model attributes are replicated.
+     return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True))
+
+   def __init__(self, jax_mesh, sharder=None):
+     """Initializes the Mesh helper.
+
+     Args:
+       jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the
+         physical device grid and logical axis names.
+       sharder: An optional callable (e.g., an instance of SingleAxisSharder)
+         that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`.
+         This serves as the default sharding strategy.
+         If None, and the provided `jax_mesh` has exactly one axis, a
+         `SingleAxisSharder` is created automatically for that single axis.
+         If None and the mesh has multiple axes, `_sharder` remains None, and
+         an `override_sharder` must be provided to methods like
+         `initialize_model_sharded`.
+     """
+     self.jax_mesh = jax_mesh
+     if sharder is None:
+       assert len(self.jax_mesh.axis_names) == 1
+       sharder = SingleAxisSharder(self.jax_mesh.axis_names[0],
+                                   len(self.jax_mesh.device_ids))
+     self._sharder = sharder
+
+   def initialize_model_sharded(self,
+                                model_class,
+                                init_args,
+                                init_kwargs=None,
+                                override_sharder=None):
+     """Initializes a PyTorch model with its parameters sharded across the mesh.
+
+     This method orchestrates the initialization of a `torch.nn.Module` such
+     that its parameters are created directly on the target devices according
+     to the sharding specifications derived from the mesh and the chosen sharder.
+     It leverages `torchax.interop.jax_jit` to achieve this.
+
+     Args:
+       model_class: The PyTorch model class (a subclass of `torch.nn.Module`).
+       init_args: A tuple containing the positional arguments required by the
+         `model_class.__init__` method.
+       init_kwargs: An optional dictionary containing the keyword arguments for
+         the `model_class.__init__` method. Defaults to None (treated as {}).
+       override_sharder: An optional callable sharding strategy to use
+         specifically for this initialization. If provided, it takes precedence
+         over the mesh's default `_sharder`. It must accept `(name, shapedtype)`
+         and return a `PartitionSpec`. If None, the mesh's default `_sharder`
+         is used.
+
+     Returns:
+       An instance of `model_class` whose parameters have been initialized and
+       are represented by sharded tensors distributed across the devices in the
+       `jax_mesh`.
+
+     Raises:
+       ValueError: If no sharder is available (i.e., `override_sharder` is None
+         and the mesh's default `_sharder` is also None).
+       AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`)
+         if it fails to determine a valid sharding for any parameter.
+       TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`.
+       Other errors from JAX JIT compilation or PyTorch model initialization.
+     """
+     init_kwargs = init_kwargs or {}
+     with torch.device("meta"), torchax.disable_temporarily():
+       model = model_class(*init_args, **init_kwargs)
+
+     sharder = override_sharder or self._sharder
+
+     states = model.state_dict()
+     output_shards = {
+         name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+         for name, tensor in states.items()
+     }
+
+     def model_initializer():
+       with torchax.default_env(), torch.device('meta'):
+         model = model_class(*init_args, **init_kwargs)
+       return dict(model.state_dict())
+
+     jitted = interop.jax_jit(
+         model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards})
+     weights_dict = jitted()
+
+     model.load_state_dict(weights_dict, assign=True)
+     return model
+
+   def shard_model(self, model, override_sharder=None):
+     sharder = override_sharder or self._sharder
+     states = model.state_dict()
+     output_shards = {
+         name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+         for name, tensor in states.items()
+     }
+     model.load_state_dict(output_shards, assign=True)
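
The two entry points above are `initialize_model_sharded`, which builds the model on the meta device and materializes its state dict through a jitted initializer with `out_shardings`, and `shard_model`, which applies the same sharder to an already-constructed model's state dict. Below is a usage sketch inferred from this file only, not an official example: the `MyMLP` module and its sizes are hypothetical, and a multi-device JAX runtime plus torchax are assumed.

    import torch
    from torchax.mesh_util import Mesh

    class MyMLP(torch.nn.Module):

      def __init__(self, dim):
        super().__init__()
        self.up = torch.nn.Linear(dim, 4 * dim, bias=False)
        self.down = torch.nn.Linear(4 * dim, dim, bias=False)

      def forward(self, x):
        return self.down(torch.nn.functional.gelu(self.up(x)))

    mesh = Mesh.fsdp_mesh()  # 1D mesh over all visible devices, axis "fsdp"
    model = mesh.initialize_model_sharded(MyMLP, (1024,))
    # Each parameter is now backed by a jax.Array sharded along the first
    # dimension divisible by the device count; parameters with no such
    # dimension are replicated because fsdp_mesh sets replicate_unshardable.
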
@@ -0,0 +1,24 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ def all_aten_jax_ops():
+   # to load the ops
+   import torchax.ops.jaten  # type: ignore
+   import torchax.ops.ops_registry  # type: ignore
+
+   return {
+       key: val.func
+       for key, val in torchax.ops.ops_registry.all_aten_ops.items()
+       if val.is_jax_function
+   }
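
A small sketch of how the returned mapping might be inspected; this is an illustration only, since the diff does not show which module the function lives in, so it assumes `all_aten_jax_ops` is already in scope.

    # Assumes torchax is installed and all_aten_jax_ops is importable/in scope.
    jax_ops = all_aten_jax_ops()
    print(len(jax_ops), "aten ops have JAX implementations")
    for op in list(jax_ops)[:3]:
      print(op)  # keys are the aten operators held by the torchax ops registry
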