torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/mesh_util.py
ADDED
@@ -0,0 +1,220 @@
import jax
from jax.sharding import PartitionSpec, NamedSharding
import torch
import torchax
from torchax import interop


def _shard_first_multiple_of(axis_name, shape, multiple_of):
  """Creates a PartitionSpec to shard the first dimension divisible by a number.

  Iterates through the dimensions specified by `shape`. Finds the first dimension
  whose size is a multiple of `multiple_of` and returns a PartitionSpec that
  shards that dimension along the given `axis_name`. All other dimensions are
  not sharded (marked as None in the PartitionSpec), which implicitly treats
  them as replicated.

  Args:
    axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl").
    shape: A tuple or list representing the shape of the tensor to be sharded.
    multiple_of: The integer value that a dimension size must be divisible by
      in order to be sharded. Typically the size of the mesh axis.

  Returns:
    A jax.sharding.PartitionSpec object specifying how to shard the tensor.
    For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4,
    it would return PartitionSpec(None, 'x', None).
    If no dimension is divisible, a fully replicated PartitionSpec is returned.
  """
  sharding = []
  found = False
  for size in shape:
    if not found and size % multiple_of == 0:
      found = True
      sharding.append(axis_name)
    else:
      sharding.append(None)
  return PartitionSpec(*sharding)


class SingleAxisSharder:
  """A callable object that generates PartitionSpecs for single-axis sharding.

  This sharder strategy attempts to shard the *first* dimension of a tensor
  that is divisible by the specified `axis_size` along the given `axis_name`.
  It is useful for simple 1D mesh sharding scenarios like FSDP, where parameters
  are typically sharded along one dimension.

  Attributes:
    axis_name: The name of the mesh axis to shard along.
    axis_size: The size of the mesh axis (number of devices along that axis).
  """

  def __init__(self, axis_name, axis_size, replicate_unshardable=False):
    """Initializes the SingleAxisSharder.

    Args:
      axis_name: The name of the mesh axis (e.g., "fsdp", "data").
      axis_size: The number of devices along the specified mesh axis.
      replicate_unshardable: Whether to return a replicated sharding (P())
        when none of the dimensions is divisible by the axis size; if False,
        an AssertionError is raised instead.
    """
    self.axis_name = axis_name
    self.axis_size = axis_size
    self.replicate_unshardable = replicate_unshardable

  def __call__(self, name, shapedtype):
    """Generates a PartitionSpec for a given tensor name and shaped type.

    Args:
      name: The name of the tensor (e.g., parameter name). This argument is
        provided for compatibility with more complex sharders but is not used
        by this simple sharder.
      shapedtype: An object with a `.shape` attribute describing the tensor's
        shape and a `.dtype` attribute describing its dtype (e.g., jax.Array,
        jax.ShapeDtypeStruct, or a torch.Tensor).

    Returns:
      A jax.sharding.PartitionSpec determined by finding the first dimension
      in `shapedtype.shape` divisible by `self.axis_size` using the helper
      `_shard_first_multiple_of`.
    """
    del name
    sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape,
                                        self.axis_size)
    if not self.replicate_unshardable and all(s is None for s in sharding):
      raise AssertionError(
          f"Unable to find a dim to shard because "
          f"none of the dims ({shapedtype.shape}) is a multiple of {self.axis_size}"
      )
    return sharding


class Mesh:
  """A helper class that wraps a `jax.sharding.Mesh` object.

  The goal of this class is to provide helper methods that facilitate the
  sharding of PyTorch tensors or models given a JAX device mesh configuration.
  It simplifies initializing models directly into a sharded state.

  Attributes:
    jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid
      and axis names.
    _sharder: The default sharding strategy callable (like SingleAxisSharder)
      used to determine the PartitionSpec for each parameter if not overridden
      during method calls. Can be None if no default is appropriate or set.
  """

  @classmethod
  def fsdp_mesh(cls, axis_name="fsdp"):
    """Creates a Mesh instance suitable for 1D FSDP-style sharding.

    This named constructor creates a 1D mesh encompassing all available XLA
    devices. It assigns the specified `axis_name` to this single dimension.
    It then creates a `Mesh` instance using this JAX mesh and a
    `SingleAxisSharder` configured appropriately for this 1D mesh.

    Args:
      axis_name: The name to assign to the single mesh axis (default: "fsdp").
        This name will be used by the default `SingleAxisSharder`.

    Returns:
      A Mesh instance configured with a 1D JAX mesh across all devices and a
      corresponding SingleAxisSharder.
    """
    ndevice = jax.device_count()
    jax_mesh = jax.make_mesh((ndevice,), (axis_name,))
    # replicate_unshardable so scalars and small model attributes are replicated.
    return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True))

  def __init__(self, jax_mesh, sharder=None):
    """Initializes the Mesh helper.

    Args:
      jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the
        physical device grid and logical axis names.
      sharder: An optional callable (e.g., an instance of SingleAxisSharder)
        that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`.
        This serves as the default sharding strategy.
        If None, and the provided `jax_mesh` has exactly one axis, a
        `SingleAxisSharder` is created automatically for that single axis.
        If None and the mesh has multiple axes, `_sharder` remains None, and
        an `override_sharder` must be provided to methods like
        `initialize_model_sharded`.
    """
    self.jax_mesh = jax_mesh
    if sharder is None:
      assert len(self.jax_mesh.axis_names) == 1
      sharder = SingleAxisSharder(self.jax_mesh.axis_names[0],
                                  len(self.jax_mesh.device_ids))
    self._sharder = sharder

  def initialize_model_sharded(self,
                               model_class,
                               init_args,
                               init_kwargs=None,
                               override_sharder=None):
    """Initializes a PyTorch model with its parameters sharded across the mesh.

    This method orchestrates the initialization of a `torch.nn.Module` such
    that its parameters are created directly on the target devices according
    to the sharding specifications derived from the mesh and the chosen sharder.
    It leverages `torchax.interop.jax_jit` to achieve this.

    Args:
      model_class: The PyTorch model class (a subclass of `torch.nn.Module`).
      init_args: A tuple containing the positional arguments required by the
        `model_class.__init__` method.
      init_kwargs: An optional dictionary containing the keyword arguments for
        the `model_class.__init__` method. Defaults to None (treated as {}).
      override_sharder: An optional callable sharding strategy to use
        specifically for this initialization. If provided, it takes precedence
        over the mesh's default `_sharder`. It must accept `(name, shapedtype)`
        and return a `PartitionSpec`. If None, the mesh's default `_sharder`
        is used.

    Returns:
      An instance of `model_class` whose parameters have been initialized and
      are represented by sharded tensors distributed across the devices in the
      `jax_mesh`.

    Raises:
      ValueError: If no sharder is available (i.e., `override_sharder` is None
        and the mesh's default `_sharder` is also None).
      AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`)
        if it fails to determine a valid sharding for any parameter.
      TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`.
      Other errors from JAX JIT compilation or PyTorch model initialization.
    """
    init_kwargs = init_kwargs or {}
    with torch.device("meta"), torchax.disable_temporarily():
      model = model_class(*init_args, **init_kwargs)

    sharder = override_sharder or self._sharder

    states = model.state_dict()
    output_shards = {
        name: NamedSharding(self.jax_mesh, sharder(name, tensor))
        for name, tensor in states.items()
    }

    def model_initializer():
      with torchax.default_env(), torch.device('meta'):
        model = model_class(*init_args, **init_kwargs)
      return dict(model.state_dict())

    jitted = interop.jax_jit(
        model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards})
    weights_dict = jitted()

    model.load_state_dict(weights_dict, assign=True)
    return model

  def shard_model(self, model, override_sharder=None):
    """Shards the parameters of an already-initialized model across the mesh.

    Computes a NamedSharding for every entry in the model's state dict using
    the default or overriding sharder, then reassigns the state dict with
    `assign=True`.
    """
    sharder = override_sharder or self._sharder
    states = model.state_dict()
    output_shards = {
        name: NamedSharding(self.jax_mesh, sharder(name, tensor))
        for name, tensor in states.items()
    }
    model.load_state_dict(output_shards, assign=True)
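As a quick illustration of the `SingleAxisSharder` semantics documented in the listing above, here is a minimal REPL-style sketch. It is not part of the package: the 4-device "fsdp" axis, tensor names, and shapes are made up, and CPU-only JAX suffices since only shapes are inspected.

import jax
import jax.numpy as jnp
from torchax.mesh_util import SingleAxisSharder

# Pretend the mesh has a 4-device axis named "fsdp".
sharder = SingleAxisSharder("fsdp", 4, replicate_unshardable=True)

# The first dimension divisible by 4 gets the axis name; the rest stay None.
spec = sharder("layer.weight", jax.ShapeDtypeStruct((10, 20, 30), jnp.float32))
print(spec)  # PartitionSpec(None, 'fsdp', None)

# No dimension divisible by 4: fully replicated, because
# replicate_unshardable=True (otherwise an AssertionError is raised).
print(sharder("bias", jax.ShapeDtypeStruct((3,), jnp.float32)))  # PartitionSpec(None,)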
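An end-to-end sketch of the new `Mesh` helper follows. This is likewise illustrative rather than code from the diff: it assumes torchax has been activated globally, that the device count divides 128, and uses `torch.nn.Linear` as a stand-in for a real model.

import torch
import torchax
from torchax.mesh_util import Mesh

torchax.enable_globally()

# 1D mesh over all available devices, axis named "fsdp" by default.
mesh = Mesh.fsdp_mesh()

# Parameters are materialized directly in their sharded layout: the jitted
# initializer runs with out_shardings, so no device ever holds a full replica.
model = mesh.initialize_model_sharded(torch.nn.Linear, (128, 128))

# mesh.shard_model(existing_model) applies the same sharder to a model that
# was already materialized.

Note the two meta-device constructions in `initialize_model_sharded`: the first only recovers parameter names and shapes to build `out_shardings`; the jitted `model_initializer` then produces the actual sharded weights, which are attached with `load_state_dict(..., assign=True)`.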