torchax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of torchax might be problematic.

torchax/__init__.py CHANGED
@@ -6,10 +6,9 @@ import os
6
6
  import torch
7
7
  from torch.utils import _pytree as pytree
8
8
  from torchax import tensor
9
- from torchax import distributed # noqa: F401
10
9
  from contextlib import contextmanager
11
10
 
12
- __version__ = "0.0.5"
11
+ __version__ = "0.0.7"
13
12
  VERSION = __version__
14
13
 
15
14
  __all__ = [
@@ -50,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None):
50
49
  states = env.t2j_copy(states)
51
50
 
52
51
  #@jax.jit
53
- def jax_func(states, inputs):
54
- (states, inputs) = env.j2t_iso((states, inputs))
52
+ def jax_func(states, args, kwargs=None):
53
+ (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
55
54
  with env:
56
- res = torch.func.functional_call(mod, states, inputs, tie_weights=False)
55
+ res = torch.func.functional_call(
56
+ mod, states, args, kwargs, tie_weights=False)
57
57
  return env.t2j_iso(res)
58
58
 
59
59
  return states, jax_func
@@ -81,11 +81,6 @@ def disable_temporarily():
81
81
 
82
82
  torch.utils.rename_privateuse1_backend('jax')
83
83
  unsupported_dtype = [torch.quint8]
84
- torch.utils.generate_methods_for_privateuse1_backend(
85
- for_tensor=True,
86
- for_module=True,
87
- for_storage=True,
88
- unsupported_dtype=unsupported_dtype)
89
84
 
90
85
  import jax
91
86
  import torchax.device_module
@@ -129,34 +124,3 @@ def compile(fn, options: Optional[CompileOptions] = None):
129
124
  raise RuntimeError('dynamo mode is not supported yet')
130
125
  elif options.mode == 'export':
131
126
  raise RuntimeError('export mode is not supported yet')
132
-
133
-
134
- @contextmanager
135
- def jax_device(target_device: str, env: tensor.Environment | None = None):
136
- """
137
- to("jax") cannot differentiate the device/platform (cpu vs tpu).
138
- Use this context manager to control jax array's storage device
139
-
140
- Examples:
141
-
142
- a = torch.ones(3, 3)
143
-
144
- with jax_device("cpu"):
145
- b = a.to("jax")
146
-
147
- with jax_device("tpu"):
148
- c = a.to("jax")
149
-
150
- with jax_device("tpu"):
151
- c = b.to("jax")
152
-
153
- """
154
- if env is None:
155
- env = default_env()
156
-
157
- prev_target_device = env.target_device
158
- try:
159
- env.target_device = target_device
160
- yield env
161
- finally:
162
- env.target_device = prev_target_device
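
In 0.0.7 the function returned by `extract_jax` takes `(states, args, kwargs)` instead of `(states, inputs)`. A minimal sketch of the new calling convention; the `torch.nn.Linear` model below is hypothetical and used only for illustration:

```python
import jax.numpy as jnp
import torch
import torchax

# Hypothetical module, only to demonstrate the new signature.
model = torch.nn.Linear(4, 2)
states, jax_func = torchax.extract_jax(model)

x = jnp.ones((1, 4), dtype=jnp.float32)
out = jax_func(states, (x,))         # kwargs defaults to None
out = jax_func(states, (x,), {})     # or pass kwargs explicitly
```
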
torchax/amp.py CHANGED
@@ -61,9 +61,8 @@ def autocast(device, dtype=torch.bfloat16, env=None):
61
61
  if env is None:
62
62
  import torchax
63
63
  env = torchax.default_env()
64
- env.autocast_dtype, old = dtype, env.autocast_dtype
65
- yield
66
- env.autocast_dtype = old
64
+ with env.override_property(autocast_dtype=dtype):
65
+ yield
67
66
 
68
67
 
69
68
  # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
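
A hedged usage sketch of the reworked `autocast` context manager: the per-op autocast policy is unchanged, only the bookkeeping now goes through `env.override_property` instead of mutating `env.autocast_dtype` directly. The matmul below is an arbitrary example op:

```python
import torch
import torchax
from torchax import amp

env = torchax.default_env()
with env:
    a = torch.randn(4, 4, device='jax')
    b = torch.randn(4, 4, device='jax')
    # Entering autocast pushes an overridden autocast_dtype onto the
    # environment's property stack for the duration of the block.
    with amp.autocast('jax', dtype=torch.bfloat16, env=env):
        c = a @ b
```
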
torchax/config.py CHANGED
@@ -10,6 +10,11 @@ class Configuration:
10
10
 
11
11
  use_int32_for_index: bool = False
12
12
 
13
+ # normally, math between CPU torch.Tensor with torchax.Tensor is not
14
+ # allowed. However, if that torch.Tensor happens to be scalar, then we
15
+ # can use scalar * tensor math to handle it
16
+ allow_mixed_math_with_scalar_tensor: bool = True
17
+
13
18
  # If true, we will convert Views into torchax.Tensors eagerly
14
19
  force_materialize_views: bool = False
15
20
 
@@ -22,5 +27,4 @@ class Configuration:
22
27
 
23
28
  # device
24
29
  treat_cuda_as_jax_device: bool = True
25
- use_torch_native_for_cpu_tensor: bool = True
26
30
  internal_respect_torch_return_dtypes: bool = False
torchax/configuration.py ADDED
@@ -0,0 +1,30 @@
1
+ import dataclasses
2
+
3
+
4
+ @dataclasses.dataclass
5
+ class Configuration:
6
+ debug_print_each_op: bool = False
7
+ debug_accuracy_for_each_op: bool = False
8
+ debug_mixed_tensor: bool = False
9
+ debug_print_each_op_operands: bool = False
10
+
11
+ use_int32_for_index: bool = False
12
+
13
+ # normally, math between CPU torch.Tensor with torchax.Tensor is not
14
+ # allowed. However, if that torch.Tensor happens to be scalar, then we
15
+ # can use scalar * tensor math to handle it
16
+ allow_mixed_math_with_scalar_tensor: bool = True
17
+
18
+ # If true, we will convert Views into torchax.Tensors eagerly
19
+ force_materialize_views: bool = False
20
+
21
+ # Use DLPack for converting jax.Arrays <-> and torch.Tensor
22
+ use_dlpack_for_data_conversion: bool = False
23
+
24
+ # Flash attention
25
+ use_tpu_flash_attention: bool = False
26
+ shmap_flash_attention: bool = False
27
+
28
+ # device
29
+ treat_cuda_as_jax_device: bool = True
30
+ internal_respect_torch_return_dtypes: bool = False
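
A small sketch of what the new `allow_mixed_math_with_scalar_tensor` flag permits, assuming `torchax.default_env().config` exposes this dataclass as the rest of the diff suggests:

```python
import torch
import torchax

env = torchax.default_env()
with env:
    a = torch.ones(3, device='jax')
    scalar = torch.tensor(2.0)  # plain CPU tensor, zero-dimensional
    # With the flag on (the default), the CPU scalar is unwrapped via
    # .item(), so mixing it with a torchax Tensor is allowed.
    b = a * scalar

# Turning the flag off restores the strict "no mixed math" behavior.
env.config.allow_mixed_math_with_scalar_tensor = False
```
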
torchax/device_module.py CHANGED
@@ -1,3 +1,6 @@
1
+ import torch
2
+
3
+
1
4
  def _is_in_bad_fork():
2
5
  return False
3
6
 
@@ -24,3 +27,7 @@ def is_available():
24
27
 
25
28
  def current_device():
26
29
  return 0
30
+
31
+
32
+ def get_amp_supported_dtype():
33
+ return [torch.float16, torch.bfloat16]
torchax/environment.py ADDED
@@ -0,0 +1 @@
1
+
torchax/interop.py CHANGED
@@ -11,6 +11,7 @@ from jax import tree_util as pytree
11
11
  from jax.experimental.shard_map import shard_map
12
12
  from torchax import tensor
13
13
  from torchax import util
14
+ from torchax.ops import mappings
14
15
  import torchax
15
16
 
16
17
  from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
@@ -90,7 +91,7 @@ class JittableModule(torch.nn.Module):
90
91
  def __call__(self, *args, **kwargs):
91
92
  return self.forward(*args, **kwargs)
92
93
 
93
- def functional_call(self, method_name, params, buffers, *args, **kwargs):
94
+ def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
94
95
  kwargs = kwargs or {}
95
96
  params_copy = copy.copy(params)
96
97
  params_copy.update(buffers)
@@ -98,22 +99,35 @@ class JittableModule(torch.nn.Module):
98
99
  for k, v in self._extra_dumped_weights.items():
99
100
  for new_key in v:
100
101
  params_copy[new_key] = params_copy[k]
102
+
103
+ if isinstance(method_or_name, str):
104
+ method = getattr(self._model, method_or_name)
105
+ else:
106
+ if not callable(method_or_name):
107
+ raise TypeError(
108
+ f"method_or_name should be a callable or a string, got {type(method_or_name)}"
109
+ )
110
+ method = method_or_name
111
+ args = (self._model,) + args
101
112
  with torch_stateless._reparametrize_module(self._model, params_copy):
102
- res = getattr(self._model, method_name)(*args, **kwargs)
113
+ res = method(*args, **kwargs)
103
114
  return res
104
115
 
105
- def forward(self, *args, **kwargs):
106
- if 'forward' not in self._jitted:
116
+ def jittable_call(self, method_name: str, *args, **kwargs):
117
+ if method_name not in self._jitted:
107
118
  jitted = jax_jit(
108
- functools.partial(self.functional_call, 'forward'),
119
+ functools.partial(self.functional_call, method_name),
109
120
  kwargs_for_jax_jit=self._extra_jit_args,
110
121
  )
111
122
 
112
123
  def jitted_forward(*args, **kwargs):
113
124
  return jitted(self.params, self.buffers, *args, **kwargs)
114
125
 
115
- self._jitted['forward'] = jitted_forward
116
- return self._jitted['forward'](*args, **kwargs)
126
+ self._jitted[method_name] = jitted_forward
127
+ return self._jitted[method_name](*args, **kwargs)
128
+
129
+ def forward(self, *args, **kwargs):
130
+ return self.jittable_call('forward', *args, **kwargs)
117
131
 
118
132
  def __getattr__(self, key):
119
133
  if key == '_model':
@@ -170,8 +184,8 @@ def _torch_view(t: JaxValue) -> TorchValue:
170
184
  if isinstance(t, jax.Array):
171
185
  # TODO
172
186
  return tensor.Tensor(t, torchax.default_env())
173
- if isinstance(t, type(jnp.int32)):
174
- return tensor.t2j_type(t)
187
+ if isinstance(t, jnp.dtype):
188
+ return mappings.j2t_dtype(t)
175
189
  if callable(t): # t is a JaxCallable
176
190
  return functools.partial(call_jax, t)
177
191
  # regular types are not changed
@@ -188,7 +202,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
188
202
  assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
189
203
  return t.jax()
190
204
  if isinstance(t, type(torch.int32)):
191
- return tensor.t2j_dtype(t)
205
+ return mappings.t2j_dtype(t)
192
206
 
193
207
  # torch.nn.Module needs special handling
194
208
  if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable
@@ -223,10 +237,39 @@ def j2t_autograd(fn, call_jax=call_jax):
223
237
  the PyTorch autograd framework by saving the residuals into the context object.
224
238
  """
225
239
 
240
+ # NOTE(qihqi): This function cannot be inlined from the callsite
241
+ # Becuase if it does, then it won't hit the compilation cache for
242
+ # call_jax. Call jax uses functions' id as key.
243
+ # It is nested inside j2t_autograd to ensure it gets a unique ID for each
244
+ # wrapped pure function, preventing cache collisions between different pure modules.
245
+ def _jax_forward(fn, other, tree_def, tensors):
246
+ """JAX function to compute output and vjp function.
247
+
248
+ primals should be a tuple (args, kwargs).
249
+ """
250
+ import jax
251
+ from jax.tree_util import tree_flatten, tree_unflatten
252
+
253
+ def fn_wrapper(*tensors):
254
+ # Reconstruct the original args and kwargs
255
+ flat_inputs = util.merge(tensors, other)
256
+ args, kwargs = tree_unflatten(tree_def, flat_inputs)
257
+ return fn(*args, **kwargs)
258
+
259
+ return jax.vjp(fn_wrapper, *tensors)
260
+
261
+ def _jax_backward(vjp_spec, saved_tensors, grad_out):
262
+ """JAX function to compute input gradients.
263
+
264
+ Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
265
+ """
266
+ from jax.tree_util import tree_unflatten
267
+ fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
268
+ return fun_vjp(grad_out)
269
+
226
270
  @wraps(fn)
227
271
  def inner(*args, **kwargs):
228
- from jax.tree_util import tree_flatten, tree_unflatten
229
- from jax.util import safe_zip
272
+ from jax.tree_util import tree_flatten
230
273
 
231
274
  class JaxFun(torch.autograd.Function):
232
275
 
@@ -261,8 +304,8 @@ def j2t_autograd(fn, call_jax=call_jax):
261
304
  # The subsequent gradients correspond to flat_inputs.
262
305
  # We need to put a None for inputs that did not require gradients.
263
306
  final_grads = [None]
264
- for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:],
265
- input_grads_structured):
307
+ for needs_grad, grad in zip(
308
+ ctx.needs_input_grad[1:], input_grads_structured, strict=True):
266
309
  final_grads.append(grad if needs_grad else None)
267
310
 
268
311
  return tuple(final_grads)
@@ -277,36 +320,6 @@ def j2t_autograd(fn, call_jax=call_jax):
277
320
  return inner
278
321
 
279
322
 
280
- # NOTE(qihqi): This function cannot be inlined from the callsite
281
- # Becuase if it does, then it won't hit the compilation cache for
282
- # call_jax. Call jax uses functions' id as key.
283
- def _jax_forward(fn, other, tree_def, tensors):
284
- """JAX function to compute output and vjp function.
285
-
286
- primals should be a tuple (args, kwargs).
287
- """
288
- import jax
289
- from jax.tree_util import tree_flatten, tree_unflatten
290
-
291
- def fn_wrapper(*tensors):
292
- # Reconstruct the original args and kwargs
293
- flat_inputs = util.merge(tensors, other)
294
- args, kwargs = tree_unflatten(tree_def, flat_inputs)
295
- return fn(*args, **kwargs)
296
-
297
- return jax.vjp(fn_wrapper, *tensors)
298
-
299
-
300
- def _jax_backward(vjp_spec, saved_tensors, grad_out):
301
- """JAX function to compute input gradients.
302
-
303
- Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
304
- """
305
- from jax.tree_util import tree_unflatten
306
- fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
307
- return fun_vjp(grad_out)
308
-
309
-
310
323
  fori_loop = torch_view(jax.lax.fori_loop)
311
324
 
312
325
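
A sketch of the new `JittableModule` surface, assuming (as in earlier releases) that a `JittableModule` is constructed directly from a module whose weights live on the `jax` device; the `Encoder` class is hypothetical. `forward` now routes through `jittable_call`, any other method can be jitted and cached by name, and `functional_call` also accepts a callable in place of a method name:

```python
import torch
import torchax
from torchax.interop import JittableModule

class Encoder(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def encode(self, x):
        return torch.relu(self.proj(x))

    def forward(self, x):
        return self.encode(x)

with torchax.default_env():
    jittable = JittableModule(Encoder().to('jax'))
    x = torch.ones(2, 8, device='jax')
    y1 = jittable(x)                          # same as jittable_call('forward', x)
    y2 = jittable.jittable_call('encode', x)  # any method, jitted by name
```
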
 
torchax/mesh_util.py CHANGED
@@ -199,7 +199,7 @@ class Mesh:
199
199
  }
200
200
 
201
201
  def model_initializer():
202
- with torchax.default_env():
202
+ with torchax.default_env(), torch.device('meta'):
203
203
  model = model_class(*init_args, **init_kwargs)
204
204
  return dict(model.state_dict())
205
205
 
@@ -209,3 +209,12 @@ class Mesh:
209
209
 
210
210
  model.load_state_dict(weights_dict, assign=True)
211
211
  return model
212
+
213
+ def shard_model(self, model, override_sharder=None):
214
+ sharder = override_sharder or self._sharder
215
+ states = model.state_dict()
216
+ output_shards = {
217
+ name: NamedSharding(self.jax_mesh, sharder(name, tensor))
218
+ for name, tensor in states.items()
219
+ }
220
+ model.load_state_dict(output_shards, assign=True)
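
A hedged fragment showing the new `shard_model` helper; `mesh` and `model` are assumed to be an existing `torchax.mesh_util.Mesh` and a module initialized through it, and a custom sharder has the signature `(name, tensor) -> PartitionSpec`:

```python
from jax.sharding import PartitionSpec as P

def replicate_everything(name, tensor):
    # Trivial sharder used for illustration: replicate every parameter.
    return P()

# Re-shard the model's state dict in place using the Mesh's default sharder,
# or override it per call.
mesh.shard_model(model)
mesh.shard_model(model, override_sharder=replicate_everything)
```
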
torchax/ops/jaten.py CHANGED
@@ -736,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs):
736
736
  return jnp.empty(sizes, dtype=dtype)
737
737
 
738
738
 
739
- @op(torch.ops.aten.index_put_)
740
739
  @op(torch.ops.aten.index_put)
741
740
  def _aten_index_put(self, indexes, values, accumulate=False):
742
741
  indexes = [slice(None, None, None) if i is None else i for i in indexes]
@@ -3532,7 +3531,7 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):
3532
3531
 
3533
3532
  @op(torch.ops.aten.randn, needs_env=True)
3534
3533
  @op_base.convert_dtype()
3535
- def _randn(
3534
+ def _aten_randn(
3536
3535
  *size,
3537
3536
  generator=None,
3538
3537
  out=None,
@@ -3652,7 +3651,7 @@ def _aten_native_batch_norm(input,
3652
3651
  @op(torch.ops.aten.normal, needs_env=True)
3653
3652
  def _aten_normal(self, mean=0, std=1, generator=None, env=None):
3654
3653
  shape = self.shape
3655
- res = _randn(*shape, generator=generator, env=env)
3654
+ res = _aten_randn(*shape, generator=generator, env=env)
3656
3655
  return res * std + mean
3657
3656
 
3658
3657
 
@@ -5541,6 +5540,7 @@ def _aten_floor_divide(x, y):
5541
5540
 
5542
5541
 
5543
5542
  @op(torch.ops.aten._assert_tensor_metadata)
5543
+ @op(torch.ops.aten._assert_scalar)
5544
5544
  def _aten__assert_tensor_metadata(*args, **kwargs):
5545
5545
  pass
5546
5546
 
@@ -5617,6 +5617,8 @@ mutation_ops_to_functional = {
5617
5617
  op_base.InplaceOp(torch.ops.aten.floor_divide),
5618
5618
  torch.ops.aten.remainder_:
5619
5619
  op_base.InplaceOp(torch.ops.aten.remainder),
5620
+ torch.ops.aten.index_put_:
5621
+ op_base.InplaceOp(torch.ops.aten.index_put),
5620
5622
  }
5621
5623
 
5622
5624
  # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`.
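
The practical effect of moving `index_put_` out of the directly registered ops and into `mutation_ops_to_functional` is that in-place index assignment still works, now via the generic `InplaceOp` wrapper around the functional `index_put`. A small sketch:

```python
import torch
import torchax

with torchax.default_env():
    x = torch.zeros(5, device='jax')
    idx = torch.tensor([1, 3], device='jax')
    # Lowers to aten.index_put_, handled through
    # mutation_ops_to_functional -> InplaceOp(aten.index_put).
    x[idx] = torch.ones(2, device='jax')
```
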
torchax/ops/jtorch.py CHANGED
@@ -179,6 +179,13 @@ def _tpu_flash_attention(query, key, value, env):
179
179
  return wrap_flash_attention(query, key, value)
180
180
 
181
181
 
182
+ @register_function(torch.nn.functional.one_hot)
183
+ def one_hot(tensor, num_classes=-1):
184
+ if num_classes == -1:
185
+ num_classes = jnp.max(tensor) + 1
186
+ return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64)
187
+
188
+
182
189
  @register_function(torch.nn.functional.pad)
183
190
  def pad(tensor, pad, mode="constant", value=None):
184
191
  # For padding modes that have different names between Torch and NumPy, this
@@ -341,7 +348,7 @@ def empty(*size: Sequence[int], dtype=None, **kwargs):
341
348
  return jnp.empty(size, dtype=dtype)
342
349
 
343
350
 
344
- @register_function(torch.arange, is_jax_function=False)
351
+ @register_function(torch.arange, is_jax_function=True)
345
352
  def arange(
346
353
  start,
347
354
  end=None,
@@ -358,10 +365,10 @@ def arange(
358
365
  start = 0
359
366
  if step is None:
360
367
  step = 1
361
- return torch.ops.aten.arange(start, end, step, dtype=dtype)
368
+ return jaten._aten_arange(start, end, step, dtype=dtype)
362
369
 
363
370
 
364
- @register_function(torch.empty_strided, is_jax_function=False)
371
+ @register_function(torch.empty_strided, is_jax_function=True)
365
372
  def empty_strided(
366
373
  size,
367
374
  stride,
@@ -372,7 +379,7 @@ def empty_strided(
372
379
  requires_grad=False,
373
380
  pin_memory=False,
374
381
  ):
375
- return empty(size, dtype=dtype)
382
+ return empty(size, dtype=dtype, requires_grad=requires_grad)
376
383
 
377
384
 
378
385
  @register_function(torch.unravel_index)
@@ -380,14 +387,14 @@ def unravel_index(indices, shape):
380
387
  return jnp.unravel_index(indices, shape)
381
388
 
382
389
 
383
- @register_function(torch.rand, is_jax_function=False)
390
+ @register_function(torch.rand, is_jax_function=True, needs_env=True)
384
391
  def rand(*size, **kwargs):
385
392
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
386
393
  size = size[0]
387
- return torch.ops.aten.rand(size, **kwargs)
394
+ return jaten._rand(size, **kwargs)
388
395
 
389
396
 
390
- @register_function(torch.randn, is_jax_function=False)
397
+ @register_function(torch.randn, is_jax_function=True, needs_env=True)
391
398
  def randn(
392
399
  *size,
393
400
  generator=None,
@@ -397,15 +404,16 @@ def randn(
397
404
  device=None,
398
405
  requires_grad=False,
399
406
  pin_memory=False,
407
+ env=None,
400
408
  ):
401
409
  if len(size) == 1 and isinstance(size[0], collections.abc.Iterable):
402
410
  size = size[0]
403
- return torch.ops.aten.randn(size, generator=generator, dtype=dtype)
411
+ return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env)
404
412
 
405
413
 
406
- @register_function(torch.randint, is_jax_function=False)
414
+ @register_function(torch.randint, is_jax_function=False, needs_env=True)
407
415
  def randint(*args, **kwargs):
408
- return torch.ops.aten.randint(*args, **kwargs)
416
+ return jaten._aten_randint(*args, **kwargs)
409
417
 
410
418
 
411
419
  @register_function(torch.logdet)
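
A minimal sketch of the newly registered `one_hot` lowering; per the code above it dispatches to `jax.nn.one_hot` and casts the result with `astype(jnp.int64)`, inferring `num_classes` as `max + 1` when it is left at -1:

```python
import torch
import torchax

with torchax.default_env():
    labels = torch.tensor([0, 2, 1], device='jax')
    onehot = torch.nn.functional.one_hot(labels, num_classes=3)
    inferred = torch.nn.functional.one_hot(labels)  # num_classes inferred
```
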
torchax/ops/mappings.py CHANGED
@@ -6,6 +6,14 @@ import torch.func
6
6
  import torch.utils.dlpack as torchdl
7
7
  import torch.utils._mode_utils as mode_utils
8
8
 
9
+ NUMPY_UNSUPPORTED_DTYPES = {
10
+ torch.bfloat16: jnp.bfloat16,
11
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
12
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
13
+ torch.float8_e5m2: jnp.float8_e5m2,
14
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
15
+ }
16
+
9
17
 
10
18
  def t2j(t, use_dlpack=True):
11
19
  is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
28
36
  if res is None:
29
37
  # https://github.com/google/jax/issues/7657
30
38
  # https://github.com/google/jax/issues/17784
31
- if t.dtype == torch.bfloat16:
39
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
32
40
  nparray = (t.cpu().detach().to(torch.float32).numpy()
33
- ) # numpy don't support bfloat16
41
+ ) # handle dtypes not supported by numpy
34
42
  else:
35
43
  nparray = t.cpu().detach().numpy()
36
44
  res = jnp.asarray(nparray)
37
- if t.dtype == torch.bfloat16:
38
- res = res.astype(jnp.bfloat16)
45
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
46
+ res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
39
47
 
40
48
  if is_bool:
41
49
  res = res.astype(jnp.bool_)
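
A short sketch of the widened NumPy fallback: previously only `bfloat16` was special-cased, now every dtype in `NUMPY_UNSUPPORTED_DTYPES` (bfloat16 and the float8 variants) is materialized as float32 and cast back on the JAX side. Disabling the DLPack path forces this code path:

```python
import torch
from torchax.ops import mappings

t = torch.ones(4, dtype=torch.bfloat16)
arr = mappings.t2j(t, use_dlpack=False)  # goes through NumPy, then casts back
assert str(arr.dtype) == 'bfloat16'
```
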
torchax/ops/type_casting.py ADDED
@@ -0,0 +1,47 @@
1
+
2
+ from torchax import config
3
+ from torchax.ops import mappings
4
+ import jax.numpy as jnp
5
+ import torch
6
+
7
+ def maybe_cast(result, torch_op):
8
+ """Casts the result to the torch op's return dtype if the config is set."""
9
+ if not config.DEFAULTS.internal_respect_torch_return_dtypes:
10
+ return result
11
+
12
+ if not hasattr(torch_op, '_schema'):
13
+ return result
14
+
15
+ schema = torch_op._schema
16
+ if not schema.returns:
17
+ return result
18
+
19
+ # TODO: Handle multiple return values
20
+ if len(schema.returns) > 1:
21
+ return result
22
+
23
+ return_type = schema.returns[0].type
24
+ if str(return_type) == 'Tensor':
25
+ # This is not quite right, we need to get the dtype of the tensor
26
+ # For now, let's assume we can get it from the first input argument
27
+ if not schema.arguments:
28
+ return result
29
+
30
+ input_type = schema.arguments[0].type
31
+ if str(input_type) != 'Tensor':
32
+ return result
33
+
34
+ # This is a hack, we need a better way to determine the return dtype
35
+ # For now, let's assume the return type is the same as the first input
36
+ # This is not always true, e.g. for comparison ops.
37
+ return result
38
+
39
+ try:
40
+ torch_dtype = getattr(torch, str(return_type))
41
+ jax_dtype = mappings.t2j_dtype(torch_dtype)
42
+ if isinstance(result, jnp.ndarray):
43
+ return result.astype(jax_dtype)
44
+ else:
45
+ return jax_dtype(result)
46
+ except (AttributeError, TypeError):
47
+ return result
torchax/tensor.py CHANGED
@@ -1,3 +1,4 @@
1
+ import threading
1
2
  import logging
2
3
  import sys
3
4
  import contextlib
@@ -16,7 +17,6 @@ from torchax.view import View
16
17
  from torchax import config
17
18
  from torchax.ops import mappings, ops_registry
18
19
  from torchax import amp
19
- from jax.experimental import mutable_array
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -25,14 +25,6 @@ class OperatorNotFound(Exception):
25
25
  pass
26
26
 
27
27
 
28
- def wrap(jaxarray):
29
- return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray)
30
-
31
-
32
- def unwrap(torchtensors):
33
- return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
34
-
35
-
36
28
  @contextlib.contextmanager
37
29
  def log_nested(env, message):
38
30
  if env.config.debug_print_each_op:
@@ -48,7 +40,7 @@ log_nested.level = 0
48
40
  class Tensor(torch.Tensor):
49
41
 
50
42
  @staticmethod
51
- def __new__(cls, elem, env):
43
+ def __new__(cls, elem, env, requires_grad=False):
52
44
  dtype = mappings.j2t_dtype(elem.dtype)
53
45
  shape = list(elem.shape)
54
46
  for i, s in enumerate(shape):
@@ -56,15 +48,19 @@ class Tensor(torch.Tensor):
56
48
  shape[i] = 1
57
49
  if dtype is None:
58
50
  dtype = torch.float32
51
+ #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
52
+ if not (dtype.is_floating_point or dtype.is_complex):
53
+ requires_grad = False
54
+
59
55
  return torch.Tensor._make_wrapper_subclass(
60
56
  cls,
61
57
  shape,
62
58
  dtype=dtype,
63
- device="meta",
64
- requires_grad=False,
59
+ device='meta',
60
+ requires_grad=requires_grad,
65
61
  )
66
62
 
67
- def __init__(self, elem: jax.Array, env: "Environment"):
63
+ def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
68
64
  super().__init__()
69
65
  self._elem = elem
70
66
  self._env = env
@@ -74,9 +70,6 @@ class Tensor(torch.Tensor):
74
70
 
75
71
  __repr__ = __str__
76
72
 
77
- def __jax_array__(self):
78
- return self._elem
79
-
80
73
  @property
81
74
  def shape(self):
82
75
  return torch.Size(self._elem.shape)
@@ -109,6 +102,8 @@ class Tensor(torch.Tensor):
109
102
  # TODO(hanq): figure out why is dispatch mode not sufficient
110
103
  if func == torch.ops._c10d_functional.wait_tensor.default:
111
104
  return args[0]._env.dispatch(func, types, args, kwargs)
105
+ if func == torch.ops.prim.device.default:
106
+ return torch.device('privateuseone', 0)
112
107
  raise AssertionError(
113
108
  'torchax Tensors can only do math within the torchax environment.'
114
109
  'Please wrap your code with `with torchax.default_env()` or '
@@ -298,6 +293,38 @@ TENSOR_CONSTRUCTORS = {
298
293
  SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]
299
294
 
300
295
 
296
+ class RuntimeProperty:
297
+ mesh: Any
298
+ prng: Any
299
+ autocast_dtype: Any
300
+
301
+ def __init__(self, mesh, prng, autocast_dtype):
302
+ self.mesh = mesh
303
+ self.prng = prng
304
+ self.autocast_dtype = autocast_dtype
305
+
306
+ def override(self, **kwargs):
307
+ return OverrideProperty(self, kwargs)
308
+
309
+ def get_and_rotate_prng_key(self):
310
+ old_key = self.prng
311
+ new_prng_key, next_key = jax.random.split(old_key)
312
+ self.prng = new_prng_key
313
+ return next_key
314
+
315
+
316
+ class OverrideProperty(RuntimeProperty):
317
+
318
+ def __init__(self, parent, override):
319
+ self.parent = parent
320
+ self._override = dict(override)
321
+
322
+ def __getattr__(self, name):
323
+ if name in self._override:
324
+ return self._override[name]
325
+ return getattr(self.parent, name)
326
+
327
+
301
328
  class Environment(contextlib.ContextDecorator):
302
329
  """This class holds a set of configurations and "globals" needed
303
330
 
@@ -321,62 +348,55 @@ class Environment(contextlib.ContextDecorator):
321
348
 
322
349
  self.load_ops()
323
350
 
324
- self._mesh = None
351
+ _mesh = None
325
352
  self.config = configuration or config.Configuration()
326
353
 
327
- self._manually_entered = False
328
354
  self.enabled = False
329
355
 
330
- self._prng_key = mutable_array(
331
- jax.random.key(torch.initial_seed() % (1 << 63)))
332
- self.autocast_dtype = None
333
- self._target_device = jax.local_devices()[0].platform
356
+ autocast_dtype = None
334
357
 
335
- @property
336
- def target_device(self):
337
- return self._target_device
358
+ _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
359
+ self._property = threading.local()
360
+ self._property.content = [
361
+ RuntimeProperty(
362
+ mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
363
+ ]
338
364
 
339
- @target_device.setter
340
- def target_device(self, device: str):
341
- self._target_device = device.lower()
365
+ @property
366
+ def param(self):
367
+ return self._property.content[-1]
342
368
 
343
369
  def manual_seed(self, key):
344
- self._prng_key = mutable_array(jax.random.key(key))
370
+ jax_key = jax.random.PRNGKey(key)
371
+ new_prop = self.param.override(prng=jax_key)
372
+ self._property.content.append(new_prop)
345
373
 
346
374
  @property
347
375
  def prng_key(self):
348
- return self._prng_key[...]
376
+ return self.param.prng
349
377
 
350
- def get_as_jax_device(self, device: Any):
378
+ def _should_use_torchax_tensor(self, device):
351
379
  if device is None:
352
380
  device = torch.get_default_device()
353
381
 
354
382
  if isinstance(device, torch.device):
355
- device = str(device)
356
-
357
- if not self.config.use_torch_native_for_cpu_tensor and device.startswith(
358
- "cpu"):
359
- return jax.devices("cpu")[0]
360
-
361
- if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
362
- return jax.local_devices()[0]
363
-
364
- if device.startswith("xla"):
365
- return jax.local_devices()[0]
366
-
367
- # TODO (wen): jax is NOT a device type,
368
- # once we can register more than one backend, revisit
369
- if device.startswith("jax"):
370
- match self.target_device:
371
- case "cpu":
372
- return jax.devices("cpu")[0]
373
- case "tpu":
374
- return jax.devices("tpu")[0]
375
- case _:
376
- raise AttributeError(
377
- f"Cannot handle env.target_device {self.target_device}")
378
-
379
- return None # fallback to torch
383
+ device = device.type
384
+
385
+ if ':' in device:
386
+ device = device.split(':')[0]
387
+
388
+ match device:
389
+ case 'cpu':
390
+ return False
391
+ case 'cuda':
392
+ return self.config.treat_cuda_as_jax_device
393
+ case 'jax':
394
+ return True
395
+ case 'privateuseone':
396
+ return True
397
+ case 'meta':
398
+ return self.enabled
399
+ return False
380
400
 
381
401
  def load_ops(self):
382
402
  from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
@@ -423,80 +443,63 @@ class Environment(contextlib.ContextDecorator):
423
443
 
424
444
  return op
425
445
 
446
+ def _is_same_device(self, the_tensor, new_device):
447
+ if new_device is None:
448
+ return True
449
+ if new_device == 'meta' and the_tensor.device.type == 'jax':
450
+ return True
451
+ if the_tensor.device.type != new_device:
452
+ if the_tensor.device.type == 'cuda':
453
+ return self.config.treat_cuda_as_jax_device
454
+ return False
455
+ return True
456
+
426
457
  def _to_copy(self, the_tensor, new_dtype, new_device):
427
458
  if isinstance(the_tensor, View):
428
459
  the_tensor = the_tensor.torch()
429
-
430
- if isinstance(the_tensor, Tensor):
431
-
432
- arr = the_tensor.jax()
433
-
434
- if new_dtype is not None and new_dtype != arr.dtype:
435
- arr = arr.astype(mappings.t2j_dtype(new_dtype))
436
-
437
- if new_device is not None:
438
- match str(new_device).lower():
439
- case "cpu":
440
- # converting to a non-jax device: let torch native handle it
441
- torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor,
442
- Tensor) else arr
443
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
444
- return torch_tensor.to(new_device)
445
- case "jax":
446
- # move torchax.tensor / jax tensor between devices
447
- # I don't know ifgit this will work after the model is jitted
448
- if self.target_device != the_tensor.jax_device.platform:
449
- arr = jax.device_put(the_tensor.jax(),
450
- jax.devices(self.target_device)[0])
451
- return Tensor(arr, self)
452
- case _:
453
- logging.error(f"torchax.Tenosr cannot handle device {new_device}")
454
-
455
- else:
456
- if new_dtype is not None and new_dtype != the_tensor.dtype:
460
+ if isinstance(new_device, torch.device):
461
+ new_device = new_device.type
462
+ res = the_tensor
463
+ if not self._is_same_device(the_tensor, new_device):
464
+ if isinstance(the_tensor, Tensor):
465
+ torch_tensor = self.j2t_copy(the_tensor._elem)
457
466
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
458
- the_tensor = the_tensor.to(new_dtype)
459
-
460
- if new_device is None: ## device is None means don't change device
461
- return the_tensor
462
-
463
- jax_device = self.get_as_jax_device(new_device)
464
- if jax_device:
467
+ return torch_tensor.to(device=new_device, dtype=new_dtype)
468
+ else:
465
469
  arr = self.t2j_copy(the_tensor)
466
- arr = jax.device_put(arr, jax_device)
470
+ res = Tensor(arr, self, the_tensor.requires_grad)
471
+
472
+ if new_dtype is not None and new_dtype != res.dtype:
473
+ if isinstance(res, Tensor):
474
+ res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
467
475
  else:
468
476
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
469
- return the_tensor.to(new_device)
470
-
471
- return Tensor(arr, self)
477
+ return res.to(device=new_device, dtype=new_dtype)
478
+ return res
472
479
 
473
480
  def get_and_rotate_prng_key(self,
474
481
  generator: Optional[torch.Generator] = None):
475
482
  if generator is not None:
476
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
477
- self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63))
478
- old_key = self._prng_key[...]
479
- new_prng_key, next_key = jax.random.split(old_key)
480
- self._prng_key[...] = new_prng_key
481
- return next_key
483
+ return jax.random.PRNGKey(generator.initial_seed() % (2**63))
484
+ return self.param.get_and_rotate_prng_key()
482
485
 
483
486
  def _handle_tensor_constructor(self, func, args, kwargs):
484
487
  device = kwargs.get("device")
485
- jax_device = self.get_as_jax_device(device)
486
- # TODO(qihqi) figure out better ways for device propagation
487
- if not self._manually_entered and jax_device is None:
488
- # let torch handle it
489
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
490
- return func(*args, **kwargs)
491
- with jax.default_device(jax_device):
488
+ if self._should_use_torchax_tensor(device):
489
+ # don't set default device, let caller set it
492
490
  requires_grad = kwargs.get("requires_grad", False)
493
491
  op = self._get_op_or_decomp(func)
492
+ if op.needs_env:
493
+ kwargs['env'] = self
494
+ if op.is_jax_function:
495
+ (args, kwargs) = self.t2j_iso((args, kwargs))
494
496
  res = op.func(*args, **kwargs)
495
497
  if isinstance(res, jax.Array):
496
- res = Tensor(res, self)
497
- if requires_grad:
498
- res.requires_grad = True
498
+ res = Tensor(res, self, requires_grad)
499
499
  return res
500
+ else:
501
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
502
+ return func(*args, **kwargs)
500
503
 
501
504
  def _torch_Tensor_to(self, args, kwargs):
502
505
  the_tensor = args[0]
@@ -560,11 +563,11 @@ class Environment(contextlib.ContextDecorator):
560
563
  args, kwargs = self.v2t_iso((args, kwargs))
561
564
 
562
565
  with self:
563
- if self.autocast_dtype is not None:
566
+ if self.param.autocast_dtype is not None:
564
567
  autocast_policy = amp.autocast_policy.get(func)
565
568
  if autocast_policy is not None:
566
569
  args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
567
- self.autocast_dtype)
570
+ self.param.autocast_dtype)
568
571
 
569
572
  if op.is_jax_function:
570
573
  args, kwargs = self.t2j_iso((args, kwargs))
@@ -609,11 +612,9 @@ class Environment(contextlib.ContextDecorator):
609
612
 
610
613
  def __enter__(self):
611
614
  self.enable_torch_modes()
612
- self._manually_entered = True
613
615
  return self
614
616
 
615
617
  def __exit__(self, *exc):
616
- self._manually_entered = False
617
618
  self.disable_torch_modes(*exc)
618
619
 
619
620
  def _move_one_value(self, val):
@@ -639,6 +640,10 @@ class Environment(contextlib.ContextDecorator):
639
640
  """
640
641
 
641
642
  def to_jax(x):
643
+ if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
644
+ x, Tensor):
645
+ if x.squeeze().ndim == 0:
646
+ return x.item()
642
647
  if isinstance(
643
648
  x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
644
649
  x = x.wait()
@@ -697,3 +702,10 @@ class Environment(contextlib.ContextDecorator):
697
702
  is_user_defined=True,
698
703
  needs_env=False,
699
704
  )
705
+
706
+ @contextlib.contextmanager
707
+ def override_property(self, **kwargs):
708
+ new_prop = self.param.override(**kwargs)
709
+ self._property.content.append(new_prop)
710
+ yield
711
+ self._property.content.pop()
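
A sketch of the new per-thread runtime-property stack, using only names that appear in this diff: `override_property` is a context manager on the environment that pushes a temporary overlay, `manual_seed` pushes a new PRNG key, and `get_and_rotate_prng_key` splits it on each call:

```python
import torch
import torchax

env = torchax.default_env()

# Mesh, PRNG key and autocast dtype now live on a thread-local stack
# (env.param is the top of that stack).
with env.override_property(autocast_dtype=torch.bfloat16):
    pass  # ops dispatched here see the overridden autocast dtype

env.manual_seed(0)
key = env.get_and_rotate_prng_key()
```
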
torchax-0.0.5.dist-info/METADATA → torchax-0.0.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchax
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: torchax is a library for running Jax and PyTorch together
5
5
  Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
6
6
  Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
torchax-0.0.5.dist-info/RECORD → torchax-0.0.7.dist-info/RECORD CHANGED
@@ -1,32 +1,34 @@
1
1
  torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
2
- torchax/__init__.py,sha256=T8tYMpwfP9i3FLzci2_TCGH58PBbRjcRO4O3sgyyk_0,3945
3
- torchax/amp.py,sha256=WycgMeZfwgzVDqu9ADnUHwhbXSQXtVUoIUXP3jcMF1k,11818
4
- torchax/config.py,sha256=N52pUw18H8UnIka524w07mX_kv3kGoRrZYhc4VbV8wc,727
2
+ torchax/__init__.py,sha256=fVp0Hgq6-FwGzj7Gt9yH0qwzAzZ3Z7TZdSyLMHc-nrY,3157
3
+ torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
4
+ torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
5
+ torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
5
6
  torchax/decompositions.py,sha256=1p5TFZfAJ2Bs9BiSO1vXbnWEXnbPfC_gCQ54rDXhd9k,28859
6
- torchax/device_module.py,sha256=yGFPczPiXPlhTtpx-hBaxnhAhOuegRrxGgyvlWI2n_M,260
7
- torchax/distributed.py,sha256=9WyscssryK9jje9LPX-iiN0p4giHXzHzzPYu9G1Rg54,7703
7
+ torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
8
+ torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
9
  torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
9
10
  torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
10
- torchax/interop.py,sha256=T7Wt3nngGF-6lweRKb28cmjqhMgpQiro6yGnj9a35MY,10676
11
- torchax/mesh_util.py,sha256=g3yv8pS4ox1_QU70U6OHyPbEpLw-1AzQSTPGyZ4D5q8,8965
12
- torchax/tensor.py,sha256=sF0Pi0V4kUwuFADAs3k9-ypDc2o1RuNCBAz9LtJwe9c,20937
11
+ torchax/interop.py,sha256=5r3ZRUQAJj9n-7NGBxbP-N87-K-8GoYftULq1r2CDxE,11285
12
+ torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
13
+ torchax/tensor.py,sha256=vU-RR6LArrQlO62fTNQQ4RFLRyKJ3Oa9GXsbmq4K8rI,20872
13
14
  torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
14
15
  torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
15
16
  torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
16
17
  torchax/util.py,sha256=cb-eudDE7AX2s-6zYtXdowgyzyvqPqE9MPP82PfH23g,3069
17
18
  torchax/view.py,sha256=1ekqRN04lAPd_icgZMKbSYWhr738DzVloc34ynml4wo,11121
18
19
  torchax/ops/__init__.py,sha256=Vr1p8zDHwfXZBUbw70iNiCJLZLNdI6gR_vUlaiA7Usg,270
19
- torchax/ops/jaten.py,sha256=rUnyJVzvU701SOIGH_b_huLWH7NrrgSDQTRzlGJNn_A,165737
20
+ torchax/ops/jaten.py,sha256=WxfZU6p7b7OR98B3z0LCXKlV6U5aslXxJMJirBr6lns,165835
20
21
  torchax/ops/jax_reimplement.py,sha256=idkmFWNCXBilkmaHBGdivKz0XhsjSpqLNlGXxbBOKWQ,7302
21
22
  torchax/ops/jc10d.py,sha256=OzSYYle_5jBmNVP64SuJPz9S-rRGD6H7e1a9HHIKsjU,1322
22
23
  torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
23
24
  torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
24
- torchax/ops/jtorch.py,sha256=LMHz85UfLerbyrB0IZqHIXpQXfTDmnzhaaE3_SHtMH4,16870
25
+ torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
25
26
  torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
26
- torchax/ops/mappings.py,sha256=AESERtXJ6i_Hm0ycwEw7z5OJnHu-7QteWlSs-mlUPE4,3492
27
+ torchax/ops/mappings.py,sha256=H-2jlG9ODuV9VzCFqZEC-djTrbcYXmw4fAVwn5Yilc4,3787
27
28
  torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
28
29
  torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
29
- torchax-0.0.5.dist-info/METADATA,sha256=fyGJQ51oOgCz8OOpxsbiuIVyQFT5G-wyx2R2KiUHGXE,10753
30
- torchax-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
31
- torchax-0.0.5.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
32
- torchax-0.0.5.dist-info/RECORD,,
30
+ torchax/ops/type_casting.py,sha256=gNz3mbA9XtRhkHcx-qpF1bFzsnsila-jkCE9BPQD9GI,1391
31
+ torchax-0.0.7.dist-info/METADATA,sha256=_F_gU0Ea6epTCngRXcBeur4oH8NgOvgq78DBhjt6zEo,10753
32
+ torchax-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
33
+ torchax-0.0.7.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
34
+ torchax-0.0.7.dist-info/RECORD,,
torchax/distributed.py DELETED
@@ -1,241 +0,0 @@
1
- """`torch.distributed` backend implemented with JAX collective ops.
2
-
3
- EXPERIMENTAL: This module is still highly experimental, and it may be removed
4
- before any stable release.
5
-
6
- Note: JAX collective ops require that axis names be defined in `pmap` or
7
- `shmap`. The distributed backend only supports one axis, named `torch_dist`.
8
- This name is defined by our mirror implementation of `spawn`.
9
- """
10
-
11
- import datetime
12
- import functools
13
- import logging
14
- import os
15
- from typing import List, Optional, Union
16
-
17
- import jax
18
- import numpy as np
19
- import torch
20
- import torch.distributed as dist
21
- import torch.distributed._functional_collectives
22
- from torch._C._distributed_c10d import ProcessGroup # type: ignore
23
- import torch.distributed
24
- import torchax
25
- from jax.sharding import NamedSharding
26
- from jax.sharding import Mesh, PartitionSpec as P
27
- from jax.experimental import mesh_utils
28
- import torch.utils._pytree as torch_pytree
29
- from torchax import interop
30
-
31
-
32
- class ProcessGroupJax(ProcessGroup):
33
- """Distributed backend implemented with JAX."""
34
-
35
- def __init__(self, prefix_store, rank, size, timeout):
36
- super().__init__(rank, size)
37
- self._group_name = None
38
-
39
- def getBackendName(self):
40
- return "jax"
41
-
42
- # TODO(wcromar): why doesn't default group name setter work?
43
- # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
44
- def _set_group_name(self, name: str) -> None:
45
- self._group_name = name
46
-
47
- @property
48
- def group_name(self):
49
- assert self._group_name
50
- return self._group_name
51
-
52
- @staticmethod
53
- def _work(
54
- tensors: Union[torch.Tensor, List[torch.Tensor],
55
- List[List[torch.Tensor]]],
56
- ) -> dist.Work:
57
- fut = torch.futures.Future()
58
- fut.set_result(tensors)
59
- return torch._C._distributed_c10d._create_work_from_future(fut)
60
-
61
- def _allgather_base(
62
- self,
63
- output: torch.Tensor,
64
- input: torch.Tensor,
65
- opts=...,
66
- ) -> dist.Work:
67
- assert isinstance(input, torchax.tensor.Tensor)
68
- assert isinstance(output, torchax.tensor.Tensor)
69
- torch.distributed._functional_collectives.all_gather_tensor_inplace(
70
- output, input, group=self)
71
- return self._work(output)
72
-
73
- def allreduce(
74
- self,
75
- tensors: List[torch.Tensor],
76
- opts: dist.AllreduceOptions = ...,
77
- ) -> dist.Work:
78
- assert len(tensors) == 1
79
- assert isinstance(tensors[0], torchax.tensor.Tensor)
80
- torch.distributed._functional_collectives.all_reduce_inplace(
81
- tensors[0],
82
- torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
83
- opts.reduceOp.op],
84
- self,
85
- )
86
-
87
- return self._work(tensors)
88
-
89
- def broadcast(
90
- self,
91
- tensors: List[torch.Tensor],
92
- opts: dist.BroadcastOptions = ...,
93
- ) -> dist.Work:
94
- assert len(tensors) == 1
95
- assert isinstance(tensors[0], torchax.tensor.Tensor)
96
- tensors[0].copy_(
97
- torch.distributed._functional_collectives.broadcast(
98
- tensors[0], opts.rootRank, group=self))
99
-
100
- return self._work(tensors)
101
-
102
-
103
- dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"])
104
-
105
-
106
- def jax_rendezvous_handler(url: str,
107
- timeout: datetime.timedelta = ...,
108
- **kwargs):
109
- """Initialize distributed store with JAX process IDs.
110
-
111
- Requires `$MASTER_ADDR` and `$MASTER_PORT`.
112
- """
113
- # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
114
- # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
115
- # of their public Python API
116
- master_ip = os.environ["MASTER_ADDR"]
117
- master_port = int(os.environ["MASTER_PORT"])
118
- # TODO(wcromar): Use `torchrun`'s store if available
119
- store = dist.TCPStore(
120
- master_ip,
121
- master_port,
122
- jax.process_count(),
123
- is_master=jax.process_index() == 0,
124
- )
125
-
126
- yield (store, jax.process_index(), jax.process_count())
127
-
128
-
129
- dist.register_rendezvous_handler("jax", jax_rendezvous_handler)
130
-
131
-
132
- def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
133
- """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
134
- `f` is expected to take the replica index as a positional argument, similar
135
- to `torch.multiprocessing.spawn`.
136
- Note: `spawn` does not actually create parallel processes.
137
- """
138
- env = env or torchax.default_env()
139
-
140
- def jax_wrapper(index, jax_args):
141
- index, args = env.j2t_iso([index, jax_args])
142
- torch_outputs = f(index, *args)
143
- return env.t2j_iso(torch_outputs)
144
-
145
- jax_outputs = jax.pmap(
146
- jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()),
147
- env.t2j_iso(args))
148
- return env.j2t_iso(jax_outputs)
149
-
150
-
151
- class DistributedDataParallel(torch.nn.Module):
152
- """Re-implementation of DistributedDataParallel using JAX SPMD.
153
-
154
- Splits inputs along batch dimension (assumed to be 0) across all devices in
155
- JAX runtime, including remote devices. Each process should load a distinct
156
- shard of the input data using e.g. DistributedSampler. Each process' shard
157
- is then further split among the addressable devices (e.g. local TPU chips)
158
- by `shard_input`.
159
-
160
- Note: since parameters are replicated across addressable devices, inputs
161
- must also be SPMD sharded using `shard_input` or `replicate_input`.
162
-
163
- Example usage:
164
-
165
- ```
166
- jax_model = torchax.distributed.DistributedDataParallel(create_model())
167
- for data, dataloader:
168
- jax_data = jax_model.shard_input(data)
169
- jax_output = jax_model(jax_data)
170
- ```
171
- """
172
-
173
- def __init__(
174
- self,
175
- module: torch.nn.Module,
176
- env: Optional[torchax.tensor.Environment] = None,
177
- **kwargs,
178
- ):
179
- if kwargs:
180
- logging.warning(f"Unsupported kwargs {kwargs}")
181
-
182
- super().__init__()
183
- self._env = env or torchax.default_env()
184
- self._mesh = Mesh(
185
- mesh_utils.create_device_mesh((jax.device_count(),)),
186
- axis_names=("batch",),
187
- )
188
- replicated_state = torch_pytree.tree_map_only(
189
- torch.Tensor,
190
- lambda t: self._env.j2t_iso(
191
- jax.device_put(
192
- self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))),
193
- module.state_dict(),
194
- )
195
- # TODO: broadcast
196
- module.load_state_dict(replicated_state, assign=True)
197
- self._module = module
198
-
199
- def shard_input(self, inp):
200
- per_process_batch_size = inp.shape[0] # assumes batch dim is 0
201
- per_replica_batch_size = per_process_batch_size // jax.local_device_count()
202
- per_replica_batches = torch.chunk(inp, jax.local_device_count())
203
- global_batch_size = per_replica_batch_size * jax.device_count()
204
- global_batch_shape = (global_batch_size,) + inp.shape[1:]
205
-
206
- sharding = NamedSharding(self._mesh, P("batch"))
207
- return self._env.j2t_iso(
208
- jax.make_array_from_single_device_arrays(
209
- global_batch_shape,
210
- NamedSharding(self._mesh, P("batch")),
211
- arrays=[
212
- jax.device_put(self._env.to_xla(batch)._elem, device) for batch,
213
- device in zip(per_replica_batches, sharding.addressable_devices)
214
- ],
215
- ))
216
-
217
- def replicate_input(self, inp):
218
- return self._env.j2t_iso(
219
- jax.device_put(inp._elem, NamedSharding(self._mesh, P())))
220
-
221
- def jit_step(self, func):
222
-
223
- @functools.partial(
224
- interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0})
225
- def _jit_fn(states, args):
226
- self.load_state_dict(states)
227
- outputs = func(*args)
228
- return self.state_dict(), outputs
229
-
230
- @functools.wraps(func)
231
- def inner(*args):
232
- jax_states = self.state_dict()
233
- new_states, outputs = _jit_fn(jax_states, args)
234
- self.load_state_dict(new_states)
235
- return outputs
236
-
237
- return inner
238
-
239
- def forward(self, *args):
240
- with self._env:
241
- return self._module(*args)