torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchax/tensor.py ADDED
@@ -0,0 +1,736 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import threading
16
+ import logging
17
+ import sys
18
+ import contextlib
19
+ from typing import Optional, Any
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy
23
+ import itertools
24
+ import torch
25
+ import torch.distributed._functional_collectives
26
+ import torch.func
27
+ import torch.utils._mode_utils as mode_utils
28
+ import torch.utils._python_dispatch as torch_dispatch
29
+ import torch.utils._pytree as torch_pytree
30
+ from torchax.view import View
31
+ from torchax import config
32
+ from torchax.ops import mappings, ops_registry
33
+ from torchax import amp
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class OperatorNotFound(Exception):
39
+ pass
40
+
41
+
42
+ @contextlib.contextmanager
43
+ def log_nested(env, message):
44
+ if env.config.debug_print_each_op:
45
+ print((" " * log_nested.level) + message, file=sys.stderr)
46
+ log_nested.level += 1
47
+ yield
48
+ log_nested.level -= 1
49
+
50
+
51
+ log_nested.level = 0
52
+
53
+
54
+ class Tensor(torch.Tensor):
55
+
56
+ @staticmethod
57
+ def __new__(cls, elem, env, requires_grad=False):
58
+ dtype = mappings.j2t_dtype(elem.dtype)
59
+ shape = list(elem.shape)
60
+ for i, s in enumerate(shape):
61
+ if not isinstance(s, int):
62
+ shape[i] = 1
63
+ if dtype is None:
64
+ dtype = torch.float32
65
+ #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
66
+ if not (dtype.is_floating_point or dtype.is_complex):
67
+ requires_grad = False
68
+
69
+ return torch.Tensor._make_wrapper_subclass(
70
+ cls,
71
+ shape,
72
+ dtype=dtype,
73
+ device='meta',
74
+ requires_grad=requires_grad,
75
+ )
76
+
77
+ def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
78
+ super().__init__()
79
+ self._elem = elem
80
+ self._env = env
81
+
82
+ def __str__(self):
83
+ return "Tensor({} {})".format(str(type(self._elem)), str(self._elem))
84
+
85
+ __repr__ = __str__
86
+
87
+ @property
88
+ def shape(self):
89
+ return torch.Size(self._elem.shape)
90
+
91
+ @property
92
+ def ndim(self):
93
+ return len(self._elem.shape)
94
+
95
+ def flatten(self, start_dim=0, end_dim=-1):
96
+ if end_dim == -1:
97
+ end_dim = self.ndim
98
+ new_shape = (
99
+ self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
100
+ new_elem = jnp.reshape(self._elem, new_shape)
101
+ return Tensor(new_elem, self._env)
102
+ # return torch.reshape(self, new_shape)
103
+
104
+ def __setitem__(self, key, val):
105
+ key, val = self._env.t2j_iso((key, val))
106
+ self._elem = self._elem.at[key].set(val)
107
+
108
+ def type_as(self, other):
109
+ self._elem = self._elem.astype(other._elem.dtype)
110
+ return self
111
+
112
+ __torch_function__ = torch._C._disabled_torch_function_impl
113
+
114
+ @classmethod
115
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
116
+ # TODO(hanq): figure out why dispatch mode is not sufficient
117
+ if func == torch.ops._c10d_functional.wait_tensor.default:
118
+ return args[0]._env.dispatch(func, types, args, kwargs)
119
+ if func == torch.ops.prim.device.default:
120
+ return torch.device('privateuseone', 0)
121
+ raise AssertionError(
122
+ 'torchax Tensors can only do math within the torchax environment. '
123
+ 'Please wrap your code with `with torchax.default_env()` or '
124
+ 'call torchax.enable_globally() beforehand.')
125
+
126
+ def detach(self):
127
+ return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
128
+
129
+ def numpy(self) -> numpy.ndarray:
130
+ import numpy as np
131
+
132
+ return np.array(self._elem)
133
+
134
+ def jax(self) -> jax.Array:
135
+ return self._elem
136
+
137
+ def torch(self) -> torch.Tensor:
138
+ return self._env.j2t_copy(self.jax())
139
+
140
+ @property
141
+ def dtype(self):
142
+ return mappings.j2t_dtype(self._elem.dtype)
143
+
144
+ def dim(self):
145
+ return self.ndim
146
+
147
+ @property
148
+ def device(self):
149
+ return torch.device("jax:0")
150
+
151
+ @property
152
+ def jax_device(self):
153
+ return self._elem.device
154
+
155
+ @property
156
+ def data(self):
157
+ logger.warning(
158
+ "In-place to .data modifications still results a copy on TPU")
159
+ return self
160
+
161
+ @data.setter
162
+ def data(self, other):
163
+ if isinstance(other, Tensor):
164
+ self._elem = other._elem
165
+
166
+ def apply_jax(self, jax_function, *args, **kwargs):
167
+ # Call a jax function on _elem
168
+ res = jax_function(self._elem, *args, **kwargs)
169
+ return self._env.j2t_iso(res)
170
+
171
+ def apply_jax_(self, jax_function, *args, **kwargs):
172
+ self._elem = jax_function(self._elem, *args, **kwargs)
173
+ return self
174
+
175
+ def tolist(self):
176
+ return self._elem.tolist()
177
+
178
+ def shard_(self, sharding):
179
+ self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
180
+
181
+
182
+ def debug_accuracy(func, args, kwargs, current_output):
183
+ args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
184
+ torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))
185
+
186
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
187
+ if "device" in kwargs_torch:
188
+ kwargs_torch["device"] = "cpu" # do the torch native for comparison
189
+ expected_out = func(*args_torch, **kwargs_torch)
190
+
191
+ flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
192
+ flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)
193
+
194
+ for ex, real in zip(flattened_expected_out, flattened_current_out):
195
+ if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
196
+ ex = ex.to(real.dtype)
197
+ try:
198
+ if isinstance(ex, torch.Tensor) and not torch.allclose(
199
+ ex, real, atol=1e-3, equal_nan=True):
200
+ import pdb
201
+
202
+ pdb.set_trace()
203
+ except:
204
+ import pdb
205
+
206
+ pdb.set_trace()
207
+
208
+ return True
209
+
210
+
211
+ def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
212
+
213
+ def _display(a):
214
+ if isinstance(a, torch.Tensor):
215
+ return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
216
+ elif isinstance(a, jax.Array):
217
+ return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
218
+ else:
219
+ return str(a)
220
+
221
+ kwargs = kwargs or {}
222
+ title = "DISPATCH" if is_dispatch else "FUNCTION"
223
+ args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
224
+ kwargs_msg = ("kwargs: " +
225
+ ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
226
+ if log_args else "")
227
+ return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"
228
+
229
+
230
+ class XLAFunctionMode(torch.overrides.TorchFunctionMode):
231
+ """Context manager that dispatches torch function calls to JAX."""
232
+
233
+ def __init__(self, env):
234
+ self.env = env
235
+
236
+ def __torch_function__(self,
237
+ func,
238
+ types,
239
+ args=(),
240
+ kwargs=None) -> torch.Tensor:
241
+ message = f"FUNCTION: {_name_of_func(func)}"
242
+ if self.env.config.debug_print_each_op_operands:
243
+ message = message + "f"
244
+ message = _make_debug_msg(False,
245
+ self.env.config.debug_print_each_op_operands,
246
+ func, args, kwargs)
247
+ with log_nested(self.env, message):
248
+ try:
249
+ return self.env.dispatch(func, types, args, kwargs)
250
+ except OperatorNotFound:
251
+ pass
252
+ if _name_of_func(func) in (
253
+ "rot90"): # skip rot90 with k%4==0 due to no change
254
+ if len(args) >= 2 and type(args[1]) == int:
255
+ if (args[1]) % 4 == 0:
256
+ return args[0]
257
+ return func(*args, **(kwargs or {}))
258
+
259
+
260
+ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
261
+
262
+ def __init__(self, env):
263
+ self.env = env
264
+
265
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
266
+ message = _make_debug_msg(True,
267
+ self.env.config.debug_print_each_op_operands,
268
+ func, args, kwargs)
269
+ with log_nested(self.env, message):
270
+ if isinstance(func, torch._ops.OpOverloadPacket):
271
+ with self:
272
+ return func(*args, **kwargs)
273
+ # Only functions under these namespaces will be intercepted
274
+ if func.namespace not in (
275
+ "aten",
276
+ "_c10d_functional",
277
+ "torchvision",
278
+ "xla",
279
+ ):
280
+ return func(*args, **kwargs)
281
+ return self.env.dispatch(func, types, args, kwargs)
282
+
283
+
284
+ def _name_of_func(func):
285
+ if hasattr(func, "name"):
286
+ return func.name()
287
+ return func.__name__
288
+
289
+
290
+ # Constructors that don't take another tensor as input
291
+ TENSOR_CONSTRUCTORS = {
292
+ torch.ones,
293
+ torch.zeros,
294
+ torch.empty,
295
+ torch.empty_strided,
296
+ torch.tensor,
297
+ torch.arange,
298
+ torch.eye,
299
+ torch.randn,
300
+ torch.rand,
301
+ torch.randint,
302
+ torch.full,
303
+ torch.as_tensor,
304
+ }
305
+
306
+ # TODO(wen): use existing types, either from torch or jax
307
+ SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]
308
+
309
+
310
+ class RuntimeProperty:
311
+ mesh: Any
312
+ prng: Any
313
+ autocast_dtype: Any
314
+
315
+ def __init__(self, mesh, prng, autocast_dtype):
316
+ self.mesh = mesh
317
+ self.prng = prng
318
+ self.autocast_dtype = autocast_dtype
319
+
320
+ def override(self, **kwargs):
321
+ return OverrideProperty(self, kwargs)
322
+
323
+ def get_and_rotate_prng_key(self):
324
+ old_key = self.prng
325
+ new_prng_key, next_key = jax.random.split(old_key)
326
+ self.prng = new_prng_key
327
+ return next_key
328
+
329
+
330
+ class OverrideProperty(RuntimeProperty):
331
+
332
+ def __init__(self, parent, override):
333
+ self.parent = parent
334
+ self._override = dict(override)
335
+
336
+ def __getattr__(self, name):
337
+ if name in self._override:
338
+ return self._override[name]
339
+ return getattr(self.parent, name)
340
+
341
+
342
+ class Environment(contextlib.ContextDecorator):
343
+ """This class holds a set of configurations and "globals" needed
344
+
345
+ for executing a torch program using jax.
346
+ Things included so far:
347
+
348
+ op registry
349
+ PRNGKey
350
+ Configs
351
+
352
+ It also provides helper functions to manipulate those.
353
+ """
354
+
355
+ def __init__(self, configuration=None):
356
+ self._function_mode = XLAFunctionMode(self)
357
+ self._dispatch_mode = XLADispatchMode(self)
358
+
359
+ # name is torch callable
360
+ self._ops = {}
361
+ self._decomps = {}
362
+
363
+ self.load_ops()
364
+
365
+ _mesh = None
366
+ self.config = configuration or config.Configuration()
367
+
368
+ self.enabled = False
369
+
370
+ autocast_dtype = None
371
+
372
+ _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
373
+ self._property = threading.local()
374
+ self._initial_content = RuntimeProperty(
375
+ mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
376
+
377
+ @property
378
+ def param(self):
379
+ if not hasattr(self._property, 'content'):
380
+ self._property.content = [
381
+ self._initial_content
382
+ ]
383
+ return self._property.content[-1]
384
+
385
+ def manual_seed(self, key):
386
+ if isinstance(key, torch.Tensor):
387
+ assert key.ndim == 0, 'manual seed can only take scalars'
388
+ assert not key.dtype.is_floating_point, 'manual seed can only be integers'
389
+
390
+ if isinstance(key, Tensor):
391
+ key = key._elem
392
+ else:
393
+ key = key.item()
394
+ jax_key = jax.random.PRNGKey(key)
395
+ new_prop = self.param.override(prng=jax_key)
396
+ self._property.content.append(new_prop)
397
+
398
+ @property
399
+ def prng_key(self):
400
+ return self.param.prng
401
+
402
+ def _should_use_torchax_tensor(self, device):
403
+ if device is None:
404
+ device = torch.get_default_device()
405
+
406
+ if isinstance(device, torch.device):
407
+ device = device.type
408
+
409
+ if ':' in device:
410
+ device = device.split(':')[0]
411
+
412
+ match device:
413
+ case 'cpu':
414
+ return False
415
+ case 'cuda':
416
+ return self.config.treat_cuda_as_jax_device
417
+ case 'jax':
418
+ return True
419
+ case 'privateuseone':
420
+ return True
421
+ case 'meta':
422
+ return self.enabled
423
+ return False
424
+
425
+ def load_ops(self):
426
+ from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
427
+
428
+ for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
429
+ ops_registry.all_torch_functions.items()):
430
+ if v.is_jax_function:
431
+ self._ops[k] = v
432
+ else:
433
+ self._decomps[k] = v
434
+
435
+ from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
436
+
437
+ for k, v in DECOMPOSITIONS.items():
438
+ if k not in self._decomps:
439
+ self._decomps[k] = ops_registry.Operator(
440
+ k,
441
+ v,
442
+ is_jax_function=False,
443
+ is_user_defined=False,
444
+ needs_env=False,
445
+ is_view_op=k in MUTABLE_DECOMPOSITION,
446
+ )
447
+
448
+ def _get_op_or_decomp(self, func):
449
+
450
+ def _get_from_dict(op_dict, op):
451
+ op = op_dict.get(func)
452
+ if op is None and isinstance(func, torch._ops.OpOverloadPacket):
453
+ op = op_dict.get(func.default)
454
+ if op is None and isinstance(func, torch._ops.OpOverload):
455
+ op = op_dict.get(func.overloadpacket)
456
+ return op
457
+
458
+ op = _get_from_dict(self._ops, func)
459
+
460
+ if op is None:
461
+ # fallback to decompose
462
+ op = _get_from_dict(self._decomps, func)
463
+
464
+ if op is None:
465
+ raise OperatorNotFound(
466
+ f"Operator with name {_name_of_func(func)} has no lowering")
467
+
468
+ return op
469
+
470
+ def _is_same_device(self, the_tensor, new_device):
471
+ if new_device is None:
472
+ return True
473
+ if new_device == 'meta' and the_tensor.device.type == 'jax':
474
+ return True
475
+ if the_tensor.device.type != new_device:
476
+ if the_tensor.device.type == 'cuda':
477
+ return self.config.treat_cuda_as_jax_device
478
+ return False
479
+ return True
480
+
481
+ def _to_copy(self, the_tensor, new_dtype, new_device):
482
+ if isinstance(the_tensor, View):
483
+ the_tensor = the_tensor.torch()
484
+ if isinstance(new_device, torch.device):
485
+ new_device = new_device.type
486
+ res = the_tensor
487
+ if not self._is_same_device(the_tensor, new_device):
488
+ if isinstance(the_tensor, Tensor):
489
+ torch_tensor = self.j2t_copy(the_tensor._elem)
490
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
491
+ return torch_tensor.to(device=new_device, dtype=new_dtype)
492
+ else:
493
+ arr = self.t2j_copy(the_tensor)
494
+ res = Tensor(arr, self, the_tensor.requires_grad)
495
+
496
+ if new_dtype is not None and new_dtype != res.dtype:
497
+ if isinstance(res, Tensor):
498
+ res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
499
+ else:
500
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
501
+ return res.to(device=new_device, dtype=new_dtype)
502
+ return res
503
+
504
+ def get_and_rotate_prng_key(self,
505
+ generator: Optional[torch.Generator] = None):
506
+ if generator is not None:
507
+ return jax.random.PRNGKey(generator.initial_seed() % (2**63))
508
+ return self.param.get_and_rotate_prng_key()
509
+
510
+ def _handle_tensor_constructor(self, func, args, kwargs):
511
+ device = kwargs.get("device")
512
+ if self._should_use_torchax_tensor(device):
513
+ # don't set default device, let caller set it
514
+ requires_grad = kwargs.get("requires_grad", False)
515
+ op = self._get_op_or_decomp(func)
516
+ if op.needs_env:
517
+ kwargs['env'] = self
518
+ if op.is_jax_function:
519
+ (args, kwargs) = self.t2j_iso((args, kwargs))
520
+ res = op.func(*args, **kwargs)
521
+ if isinstance(res, jax.Array):
522
+ res = Tensor(res, self, requires_grad)
523
+ return res
524
+ else:
525
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
526
+ return func(*args, **kwargs)
527
+
528
+ def _torch_Tensor_to(self, args, kwargs):
529
+ the_tensor = args[0]
530
+ args = args[1:]
531
+ if len(args) >= 1 and isinstance(args[0], torch.Tensor):
532
+ dtype = args[0].dtype
533
+ device = args[0].device
534
+ return self._to_copy(the_tensor, dtype, device)
535
+ device = kwargs.get("device")
536
+ dtype = kwargs.get("dtype")
537
+ # args like pin_memory etc. that we ignore
538
+ args = list(filter(lambda x: not isinstance(x, bool), args))
539
+ if len(args) >= 2:
540
+ device, dtype, *_ = args
541
+ elif len(args) == 1 and isinstance(args[0], torch.dtype):
542
+ dtype = args[0]
543
+ elif len(args) == 1:
544
+ device = args[0]
545
+ return self._to_copy(the_tensor, dtype, device)
546
+
547
+ def dispatch(self, func, types, args, kwargs):
548
+ kwargs = kwargs or {}
549
+ if func in TENSOR_CONSTRUCTORS:
550
+ return self._handle_tensor_constructor(func, args, kwargs)
551
+ if func in (
552
+ torch.Tensor.to,
553
+ torch.ops.aten.lift_fresh.default,
554
+ torch.ops.aten._to_copy,
555
+ torch.ops.aten._to_copy.default,
556
+ ):
557
+ return self._torch_Tensor_to(args, kwargs)
558
+
559
+ # If the func doesn't act on a Tensor and is not a tensor constructor,
560
+ # we should skip it and let torch handle it.
561
+
562
+ tensor_args = [
563
+ t for t in torch_pytree.tree_flatten(args)[0]
564
+ if isinstance(t, torch.Tensor)
565
+ ]
566
+
567
+ def is_not_torchax_tensor(x):
568
+ return not isinstance(x, Tensor) and not isinstance(x, View)
569
+
570
+ if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
571
+ res = func(*args, **kwargs)
572
+ return res
573
+
574
+ with jax.named_scope(_name_of_func(func)):
575
+ op = self._get_op_or_decomp(func)
576
+
577
+ old_args, old_kwargs = args, kwargs
578
+ with self._dispatch_mode:
579
+ args, kwargs = torch_pytree.tree_map_only(
580
+ torch.distributed._functional_collectives.AsyncCollectiveTensor,
581
+ torch.distributed._functional_collectives.wait_tensor,
582
+ (args, kwargs),
583
+ )
584
+
585
+ try:
586
+ if not op.is_view_op:
587
+ args, kwargs = self.v2t_iso((args, kwargs))
588
+
589
+ with self:
590
+ if self.param.autocast_dtype is not None:
591
+ autocast_policy = amp.autocast_policy.get(func)
592
+ if autocast_policy is not None:
593
+ args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
594
+ self.param.autocast_dtype)
595
+
596
+ if op.is_jax_function:
597
+ args, kwargs = self.t2j_iso((args, kwargs))
598
+ except AssertionError:
599
+ if self.config.debug_mixed_tensor:
600
+ breakpoint()
601
+ else:
602
+ raise
603
+
604
+ if op.needs_env:
605
+ kwargs["env"] = self
606
+
607
+ if op.is_jax_function:
608
+ res = op.func(*args, **kwargs)
609
+ else:
610
+ # enable dispatch mode because this op could be a composite autograd op
611
+ # meaning, it will decompose in C++
612
+ with self._dispatch_mode:
613
+ res = op.func(*args, **kwargs)
614
+
615
+ if op.is_jax_function:
616
+ res = self.j2t_iso(res)
617
+
618
+ if self.config.force_materialize_views and isinstance(res, View):
619
+ res = res.torch()
620
+
621
+ if self.config.debug_accuracy_for_each_op:
622
+ debug_accuracy(func, old_args, old_kwargs, res)
623
+ return res
624
+
625
+ def enable_torch_modes(self):
626
+ self._dispatch_mode.__enter__()
627
+ self._function_mode.__enter__()
628
+ self.enabled = True
629
+
630
+ def disable_torch_modes(self, *exc):
631
+ if not exc:
632
+ exc = (None, None, None)
633
+ self._function_mode.__exit__(*exc)
634
+ self._dispatch_mode.__exit__(*exc)
635
+ self.enabled = False
636
+
637
+ def __enter__(self):
638
+ self.enable_torch_modes()
639
+ return self
640
+
641
+ def __exit__(self, *exc):
642
+ self.disable_torch_modes(*exc)
643
+
644
+ def _move_one_value(self, val):
645
+ if isinstance(val, torch.nn.Module):
646
+ with self:
647
+ return val.to("jax")
648
+ if isinstance(val, Tensor):
649
+ return val
650
+ if isinstance(val, torch.Tensor):
651
+ return Tensor(self.t2j_copy(val), self)
652
+ return val
653
+
654
+ def to_xla(self, torchvalues):
655
+ # tensors are torch.Tensors (not XLATensor)
656
+ res = torch_pytree.tree_map(self._move_one_value, torchvalues)
657
+ return res
658
+
659
+ def t2j_iso(self, torchtensors):
660
+ """Convert torchax Tensor to jax array.
661
+
662
+ This function does not copy; it just unwraps the inner jax array.
663
+ Note: iso is short for "isomorphic"
664
+ """
665
+
666
+ def to_jax(x):
667
+ if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
668
+ x, Tensor) and not isinstance(x, View):
669
+ if x.squeeze().ndim == 0:
670
+ return x.item()
671
+ if isinstance(
672
+ x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
673
+ x = x.wait()
674
+ assert isinstance(x, Tensor) or isinstance(x, View), (
675
+ f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
676
+ )
677
+ return x.jax()
678
+
679
+ res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
680
+ return res
681
+
682
+ def v2t_iso(self, views):
683
+
684
+ def to_tensor(x):
685
+ if isinstance(x, View):
686
+ return x.torch()
687
+ return x
688
+
689
+ res = torch_pytree.tree_map_only(View, to_tensor, views)
690
+ return res
691
+
692
+ def j2t_iso(self, jaxarray):
693
+ """Convert jax array to torchax Tensor.
694
+
695
+ This function does not copy; it just wraps the jax array in a torchax Tensor.
696
+ Note: iso is short for "isomorphic"
697
+ """
698
+ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
699
+ jaxarray)
700
+
701
+ def j2t_copy(self, args):
702
+ """Convert torch.Tensor in cpu to a jax array
703
+
704
+ This might involve copying the data (depending on whether dlpack is enabled).
705
+ """
706
+ return torch_pytree.tree_map_only(
707
+ jax.Array,
708
+ lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
709
+ args)
710
+
711
+ def t2j_copy(self, args):
712
+ """Convert jax array to torch.Tensor in cpu.
713
+
714
+ This might involve copying the data (depending on whether dlpack is enabled).
715
+ """
716
+ return torch_pytree.tree_map_only(
717
+ torch.Tensor,
718
+ lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
719
+ args)
720
+
721
+ def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
722
+ self._ops[op_to_override] = ops_registry.Operator(
723
+ op_to_override,
724
+ op_impl,
725
+ is_jax_function=False,
726
+ is_user_defined=True,
727
+ needs_env=False,
728
+ is_view_op=is_view_op,
729
+ )
730
+
731
+ @contextlib.contextmanager
732
+ def override_property(self, **kwargs):
733
+ new_prop = self.param.override(**kwargs)
734
+ self._property.content.append(new_prop)
735
+ yield
736
+ self._property.content.pop()
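
For orientation, here is a minimal usage sketch tying together the pieces defined in this file: Environment installs the torch function and dispatch modes, intercepted constructors produce JAX-backed Tensor objects, and jax()/torch()/apply_jax() move data between the torch and JAX worlds. This sketch is editorial and not part of the package; it assumes torchax.default_env() returns an Environment instance (as the error message in Tensor.__torch_dispatch__ suggests) and that the "jax" device name is registered elsewhere in the package, so the exact entry points may differ.

# Editorial usage sketch (not part of the package). Assumes torchax.default_env()
# returns an Environment, as hinted by the error message in __torch_dispatch__,
# and that the "jax" device name is registered elsewhere in the package.
import torch
import jax.numpy as jnp
import torchax

env = torchax.default_env()               # assumed accessor for the global Environment

with env:                                 # enters XLAFunctionMode and XLADispatchMode
    x = torch.randn(4, 4, device="jax")   # intercepted constructor -> JAX-backed Tensor
    y = torch.ones(4, 4, device="jax")
    z = x @ y                             # routed through Environment.dispatch to a JAX lowering

z_jax = z.jax()                           # unwrap the underlying jax.Array (no copy)
z_cpu = z.torch()                         # copy back to a regular CPU torch.Tensor
z_t = z.apply_jax(jnp.transpose)          # run an arbitrary JAX function on the wrapped array

# Existing CPU modules can be moved onto the JAX backend in bulk:
linear = env.to_xla(torch.nn.Linear(4, 4))  # parameters become JAX-backed Tensors

The conversion helpers follow the naming used in the file: the *_iso variants wrap or unwrap without copying, while j2t_copy and t2j_copy materialize data across the torch/JAX boundary.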