torchax 0.0.10.dev20251118__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/tensor.py ADDED
@@ -0,0 +1,732 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ import itertools
+ import logging
+ import sys
+ import threading
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+ import numpy
+ import torch
+ import torch.distributed._functional_collectives
+ import torch.func
+ import torch.utils._mode_utils as mode_utils
+ import torch.utils._python_dispatch as torch_dispatch
+ import torch.utils._pytree as torch_pytree
+
+ from torchax import amp, config
+ from torchax.ops import mappings, ops_registry
+ from torchax.view import View
+
+ logger = logging.getLogger(__name__)
+
+
+ class OperatorNotFound(Exception):
+     pass
+
+
+ @contextlib.contextmanager
+ def log_nested(env, message):
+     if env.config.debug_print_each_op:
+         print((" " * log_nested.level) + message, file=sys.stderr)
+     log_nested.level += 1
+     yield
+     log_nested.level -= 1
+
+
+ log_nested.level = 0
+
+
+ class Tensor(torch.Tensor):
+     @staticmethod
+     def __new__(cls, elem, env, requires_grad=False):
+         dtype = mappings.j2t_dtype(elem.dtype)
+         shape = list(elem.shape)
+         for i, s in enumerate(shape):
+             if not isinstance(s, int):
+                 shape[i] = 1
+         if dtype is None:
+             dtype = torch.float32
+         # dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
+         if not (dtype.is_floating_point or dtype.is_complex):
+             requires_grad = False
+
+         return torch.Tensor._make_wrapper_subclass(
+             cls,
+             shape,
+             dtype=dtype,
+             device="meta",
+             requires_grad=requires_grad,
+         )
+
+     def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
+         super().__init__()
+         self._elem = elem
+         self._env = env
+
+     def __str__(self):
+         return f"Tensor({str(type(self._elem))} {str(self._elem)})"
+
+     __repr__ = __str__
+
+     @property
+     def shape(self):
+         return torch.Size(self._elem.shape)
+
+     @property
+     def ndim(self):
+         return len(self._elem.shape)
+
+     def flatten(self, start_dim=0, end_dim=-1):
+         if end_dim == -1:
+             end_dim = self.ndim
+         new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :]
+         new_elem = jnp.reshape(self._elem, new_shape)
+         return Tensor(new_elem, self._env)
+         # return torch.reshape(self, new_shape)
+
+     def __setitem__(self, key, val):
+         key, val = self._env.t2j_iso((key, val))
+         self._elem = self._elem.at[key].set(val)
+
+     def type_as(self, other):
+         self._elem = self._elem.astype(other._elem.dtype)
+         return self
+
+     __torch_function__ = torch._C._disabled_torch_function_impl
+
+     @classmethod
+     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+         # TODO(hanq): figure out why is dispatch mode not sufficient
+         if func == torch.ops._c10d_functional.wait_tensor.default:
+             return args[0]._env.dispatch(func, types, args, kwargs)
+         if func == torch.ops.prim.device.default:
+             return torch.device("privateuseone", 0)
+         raise AssertionError(
+             "torchax Tensors can only do math within the torchax environment. "
+             "Please wrap your code with `with torchax.default_env()` or "
+             "call torchax.enable_globally() beforehand."
+         )
+
+     def detach(self):
+         return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
+
+     def numpy(self) -> numpy.ndarray:
+         import numpy as np
+
+         return np.array(self._elem)
+
+     def jax(self) -> jax.Array:
+         return self._elem
+
+     def torch(self) -> torch.Tensor:
+         return self._env.j2t_copy(self.jax())
+
+     @property
+     def dtype(self):
+         return mappings.j2t_dtype(self._elem.dtype)
+
+     def dim(self):
+         return self.ndim
+
+     @property
+     def device(self):
+         return torch.device("jax:0")
+
+     @property
+     def jax_device(self):
+         return self._elem.device
+
+     @property
+     def data(self):
+         logger.warning("In-place modifications to .data still result in a copy on TPU")
+         return self
+
+     @data.setter
+     def data(self, other):
+         if isinstance(other, Tensor):
+             self._elem = other._elem
+
+     def apply_jax(self, jax_function, *args, **kwargs):
+         # Call a jax function on _elem
+         res = jax_function(self._elem, *args, **kwargs)
+         return self._env.j2t_iso(res)
+
+     def apply_jax_(self, jax_function, *args, **kwargs):
+         self._elem = jax_function(self._elem, *args, **kwargs)
+         return self
+
+     def tolist(self):
+         return self._elem.tolist()
+
+     def shard_(self, sharding):
+         self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
+
+
+ def debug_accuracy(func, args, kwargs, current_output):
+     args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
+         torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)
+     )
+
+     with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+         if "device" in kwargs_torch:
+             kwargs_torch["device"] = "cpu"  # do the torch native for comparison
+         expected_out = func(*args_torch, **kwargs_torch)
+
+     flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
+     flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)
+
+     for ex, real in zip(flattened_expected_out, flattened_current_out, strict=False):
+         if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
+             ex = ex.to(real.dtype)
+         try:
+             if isinstance(ex, torch.Tensor) and not torch.allclose(
+                 ex, real, atol=1e-3, equal_nan=True
+             ):
+                 import pdb
+
+                 pdb.set_trace()
+         except Exception:
+             import pdb
+
+             pdb.set_trace()
+
+     return True
+
+
+ def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
+     def _display(a):
+         if isinstance(a, torch.Tensor):
+             return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
+         elif isinstance(a, jax.Array):
+             return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
+         else:
+             return str(a)
+
+     kwargs = kwargs or {}
+     title = "DISPATCH" if is_dispatch else "FUNCTION"
+     args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
+     kwargs_msg = (
+         "kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
+         if log_args
+         else ""
+     )
+     return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"
+
+
+ class XLAFunctionMode(torch.overrides.TorchFunctionMode):
+     """Context manager that dispatches torch function calls to JAX."""
+
+     def __init__(self, env):
+         self.env = env
+
+     def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
+         message = f"FUNCTION: {_name_of_func(func)}"
+         if self.env.config.debug_print_each_op_operands:
+             message = message + "f"
+         message = _make_debug_msg(
+             False, self.env.config.debug_print_each_op_operands, func, args, kwargs
+         )
+         with log_nested(self.env, message):
+             try:
+                 return self.env.dispatch(func, types, args, kwargs)
+             except OperatorNotFound:
+                 pass
+             if _name_of_func(func) in ("rot90",):  # skip rot90 with k % 4 == 0 due to no change
+                 if len(args) >= 2 and isinstance(args[1], int):
+                     if args[1] % 4 == 0:
+                         return args[0]
+             return func(*args, **(kwargs or {}))
+
+
+ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
+     def __init__(self, env):
+         self.env = env
+
+     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+         message = _make_debug_msg(
+             True, self.env.config.debug_print_each_op_operands, func, args, kwargs
+         )
+         with log_nested(self.env, message):
+             if isinstance(func, torch._ops.OpOverloadPacket):
+                 with self:
+                     return func(*args, **kwargs)
+             # Only functions under these namespaces will be intercepted
+             if func.namespace not in (
+                 "aten",
+                 "_c10d_functional",
+                 "torchvision",
+                 "xla",
+             ):
+                 return func(*args, **kwargs)
+             return self.env.dispatch(func, types, args, kwargs)
+
+
+ def _name_of_func(func):
+     if hasattr(func, "name"):
+         return func.name()
+     return func.__name__
+
+
+ # Constructors that don't take other tensor as input
+ TENSOR_CONSTRUCTORS = {
+     torch.ones,
+     torch.zeros,
+     torch.empty,
+     torch.empty_strided,
+     torch.tensor,
+     torch.arange,
+     torch.eye,
+     torch.randn,
+     torch.rand,
+     torch.randint,
+     torch.full,
+     torch.as_tensor,
+ }
+
+ # TODO(wen): use existing types, either from torch or jax
+ SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]
+
+
+ class RuntimeProperty:
+     mesh: Any
+     prng: Any
+     autocast_dtype: Any
+
+     def __init__(self, mesh, prng, autocast_dtype):
+         self.mesh = mesh
+         self.prng = prng
+         self.autocast_dtype = autocast_dtype
+
+     def override(self, **kwargs):
+         return OverrideProperty(self, kwargs)
+
+     def get_and_rotate_prng_key(self):
+         old_key = self.prng
+         new_prng_key, next_key = jax.random.split(old_key)
+         self.prng = new_prng_key
+         return next_key
+
+
+ class OverrideProperty(RuntimeProperty):
+     def __init__(self, parent, override):
+         self.parent = parent
+         self._override = dict(override)
+
+     def __getattr__(self, name):
+         if name in self._override:
+             return self._override[name]
+         return getattr(self.parent, name)
+
+
+ class Environment(contextlib.ContextDecorator):
+     """This class holds a set of configurations and "globals" needed
+     for executing a torch program using jax.
+
+     Things included so far:
+
+       op registry
+       PRNGKey
+       Configs
+
+     Also helper functions to manipulate those.
+     """
+
+     def __init__(self, configuration=None):
+         self._function_mode = XLAFunctionMode(self)
+         self._dispatch_mode = XLADispatchMode(self)
+
+         # name is torch callable
+         self._ops = {}
+         self._decomps = {}
+
+         self.load_ops()
+
+         _mesh = None
+         self.config = configuration or config.Configuration()
+
+         self.enabled = False
+
+         autocast_dtype = None
+
+         _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
+         self._property = threading.local()
+         self._initial_content = RuntimeProperty(
+             mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype
+         )
+
+     @property
+     def param(self):
+         if not hasattr(self._property, "content"):
+             self._property.content = [self._initial_content]
+         return self._property.content[-1]
+
+     def manual_seed(self, key):
+         if isinstance(key, torch.Tensor):
+             assert key.ndim == 0, "manual seed can only take scalars"
+             assert not key.dtype.is_floating_point, "manual seed can only be integers"
+
+             if isinstance(key, Tensor):
+                 key = key._elem
+             else:
+                 key = key.item()
+         jax_key = jax.random.PRNGKey(key)
+         new_prop = self.param.override(prng=jax_key)
+         self._property.content.append(new_prop)
+
+     @property
+     def prng_key(self):
+         return self.param.prng
+
+     def _should_use_torchax_tensor(self, device):
+         if device is None:
+             device = torch.get_default_device()
+
+         if isinstance(device, torch.device):
+             device = device.type
+
+         if ":" in device:
+             device = device.split(":")[0]
+
+         match device:
+             case "cpu":
+                 return False
+             case "cuda":
+                 return self.config.treat_cuda_as_jax_device
+             case "jax":
+                 return True
+             case "privateuseone":
+                 return True
+             case "meta":
+                 return self.enabled
+         return False
+
+     def load_ops(self):
+         from torchax.ops import jaten, jc10d, jtorch, jtorchvision_nms  # noqa: F401
+
+         for k, v in itertools.chain(
+             ops_registry.all_aten_ops.items(), ops_registry.all_torch_functions.items()
+         ):
+             if v.is_jax_function:
+                 self._ops[k] = v
+             else:
+                 self._decomps[k] = v
+
+         from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
+
+         for k, v in DECOMPOSITIONS.items():
+             if k not in self._decomps:
+                 self._decomps[k] = ops_registry.Operator(
+                     k,
+                     v,
+                     is_jax_function=False,
+                     is_user_defined=False,
+                     needs_env=False,
+                     is_view_op=k in MUTABLE_DECOMPOSITION,
+                 )
+
+     def _get_op_or_decomp(self, func):
+         def _get_from_dict(op_dict, op):
+             op = op_dict.get(func)
+             if op is None and isinstance(func, torch._ops.OpOverloadPacket):
+                 op = op_dict.get(func.default)
+             if op is None and isinstance(func, torch._ops.OpOverload):
+                 op = op_dict.get(func.overloadpacket)
+             return op
+
+         op = _get_from_dict(self._ops, func)
+
+         if op is None:
+             # fallback to decompose
+             op = _get_from_dict(self._decomps, func)
+
+         if op is None:
+             raise OperatorNotFound(
+                 f"Operator with name {_name_of_func(func)} has no lowering"
+             )
+
+         return op
+
+     def _is_same_device(self, the_tensor, new_device):
+         if new_device is None:
+             return True
+         if new_device == "meta" and the_tensor.device.type == "jax":
+             return True
+         if the_tensor.device.type != new_device:
+             if the_tensor.device.type == "cuda":
+                 return self.config.treat_cuda_as_jax_device
+             return False
+         return True
+
+     def _to_copy(self, the_tensor, new_dtype, new_device):
+         if isinstance(the_tensor, View):
+             the_tensor = the_tensor.torch()
+         if isinstance(new_device, torch.device):
+             new_device = new_device.type
+         res = the_tensor
+         if not self._is_same_device(the_tensor, new_device):
+             if isinstance(the_tensor, Tensor):
+                 torch_tensor = self.j2t_copy(the_tensor._elem)
+                 with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+                     return torch_tensor.to(device=new_device, dtype=new_dtype)
+             else:
+                 arr = self.t2j_copy(the_tensor)
+                 res = Tensor(arr, self, the_tensor.requires_grad)
+
+         if new_dtype is not None and new_dtype != res.dtype:
+             if isinstance(res, Tensor):
+                 res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
+             else:
+                 with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+                     return res.to(device=new_device, dtype=new_dtype)
+         return res
+
+     def get_and_rotate_prng_key(self, generator: torch.Generator | None = None):
+         if generator is not None:
+             return jax.random.PRNGKey(generator.initial_seed() % (2**63))
+         return self.param.get_and_rotate_prng_key()
+
+     def _handle_tensor_constructor(self, func, args, kwargs):
+         device = kwargs.get("device")
+         if self._should_use_torchax_tensor(device):
+             # don't set default device, let caller set it
+             requires_grad = kwargs.get("requires_grad", False)
+             op = self._get_op_or_decomp(func)
+             if op.needs_env:
+                 kwargs["env"] = self
+             if op.is_jax_function:
+                 (args, kwargs) = self.t2j_iso((args, kwargs))
+             res = op.func(*args, **kwargs)
+             if isinstance(res, jax.Array):
+                 res = Tensor(res, self, requires_grad)
+             return res
+         else:
+             with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+                 return func(*args, **kwargs)
+
+     def _torch_Tensor_to(self, args, kwargs):
+         the_tensor = args[0]
+         args = args[1:]
+         if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+             dtype = args[0].dtype
+             device = args[0].device
+             return self._to_copy(the_tensor, dtype, device)
+         device = kwargs.get("device")
+         dtype = kwargs.get("dtype")
+         # args like pin_memory etc that we will ignore
+         args = list(filter(lambda x: not isinstance(x, bool), args))
+         if len(args) >= 2:
+             device, dtype, *_ = args
+         elif len(args) == 1 and isinstance(args[0], torch.dtype):
+             dtype = args[0]
+         elif len(args) == 1:
+             device = args[0]
+         return self._to_copy(the_tensor, dtype, device)
+
+     def dispatch(self, func, types, args, kwargs):
+         kwargs = kwargs or {}
+         if func in TENSOR_CONSTRUCTORS:
+             return self._handle_tensor_constructor(func, args, kwargs)
+         if func in (
+             torch.Tensor.to,
+             torch.ops.aten.lift_fresh.default,
+             torch.ops.aten._to_copy,
+             torch.ops.aten._to_copy.default,
+         ):
+             return self._torch_Tensor_to(args, kwargs)
+
+         # If the func doesn't act on Tensor, and is not a tensor constructor,
+         # We should skip and let torch handle it.
+
+         tensor_args = [
+             t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)
+         ]
+
+         def is_not_torchax_tensor(x):
+             return not isinstance(x, Tensor) and not isinstance(x, View)
+
+         if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
+             res = func(*args, **kwargs)
+             return res
+
+         with jax.named_scope(_name_of_func(func)):
+             op = self._get_op_or_decomp(func)
+
+             old_args, old_kwargs = args, kwargs
+             with self._dispatch_mode:
+                 args, kwargs = torch_pytree.tree_map_only(
+                     torch.distributed._functional_collectives.AsyncCollectiveTensor,
+                     torch.distributed._functional_collectives.wait_tensor,
+                     (args, kwargs),
+                 )
+
+             try:
+                 if not op.is_view_op:
+                     args, kwargs = self.v2t_iso((args, kwargs))
+
+                 with self:
+                     if self.param.autocast_dtype is not None:
+                         autocast_policy = amp.autocast_policy.get(func)
+                         if autocast_policy is not None:
+                             args, kwargs = amp.execute_policy(
+                                 autocast_policy, args, kwargs, self.param.autocast_dtype
+                             )
+
+                 if op.is_jax_function:
+                     args, kwargs = self.t2j_iso((args, kwargs))
+             except AssertionError:
+                 if self.config.debug_mixed_tensor:
+                     breakpoint()
+                 else:
+                     raise
+
+             if op.needs_env:
+                 kwargs["env"] = self
+
+             if op.is_jax_function:
+                 res = op.func(*args, **kwargs)
+             else:
+                 # enable dispatch mode because this op could be a composite autograd op
+                 # meaning, it will decompose in C++
+                 with self._dispatch_mode:
+                     res = op.func(*args, **kwargs)
+
+             if op.is_jax_function:
+                 res = self.j2t_iso(res)
+
+             if self.config.force_materialize_views and isinstance(res, View):
+                 res = res.torch()
+
+             if self.config.debug_accuracy_for_each_op:
+                 debug_accuracy(func, old_args, old_kwargs, res)
+             return res
+
+     def enable_torch_modes(self):
+         self._dispatch_mode.__enter__()
+         self._function_mode.__enter__()
+         self.enabled = True
+
+     def disable_torch_modes(self, *exc):
+         if not exc:
+             exc = (None, None, None)
+         self._function_mode.__exit__(*exc)
+         self._dispatch_mode.__exit__(*exc)
+         self.enabled = False
+
+     def __enter__(self):
+         self.enable_torch_modes()
+         return self
+
+     def __exit__(self, *exc):
+         self.disable_torch_modes(*exc)
+
+     def _move_one_value(self, val):
+         if isinstance(val, torch.nn.Module):
+             with self:
+                 return val.to("jax")
+         if isinstance(val, Tensor):
+             return val
+         if isinstance(val, torch.Tensor):
+             return Tensor(self.t2j_copy(val), self)
+         return val
+
+     def to_xla(self, torchvalues):
+         # tensors are torch.Tensors (not XLATensor)
+         res = torch_pytree.tree_map(self._move_one_value, torchvalues)
+         return res
+
+     def t2j_iso(self, torchtensors):
+         """Convert a torchax Tensor to a jax array.
+
+         This function will not copy; it just unwraps the inner jax array.
+         Note: "iso" is short for "isomorphic".
+         """
+
+         def to_jax(x):
+             if (
+                 self.config.allow_mixed_math_with_scalar_tensor
+                 and not isinstance(x, Tensor)
+                 and not isinstance(x, View)
+             ):
+                 if x.squeeze().ndim == 0:
+                     return x.item()
+             if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+                 x = x.wait()
+             assert isinstance(x, Tensor) or isinstance(x, View), (
+                 f"Expected a Tensor or a View but got {type(x)}; usually this means there is mixed math between a torchax Tensor and a plain torch.Tensor"
+             )
+             return x.jax()
+
+         res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+         return res
+
+     def v2t_iso(self, views):
+         def to_tensor(x):
+             if isinstance(x, View):
+                 return x.torch()
+             return x
+
+         res = torch_pytree.tree_map_only(View, to_tensor, views)
+         return res
+
+     def j2t_iso(self, jaxarray):
+         """Convert a jax array to a torchax Tensor.
+
+         This function will not copy; it just wraps the jax array in a torchax Tensor.
+         Note: "iso" is short for "isomorphic".
+         """
+         return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray)
+     def j2t_copy(self, args):
+         """Convert a jax array to a torch.Tensor on CPU.
+
+         This may involve copying the data (depending on whether dlpack is enabled).
+         """
+         return torch_pytree.tree_map_only(
+             jax.Array,
+             lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+             args,
+         )
+
+     def t2j_copy(self, args):
+         """Convert a torch.Tensor on CPU to a jax array.
+
+         This may involve copying the data (depending on whether dlpack is enabled).
+         """
+         return torch_pytree.tree_map_only(
+             torch.Tensor,
+             lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+             args,
+         )
+
+     def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
+         self._ops[op_to_override] = ops_registry.Operator(
+             op_to_override,
+             op_impl,
+             is_jax_function=False,
+             is_user_defined=True,
+             needs_env=False,
+             is_view_op=is_view_op,
+         )
+
+     @contextlib.contextmanager
+     def override_property(self, **kwargs):
+         new_prop = self.param.override(**kwargs)
+         self._property.content.append(new_prop)
+         yield
+         self._property.content.pop()
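
For orientation, a minimal usage sketch of the Environment and Tensor wrapper defined above might look like the following. It assumes the `torchax.default_env()` and `torchax.enable_globally()` helpers mentioned in the file's own error message are exposed at the package top level; it is an illustrative sketch, not code from this release.

    import torch
    import torchax

    env = torchax.default_env()  # an Environment instance like the one defined above
    with env:  # installs XLAFunctionMode / XLADispatchMode
        a = torch.randn(4, 4, device="jax")  # routed through _handle_tensor_constructor
        b = torch.ones(4, 4, device="jax")
        c = a @ b  # dispatched to a JAX lowering via Environment.dispatch

    print(type(c))        # a torchax Tensor wrapping a jax.Array
    print(c.jax().shape)  # .jax() unwraps the underlying jax.Array without copying
    print(c.torch())      # .torch() copies the data back to a CPU torch.Tensor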