torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/tensor.py CHANGED
@@ -12,25 +12,26 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import threading
+ import contextlib
+ import itertools
  import logging
  import sys
- import contextlib
- from typing import Optional, Any
+ import threading
+ from typing import Any
+
  import jax
  import jax.numpy as jnp
  import numpy
- import itertools
  import torch
  import torch.distributed._functional_collectives
  import torch.func
  import torch.utils._mode_utils as mode_utils
  import torch.utils._python_dispatch as torch_dispatch
  import torch.utils._pytree as torch_pytree
- from torchax.view import View
- from torchax import config
+
+ from torchax import amp, config
  from torchax.ops import mappings, ops_registry
- from torchax import amp
+ from torchax.view import View

  logger = logging.getLogger(__name__)

@@ -52,7 +53,6 @@ log_nested.level = 0


  class Tensor(torch.Tensor):
-
  @staticmethod
  def __new__(cls, elem, env, requires_grad=False):
  dtype = mappings.j2t_dtype(elem.dtype)
@@ -62,16 +62,16 @@ class Tensor(torch.Tensor):
  shape[i] = 1
  if dtype is None:
  dtype = torch.float32
- #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
+ # dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
  if not (dtype.is_floating_point or dtype.is_complex):
  requires_grad = False

  return torch.Tensor._make_wrapper_subclass(
- cls,
- shape,
- dtype=dtype,
- device='meta',
- requires_grad=requires_grad,
+ cls,
+ shape,
+ dtype=dtype,
+ device="meta",
+ requires_grad=requires_grad,
  )

  def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
@@ -80,7 +80,7 @@ class Tensor(torch.Tensor):
  self._env = env

  def __str__(self):
- return "Tensor({} {})".format(str(type(self._elem)), str(self._elem))
+ return f"Tensor({str(type(self._elem))} {str(self._elem)})"

  __repr__ = __str__

@@ -95,8 +95,7 @@ class Tensor(torch.Tensor):
  def flatten(self, start_dim=0, end_dim=-1):
  if end_dim == -1:
  end_dim = self.ndim
- new_shape = (
- self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
+ new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :]
  new_elem = jnp.reshape(self._elem, new_shape)
  return Tensor(new_elem, self._env)
  # return torch.reshape(self, new_shape)
@@ -117,11 +116,12 @@ class Tensor(torch.Tensor):
  if func == torch.ops._c10d_functional.wait_tensor.default:
  return args[0]._env.dispatch(func, types, args, kwargs)
  if func == torch.ops.prim.device.default:
- return torch.device('privateuseone', 0)
+ return torch.device("privateuseone", 0)
  raise AssertionError(
- 'torchax Tensors can only do math within the torchax environment.'
- 'Please wrap your code with `with torchax.default_env()` or '
- 'call torchax.enable_globally() before.')
+ "torchax Tensors can only do math within the torchax environment."
+ "Please wrap your code with `with torchax.default_env()` or "
+ "call torchax.enable_globally() before."
+ )

  def detach(self):
  return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -154,8 +154,7 @@ class Tensor(torch.Tensor):

  @property
  def data(self):
- logger.warning(
- "In-place to .data modifications still results a copy on TPU")
+ logger.warning("In-place to .data modifications still results a copy on TPU")
  return self

  @data.setter
@@ -181,7 +180,8 @@ class Tensor(torch.Tensor):

  def debug_accuracy(func, args, kwargs, current_output):
  args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
- torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))
+ torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)
+ )

  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
  if "device" in kwargs_torch:
@@ -191,16 +191,17 @@ def debug_accuracy(func, args, kwargs, current_output):
  flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
  flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)

- for ex, real in zip(flattened_expected_out, flattened_current_out):
+ for ex, real in zip(flattened_expected_out, flattened_current_out, strict=False):
  if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
  ex = ex.to(real.dtype)
  try:
  if isinstance(ex, torch.Tensor) and not torch.allclose(
- ex, real, atol=1e-3, equal_nan=True):
+ ex, real, atol=1e-3, equal_nan=True
+ ):
  import pdb

  pdb.set_trace()
- except:
+ except Exception:
  import pdb

  pdb.set_trace()
@@ -209,7 +210,6 @@ def debug_accuracy(func, args, kwargs, current_output):


  def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
-
  def _display(a):
  if isinstance(a, torch.Tensor):
  return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
@@ -221,9 +221,11 @@ def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
  kwargs = kwargs or {}
  title = "DISPATCH" if is_dispatch else "FUNCTION"
  args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
- kwargs_msg = ("kwargs: " +
- ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
- if log_args else "")
+ kwargs_msg = (
+ "kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
+ if log_args
+ else ""
+ )
  return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"


@@ -233,49 +235,43 @@ class XLAFunctionMode(torch.overrides.TorchFunctionMode):
  def __init__(self, env):
  self.env = env

- def __torch_function__(self,
- func,
- types,
- args=(),
- kwargs=None) -> torch.Tensor:
+ def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
  message = f"FUNCTION: {_name_of_func(func)}"
  if self.env.config.debug_print_each_op_operands:
  message = message + "f"
- message = _make_debug_msg(False,
- self.env.config.debug_print_each_op_operands,
- func, args, kwargs)
+ message = _make_debug_msg(
+ False, self.env.config.debug_print_each_op_operands, func, args, kwargs
+ )
  with log_nested(self.env, message):
  try:
  return self.env.dispatch(func, types, args, kwargs)
  except OperatorNotFound:
  pass
- if _name_of_func(func) in (
- "rot90"): # skip rot90 with k%4==0 due to no change
- if len(args) >= 2 and type(args[1]) == int:
+ if _name_of_func(func) in ("rot90"): # skip rot90 with k%4==0 due to no change
+ if len(args) >= 2 and isinstance(args[1], int):
  if (args[1]) % 4 == 0:
  return args[0]
  return func(*args, **(kwargs or {}))


  class XLADispatchMode(torch_dispatch.TorchDispatchMode):
-
  def __init__(self, env):
  self.env = env

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
- message = _make_debug_msg(True,
- self.env.config.debug_print_each_op_operands,
- func, args, kwargs)
+ message = _make_debug_msg(
+ True, self.env.config.debug_print_each_op_operands, func, args, kwargs
+ )
  with log_nested(self.env, message):
  if isinstance(func, torch._ops.OpOverloadPacket):
  with self:
  return func(*args, **kwargs)
  # Only functions under these namespaces will be intercepted
  if func.namespace not in (
- "aten",
- "_c10d_functional",
- "torchvision",
- "xla",
+ "aten",
+ "_c10d_functional",
+ "torchvision",
+ "xla",
  ):
  return func(*args, **kwargs)
  return self.env.dispatch(func, types, args, kwargs)
@@ -289,18 +285,18 @@ def _name_of_func(func):

  # Constructors that don't take other tensor as input
  TENSOR_CONSTRUCTORS = {
- torch.ones,
- torch.zeros,
- torch.empty,
- torch.empty_strided,
- torch.tensor,
- torch.arange,
- torch.eye,
- torch.randn,
- torch.rand,
- torch.randint,
- torch.full,
- torch.as_tensor,
+ torch.ones,
+ torch.zeros,
+ torch.empty,
+ torch.empty_strided,
+ torch.tensor,
+ torch.arange,
+ torch.eye,
+ torch.randn,
+ torch.rand,
+ torch.randint,
+ torch.full,
+ torch.as_tensor,
  }

  # TODO(wen): use existing types, either from torch or jax
@@ -328,7 +324,6 @@ class RuntimeProperty:


  class OverrideProperty(RuntimeProperty):
-
  def __init__(self, parent, override):
  self.parent = parent
  self._override = dict(override)
@@ -372,25 +367,24 @@ class Environment(contextlib.ContextDecorator):
  _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
  self._property = threading.local()
  self._initial_content = RuntimeProperty(
- mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
+ mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype
+ )

  @property
  def param(self):
- if not hasattr(self._property, 'content'):
- self._property.content = [
- self._initial_content
- ]
+ if not hasattr(self._property, "content"):
+ self._property.content = [self._initial_content]
  return self._property.content[-1]

  def manual_seed(self, key):
  if isinstance(key, torch.Tensor):
- assert key.ndim == 0, 'manual seed can only take scalars'
- assert not key.dtype.is_floating_point, 'manual seed can only be integers'
+ assert key.ndim == 0, "manual seed can only take scalars"
+ assert not key.dtype.is_floating_point, "manual seed can only be integers"

- if isinstance(key, Tensor):
- key = key._elem
- else:
- key = key.item()
+ if isinstance(key, Tensor):
+ key = key._elem
+ else:
+ key = key.item()
  jax_key = jax.random.PRNGKey(key)
  new_prop = self.param.override(prng=jax_key)
  self._property.content.append(new_prop)
@@ -406,27 +400,28 @@ class Environment(contextlib.ContextDecorator):
  if isinstance(device, torch.device):
  device = device.type

- if ':' in device:
- device = device.split(':')[0]
+ if ":" in device:
+ device = device.split(":")[0]

  match device:
- case 'cpu':
+ case "cpu":
  return False
- case 'cuda':
+ case "cuda":
  return self.config.treat_cuda_as_jax_device
- case 'jax':
+ case "jax":
  return True
- case 'privateuseone':
+ case "privateuseone":
  return True
- case 'meta':
+ case "meta":
  return self.enabled
  return False

  def load_ops(self):
- from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+ from torchax.ops import jaten, jc10d, jtorch, jtorchvision_nms # noqa: F401

- for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
- ops_registry.all_torch_functions.items()):
+ for k, v in itertools.chain(
+ ops_registry.all_aten_ops.items(), ops_registry.all_torch_functions.items()
+ ):
  if v.is_jax_function:
  self._ops[k] = v
  else:
@@ -437,16 +432,15 @@ class Environment(contextlib.ContextDecorator):
  for k, v in DECOMPOSITIONS.items():
  if k not in self._decomps:
  self._decomps[k] = ops_registry.Operator(
- k,
- v,
- is_jax_function=False,
- is_user_defined=False,
- needs_env=False,
- is_view_op=k in MUTABLE_DECOMPOSITION,
+ k,
+ v,
+ is_jax_function=False,
+ is_user_defined=False,
+ needs_env=False,
+ is_view_op=k in MUTABLE_DECOMPOSITION,
  )

  def _get_op_or_decomp(self, func):
-
  def _get_from_dict(op_dict, op):
  op = op_dict.get(func)
  if op is None and isinstance(func, torch._ops.OpOverloadPacket):
@@ -463,17 +457,18 @@ class Environment(contextlib.ContextDecorator):

  if op is None:
  raise OperatorNotFound(
- f"Operator with name {_name_of_func(func)} has no lowering")
+ f"Operator with name {_name_of_func(func)} has no lowering"
+ )

  return op

  def _is_same_device(self, the_tensor, new_device):
  if new_device is None:
  return True
- if new_device == 'meta' and the_tensor.device.type == 'jax':
+ if new_device == "meta" and the_tensor.device.type == "jax":
  return True
  if the_tensor.device.type != new_device:
- if the_tensor.device.type == 'cuda':
+ if the_tensor.device.type == "cuda":
  return self.config.treat_cuda_as_jax_device
  return False
  return True
@@ -501,8 +496,7 @@ class Environment(contextlib.ContextDecorator):
  return res.to(device=new_device, dtype=new_dtype)
  return res

- def get_and_rotate_prng_key(self,
- generator: Optional[torch.Generator] = None):
+ def get_and_rotate_prng_key(self, generator: torch.Generator | None = None):
  if generator is not None:
  return jax.random.PRNGKey(generator.initial_seed() % (2**63))
  return self.param.get_and_rotate_prng_key()
@@ -514,7 +508,7 @@ class Environment(contextlib.ContextDecorator):
  requires_grad = kwargs.get("requires_grad", False)
  op = self._get_op_or_decomp(func)
  if op.needs_env:
- kwargs['env'] = self
+ kwargs["env"] = self
  if op.is_jax_function:
  (args, kwargs) = self.t2j_iso((args, kwargs))
  res = op.func(*args, **kwargs)
@@ -549,10 +543,10 @@ class Environment(contextlib.ContextDecorator):
  if func in TENSOR_CONSTRUCTORS:
  return self._handle_tensor_constructor(func, args, kwargs)
  if func in (
- torch.Tensor.to,
- torch.ops.aten.lift_fresh.default,
- torch.ops.aten._to_copy,
- torch.ops.aten._to_copy.default,
+ torch.Tensor.to,
+ torch.ops.aten.lift_fresh.default,
+ torch.ops.aten._to_copy,
+ torch.ops.aten._to_copy.default,
  ):
  return self._torch_Tensor_to(args, kwargs)

@@ -560,8 +554,7 @@ class Environment(contextlib.ContextDecorator):
  # We should skip and let torch handle it.

  tensor_args = [
- t for t in torch_pytree.tree_flatten(args)[0]
- if isinstance(t, torch.Tensor)
+ t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)
  ]

  def is_not_torchax_tensor(x):
@@ -577,9 +570,9 @@ class Environment(contextlib.ContextDecorator):
  old_args, old_kwargs = args, kwargs
  with self._dispatch_mode:
  args, kwargs = torch_pytree.tree_map_only(
- torch.distributed._functional_collectives.AsyncCollectiveTensor,
- torch.distributed._functional_collectives.wait_tensor,
- (args, kwargs),
+ torch.distributed._functional_collectives.AsyncCollectiveTensor,
+ torch.distributed._functional_collectives.wait_tensor,
+ (args, kwargs),
  )

  try:
@@ -590,8 +583,9 @@ class Environment(contextlib.ContextDecorator):
  if self.param.autocast_dtype is not None:
  autocast_policy = amp.autocast_policy.get(func)
  if autocast_policy is not None:
- args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
- self.param.autocast_dtype)
+ args, kwargs = amp.execute_policy(
+ autocast_policy, args, kwargs, self.param.autocast_dtype
+ )

  if op.is_jax_function:
  args, kwargs = self.t2j_iso((args, kwargs))
@@ -664,15 +658,17 @@ class Environment(contextlib.ContextDecorator):
  """

  def to_jax(x):
- if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
- x, Tensor) and not isinstance(x, View):
+ if (
+ self.config.allow_mixed_math_with_scalar_tensor
+ and not isinstance(x, Tensor)
+ and not isinstance(x, View)
+ ):
  if x.squeeze().ndim == 0:
  return x.item()
- if isinstance(
- x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+ if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
  x = x.wait()
  assert isinstance(x, Tensor) or isinstance(x, View), (
- f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
+ f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
  )
  return x.jax()

@@ -680,7 +676,6 @@ class Environment(contextlib.ContextDecorator):
  return res

  def v2t_iso(self, views):
-
  def to_tensor(x):
  if isinstance(x, View):
  return x.torch()
@@ -695,8 +690,7 @@ class Environment(contextlib.ContextDecorator):
  This function will not copy, will just wrap the jax array with a torchax Tensor
  Note: iso is short for "isomorphic"
  """
- return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
- jaxarray)
+ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray)

  def j2t_copy(self, args):
  """Convert torch.Tensor in cpu to a jax array
@@ -704,9 +698,10 @@ class Environment(contextlib.ContextDecorator):
  This might involves copying the data (depending if dlpack is enabled)
  """
  return torch_pytree.tree_map_only(
- jax.Array,
- lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
- args)
+ jax.Array,
+ lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+ args,
+ )

  def t2j_copy(self, args):
  """Convert jax array to torch.Tensor in cpu.
@@ -714,18 +709,19 @@ class Environment(contextlib.ContextDecorator):
  This might involves copying the data (depending if dlpack is enabled)
  """
  return torch_pytree.tree_map_only(
- torch.Tensor,
- lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
- args)
+ torch.Tensor,
+ lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+ args,
+ )

  def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
  self._ops[op_to_override] = ops_registry.Operator(
- op_to_override,
- op_impl,
- is_jax_function=False,
- is_user_defined=True,
- needs_env=False,
- is_view_op=is_view_op,
+ op_to_override,
+ op_impl,
+ is_jax_function=False,
+ is_user_defined=True,
+ needs_env=False,
+ is_view_op=is_view_op,
  )

  @contextlib.contextmanager
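
For reference, the assertion message reformatted above points users at the torchax environment. A minimal usage sketch of the two entry points named in that message, torchax.default_env() and torchax.enable_globally(), follows; this is an illustration only, and exact tensor-creation behavior may differ between the two versions compared in this diff.

import torch
import torchax

# Option 1: scope the torchax environment to a block, as the message suggests.
with torchax.default_env():
    x = torch.randn(4, 4)   # created inside the environment
    y = (x @ x.T).sum()

# Option 2: enable the environment for the rest of the process.
torchax.enable_globally()
z = torch.ones(3) * 2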