torchax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/__init__.py +5 -41
- torchax/amp.py +2 -3
- torchax/config.py +5 -1
- torchax/configuration.py +30 -0
- torchax/device_module.py +7 -0
- torchax/environment.py +1 -0
- torchax/interop.py +27 -14
- torchax/mesh_util.py +10 -1
- torchax/ops/jaten.py +5 -3
- torchax/ops/jtorch.py +18 -10
- torchax/tensor.py +127 -115
- {torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/METADATA +1 -1
- {torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/RECORD +15 -14
- torchax/distributed.py +0 -241
- {torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/__init__.py
CHANGED

@@ -6,10 +6,9 @@ import os
 import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
-from torchax import distributed  # noqa: F401
 from contextlib import contextmanager

-__version__ = "0.0.5"
+__version__ = "0.0.6"
 VERSION = __version__

 __all__ = [

@@ -50,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None):
   states = env.t2j_copy(states)

   #@jax.jit
-  def jax_func(states,
-    (states,
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
     with env:
-      res = torch.func.functional_call(
+      res = torch.func.functional_call(
+          mod, states, args, kwargs, tie_weights=False)
     return env.t2j_iso(res)

   return states, jax_func

@@ -81,11 +81,6 @@ def disable_temporarily():

 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
-torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True,
-    for_module=True,
-    for_storage=True,
-    unsupported_dtype=unsupported_dtype)

 import jax
 import torchax.device_module

@@ -129,34 +124,3 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
   elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
-
-
-@contextmanager
-def jax_device(target_device: str, env: tensor.Environment | None = None):
-  """
-  to("jax") cannot differentiate the device/platform (cpu vs tpu).
-  Use this context manager to control jax array's storage device
-
-  Examples:
-
-  a = torch.ones(3, 3)
-
-  with jax_device("cpu"):
-    b = a.to("jax")
-
-  with jax_device("tpu"):
-    c = a.to("jax")
-
-  with jax_device("tpu"):
-    c = b.to("jax")
-
-  """
-  if env is None:
-    env = default_env()
-
-  prev_target_device = env.target_device
-  try:
-    env.target_device = target_device
-    yield env
-  finally:
-    env.target_device = prev_target_device
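The user-visible part of the extract_jax hunk is the new optional kwargs pytree accepted by jax_func. A minimal usage sketch (not part of the diff; it assumes a CPU JAX backend and the public extract_jax API shown above):

    import jax.numpy as jnp
    import torch
    import torchax

    model = torch.nn.Linear(4, 2)
    # states is a pytree of jax arrays; jax_func is a pure function over them.
    states, jax_func = torchax.extract_jax(model)

    x = jnp.ones((1, 4), dtype=jnp.float32)
    out = jax_func(states, (x,))        # positional args only, as in 0.0.5
    out = jax_func(states, (x,), {})    # 0.0.6 additionally accepts a kwargs dict
    print(out.shape)                    # (1, 2)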
torchax/amp.py
CHANGED

@@ -61,9 +61,8 @@ def autocast(device, dtype=torch.bfloat16, env=None):
   if env is None:
     import torchax
     env = torchax.default_env()
-  env.autocast_dtype
-
-  env.autocast_dtype = old
+  with env.override_property(autocast_dtype=dtype):
+    yield


 # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
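A usage sketch of the rewritten autocast (inferred from the signature above, not taken from the diff): the dtype is now pushed through env.override_property, so the override is scoped to the with block and unwinds automatically.

    import torch
    import torchax
    from torchax import amp

    env = torchax.default_env()
    with env:
      x = torch.randn(8, 8, device='jax')
      with amp.autocast('jax', dtype=torch.bfloat16, env=env):
        y = x @ x   # ops with an autocast policy run under bfloat16 here
      z = x @ x     # the scoped override is popped once the block exits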
torchax/config.py
CHANGED

@@ -10,6 +10,11 @@ class Configuration:

   use_int32_for_index: bool = False

+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
   # If true, we will convert Views into torchax.Tensors eagerly
   force_materialize_views: bool = False


@@ -22,5 +27,4 @@ class Configuration:

   # device
   treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = True
   internal_respect_torch_return_dtypes: bool = False
torchax/configuration.py
ADDED

@@ -0,0 +1,30 @@
+import dataclasses
+
+
+@dataclasses.dataclass
+class Configuration:
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
+
+  use_int32_for_index: bool = False
+
+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  internal_respect_torch_return_dtypes: bool = False
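Since the new module is a plain dataclass, a customized instance can be built with keyword overrides. A short sketch (field names come from the file above; passing the instance to an Environment is an assumption based on the `configuration or config.Configuration()` fallback visible in the tensor.py hunk further down):

    from torchax import configuration

    cfg = configuration.Configuration(
        debug_print_each_op=True,
        allow_mixed_math_with_scalar_tensor=False,
        use_tpu_flash_attention=True,
    )
    print(cfg.treat_cuda_as_jax_device)   # untouched fields keep their defaults: True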
torchax/device_module.py
CHANGED
torchax/environment.py
ADDED

@@ -0,0 +1 @@
+
torchax/interop.py
CHANGED

@@ -11,6 +11,7 @@ from jax import tree_util as pytree
 from jax.experimental.shard_map import shard_map
 from torchax import tensor
 from torchax import util
+from torchax.ops import mappings
 import torchax

 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable

@@ -90,7 +91,7 @@ class JittableModule(torch.nn.Module):
   def __call__(self, *args, **kwargs):
     return self.forward(*args, **kwargs)

-  def functional_call(self,
+  def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
     kwargs = kwargs or {}
     params_copy = copy.copy(params)
     params_copy.update(buffers)

@@ -98,22 +99,35 @@ class JittableModule(torch.nn.Module):
     for k, v in self._extra_dumped_weights.items():
       for new_key in v:
         params_copy[new_key] = params_copy[k]
+
+    if isinstance(method_or_name, str):
+      method = getattr(self._model, method_or_name)
+    else:
+      if not callable(method_or_name):
+        raise TypeError(
+            f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+        )
+      method = method_or_name
+      args = (self._model,) + args
     with torch_stateless._reparametrize_module(self._model, params_copy):
-      res =
+      res = method(*args, **kwargs)
     return res

-  def
-    if
+  def jittable_call(self, method_name: str, *args, **kwargs):
+    if method_name not in self._jitted:
       jitted = jax_jit(
-          functools.partial(self.functional_call,
+          functools.partial(self.functional_call, method_name),
           kwargs_for_jax_jit=self._extra_jit_args,
       )

       def jitted_forward(*args, **kwargs):
         return jitted(self.params, self.buffers, *args, **kwargs)

-      self._jitted[
-    return self._jitted[
+      self._jitted[method_name] = jitted_forward
+    return self._jitted[method_name](*args, **kwargs)
+
+  def forward(self, *args, **kwargs):
+    return self.jittable_call('forward', *args, **kwargs)

   def __getattr__(self, key):
     if key == '_model':

@@ -170,8 +184,8 @@ def _torch_view(t: JaxValue) -> TorchValue:
   if isinstance(t, jax.Array):
     # TODO
     return tensor.Tensor(t, torchax.default_env())
-  if isinstance(t,
-    return
+  if isinstance(t, jnp.dtype):
+    return mappings.j2t_dtype(t)
   if callable(t):  # t is a JaxCallable
     return functools.partial(call_jax, t)
   # regular types are not changed

@@ -188,7 +202,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
     assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
     return t.jax()
   if isinstance(t, type(torch.int32)):
-    return
+    return mappings.t2j_dtype(t)

   # torch.nn.Module needs special handling
   if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable

@@ -225,8 +239,7 @@ def j2t_autograd(fn, call_jax=call_jax):

   @wraps(fn)
   def inner(*args, **kwargs):
-    from jax.tree_util import tree_flatten
-    from jax.util import safe_zip
+    from jax.tree_util import tree_flatten

     class JaxFun(torch.autograd.Function):

@@ -261,8 +274,8 @@ def j2t_autograd(fn, call_jax=call_jax):
         # The subsequent gradients correspond to flat_inputs.
         # We need to put a None for inputs that did not require gradients.
         final_grads = [None]
-        for needs_grad, grad in
-
+        for needs_grad, grad in zip(
+            ctx.needs_input_grad[1:], input_grads_structured, strict=True):
           final_grads.append(grad if needs_grad else None)

         return tuple(final_grads)
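A self-contained sketch of the new method_or_name dispatch added to functional_call (the Wrapper class and names here are illustrative, not torchax APIs): a string selects a bound method on the wrapped model, while a callable is invoked with the model prepended to the positional arguments.

    import torch

    class Wrapper:

      def __init__(self, model):
        self._model = model

      def call(self, method_or_name, *args, **kwargs):
        if isinstance(method_or_name, str):
          method = getattr(self._model, method_or_name)
        else:
          if not callable(method_or_name):
            raise TypeError("method_or_name should be a callable or a string")
          method = method_or_name
          args = (self._model,) + args
        return method(*args, **kwargs)

    m = torch.nn.Linear(3, 3)
    w = Wrapper(m)
    x = torch.ones(1, 3)
    print(w.call('forward', x))                  # dispatch by method name
    print(w.call(lambda mod, t: mod(t) * 2, x))  # dispatch by callable, model prepended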
torchax/mesh_util.py
CHANGED

@@ -199,7 +199,7 @@ class Mesh:
     }

     def model_initializer():
-      with torchax.default_env():
+      with torchax.default_env(), torch.device('meta'):
        model = model_class(*init_args, **init_kwargs)
        return dict(model.state_dict())


@@ -209,3 +209,12 @@ class Mesh:

     model.load_state_dict(weights_dict, assign=True)
     return model
+
+  def shard_model(self, model, override_sharder=None):
+    sharder = override_sharder or self._sharder
+    states = model.state_dict()
+    output_shards = {
+        name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+        for name, tensor in states.items()
+    }
+    model.load_state_dict(output_shards, assign=True)
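The only change to model_initializer is the extra torch.device('meta') context. A standalone illustration of what that buys (plain PyTorch, independent of torchax): parameters are created without real storage, so a large model can be constructed cheaply and its weights materialized or sharded afterwards, which is what the new shard_model does via load_state_dict(..., assign=True).

    import torch

    with torch.device('meta'):
      m = torch.nn.Linear(1024, 1024)

    print(m.weight.device)                 # meta: no storage was allocated
    print(m.weight.shape, m.weight.dtype)  # shape/dtype metadata is still available
    # Real (or sharded) tensors can later be attached with
    # m.load_state_dict(real_weights, assign=True).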
torchax/ops/jaten.py
CHANGED

@@ -736,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
   return jnp.empty(sizes, dtype=dtype)


-@op(torch.ops.aten.index_put_)
 @op(torch.ops.aten.index_put)
 def _aten_index_put(self, indexes, values, accumulate=False):
   indexes = [slice(None, None, None) if i is None else i for i in indexes]

@@ -3532,7 +3531,7 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):

 @op(torch.ops.aten.randn, needs_env=True)
 @op_base.convert_dtype()
-def
+def _aten_randn(
     *size,
     generator=None,
     out=None,

@@ -3652,7 +3651,7 @@ def _aten_native_batch_norm(input,
 @op(torch.ops.aten.normal, needs_env=True)
 def _aten_normal(self, mean=0, std=1, generator=None, env=None):
   shape = self.shape
-  res =
+  res = _aten_randn(*shape, generator=generator, env=env)
   return res * std + mean


@@ -5541,6 +5540,7 @@ def _aten_floor_divide(x, y):


 @op(torch.ops.aten._assert_tensor_metadata)
+@op(torch.ops.aten._assert_scalar)
 def _aten__assert_tensor_metadata(*args, **kwargs):
   pass


@@ -5617,6 +5617,8 @@ mutation_ops_to_functional = {
         op_base.InplaceOp(torch.ops.aten.floor_divide),
     torch.ops.aten.remainder_:
         op_base.InplaceOp(torch.ops.aten.remainder),
+    torch.ops.aten.index_put_:
+        op_base.InplaceOp(torch.ops.aten.index_put),
 }

 # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
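Net effect of the index_put_ change: the in-place variant is no longer registered as its own lowering but is routed through mutation_ops_to_functional to the functional index_put. A small JAX-only sketch of the functional form this ultimately maps onto (illustrative, not the torchax implementation itself):

    import jax.numpy as jnp

    x = jnp.zeros((3, 3))
    rows = jnp.array([0, 2])
    print(x.at[rows].set(1.0))  # index_put(..., accumulate=False)
    print(x.at[rows].add(1.0))  # index_put(..., accumulate=True)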
torchax/ops/jtorch.py
CHANGED

@@ -179,6 +179,13 @@ def _tpu_flash_attention(query, key, value, env):
   return wrap_flash_attention(query, key, value)


+@register_function(torch.nn.functional.one_hot)
+def one_hot(tensor, num_classes=-1):
+  if num_classes == -1:
+    num_classes = jnp.max(tensor) + 1
+  return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
+
+
 @register_function(torch.nn.functional.pad)
 def pad(tensor, pad, mode="constant", value=None):
   # For padding modes that have different names between Torch and NumPy, this

@@ -341,7 +348,7 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
   return jnp.empty(size, dtype=dtype)


-@register_function(torch.arange, is_jax_function=
+@register_function(torch.arange, is_jax_function=True)
 def arange(
     start,
     end=None,

@@ -358,10 +365,10 @@ def arange(
     start = 0
   if step is None:
     step = 1
-  return
+  return jaten._aten_arange(start, end, step, dtype=dtype)


-@register_function(torch.empty_strided, is_jax_function=
+@register_function(torch.empty_strided, is_jax_function=True)
 def empty_strided(
     size,
     stride,

@@ -372,7 +379,7 @@ def empty_strided(
     requires_grad=False,
     pin_memory=False,
 ):
-  return empty(size, dtype=dtype)
+  return empty(size, dtype=dtype, requires_grad=requires_grad)


 @register_function(torch.unravel_index)

@@ -380,14 +387,14 @@ def unravel_index(indices, shape):
   return jnp.unravel_index(indices, shape)


-@register_function(torch.rand, is_jax_function=
+@register_function(torch.rand, is_jax_function=True, needs_env=True)
 def rand(*size, **kwargs):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return
+  return jaten._rand(size, **kwargs)


-@register_function(torch.randn, is_jax_function=
+@register_function(torch.randn, is_jax_function=True, needs_env=True)
 def randn(
     *size,
     generator=None,

@@ -397,15 +404,16 @@ def randn(
     device=None,
     requires_grad=False,
     pin_memory=False,
+    env=None,
 ):
   if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
     size = size[0]
-  return
+  return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)


-@register_function(torch.randint, is_jax_function=False)
+@register_function(torch.randint, is_jax_function=False, needs_env=True)
 def randint(*args, **kwargs):
-  return
+  return jaten._aten_randint(*args, **kwargs)


 @register_function(torch.logdet)
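For the new one_hot lowering, a short sketch of the semantics it mirrors: torch.nn.functional.one_hot returns integer tensors while jax.nn.one_hot returns floats, hence the cast; note that without x64 enabled JAX narrows the int64 cast to int32.

    import jax
    import jax.numpy as jnp

    labels = jnp.array([0, 2, 1])
    num_classes = int(jnp.max(labels)) + 1   # the num_classes == -1 branch
    onehot = jax.nn.one_hot(labels, num_classes).astype(jnp.int64)
    print(onehot)        # [[1 0 0] [0 0 1] [0 1 0]]
    print(onehot.dtype)  # int32 unless jax_enable_x64 is set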
torchax/tensor.py
CHANGED

@@ -1,3 +1,4 @@
+import threading
 import logging
 import sys
 import contextlib

@@ -16,7 +17,6 @@ from torchax.view import View
 from torchax import config
 from torchax.ops import mappings, ops_registry
 from torchax import amp
-from jax.experimental import mutable_array

 logger = logging.getLogger(__name__)


@@ -25,14 +25,6 @@ class OperatorNotFound(Exception):
   pass


-def wrap(jaxarray):
-  return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray)
-
-
-def unwrap(torchtensors):
-  return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
-
-
 @contextlib.contextmanager
 def log_nested(env, message):
   if env.config.debug_print_each_op:

@@ -48,7 +40,7 @@ log_nested.level = 0
 class Tensor(torch.Tensor):

   @staticmethod
-  def __new__(cls, elem, env):
+  def __new__(cls, elem, env, requires_grad=False):
     dtype = mappings.j2t_dtype(elem.dtype)
     shape = list(elem.shape)
     for i, s in enumerate(shape):

@@ -56,15 +48,19 @@ class Tensor(torch.Tensor):
         shape[i] = 1
     if dtype is None:
       dtype = torch.float32
+    #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
+    if not (dtype.is_floating_point or dtype.is_complex):
+      requires_grad = False
+
     return torch.Tensor._make_wrapper_subclass(
         cls,
         shape,
         dtype=dtype,
-        device=
-        requires_grad=
+        device='meta',
+        requires_grad=requires_grad,
     )

-  def __init__(self, elem: jax.Array, env: "Environment"):
+  def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
     super().__init__()
     self._elem = elem
     self._env = env

@@ -74,9 +70,6 @@ class Tensor(torch.Tensor):

   __repr__ = __str__

-  def __jax_array__(self):
-    return self._elem
-
   @property
   def shape(self):
     return torch.Size(self._elem.shape)

@@ -109,6 +102,8 @@ class Tensor(torch.Tensor):
     # TODO(hanq): figure out why is dispatch mode not sufficient
     if func == torch.ops._c10d_functional.wait_tensor.default:
       return args[0]._env.dispatch(func, types, args, kwargs)
+    if func == torch.ops.prim.device.default:
+      return torch.device('privateuseone', 0)
     raise AssertionError(
         'torchax Tensors can only do math within the torchax environment.'
         'Please wrap your code with `with torchax.default_env()` or '

@@ -298,6 +293,38 @@ TENSOR_CONSTRUCTORS = {
 SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]


+class RuntimeProperty:
+  mesh: Any
+  prng: Any
+  autocast_dtype: Any
+
+  def __init__(self, mesh, prng, autocast_dtype):
+    self.mesh = mesh
+    self.prng = prng
+    self.autocast_dtype = autocast_dtype
+
+  def override(self, **kwargs):
+    return OverrideProperty(self, kwargs)
+
+  def get_and_rotate_prng_key(self):
+    old_key = self.prng
+    new_prng_key, next_key = jax.random.split(old_key)
+    self.prng = new_prng_key
+    return next_key
+
+
+class OverrideProperty(RuntimeProperty):
+
+  def __init__(self, parent, override):
+    self.parent = parent
+    self._override = dict(override)
+
+  def __getattr__(self, name):
+    if name in self._override:
+      return self._override[name]
+    return getattr(self.parent, name)
+
+
 class Environment(contextlib.ContextDecorator):
   """This class holds a set of configurations and "globals" needed


@@ -321,62 +348,55 @@ class Environment(contextlib.ContextDecorator):

     self.load_ops()

-
+    _mesh = None
     self.config = configuration or config.Configuration()

-    self._manually_entered = False
     self.enabled = False

-
-    jax.random.key(torch.initial_seed() % (1 << 63)))
-    self.autocast_dtype = None
-    self._target_device = jax.local_devices()[0].platform
+    autocast_dtype = None

-
-
-
+    _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
+    self._property = threading.local()
+    self._property.content = [
+        RuntimeProperty(
+            mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
+    ]

-  @
-  def
-    self.
+  @property
+  def param(self):
+    return self._property.content[-1]

   def manual_seed(self, key):
-
+    jax_key = jax.random.PRNGKey(key)
+    new_prop = self.param.override(prng=jax_key)
+    self._property.content.append(new_prop)

   @property
   def prng_key(self):
-    return self.
+    return self.param.prng

-  def
+  def _should_use_torchax_tensor(self, device):
     if device is None:
       device = torch.get_default_device()

     if isinstance(device, torch.device):
-      device =
-
-    if
-
-
-
-
-
-
-
-
-
-
-
-      return jax.devices("cpu")[0]
-    case "tpu":
-      return jax.devices("tpu")[0]
-    case _:
-      raise AttributeError(
-          f"Cannot handle env.target_device {self.target_device}")
-
-    return None  # fallback to torch
+      device = device.type
+
+    if ':' in device:
+      device = device.split(':')[0]
+
+    match device:
+      case 'cpu':
+        return False
+      case 'cuda':
+        return self.config.treat_cuda_as_jax_device
+      case 'jax':
+        return True
+      case 'privateuseone':
+        return True
+      case 'meta':
+        return self.enabled
+    return False

   def load_ops(self):
     from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms

@@ -423,80 +443,63 @@ class Environment(contextlib.ContextDecorator):

     return op

+  def _is_same_device(self, the_tensor, new_device):
+    if new_device is None:
+      return True
+    if new_device == 'meta' and the_tensor.device.type == 'jax':
+      return True
+    if the_tensor.device.type != new_device:
+      if the_tensor.device.type == 'cuda':
+        return self.config.treat_cuda_as_jax_device
+      return False
+    return True
+
   def _to_copy(self, the_tensor, new_dtype, new_device):
     if isinstance(the_tensor, View):
       the_tensor = the_tensor.torch()
-
-
-
-
-
-
-      arr = arr.astype(mappings.t2j_dtype(new_dtype))
-
-    if new_device is not None:
-      match str(new_device).lower():
-        case "cpu":
-          # converting to a non-jax device: let torch native handle it
-          torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor,
-                                                          Tensor) else arr
-          with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-            return torch_tensor.to(new_device)
-        case "jax":
-          # move torchax.tensor / jax tensor between devices
-          # I don't know ifgit this will work after the model is jitted
-          if self.target_device != the_tensor.jax_device.platform:
-            arr = jax.device_put(the_tensor.jax(),
-                                 jax.devices(self.target_device)[0])
-          return Tensor(arr, self)
-        case _:
-          logging.error(f"torchax.Tenosr cannot handle device {new_device}")
-
-    else:
-      if new_dtype is not None and new_dtype != the_tensor.dtype:
+    if isinstance(new_device, torch.device):
+      new_device = new_device.type
+    res = the_tensor
+    if not self._is_same_device(the_tensor, new_device):
+      if isinstance(the_tensor, Tensor):
+        torch_tensor = self.j2t_copy(the_tensor._elem)
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-
-
-    if new_device is None:  ## device is None means don't change device
-      return the_tensor
-
-    jax_device = self.get_as_jax_device(new_device)
-    if jax_device:
+          return torch_tensor.to(device=new_device, dtype=new_dtype)
+      else:
         arr = self.t2j_copy(the_tensor)
-
+        res = Tensor(arr, self, the_tensor.requires_grad)
+
+    if new_dtype is not None and new_dtype != the_tensor.dtype:
+      if isinstance(the_tensor, Tensor):
+        res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
       else:
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-          return the_tensor.to(new_device)
-
-    return Tensor(arr, self)
+          return the_tensor.to(device=new_device, dtype=new_dtype)
+    return res

   def get_and_rotate_prng_key(self,
                               generator: Optional[torch.Generator] = None):
     if generator is not None:
-
-
-      old_key = self._prng_key[...]
-      new_prng_key, next_key = jax.random.split(old_key)
-      self._prng_key[...] = new_prng_key
-      return next_key
+      return jax.random.PRNGKey(generator.initial_seed() % (2**63))
+    return self.param.get_and_rotate_prng_key()

   def _handle_tensor_constructor(self, func, args, kwargs):
     device = kwargs.get("device")
-
-
-    if not self._manually_entered and jax_device is None:
-      # let torch handle it
-      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-        return func(*args, **kwargs)
-    with jax.default_device(jax_device):
+    if self._should_use_torchax_tensor(device):
+      # don't set default device, let caller set it
       requires_grad = kwargs.get("requires_grad", False)
       op = self._get_op_or_decomp(func)
+      if op.needs_env:
+        kwargs['env'] = self
+      if op.is_jax_function:
+        (args, kwargs) = self.t2j_iso((args, kwargs))
       res = op.func(*args, **kwargs)
       if isinstance(res, jax.Array):
-        res = Tensor(res, self)
-        if requires_grad:
-          res.requires_grad = True
+        res = Tensor(res, self, requires_grad)
       return res
+    else:
+      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+        return func(*args, **kwargs)

   def _torch_Tensor_to(self, args, kwargs):
     the_tensor = args[0]

@@ -560,11 +563,11 @@ class Environment(contextlib.ContextDecorator):
     args, kwargs = self.v2t_iso((args, kwargs))

     with self:
-      if self.autocast_dtype is not None:
+      if self.param.autocast_dtype is not None:
        autocast_policy = amp.autocast_policy.get(func)
        if autocast_policy is not None:
          args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
-                                            self.autocast_dtype)
+                                            self.param.autocast_dtype)

      if op.is_jax_function:
        args, kwargs = self.t2j_iso((args, kwargs))

@@ -609,11 +612,9 @@ class Environment(contextlib.ContextDecorator):

   def __enter__(self):
     self.enable_torch_modes()
-    self._manually_entered = True
     return self

   def __exit__(self, *exc):
-    self._manually_entered = False
     self.disable_torch_modes(*exc)

   def _move_one_value(self, val):

@@ -639,6 +640,10 @@ class Environment(contextlib.ContextDecorator):
     """

     def to_jax(x):
+      if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
+          x, Tensor):
+        if x.squeeze().ndim == 0:
+          return x.item()
      if isinstance(
          x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
        x = x.wait()

@@ -697,3 +702,10 @@ class Environment(contextlib.ContextDecorator):
        is_user_defined=True,
        needs_env=False,
    )
+
+  @contextlib.contextmanager
+  def override_property(self, **kwargs):
+    new_prop = self.param.override(**kwargs)
+    self._property.content.append(new_prop)
+    yield
+    self._property.content.pop()
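The largest conceptual change in tensor.py is the thread-local stack of RuntimeProperty objects that replaces the loose autocast_dtype and _prng_key attributes; amp.autocast and Environment.manual_seed now push overrides onto it. A self-contained sketch of that override-stack pattern (stripped of JAX and torch so it runs on its own; the names mirror the diff but the bodies are simplified):

    import contextlib
    import threading

    class RuntimeProperty:

      def __init__(self, prng, autocast_dtype):
        self.prng = prng
        self.autocast_dtype = autocast_dtype

      def override(self, **kwargs):
        return OverrideProperty(self, kwargs)

    class OverrideProperty(RuntimeProperty):

      def __init__(self, parent, override):
        self.parent = parent
        self._override = dict(override)

      def __getattr__(self, name):
        # Fall back to the parent for anything not explicitly overridden.
        if name in self._override:
          return self._override[name]
        return getattr(self.parent, name)

    class Env:

      def __init__(self):
        self._property = threading.local()
        self._property.content = [RuntimeProperty(prng=0, autocast_dtype=None)]

      @property
      def param(self):
        return self._property.content[-1]

      @contextlib.contextmanager
      def override_property(self, **kwargs):
        self._property.content.append(self.param.override(**kwargs))
        yield
        self._property.content.pop()

    env = Env()
    with env.override_property(autocast_dtype='bfloat16'):
      print(env.param.autocast_dtype)  # 'bfloat16' inside the scope
    print(env.param.autocast_dtype)    # None again after the pop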
{torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.5
+Version: 0.0.6
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
{torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/RECORD
CHANGED

@@ -1,32 +1,33 @@
 torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
-torchax/__init__.py,sha256=
-torchax/amp.py,sha256=
-torchax/config.py,sha256=
+torchax/__init__.py,sha256=c98iIGugRTbEVcsx8eWnbAjsC4mpcDrK23ZQqiMycLg,3157
+torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
+torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
+torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
 torchax/decompositions.py,sha256=1p5TFZfAJ2Bs9BiSO1vXbnWEXnbPfC_gCQ54rDXhd9k,28859
-torchax/device_module.py,sha256=
-torchax/
+torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
+torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
 torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
-torchax/interop.py,sha256=
-torchax/mesh_util.py,sha256=
-torchax/tensor.py,sha256=
+torchax/interop.py,sha256=7HvJwtxdodcCrMyJzs-Wr47hkHuoh6CWb2-YKoBwqV0,11076
+torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
+torchax/tensor.py,sha256=XjAp7khpQNhoVsSMzDj-V8l4DFT9jBaL4NVCi88a6K0,20893
 torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
 torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
 torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
 torchax/util.py,sha256=cb-eudDE7AX2s-6zYtXdowgyzyvqPqE9MPP82PfH23g,3069
 torchax/view.py,sha256=1ekqRN04lAPd_icgZMKbSYWhr738DzVloc34ynml4wo,11121
 torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
-torchax/ops/jaten.py,sha256=
+torchax/ops/jaten.py,sha256=WxfZU6p7b7OR98B3z0LCXKlV6U5aslXxJMJirBr6lns,165835
 torchax/ops/jax_reimplement.py,sha256=idkmFWNCXBilkmaHBGdivKz0XhsjSpqLNlGXxbBOKWQ,7302
 torchax/ops/jc10d.py,sha256=OzSYYle_5jBmNVP64SuJPz9S-rRGD6H7e1a9HHIKsjU,1322
 torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
 torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
-torchax/ops/jtorch.py,sha256=
+torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
 torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
 torchax/ops/mappings.py,sha256=AESERtXJ6i_Hm0ycwEw7z5OJnHu-7QteWlSs-mlUPE4,3492
 torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
 torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
-torchax-0.0.5.dist-info/METADATA,sha256=
-torchax-0.0.5.dist-info/WHEEL,sha256=
-torchax-0.0.5.dist-info/licenses/LICENSE,sha256=
-torchax-0.0.5.dist-info/RECORD,,
+torchax-0.0.6.dist-info/METADATA,sha256=uB9hoyxdfrAD14pHy0U8Gh1uCHbYwok-oEW12pEa6qs,10753
+torchax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torchax-0.0.6.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
+torchax-0.0.6.dist-info/RECORD,,
torchax/distributed.py
DELETED

@@ -1,241 +0,0 @@
-"""`torch.distributed` backend implemented with JAX collective ops.
-
-EXPERIMENTAL: This module is still highly experimental, and it may be removed
-before any stable release.
-
-Note: JAX collective ops require that axis names be defined in `pmap` or
-`shmap`. The distributed backend only supports one axis, named `torch_dist`.
-This name is defined by our mirror implementation of `spawn`.
-"""
-
-import datetime
-import functools
-import logging
-import os
-from typing import List, Optional, Union
-
-import jax
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.distributed._functional_collectives
-from torch._C._distributed_c10d import ProcessGroup  # type: ignore
-import torch.distributed
-import torchax
-from jax.sharding import NamedSharding
-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental import mesh_utils
-import torch.utils._pytree as torch_pytree
-from torchax import interop
-
-
-class ProcessGroupJax(ProcessGroup):
-  """Distributed backend implemented with JAX."""
-
-  def __init__(self, prefix_store, rank, size, timeout):
-    super().__init__(rank, size)
-    self._group_name = None
-
-  def getBackendName(self):
-    return "jax"
-
-  # TODO(wcromar): why doesn't default group name setter work?
-  # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
-  def _set_group_name(self, name: str) -> None:
-    self._group_name = name
-
-  @property
-  def group_name(self):
-    assert self._group_name
-    return self._group_name
-
-  @staticmethod
-  def _work(
-      tensors: Union[torch.Tensor, List[torch.Tensor],
-                     List[List[torch.Tensor]]],
-  ) -> dist.Work:
-    fut = torch.futures.Future()
-    fut.set_result(tensors)
-    return torch._C._distributed_c10d._create_work_from_future(fut)
-
-  def _allgather_base(
-      self,
-      output: torch.Tensor,
-      input: torch.Tensor,
-      opts=...,
-  ) -> dist.Work:
-    assert isinstance(input, torchax.tensor.Tensor)
-    assert isinstance(output, torchax.tensor.Tensor)
-    torch.distributed._functional_collectives.all_gather_tensor_inplace(
-        output, input, group=self)
-    return self._work(output)
-
-  def allreduce(
-      self,
-      tensors: List[torch.Tensor],
-      opts: dist.AllreduceOptions = ...,
-  ) -> dist.Work:
-    assert len(tensors) == 1
-    assert isinstance(tensors[0], torchax.tensor.Tensor)
-    torch.distributed._functional_collectives.all_reduce_inplace(
-        tensors[0],
-        torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
-            opts.reduceOp.op],
-        self,
-    )
-
-    return self._work(tensors)
-
-  def broadcast(
-      self,
-      tensors: List[torch.Tensor],
-      opts: dist.BroadcastOptions = ...,
-  ) -> dist.Work:
-    assert len(tensors) == 1
-    assert isinstance(tensors[0], torchax.tensor.Tensor)
-    tensors[0].copy_(
-        torch.distributed._functional_collectives.broadcast(
-            tensors[0], opts.rootRank, group=self))
-
-    return self._work(tensors)
-
-
-dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"])
-
-
-def jax_rendezvous_handler(url: str,
-                           timeout: datetime.timedelta = ...,
-                           **kwargs):
-  """Initialize distributed store with JAX process IDs.
-
-  Requires `$MASTER_ADDR` and `$MASTER_PORT`.
-  """
-  # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
-  # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
-  # of their public Python API
-  master_ip = os.environ["MASTER_ADDR"]
-  master_port = int(os.environ["MASTER_PORT"])
-  # TODO(wcromar): Use `torchrun`'s store if available
-  store = dist.TCPStore(
-      master_ip,
-      master_port,
-      jax.process_count(),
-      is_master=jax.process_index() == 0,
-  )
-
-  yield (store, jax.process_index(), jax.process_count())
-
-
-dist.register_rendezvous_handler("jax", jax_rendezvous_handler)
-
-
-def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
-  """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
-
-  `f` is expected to take the replica index as a positional argument, similar
-  to `torch.multiprocessing.spawn`.
-
-  Note: `spawn` does not actually create parallel processes.
-  """
-  env = env or torchax.default_env()
-
-  def jax_wrapper(index, jax_args):
-    index, args = env.j2t_iso([index, jax_args])
-    torch_outputs = f(index, *args)
-    return env.t2j_iso(torch_outputs)
-
-  jax_outputs = jax.pmap(
-      jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()),
-                                           env.t2j_iso(args))
-  return env.j2t_iso(jax_outputs)
-
-
-class DistributedDataParallel(torch.nn.Module):
-  """Re-implementation of DistributedDataParallel using JAX SPMD.
-
-  Splits inputs along batch dimension (assumed to be 0) across all devices in
-  JAX runtime, including remote devices. Each process should load a distinct
-  shard of the input data using e.g. DistributedSampler. Each process' shard
-  is then further split among the addressable devices (e.g. local TPU chips)
-  by `shard_input`.
-
-  Note: since parameters are replicated across addressable devices, inputs
-  must also be SPMD sharded using `shard_input` or `replicate_input`.
-
-  Example usage:
-
-  ```
-  jax_model = torchax.distributed.DistributedDataParallel(create_model())
-  for data, dataloader:
-    jax_data = jax_model.shard_input(data)
-    jax_output = jax_model(jax_data)
-  ```
-  """
-
-  def __init__(
-      self,
-      module: torch.nn.Module,
-      env: Optional[torchax.tensor.Environment] = None,
-      **kwargs,
-  ):
-    if kwargs:
-      logging.warning(f"Unsupported kwargs {kwargs}")
-
-    super().__init__()
-    self._env = env or torchax.default_env()
-    self._mesh = Mesh(
-        mesh_utils.create_device_mesh((jax.device_count(),)),
-        axis_names=("batch",),
-    )
-    replicated_state = torch_pytree.tree_map_only(
-        torch.Tensor,
-        lambda t: self._env.j2t_iso(
-            jax.device_put(
-                self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))),
-        module.state_dict(),
-    )
-    # TODO: broadcast
-    module.load_state_dict(replicated_state, assign=True)
-    self._module = module
-
-  def shard_input(self, inp):
-    per_process_batch_size = inp.shape[0]  # assumes batch dim is 0
-    per_replica_batch_size = per_process_batch_size // jax.local_device_count()
-    per_replica_batches = torch.chunk(inp, jax.local_device_count())
-    global_batch_size = per_replica_batch_size * jax.device_count()
-    global_batch_shape = (global_batch_size,) + inp.shape[1:]
-
-    sharding = NamedSharding(self._mesh, P("batch"))
-    return self._env.j2t_iso(
-        jax.make_array_from_single_device_arrays(
-            global_batch_shape,
-            NamedSharding(self._mesh, P("batch")),
-            arrays=[
-                jax.device_put(self._env.to_xla(batch)._elem, device) for batch,
-                device in zip(per_replica_batches, sharding.addressable_devices)
-            ],
-        ))
-
-  def replicate_input(self, inp):
-    return self._env.j2t_iso(
-        jax.device_put(inp._elem, NamedSharding(self._mesh, P())))
-
-  def jit_step(self, func):
-
-    @functools.partial(
-        interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0})
-    def _jit_fn(states, args):
-      self.load_state_dict(states)
-      outputs = func(*args)
-      return self.state_dict(), outputs
-
-    @functools.wraps(func)
-    def inner(*args):
-      jax_states = self.state_dict()
-      new_states, outputs = _jit_fn(jax_states, args)
-      self.load_state_dict(new_states)
-      return outputs
-
-    return inner
-
-  def forward(self, *args):
-    with self._env:
-      return self._module(*args)
{torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/WHEEL
File without changes

{torchax-0.0.5.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE
File without changes