torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

torchax/mesh_util.py ADDED
@@ -0,0 +1,220 @@
+ import jax
+ from jax.sharding import PartitionSpec, NamedSharding
+ import torch
+ import torchax
+ from torchax import interop
+
+
+ def _shard_first_multiple_of(axis_name, shape, multiple_of):
+   """Creates a PartitionSpec to shard the first dimension divisible by a number.
+
+   Iterates through the dimensions given by `shape`, finds the first dimension
+   whose size is a multiple of `multiple_of`, and returns a PartitionSpec that
+   shards that dimension along the given `axis_name`. All preceding dimensions
+   are not sharded (marked as None in the PartitionSpec), and all subsequent
+   dimensions are likewise left as None, which JAX implicitly treats as
+   replicated.
+
+   Args:
+     axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl").
+     shape: A tuple or list representing the shape of the tensor to be sharded.
+     multiple_of: The integer value that a dimension size must be divisible by
+       in order to be sharded. Typically the size of the mesh axis.
+
+   Returns:
+     A jax.sharding.PartitionSpec object specifying how to shard the tensor.
+     For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, it
+     would return PartitionSpec(None, 'x', None).
+     If no dimension is divisible, a fully replicated PartitionSpec (all None)
+     is returned.
+   """
+   sharding = []
+   found = False
+   for size in shape:
+     if not found and size % multiple_of == 0:
+       found = True
+       sharding.append(axis_name)
+     else:
+       sharding.append(None)
+   return PartitionSpec(*sharding)
+
+
+ class SingleAxisSharder:
+   """A callable object that generates PartitionSpecs for single-axis sharding.
+
+   This sharding strategy attempts to shard the *first* dimension of a tensor
+   that is divisible by the specified `axis_size` along the given `axis_name`.
+   It is useful for simple 1D mesh sharding scenarios like FSDP, where
+   parameters are typically sharded along one dimension.
+
+   Attributes:
+     axis_name: The name of the mesh axis to shard along.
+     axis_size: The size of the mesh axis (number of devices along that axis).
+     replicate_unshardable: Whether tensors with no shardable dimension are
+       replicated instead of raising an error.
+   """
+
+   def __init__(self, axis_name, axis_size, replicate_unshardable=False):
+     """Initializes the SingleAxisSharder.
+
+     Args:
+       axis_name: The name of the mesh axis (e.g., "fsdp", "data").
+       axis_size: The number of devices along the specified mesh axis.
+       replicate_unshardable: Whether to return a replicated sharding (P())
+         when none of the dimensions is divisible by the axis size; if False,
+         an AssertionError is raised in that case.
+     """
+     self.axis_name = axis_name
+     self.axis_size = axis_size
+     self.replicate_unshardable = replicate_unshardable
+
+   def __call__(self, name, shapedtype):
+     """Generates a PartitionSpec for a given tensor name and shaped type.
+
+     Args:
+       name: The name of the tensor (e.g., parameter name). This argument is
+         provided for compatibility with more complex sharders but is not used
+         by this simple sharder.
+       shapedtype: An object with a `.shape` attribute describing the tensor's
+         shape and a `.dtype` attribute describing its dtype (e.g., jax.Array,
+         jax.ShapeDtypeStruct, or a torch.Tensor).
+
+     Returns:
+       A jax.sharding.PartitionSpec determined by finding the first dimension
+       in `shapedtype.shape` divisible by `self.axis_size` using the helper
+       `_shard_first_multiple_of`.
+     """
+     del name
+     sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape,
+                                         self.axis_size)
+     if not self.replicate_unshardable and all(s is None for s in sharding):
+       raise AssertionError(
+           f"Unable to find a dim to shard: none of the dims in shape "
+           f"({shapedtype.shape}) is a multiple of {self.axis_size}")
+     return sharding
+
+
+ class Mesh:
+   """A helper class that wraps a `jax.sharding.Mesh` object.
+
+   The goal of this class is to provide helper methods that facilitate the
+   sharding of PyTorch tensors or models given a JAX device mesh configuration.
+   It simplifies initializing models directly into a sharded state.
+
+   Attributes:
+     jax_mesh: The underlying `jax.sharding.Mesh` object defining the device
+       grid and axis names.
+     _sharder: The default sharding strategy callable (like SingleAxisSharder)
+       used to determine the PartitionSpec for each parameter if not
+       overridden during method calls. Can be None if no default is
+       appropriate or set.
+   """
+
+   @classmethod
+   def fsdp_mesh(cls, axis_name="fsdp"):
+     """Creates a Mesh instance suitable for 1D FSDP-style sharding.
+
+     This named constructor creates a 1D mesh encompassing all available XLA
+     devices and assigns the specified `axis_name` to this single dimension.
+     It then creates a `Mesh` instance using this JAX mesh and a
+     `SingleAxisSharder` configured appropriately for this 1D mesh.
+
+     Args:
+       axis_name: The name to assign to the single mesh axis (default:
+         "fsdp"). This name will be used by the default `SingleAxisSharder`.
+
+     Returns:
+       A Mesh instance configured with a 1D JAX mesh across all devices and a
+       corresponding SingleAxisSharder.
+     """
+     ndevice = jax.device_count()
+     jax_mesh = jax.make_mesh((ndevice,), (axis_name,))
+     # replicate_unshardable so scalars and small model attributes are
+     # replicated.
+     return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True))
+
+   def __init__(self, jax_mesh, sharder=None):
+     """Initializes the Mesh helper.
+
+     Args:
+       jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the
+         physical device grid and logical axis names.
+       sharder: An optional callable (e.g., an instance of SingleAxisSharder)
+         that takes (name, shapedtype) and returns a
+         `jax.sharding.PartitionSpec`. This serves as the default sharding
+         strategy. If None and the provided `jax_mesh` has exactly one axis,
+         a `SingleAxisSharder` is created automatically for that single axis;
+         if the mesh has multiple axes, an explicit sharder must be provided,
+         or an `override_sharder` must be passed to methods like
+         `initialize_model_sharded`.
+     """
+     self.jax_mesh = jax_mesh
+     if sharder is None:
+       assert len(self.jax_mesh.axis_names) == 1
+       sharder = SingleAxisSharder(self.jax_mesh.axis_names[0],
+                                   len(self.jax_mesh.device_ids))
+     self._sharder = sharder
+
+   def initialize_model_sharded(self,
+                                model_class,
+                                init_args,
+                                init_kwargs=None,
+                                override_sharder=None):
+     """Initializes a PyTorch model with its parameters sharded across the mesh.
+
+     This method orchestrates the initialization of a `torch.nn.Module` such
+     that its parameters are created directly on the target devices according
+     to the sharding specifications derived from the mesh and the chosen
+     sharder. It leverages `torchax.interop.jax_jit` to achieve this.
+
+     Args:
+       model_class: The PyTorch model class (a subclass of `torch.nn.Module`).
+       init_args: A tuple containing the positional arguments required by the
+         `model_class.__init__` method.
+       init_kwargs: An optional dictionary containing the keyword arguments
+         for the `model_class.__init__` method. Defaults to None (treated as
+         {}).
+       override_sharder: An optional callable sharding strategy to use
+         specifically for this initialization. If provided, it takes
+         precedence over the mesh's default `_sharder`. It must accept
+         `(name, shapedtype)` and return a `PartitionSpec`. If None, the
+         mesh's default `_sharder` is used.
+
+     Returns:
+       An instance of `model_class` whose parameters have been initialized
+       and are represented by sharded tensors distributed across the devices
+       in the `jax_mesh`.
+
+     Raises:
+       ValueError: If no sharder is available (i.e., `override_sharder` is
+         None and the mesh's default `_sharder` is also None).
+       AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`)
+         if it fails to determine a valid sharding for any parameter.
+       TypeError: If `shapedtype` passed to the sharder doesn't have a
+         `.shape`.
+       Other errors may propagate from JAX JIT compilation or PyTorch model
+         initialization.
+     """
+     init_kwargs = init_kwargs or {}
+     # Build a skeleton on the meta device to capture parameter names and
+     # shapes without allocating real storage.
+     with torch.device("meta"), torchax.disable_temporarily():
+       model = model_class(*init_args, **init_kwargs)
+
+     sharder = override_sharder or self._sharder
+
+     states = model.state_dict()
+     output_shards = {
+         name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+         for name, tensor in states.items()
+     }
+
+     def model_initializer():
+       # Re-run initialization under the torchax environment so real traced
+       # values are produced; jit materializes them with `out_shardings`.
+       with torchax.default_env():
+         model = model_class(*init_args, **init_kwargs)
+       return dict(model.state_dict())
+
+     jitted = interop.jax_jit(
+         model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards})
+     weights_dict = jitted()
+
+     model.load_state_dict(weights_dict, assign=True)
+     return model
+
+   def shard_model(self, model, override_sharder=None):
+     """Shards an already-initialized model's parameters in place."""
+     sharder = override_sharder or self._sharder
+     states = model.state_dict()
+     # Apply each tensor's sharding with jax.device_put, then load the
+     # sharded tensors back into the model.
+     sharded_states = {
+         name: interop.call_jax(
+             jax.device_put, tensor,
+             NamedSharding(self.jax_mesh, sharder(name, tensor)))
+         for name, tensor in states.items()
+     }
+     model.load_state_dict(sharded_states, assign=True)
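For orientation, here is a minimal usage sketch of the new API (an editorial illustration, not part of the package diff). It assumes a multi-device JAX runtime; `MyModel` and `hidden_dim` are hypothetical stand-ins for any `torch.nn.Module` and its constructor argument.

import torch
import torchax
from torchax.mesh_util import Mesh

torchax.enable_globally()

# A 1D mesh over all devices. The default SingleAxisSharder shards the first
# dimension of each parameter divisible by the device count, e.g. with 4
# devices a (10, 20, 30) weight gets PartitionSpec(None, 'fsdp', None).
mesh = Mesh.fsdp_mesh()

# MyModel / hidden_dim are hypothetical and defined elsewhere.
# Initialize directly into a sharded state, so a full replica of the weights
# is never materialized on a single device.
model = mesh.initialize_model_sharded(MyModel, (hidden_dim,))

# Or shard a model that was already initialized the usual way.
model2 = MyModel(hidden_dim)
mesh.shard_model(model2)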