torchax 0.0.4-py3-none-any.whl → 0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



torchax/tensor.py CHANGED
@@ -1,3 +1,4 @@
+ import threading
  import logging
  import sys
  import contextlib
@@ -5,15 +6,17 @@ from typing import Optional, Any
  import jax
  import jax.numpy as jnp
  import numpy
+ import itertools
  import torch
  import torch.distributed._functional_collectives
  import torch.func
  import torch.utils._mode_utils as mode_utils
  import torch.utils._python_dispatch as torch_dispatch
  import torch.utils._pytree as torch_pytree
-
+ from torchax.view import View
  from torchax import config
  from torchax.ops import mappings, ops_registry
+ from torchax import amp

  logger = logging.getLogger(__name__)

@@ -22,63 +25,42 @@ class OperatorNotFound(Exception):
22
25
  pass
23
26
 
24
27
 
25
- def wrap(jaxarray):
26
- return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray)
27
-
28
-
29
- def unwrap(torchtensors):
30
- return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
31
-
32
-
33
- def t2j(t):
34
- if isinstance(t, Tensor):
35
- return t._elem
36
- return mappings.t2j(t)
37
-
38
-
39
- def j2t(x):
40
- return mappings.j2t(x)
41
-
42
-
43
- def t2j_dtype(dtype):
44
- return mappings.t2j_dtype(dtype)
45
-
46
-
47
- def j2t_dtype(dtype):
48
- return mappings.j2t_dtype(dtype)
49
-
50
-
51
28
  @contextlib.contextmanager
52
29
  def log_nested(env, message):
53
30
  if env.config.debug_print_each_op:
54
- print((' ' * log_nested.level) + message, file=sys.stderr)
31
+ print((" " * log_nested.level) + message, file=sys.stderr)
55
32
  log_nested.level += 1
56
33
  yield
57
34
  log_nested.level -= 1
58
35
 
36
+
59
37
  log_nested.level = 0
60
38
 
61
39
 
62
40
  class Tensor(torch.Tensor):
63
41
 
64
42
  @staticmethod
65
- def __new__(cls, elem, env):
66
- dtype = j2t_dtype(elem.dtype)
43
+ def __new__(cls, elem, env, requires_grad=False):
44
+ dtype = mappings.j2t_dtype(elem.dtype)
67
45
  shape = list(elem.shape)
68
46
  for i, s in enumerate(shape):
69
47
  if not isinstance(s, int):
70
48
  shape[i] = 1
71
49
  if dtype is None:
72
50
  dtype = torch.float32
51
+ #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
52
+ if not (dtype.is_floating_point or dtype.is_complex):
53
+ requires_grad = False
54
+
73
55
  return torch.Tensor._make_wrapper_subclass(
74
56
  cls,
75
57
  shape,
76
58
  dtype=dtype,
77
59
  device='meta',
78
- requires_grad=False,
60
+ requires_grad=requires_grad,
79
61
  )
80
62
 
81
- def __init__(self, elem: jax.Array, env: 'Environment'):
63
+ def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
82
64
  super().__init__()
83
65
  self._elem = elem
84
66
  self._env = env
@@ -88,12 +70,9 @@ class Tensor(torch.Tensor):
88
70
 
89
71
  __repr__ = __str__
90
72
 
91
- def __jax_array__(self):
92
- return self._elem
93
-
94
73
  @property
95
74
  def shape(self):
96
- return self._elem.shape
75
+ return torch.Size(self._elem.shape)
97
76
 
98
77
  @property
99
78
  def ndim(self):
@@ -120,14 +99,15 @@ class Tensor(torch.Tensor):
120
99
 
121
100
  @classmethod
122
101
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
123
- env = None
124
- for arg in torch_pytree.arg_tree_leaves(*args, **kwargs):
125
- if isinstance(arg, Tensor):
126
- env = arg._env
127
- break
128
-
129
- with env:
130
- return func(*args, **(kwargs or {}))
102
+ # TODO(hanq): figure out why is dispatch mode not sufficient
103
+ if func == torch.ops._c10d_functional.wait_tensor.default:
104
+ return args[0]._env.dispatch(func, types, args, kwargs)
105
+ if func == torch.ops.prim.device.default:
106
+ return torch.device('privateuseone', 0)
107
+ raise AssertionError(
108
+ 'torchax Tensors can only do math within the torchax environment.'
109
+ 'Please wrap your code with `with torchax.default_env()` or '
110
+ 'call torchax.enable_globally() before.')
131
111
 
132
112
  def detach(self):
133
113
  return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -141,18 +121,18 @@ class Tensor(torch.Tensor):
141
121
  return self._elem
142
122
 
143
123
  def torch(self) -> torch.Tensor:
144
- return j2t(self.jax())
124
+ return self._env.j2t_copy(self.jax())
145
125
 
146
126
  @property
147
127
  def dtype(self):
148
- return j2t_dtype(self._elem.dtype)
128
+ return mappings.j2t_dtype(self._elem.dtype)
149
129
 
150
130
  def dim(self):
151
131
  return self.ndim
152
132
 
153
133
  @property
154
134
  def device(self):
155
- return torch.device('jax:0')
135
+ return torch.device("jax:0")
156
136
 
157
137
  @property
158
138
  def jax_device(self):
@@ -160,7 +140,8 @@ class Tensor(torch.Tensor):
160
140
 
161
141
  @property
162
142
  def data(self):
163
- logger.warn("In-place to .data modifications still results a copy on TPU")
143
+ logger.warning(
144
+ "In-place to .data modifications still results a copy on TPU")
164
145
  return self
165
146
 
166
147
  @data.setter
@@ -182,15 +163,15 @@ class Tensor(torch.Tensor):
182
163
 
183
164
  def shard_(self, sharding):
184
165
  self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
185
-
166
+
186
167
 
187
168
  def debug_accuracy(func, args, kwargs, current_output):
188
169
  args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
189
- torch.Tensor, lambda x: j2t(x._elem), (args, kwargs, current_output))
170
+ torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))
190
171
 
191
172
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
192
- if 'device' in kwargs_torch:
193
- kwargs_torch['device'] = 'cpu' # do the torch native for comparison
173
+ if "device" in kwargs_torch:
174
+ kwargs_torch["device"] = "cpu" # do the torch native for comparison
194
175
  expected_out = func(*args_torch, **kwargs_torch)
195
176
 
196
177
  flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
@@ -200,8 +181,8 @@ def debug_accuracy(func, args, kwargs, current_output):
200
181
  if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
201
182
  ex = ex.to(real.dtype)
202
183
  try:
203
- if (isinstance(ex, torch.Tensor) and
204
- not torch.allclose(ex, real, atol=1e-3, equal_nan=True)):
184
+ if isinstance(ex, torch.Tensor) and not torch.allclose(
185
+ ex, real, atol=1e-3, equal_nan=True):
205
186
  import pdb
206
187
 
207
188
  pdb.set_trace()
@@ -212,46 +193,52 @@ def debug_accuracy(func, args, kwargs, current_output):
212
193
 
213
194
  return True
214
195
 
196
+
215
197
  def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
198
+
216
199
  def _display(a):
217
200
  if isinstance(a, torch.Tensor):
218
- return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
201
+ return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
219
202
  elif isinstance(a, jax.Array):
220
- return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
203
+ return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
221
204
  else:
222
205
  return str(a)
223
206
 
224
207
  kwargs = kwargs or {}
225
- title = 'DISPATCH' if is_dispatch else 'FUNCTION'
226
- args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
227
- kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
228
- return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
208
+ title = "DISPATCH" if is_dispatch else "FUNCTION"
209
+ args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
210
+ kwargs_msg = ("kwargs: " +
211
+ ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
212
+ if log_args else "")
213
+ return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"
229
214
 
230
215
 
231
216
  class XLAFunctionMode(torch.overrides.TorchFunctionMode):
232
217
  """Context manager that dispatches torch function calls to JAX."""
233
218
 
234
219
  def __init__(self, env):
235
- self.env = env
220
+ self.env = env
236
221
 
237
222
  def __torch_function__(self,
238
223
  func,
239
224
  types,
240
225
  args=(),
241
226
  kwargs=None) -> torch.Tensor:
242
- message = f'FUNCTION: {_name_of_func(func)}'
227
+ message = f"FUNCTION: {_name_of_func(func)}"
243
228
  if self.env.config.debug_print_each_op_operands:
244
- message = message + 'f'
245
- message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
229
+ message = message + "f"
230
+ message = _make_debug_msg(False,
231
+ self.env.config.debug_print_each_op_operands,
246
232
  func, args, kwargs)
247
233
  with log_nested(self.env, message):
248
234
  try:
249
235
  return self.env.dispatch(func, types, args, kwargs)
250
236
  except OperatorNotFound:
251
237
  pass
252
- if _name_of_func(func) in ('rot90'): # skip rot90 with k%4==0 due to no change
238
+ if _name_of_func(func) in (
239
+ "rot90"): # skip rot90 with k%4==0 due to no change
253
240
  if len(args) >= 2 and type(args[1]) == int:
254
- if ((args[1])%4 == 0):
241
+ if (args[1]) % 4 == 0:
255
242
  return args[0]
256
243
  return func(*args, **(kwargs or {}))
257
244
 
@@ -262,296 +249,463 @@ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
262
249
  self.env = env
263
250
 
264
251
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
265
- message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
252
+ message = _make_debug_msg(True,
253
+ self.env.config.debug_print_each_op_operands,
266
254
  func, args, kwargs)
267
255
  with log_nested(self.env, message):
268
256
  if isinstance(func, torch._ops.OpOverloadPacket):
269
257
  with self:
270
258
  return func(*args, **kwargs)
271
- if func.namespace not in ('aten', '_c10d_functional', 'torchvision'):
259
+ # Only functions under these namespaces will be intercepted
260
+ if func.namespace not in (
261
+ "aten",
262
+ "_c10d_functional",
263
+ "torchvision",
264
+ "xla",
265
+ ):
272
266
  return func(*args, **kwargs)
273
267
  return self.env.dispatch(func, types, args, kwargs)
274
268
 
269
+
275
270
  def _name_of_func(func):
276
- if hasattr(func, 'name'):
271
+ if hasattr(func, "name"):
277
272
  return func.name()
278
273
  return func.__name__
279
274
 
280
275
 
281
276
  # Constructors that don't take other tensor as input
282
277
  TENSOR_CONSTRUCTORS = {
283
- torch.ones,
284
- torch.zeros,
285
- torch.empty,
286
- torch.empty_strided,
287
- torch.tensor,
288
- torch.arange,
289
- torch.eye,
290
- torch.randn,
291
- torch.rand,
292
- torch.randint,
293
- torch.full,
294
- torch.as_tensor,
278
+ torch.ones,
279
+ torch.zeros,
280
+ torch.empty,
281
+ torch.empty_strided,
282
+ torch.tensor,
283
+ torch.arange,
284
+ torch.eye,
285
+ torch.randn,
286
+ torch.rand,
287
+ torch.randint,
288
+ torch.full,
289
+ torch.as_tensor,
295
290
  }
296
291
 
292
+ # TODO(wen): use existing types, either from torch or jax
293
+ SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]
297
294
 
298
- class Environment(contextlib.ContextDecorator):
299
- """This class holds a set of configurations and "globals" needed
300
295
 
301
- for executing torch program using jax.
302
- Things included so far:
296
+ class RuntimeProperty:
297
+ mesh: Any
298
+ prng: Any
299
+ autocast_dtype: Any
303
300
 
304
- op registry
305
- PRNGKey
306
- Configs
301
+ def __init__(self, mesh, prng, autocast_dtype):
302
+ self.mesh = mesh
303
+ self.prng = prng
304
+ self.autocast_dtype = autocast_dtype
307
305
 
308
- Also helper functions to manipulate those.
309
- """
306
+ def override(self, **kwargs):
307
+ return OverrideProperty(self, kwargs)
308
+
309
+ def get_and_rotate_prng_key(self):
310
+ old_key = self.prng
311
+ new_prng_key, next_key = jax.random.split(old_key)
312
+ self.prng = new_prng_key
313
+ return next_key
314
+
315
+
316
+ class OverrideProperty(RuntimeProperty):
310
317
 
311
- _prng_key: jax.random.PRNGKey
318
+ def __init__(self, parent, override):
319
+ self.parent = parent
320
+ self._override = dict(override)
312
321
 
322
+ def __getattr__(self, name):
323
+ if name in self._override:
324
+ return self._override[name]
325
+ return getattr(self.parent, name)
313
326
 
314
- def __init__(self, configuration=None):
315
- self._function_mode = XLAFunctionMode(self)
316
- self._dispatch_mode = XLADispatchMode(self)
317
327
 
318
- # name is torch callable
319
- self._ops = {}
320
- self.load_ops()
328
+ class Environment(contextlib.ContextDecorator):
329
+ """This class holds a set of configurations and "globals" needed
330
+
331
+ for executing torch program using jax.
332
+ Things included so far:
333
+
334
+ op registry
335
+ PRNGKey
336
+ Configs
321
337
 
322
- self._mesh = None
323
- self.config = configuration or config.Configuration()
338
+ Also helper functions to manipulate those.
339
+ """
324
340
 
325
- self._manually_entered = False
326
- self.enabled = False
327
- self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
341
+ def __init__(self, configuration=None):
342
+ self._function_mode = XLAFunctionMode(self)
343
+ self._dispatch_mode = XLADispatchMode(self)
328
344
 
329
- def get_as_jax_device(self, device: Any):
330
- if device is None:
331
- device = torch.get_default_device()
345
+ # name is torch callable
346
+ self._ops = {}
347
+ self._decomps = {}
332
348
 
333
- if isinstance(device, torch.device):
334
- device = str(device)
349
+ self.load_ops()
335
350
 
336
- if (not self.config.use_torch_native_for_cpu_tensor and
337
- device.startswith('cpu')):
338
- return jax.devices('cpu')[0]
351
+ _mesh = None
352
+ self.config = configuration or config.Configuration()
339
353
 
340
- if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
341
- return jax.local_devices()[0]
354
+ self.enabled = False
342
355
 
343
- if device.startswith('jax'):
344
- return jax.local_devices()[0]
356
+ autocast_dtype = None
345
357
 
346
- return None # fallback to torch
347
-
358
+ _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
359
+ self._property = threading.local()
360
+ self._property.content = [
361
+ RuntimeProperty(
362
+ mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
363
+ ]
348
364
 
365
+ @property
366
+ def param(self):
367
+ return self._property.content[-1]
349
368
 
350
- def load_ops(self):
351
- from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
352
- self._ops.update(ops_registry.all_aten_ops)
353
- self._ops.update(ops_registry.all_torch_functions)
369
+ def manual_seed(self, key):
370
+ jax_key = jax.random.PRNGKey(key)
371
+ new_prop = self.param.override(prng=jax_key)
372
+ self._property.content.append(new_prop)
354
373
 
355
- decomps = torch._decomp.core_aten_decompositions()
356
- from torchax.decompositions import EXTRA_DECOMP
357
- decomps.update(EXTRA_DECOMP)
358
- for k, v in decomps.items():
359
- if k not in self._ops:
360
- self._ops[k] = ops_registry.Operator(
374
+ @property
375
+ def prng_key(self):
376
+ return self.param.prng
377
+
378
+ def _should_use_torchax_tensor(self, device):
379
+ if device is None:
380
+ device = torch.get_default_device()
381
+
382
+ if isinstance(device, torch.device):
383
+ device = device.type
384
+
385
+ if ':' in device:
386
+ device = device.split(':')[0]
387
+
388
+ match device:
389
+ case 'cpu':
390
+ return False
391
+ case 'cuda':
392
+ return self.config.treat_cuda_as_jax_device
393
+ case 'jax':
394
+ return True
395
+ case 'privateuseone':
396
+ return True
397
+ case 'meta':
398
+ return self.enabled
399
+ return False
400
+
401
+ def load_ops(self):
402
+ from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
403
+
404
+ for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
405
+ ops_registry.all_torch_functions.items()):
406
+ if v.is_jax_function:
407
+ self._ops[k] = v
408
+ else:
409
+ self._decomps[k] = v
410
+
411
+ from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
412
+
413
+ for k, v in DECOMPOSITIONS.items():
414
+ if k not in self._decomps:
415
+ self._decomps[k] = ops_registry.Operator(
361
416
  k,
362
417
  v,
363
418
  is_jax_function=False,
364
419
  is_user_defined=False,
365
- needs_env=False
366
- )
367
-
368
- def _to_copy(self, the_tensor, new_dtype, new_device):
420
+ needs_env=False,
421
+ is_view_op=k in MUTABLE_DECOMPOSITION,
422
+ )
423
+
424
+ def _get_op_or_decomp(self, func):
425
+
426
+ def _get_from_dict(op_dict, op):
427
+ op = op_dict.get(func)
428
+ if op is None and isinstance(func, torch._ops.OpOverloadPacket):
429
+ op = op_dict.get(func.default)
430
+ if op is None and isinstance(func, torch._ops.OpOverload):
431
+ op = op_dict.get(func.overloadpacket)
432
+ return op
433
+
434
+ op = _get_from_dict(self._ops, func)
435
+
436
+ if op is None:
437
+ # fallback to decompose
438
+ op = _get_from_dict(self._decomps, func)
439
+
440
+ if op is None:
441
+ raise OperatorNotFound(
442
+ f"Operator with name {_name_of_func(func)} has no lowering")
443
+
444
+ return op
445
+
446
+ def _is_same_device(self, the_tensor, new_device):
447
+ if new_device is None:
448
+ return True
449
+ if new_device == 'meta' and the_tensor.device.type == 'jax':
450
+ return True
451
+ if the_tensor.device.type != new_device:
452
+ if the_tensor.device.type == 'cuda':
453
+ return self.config.treat_cuda_as_jax_device
454
+ return False
455
+ return True
456
+
457
+ def _to_copy(self, the_tensor, new_dtype, new_device):
458
+ if isinstance(the_tensor, View):
459
+ the_tensor = the_tensor.torch()
460
+ if isinstance(new_device, torch.device):
461
+ new_device = new_device.type
462
+ res = the_tensor
463
+ if not self._is_same_device(the_tensor, new_device):
369
464
  if isinstance(the_tensor, Tensor):
370
- arr = the_tensor.jax()
371
- if new_dtype is not None and new_dtype != arr.dtype:
372
- arr = arr.astype(mappings.t2j_dtype(new_dtype))
373
- if new_device is not None:
374
- # convert xla tensor to other device
375
- # only supported is CPU
376
- if str(new_device).startswith('cpu'):
377
- # converting to a non-jax device: let torch native handle it
378
- torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
379
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
380
- return torch_tensor.to(new_device)
465
+ torch_tensor = self.j2t_copy(the_tensor._elem)
466
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
467
+ return torch_tensor.to(device=new_device, dtype=new_dtype)
381
468
  else:
382
- if new_dtype is not None and new_dtype != the_tensor.dtype:
383
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
384
- the_tensor = the_tensor.to(new_dtype)
385
- jax_device = self.get_as_jax_device(new_device)
386
- if jax_device:
387
- arr = t2j(the_tensor)
388
- arr = jax.device_put(arr, jax_device)
389
- else:
390
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
391
- return the_tensor.to(new_device)
392
-
393
- return Tensor(arr, self)
394
-
469
+ arr = self.t2j_copy(the_tensor)
470
+ res = Tensor(arr, self, the_tensor.requires_grad)
395
471
 
396
- def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
397
- # Always use the default `randint` to get the next seed
472
+ if new_dtype is not None and new_dtype != the_tensor.dtype:
473
+ if isinstance(the_tensor, Tensor):
474
+ res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
475
+ else:
476
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
477
+ return the_tensor.to(device=new_device, dtype=new_dtype)
478
+ return res
479
+
480
+ def get_and_rotate_prng_key(self,
481
+ generator: Optional[torch.Generator] = None):
482
+ if generator is not None:
483
+ return jax.random.PRNGKey(generator.initial_seed() % (2**63))
484
+ return self.param.get_and_rotate_prng_key()
485
+
486
+ def _handle_tensor_constructor(self, func, args, kwargs):
487
+ device = kwargs.get("device")
488
+ if self._should_use_torchax_tensor(device):
489
+ # don't set default device, let caller set it
490
+ requires_grad = kwargs.get("requires_grad", False)
491
+ op = self._get_op_or_decomp(func)
492
+ if op.needs_env:
493
+ kwargs['env'] = self
494
+ if op.is_jax_function:
495
+ (args, kwargs) = self.t2j_iso((args, kwargs))
496
+ res = op.func(*args, **kwargs)
497
+ if isinstance(res, jax.Array):
498
+ res = Tensor(res, self, requires_grad)
499
+ return res
500
+ else:
398
501
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
399
- next_key = torch.randint(
400
- 0, 2**32, (), dtype=torch.uint32, generator=generator).numpy()
401
-
402
- return jax.random.key(next_key)
502
+ return func(*args, **kwargs)
403
503
 
404
- def _handle_tensor_constructor(self, func, args, kwargs):
405
- device = kwargs.get('device')
406
- jax_device = self.get_as_jax_device(device)
407
- # TODO(qihqi) figure out better ways for device propagation
408
- if not self._manually_entered and jax_device is None:
409
- # let torch handle it
410
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
411
- return func(*args, **kwargs)
412
- with jax.default_device(jax_device):
413
- op = self._ops.get(func)
414
- if op is None and isinstance(func, torch._ops.OpOverload):
415
- op = self._ops.get(func.overloadpacket)
416
- res = op.func(*args, **kwargs)
417
- if isinstance(res, jax.Array):
418
- res = Tensor(res, self)
419
- return res
420
-
421
- def _torch_Tensor_to(self, args, kwargs):
422
- the_tensor = args[0]
423
- args = args[1:]
424
- if len(args) >= 1 and isinstance(args[0], torch.Tensor):
425
- dtype = args[0].dtype
426
- device = args[0].device
427
- return self._to_copy(the_tensor, dtype, device)
428
- device = kwargs.get('device')
429
- dtype = kwargs.get('dtype')
430
- # args like pin_memory etc that we will ignore
431
- args = list(filter(lambda x: not isinstance(x, bool), args))
432
- if len(args) >= 2:
433
- device, dtype, *_ = args
434
- elif len(args) == 1 and isinstance(args[0], torch.dtype):
435
- dtype = args[0]
436
- elif len(args) == 1:
437
- device = args[0]
504
+ def _torch_Tensor_to(self, args, kwargs):
505
+ the_tensor = args[0]
506
+ args = args[1:]
507
+ if len(args) >= 1 and isinstance(args[0], torch.Tensor):
508
+ dtype = args[0].dtype
509
+ device = args[0].device
438
510
  return self._to_copy(the_tensor, dtype, device)
511
+ device = kwargs.get("device")
512
+ dtype = kwargs.get("dtype")
513
+ # args like pin_memory etc that we will ignore
514
+ args = list(filter(lambda x: not isinstance(x, bool), args))
515
+ if len(args) >= 2:
516
+ device, dtype, *_ = args
517
+ elif len(args) == 1 and isinstance(args[0], torch.dtype):
518
+ dtype = args[0]
519
+ elif len(args) == 1:
520
+ device = args[0]
521
+ return self._to_copy(the_tensor, dtype, device)
522
+
523
+ def dispatch(self, func, types, args, kwargs):
524
+ kwargs = kwargs or {}
525
+ if func in TENSOR_CONSTRUCTORS:
526
+ return self._handle_tensor_constructor(func, args, kwargs)
527
+ if func in (
528
+ torch.Tensor.to,
529
+ torch.ops.aten.lift_fresh.default,
530
+ torch.ops.aten._to_copy,
531
+ torch.ops.aten._to_copy.default,
532
+ ):
533
+ return self._torch_Tensor_to(args, kwargs)
534
+
535
+ # If the func doesn't act on Tensor, and is not a tensor constructor,
536
+ # We should skip and let torch handle it.
537
+
538
+ tensor_args = [
539
+ t for t in torch_pytree.tree_flatten(args)[0]
540
+ if isinstance(t, torch.Tensor)
541
+ ]
542
+
543
+ def is_not_torchax_tensor(x):
544
+ return not isinstance(x, Tensor) and not isinstance(x, View)
545
+
546
+ if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
547
+ res = func(*args, **kwargs)
548
+ return res
439
549
 
550
+ with jax.named_scope(_name_of_func(func)):
551
+ op = self._get_op_or_decomp(func)
440
552
 
441
- def dispatch(self, func, types, args, kwargs):
553
+ old_args, old_kwargs = args, kwargs
554
+ with self._dispatch_mode:
555
+ args, kwargs = torch_pytree.tree_map_only(
556
+ torch.distributed._functional_collectives.AsyncCollectiveTensor,
557
+ torch.distributed._functional_collectives.wait_tensor,
558
+ (args, kwargs),
559
+ )
442
560
 
443
- kwargs = kwargs or {}
444
- if func in TENSOR_CONSTRUCTORS:
445
- return self._handle_tensor_constructor(func, args, kwargs)
446
- if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default ,torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
447
- return self._torch_Tensor_to(args, kwargs)
561
+ try:
562
+ if not op.is_view_op:
563
+ args, kwargs = self.v2t_iso((args, kwargs))
448
564
 
449
- # If the func doesn't act on Tensor, and is not a tensor constructor,
450
- # We should skip and let torch handle it.
451
-
452
- tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)]
453
- if tensor_args and all(not isinstance(t, Tensor) for t in tensor_args):
454
- return func(*args, **kwargs)
565
+ with self:
566
+ if self.param.autocast_dtype is not None:
567
+ autocast_policy = amp.autocast_policy.get(func)
568
+ if autocast_policy is not None:
569
+ args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
570
+ self.param.autocast_dtype)
455
571
 
456
- with jax.named_scope(_name_of_func(func)):
457
- op = self._ops.get(func)
572
+ if op.is_jax_function:
573
+ args, kwargs = self.t2j_iso((args, kwargs))
574
+ except AssertionError:
575
+ if self.config.debug_mixed_tensor:
576
+ breakpoint()
577
+ else:
578
+ raise
458
579
 
459
- if op is None and isinstance(func, torch._ops.OpOverloadPacket):
460
- op = self._ops.get(func.default)
580
+ if op.needs_env:
581
+ kwargs["env"] = self
461
582
 
462
- if op is None and isinstance(func, torch._ops.OpOverload):
463
- op = self._ops.get(func.overloadpacket)
583
+ if op.is_jax_function:
584
+ res = op.func(*args, **kwargs)
585
+ else:
586
+ # enable dispatch mode because this op could be a composite autograd op
587
+ # meaning, it will decompose in C++
588
+ with self._dispatch_mode:
589
+ res = op.func(*args, **kwargs)
464
590
 
465
- if op is None:
466
- raise OperatorNotFound(
467
- f'Operator with name {_name_of_func(func)} has no lowering')
591
+ if op.is_jax_function:
592
+ res = self.j2t_iso(res)
468
593
 
469
- old_args, old_kwargs = args, kwargs
470
- args, kwargs = torch_pytree.tree_map_only(
471
- torch.distributed._functional_collectives.AsyncCollectiveTensor,
472
- torch.distributed._functional_collectives.wait_tensor,
473
- (args, kwargs))
474
- try:
475
- if op.is_jax_function:
476
- args, kwargs = self.t2j_iso((args, kwargs))
477
- except AssertionError:
478
- if self.config.debug_mixed_tensor:
479
- import pdb; pdb.set_trace()
480
- else:
481
- raise
594
+ if self.config.force_materialize_views and isinstance(res, View):
595
+ res = res.torch()
482
596
 
597
+ if self.config.debug_accuracy_for_each_op:
598
+ debug_accuracy(func, old_args, old_kwargs, res)
599
+ return res
483
600
 
484
- if op.needs_env:
485
- kwargs['env'] = self
601
+ def enable_torch_modes(self):
602
+ self._dispatch_mode.__enter__()
603
+ self._function_mode.__enter__()
604
+ self.enabled = True
486
605
 
487
- with self:
488
- res = op.func(*args, **kwargs)
606
+ def disable_torch_modes(self, *exc):
607
+ if not exc:
608
+ exc = (None, None, None)
609
+ self._function_mode.__exit__(*exc)
610
+ self._dispatch_mode.__exit__(*exc)
611
+ self.enabled = False
489
612
 
490
- if op.is_jax_function:
491
- res = self.j2t_iso(res)
613
+ def __enter__(self):
614
+ self.enable_torch_modes()
615
+ return self
492
616
 
493
- if self.config.debug_accuracy_for_each_op:
494
- debug_accuracy(func, old_args, old_kwargs, res)
495
- return res
617
+ def __exit__(self, *exc):
618
+ self.disable_torch_modes(*exc)
496
619
 
497
- def enable_torch_modes(self):
498
- self._dispatch_mode.__enter__()
499
- self._function_mode.__enter__()
500
- self.enabled = True
501
-
502
- def disable_torch_modes(self, *exc):
503
- if not exc:
504
- exc = (None, None, None)
505
- self._function_mode.__exit__(*exc)
506
- self._dispatch_mode.__exit__(*exc)
507
- self.enabled = False
508
-
509
- def __enter__(self):
510
- self.enable_torch_modes()
511
- self._manually_entered = True
512
- return self
513
-
514
- def __exit__(self, *exc):
515
- self._manually_entered = False
516
- self.disable_torch_modes(*exc)
517
-
518
- def _move_one_value(self, val):
519
- if isinstance(val, torch.nn.Module):
520
- with self:
521
- return val.to('jax')
522
- if isinstance(val, Tensor):
523
- return val
524
- if isinstance(val, torch.Tensor):
525
- return Tensor(t2j(val), self)
620
+ def _move_one_value(self, val):
621
+ if isinstance(val, torch.nn.Module):
622
+ with self:
623
+ return val.to("jax")
624
+ if isinstance(val, Tensor):
526
625
  return val
626
+ if isinstance(val, torch.Tensor):
627
+ return Tensor(self.t2j_copy(val), self)
628
+ return val
527
629
 
528
- def to_xla(self, torchvalues):
529
- # tensors are torch.Tensors (not XLATensor)
530
- res = torch_pytree.tree_map(
531
- self._move_one_value,
532
- torchvalues)
533
- return res
630
+ def to_xla(self, torchvalues):
631
+ # tensors are torch.Tensors (not XLATensor)
632
+ res = torch_pytree.tree_map(self._move_one_value, torchvalues)
633
+ return res
634
+
635
+ def t2j_iso(self, torchtensors):
636
+ """Convert torchax Tensor to jax array.
637
+
638
+ This function will not copy, will just unwrap the inner jax array out.
639
+ Note: iso is short for "isomorphic"
640
+ """
641
+
642
+ def to_jax(x):
643
+ if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
644
+ x, Tensor):
645
+ if x.squeeze().ndim == 0:
646
+ return x.item()
647
+ if isinstance(
648
+ x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
649
+ x = x.wait()
650
+ assert isinstance(x, Tensor) or isinstance(x, View), (
651
+ f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
652
+ )
653
+ return x.jax()
654
+
655
+ res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
656
+ return res
534
657
 
535
- def t2j_iso(self, torchtensors):
536
- def to_jax(x):
537
- if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
538
- x = x.wait()
539
- assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor'
540
- return x.jax()
541
- return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
658
+ def v2t_iso(self, views):
542
659
 
543
- def j2t_iso(self, jaxarray):
544
- return torch_pytree.tree_map_only(
545
- jnp.ndarray, lambda x: Tensor(x, self), jaxarray)
660
+ def to_tensor(x):
661
+ if isinstance(x, View):
662
+ return x.torch()
663
+ return x
546
664
 
547
- def j2t_copy(self, args):
548
- pass
665
+ res = torch_pytree.tree_map_only(View, to_tensor, views)
666
+ return res
549
667
 
550
- def override_op_definition(self, op_to_override, op_impl):
551
- self._ops[op_to_override] = ops_registry.Operator(
668
+ def j2t_iso(self, jaxarray):
669
+ """Convert jax array to torchax Tensor.
670
+
671
+ This function will not copy, will just wrap the jax array with a torchax Tensor
672
+ Note: iso is short for "isomorphic"
673
+ """
674
+ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
675
+ jaxarray)
676
+
677
+ def j2t_copy(self, args):
678
+ """Convert torch.Tensor in cpu to a jax array
679
+
680
+ This might involves copying the data (depending if dlpack is enabled)
681
+ """
682
+ return torch_pytree.tree_map_only(
683
+ jax.Array,
684
+ lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
685
+ args)
686
+
687
+ def t2j_copy(self, args):
688
+ """Convert jax array to torch.Tensor in cpu.
689
+
690
+ This might involves copying the data (depending if dlpack is enabled)
691
+ """
692
+ return torch_pytree.tree_map_only(
693
+ torch.Tensor,
694
+ lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
695
+ args)
696
+
697
+ def override_op_definition(self, op_to_override, op_impl):
698
+ self._ops[op_to_override] = ops_registry.Operator(
552
699
  op_to_override,
553
700
  op_impl,
554
701
  is_jax_function=False,
555
702
  is_user_defined=True,
556
- needs_env=False
557
- )
703
+ needs_env=False,
704
+ )
705
+
706
+ @contextlib.contextmanager
707
+ def override_property(self, **kwargs):
708
+ new_prop = self.param.override(**kwargs)
709
+ self._property.content.append(new_prop)
710
+ yield
711
+ self._property.content.pop()
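
The Environment in 0.0.6 is stricter than in 0.0.4: Tensor.__torch_dispatch__ now refuses to run ops outside the torchax environment (it raises an AssertionError telling you to use `with torchax.default_env()` or `torchax.enable_globally()`), and the environment gains manual_seed, override_property (over mesh / prng / autocast_dtype) and explicit iso/copy conversion helpers. The following is a minimal usage sketch, not part of the packaged code: it assumes the torchax.default_env() and torchax.enable_globally() helpers referenced by the new error message, and that tensor constructors accept device="jax" as suggested by the device handling in _should_use_torchax_tensor.

import torch
import torchax

env = torchax.default_env()  # the Environment defined in torchax/tensor.py

with env:  # enters XLAFunctionMode and XLADispatchMode (enable_torch_modes)
    env.manual_seed(0)                    # pushes a fresh jax.random.PRNGKey onto the property stack
    x = torch.randn(4, 4, device="jax")   # constructor is intercepted; returns a torchax Tensor wrapping a jax.Array
    y = x @ x                             # routed through Environment.dispatch to the JAX lowering

    # Temporarily set an autocast dtype; ops that have a policy registered in
    # torchax.amp are executed under that policy while the context is active.
    with env.override_property(autocast_dtype=torch.bfloat16):
        z = x @ x

jax_array = y.jax()      # no copy: unwraps the underlying jax.Array
cpu_tensor = y.torch()   # copies back to a plain CPU torch.Tensor via Environment.j2t_copy

As the new docstrings spell out, t2j_iso / j2t_iso only wrap or unwrap the underlying jax.Array ("iso" for isomorphic, no copy), whereas t2j_copy / j2t_copy move data between torch and JAX and may copy depending on config.use_dlpack_for_data_conversion.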