torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic; see the package's release page on the registry for more details.

torchax/tensor.py CHANGED
@@ -5,15 +5,18 @@ from typing import Optional, Any
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy
8
+ import itertools
8
9
  import torch
9
10
  import torch.distributed._functional_collectives
10
11
  import torch.func
11
12
  import torch.utils._mode_utils as mode_utils
12
13
  import torch.utils._python_dispatch as torch_dispatch
13
14
  import torch.utils._pytree as torch_pytree
14
-
15
+ from torchax.view import View
15
16
  from torchax import config
16
17
  from torchax.ops import mappings, ops_registry
18
+ from torchax import amp
19
+ from jax.experimental import mutable_array
17
20
 
18
21
  logger = logging.getLogger(__name__)
19
22
 
@@ -30,32 +33,15 @@ def unwrap(torchtensors):
30
33
  return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
31
34
 
32
35
 
33
- def t2j(t):
34
- if isinstance(t, Tensor):
35
- return t._elem
36
- return mappings.t2j(t)
37
-
38
-
39
- def j2t(x):
40
- return mappings.j2t(x)
41
-
42
-
43
- def t2j_dtype(dtype):
44
- return mappings.t2j_dtype(dtype)
45
-
46
-
47
- def j2t_dtype(dtype):
48
- return mappings.j2t_dtype(dtype)
49
-
50
-
51
36
  @contextlib.contextmanager
52
37
  def log_nested(env, message):
53
38
  if env.config.debug_print_each_op:
54
- print((' ' * log_nested.level) + message, file=sys.stderr)
39
+ print((" " * log_nested.level) + message, file=sys.stderr)
55
40
  log_nested.level += 1
56
41
  yield
57
42
  log_nested.level -= 1
58
43
 
44
+
59
45
  log_nested.level = 0
60
46
 
61
47
 
@@ -63,7 +49,7 @@ class Tensor(torch.Tensor):
63
49
 
64
50
  @staticmethod
65
51
  def __new__(cls, elem, env):
66
- dtype = j2t_dtype(elem.dtype)
52
+ dtype = mappings.j2t_dtype(elem.dtype)
67
53
  shape = list(elem.shape)
68
54
  for i, s in enumerate(shape):
69
55
  if not isinstance(s, int):
@@ -74,11 +60,11 @@ class Tensor(torch.Tensor):
74
60
  cls,
75
61
  shape,
76
62
  dtype=dtype,
77
- device='meta',
63
+ device="meta",
78
64
  requires_grad=False,
79
65
  )
80
66
 
81
- def __init__(self, elem: jax.Array, env: 'Environment'):
67
+ def __init__(self, elem: jax.Array, env: "Environment"):
82
68
  super().__init__()
83
69
  self._elem = elem
84
70
  self._env = env
@@ -93,7 +79,7 @@ class Tensor(torch.Tensor):
93
79
 
94
80
  @property
95
81
  def shape(self):
96
- return self._elem.shape
82
+ return torch.Size(self._elem.shape)
97
83
 
98
84
  @property
99
85
  def ndim(self):
@@ -120,14 +106,13 @@ class Tensor(torch.Tensor):
120
106
 
121
107
  @classmethod
122
108
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
123
- env = None
124
- for arg in torch_pytree.arg_tree_leaves(*args, **kwargs):
125
- if isinstance(arg, Tensor):
126
- env = arg._env
127
- break
128
-
129
- with env:
130
- return func(*args, **(kwargs or {}))
109
+ # TODO(hanq): figure out why is dispatch mode not sufficient
110
+ if func == torch.ops._c10d_functional.wait_tensor.default:
111
+ return args[0]._env.dispatch(func, types, args, kwargs)
112
+ raise AssertionError(
113
+ 'torchax Tensors can only do math within the torchax environment.'
114
+ 'Please wrap your code with `with torchax.default_env()` or '
115
+ 'call torchax.enable_globally() before.')
131
116
 
132
117
  def detach(self):
133
118
  return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -141,18 +126,18 @@ class Tensor(torch.Tensor):
141
126
  return self._elem
142
127
 
143
128
  def torch(self) -> torch.Tensor:
144
- return j2t(self.jax())
129
+ return self._env.j2t_copy(self.jax())
145
130
 
146
131
  @property
147
132
  def dtype(self):
148
- return j2t_dtype(self._elem.dtype)
133
+ return mappings.j2t_dtype(self._elem.dtype)
149
134
 
150
135
  def dim(self):
151
136
  return self.ndim
152
137
 
153
138
  @property
154
139
  def device(self):
155
- return torch.device('jax:0')
140
+ return torch.device("jax:0")
156
141
 
157
142
  @property
158
143
  def jax_device(self):
@@ -160,7 +145,8 @@ class Tensor(torch.Tensor):
160
145
 
161
146
  @property
162
147
  def data(self):
163
- logger.warn("In-place to .data modifications still results a copy on TPU")
148
+ logger.warning(
149
+ "In-place to .data modifications still results a copy on TPU")
164
150
  return self
165
151
 
166
152
  @data.setter
@@ -182,15 +168,15 @@ class Tensor(torch.Tensor):
182
168
 
183
169
  def shard_(self, sharding):
184
170
  self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
185
-
171
+
186
172
 
187
173
  def debug_accuracy(func, args, kwargs, current_output):
188
174
  args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
189
- torch.Tensor, lambda x: j2t(x._elem), (args, kwargs, current_output))
175
+ torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))
190
176
 
191
177
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
192
- if 'device' in kwargs_torch:
193
- kwargs_torch['device'] = 'cpu' # do the torch native for comparison
178
+ if "device" in kwargs_torch:
179
+ kwargs_torch["device"] = "cpu" # do the torch native for comparison
194
180
  expected_out = func(*args_torch, **kwargs_torch)
195
181
 
196
182
  flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
@@ -200,8 +186,8 @@ def debug_accuracy(func, args, kwargs, current_output):
200
186
  if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
201
187
  ex = ex.to(real.dtype)
202
188
  try:
203
- if (isinstance(ex, torch.Tensor) and
204
- not torch.allclose(ex, real, atol=1e-3, equal_nan=True)):
189
+ if isinstance(ex, torch.Tensor) and not torch.allclose(
190
+ ex, real, atol=1e-3, equal_nan=True):
205
191
  import pdb
206
192
 
207
193
  pdb.set_trace()
@@ -212,46 +198,52 @@ def debug_accuracy(func, args, kwargs, current_output):
212
198
 
213
199
  return True
214
200
 
201
+
215
202
  def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
203
+
216
204
  def _display(a):
217
205
  if isinstance(a, torch.Tensor):
218
- return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
206
+ return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
219
207
  elif isinstance(a, jax.Array):
220
- return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
208
+ return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
221
209
  else:
222
210
  return str(a)
223
211
 
224
212
  kwargs = kwargs or {}
225
- title = 'DISPATCH' if is_dispatch else 'FUNCTION'
226
- args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
227
- kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
228
- return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
213
+ title = "DISPATCH" if is_dispatch else "FUNCTION"
214
+ args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
215
+ kwargs_msg = ("kwargs: " +
216
+ ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
217
+ if log_args else "")
218
+ return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"
229
219
 
230
220
 
231
221
  class XLAFunctionMode(torch.overrides.TorchFunctionMode):
232
222
  """Context manager that dispatches torch function calls to JAX."""
233
223
 
234
224
  def __init__(self, env):
235
- self.env = env
225
+ self.env = env
236
226
 
237
227
  def __torch_function__(self,
238
228
  func,
239
229
  types,
240
230
  args=(),
241
231
  kwargs=None) -> torch.Tensor:
242
- message = f'FUNCTION: {_name_of_func(func)}'
232
+ message = f"FUNCTION: {_name_of_func(func)}"
243
233
  if self.env.config.debug_print_each_op_operands:
244
- message = message + 'f'
245
- message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
234
+ message = message + "f"
235
+ message = _make_debug_msg(False,
236
+ self.env.config.debug_print_each_op_operands,
246
237
  func, args, kwargs)
247
238
  with log_nested(self.env, message):
248
239
  try:
249
240
  return self.env.dispatch(func, types, args, kwargs)
250
241
  except OperatorNotFound:
251
242
  pass
252
- if _name_of_func(func) in ('rot90'): # skip rot90 with k%4==0 due to no change
243
+ if _name_of_func(func) in (
244
+ "rot90"): # skip rot90 with k%4==0 due to no change
253
245
  if len(args) >= 2 and type(args[1]) == int:
254
- if ((args[1])%4 == 0):
246
+ if (args[1]) % 4 == 0:
255
247
  return args[0]
256
248
  return func(*args, **(kwargs or {}))
257
249
 
@@ -262,296 +254,446 @@ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
262
254
  self.env = env
263
255
 
264
256
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
265
- message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
257
+ message = _make_debug_msg(True,
258
+ self.env.config.debug_print_each_op_operands,
266
259
  func, args, kwargs)
267
260
  with log_nested(self.env, message):
268
261
  if isinstance(func, torch._ops.OpOverloadPacket):
269
262
  with self:
270
263
  return func(*args, **kwargs)
271
- if func.namespace not in ('aten', '_c10d_functional', 'torchvision'):
264
+ # Only functions under these namespaces will be intercepted
265
+ if func.namespace not in (
266
+ "aten",
267
+ "_c10d_functional",
268
+ "torchvision",
269
+ "xla",
270
+ ):
272
271
  return func(*args, **kwargs)
273
272
  return self.env.dispatch(func, types, args, kwargs)
274
273
 
274
+
275
275
  def _name_of_func(func):
276
- if hasattr(func, 'name'):
276
+ if hasattr(func, "name"):
277
277
  return func.name()
278
278
  return func.__name__
279
279
 
280
280
 
281
281
  # Constructors that don't take other tensor as input
282
282
  TENSOR_CONSTRUCTORS = {
283
- torch.ones,
284
- torch.zeros,
285
- torch.empty,
286
- torch.empty_strided,
287
- torch.tensor,
288
- torch.arange,
289
- torch.eye,
290
- torch.randn,
291
- torch.rand,
292
- torch.randint,
293
- torch.full,
294
- torch.as_tensor,
283
+ torch.ones,
284
+ torch.zeros,
285
+ torch.empty,
286
+ torch.empty_strided,
287
+ torch.tensor,
288
+ torch.arange,
289
+ torch.eye,
290
+ torch.randn,
291
+ torch.rand,
292
+ torch.randint,
293
+ torch.full,
294
+ torch.as_tensor,
295
295
  }
296
296
 
297
+ # TODO(wen): use existing types, either from torch or jax
298
+ SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]
297
299
 
298
- class Environment(contextlib.ContextDecorator):
299
- """This class holds a set of configurations and "globals" needed
300
-
301
- for executing torch program using jax.
302
- Things included so far:
303
-
304
- op registry
305
- PRNGKey
306
- Configs
307
300
 
308
- Also helper functions to manipulate those.
309
- """
301
+ class Environment(contextlib.ContextDecorator):
302
+ """This class holds a set of configurations and "globals" needed
310
303
 
311
- _prng_key: jax.random.PRNGKey
304
+ for executing torch program using jax.
305
+ Things included so far:
312
306
 
307
+ op registry
308
+ PRNGKey
309
+ Configs
313
310
 
314
- def __init__(self, configuration=None):
315
- self._function_mode = XLAFunctionMode(self)
316
- self._dispatch_mode = XLADispatchMode(self)
311
+ Also helper functions to manipulate those.
312
+ """
317
313
 
318
- # name is torch callable
319
- self._ops = {}
320
- self.load_ops()
314
+ def __init__(self, configuration=None):
315
+ self._function_mode = XLAFunctionMode(self)
316
+ self._dispatch_mode = XLADispatchMode(self)
321
317
 
322
- self._mesh = None
323
- self.config = configuration or config.Configuration()
318
+ # name is torch callable
319
+ self._ops = {}
320
+ self._decomps = {}
324
321
 
325
- self._manually_entered = False
326
- self.enabled = False
327
- self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
322
+ self.load_ops()
328
323
 
329
- def get_as_jax_device(self, device: Any):
330
- if device is None:
331
- device = torch.get_default_device()
324
+ self._mesh = None
325
+ self.config = configuration or config.Configuration()
332
326
 
333
- if isinstance(device, torch.device):
334
- device = str(device)
327
+ self._manually_entered = False
328
+ self.enabled = False
335
329
 
336
- if (not self.config.use_torch_native_for_cpu_tensor and
337
- device.startswith('cpu')):
338
- return jax.devices('cpu')[0]
330
+ self._prng_key = mutable_array(
331
+ jax.random.key(torch.initial_seed() % (1 << 63)))
332
+ self.autocast_dtype = None
333
+ self._target_device = jax.local_devices()[0].platform
339
334
 
340
- if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
341
- return jax.local_devices()[0]
335
+ @property
336
+ def target_device(self):
337
+ return self._target_device
342
338
 
343
- if device.startswith('jax'):
344
- return jax.local_devices()[0]
339
+ @target_device.setter
340
+ def target_device(self, device: str):
341
+ self._target_device = device.lower()
345
342
 
346
- return None # fallback to torch
347
-
343
+ def manual_seed(self, key):
344
+ self._prng_key = mutable_array(jax.random.key(key))
348
345
 
346
+ @property
347
+ def prng_key(self):
348
+ return self._prng_key[...]
349
+
350
+ def get_as_jax_device(self, device: Any):
351
+ if device is None:
352
+ device = torch.get_default_device()
353
+
354
+ if isinstance(device, torch.device):
355
+ device = str(device)
356
+
357
+ if not self.config.use_torch_native_for_cpu_tensor and device.startswith(
358
+ "cpu"):
359
+ return jax.devices("cpu")[0]
360
+
361
+ if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
362
+ return jax.local_devices()[0]
363
+
364
+ if device.startswith("xla"):
365
+ return jax.local_devices()[0]
366
+
367
+ # TODO (wen): jax is NOT a device type,
368
+ # once we can register more than one backend, revisit
369
+ if device.startswith("jax"):
370
+ match self.target_device:
371
+ case "cpu":
372
+ return jax.devices("cpu")[0]
373
+ case "tpu":
374
+ return jax.devices("tpu")[0]
375
+ case _:
376
+ raise AttributeError(
377
+ f"Cannot handle env.target_device {self.target_device}")
378
+
379
+ return None # fallback to torch
380
+
381
+ def load_ops(self):
382
+ from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
383
+
384
+ for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
385
+ ops_registry.all_torch_functions.items()):
386
+ if v.is_jax_function:
387
+ self._ops[k] = v
388
+ else:
389
+ self._decomps[k] = v
349
390
 
350
- def load_ops(self):
351
- from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
352
- self._ops.update(ops_registry.all_aten_ops)
353
- self._ops.update(ops_registry.all_torch_functions)
391
+ from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
354
392
 
355
- decomps = torch._decomp.core_aten_decompositions()
356
- from torchax.decompositions import EXTRA_DECOMP
357
- decomps.update(EXTRA_DECOMP)
358
- for k, v in decomps.items():
359
- if k not in self._ops:
360
- self._ops[k] = ops_registry.Operator(
393
+ for k, v in DECOMPOSITIONS.items():
394
+ if k not in self._decomps:
395
+ self._decomps[k] = ops_registry.Operator(
361
396
  k,
362
397
  v,
363
398
  is_jax_function=False,
364
399
  is_user_defined=False,
365
- needs_env=False
366
- )
367
-
368
- def _to_copy(self, the_tensor, new_dtype, new_device):
369
- if isinstance(the_tensor, Tensor):
370
- arr = the_tensor.jax()
371
- if new_dtype is not None and new_dtype != arr.dtype:
372
- arr = arr.astype(mappings.t2j_dtype(new_dtype))
373
- if new_device is not None:
374
- # convert xla tensor to other device
375
- # only supported is CPU
376
- if str(new_device).startswith('cpu'):
400
+ needs_env=False,
401
+ is_view_op=k in MUTABLE_DECOMPOSITION,
402
+ )
403
+
404
+ def _get_op_or_decomp(self, func):
405
+
406
+ def _get_from_dict(op_dict, op):
407
+ op = op_dict.get(func)
408
+ if op is None and isinstance(func, torch._ops.OpOverloadPacket):
409
+ op = op_dict.get(func.default)
410
+ if op is None and isinstance(func, torch._ops.OpOverload):
411
+ op = op_dict.get(func.overloadpacket)
412
+ return op
413
+
414
+ op = _get_from_dict(self._ops, func)
415
+
416
+ if op is None:
417
+ # fallback to decompose
418
+ op = _get_from_dict(self._decomps, func)
419
+
420
+ if op is None:
421
+ raise OperatorNotFound(
422
+ f"Operator with name {_name_of_func(func)} has no lowering")
423
+
424
+ return op
425
+
426
+ def _to_copy(self, the_tensor, new_dtype, new_device):
427
+ if isinstance(the_tensor, View):
428
+ the_tensor = the_tensor.torch()
429
+
430
+ if isinstance(the_tensor, Tensor):
431
+
432
+ arr = the_tensor.jax()
433
+
434
+ if new_dtype is not None and new_dtype != arr.dtype:
435
+ arr = arr.astype(mappings.t2j_dtype(new_dtype))
436
+
437
+ if new_device is not None:
438
+ match str(new_device).lower():
439
+ case "cpu":
377
440
  # converting to a non-jax device: let torch native handle it
378
- torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
441
+ torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor,
442
+ Tensor) else arr
379
443
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
380
444
  return torch_tensor.to(new_device)
445
+ case "jax":
446
+ # move torchax.tensor / jax tensor between devices
447
+ # I don't know ifgit this will work after the model is jitted
448
+ if self.target_device != the_tensor.jax_device.platform:
449
+ arr = jax.device_put(the_tensor.jax(),
450
+ jax.devices(self.target_device)[0])
451
+ return Tensor(arr, self)
452
+ case _:
453
+ logging.error(f"torchax.Tenosr cannot handle device {new_device}")
454
+
455
+ else:
456
+ if new_dtype is not None and new_dtype != the_tensor.dtype:
457
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
458
+ the_tensor = the_tensor.to(new_dtype)
459
+
460
+ if new_device is None: ## device is None means don't change device
461
+ return the_tensor
462
+
463
+ jax_device = self.get_as_jax_device(new_device)
464
+ if jax_device:
465
+ arr = self.t2j_copy(the_tensor)
466
+ arr = jax.device_put(arr, jax_device)
381
467
  else:
382
- if new_dtype is not None and new_dtype != the_tensor.dtype:
383
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
384
- the_tensor = the_tensor.to(new_dtype)
385
- jax_device = self.get_as_jax_device(new_device)
386
- if jax_device:
387
- arr = t2j(the_tensor)
388
- arr = jax.device_put(arr, jax_device)
389
- else:
390
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
391
- return the_tensor.to(new_device)
468
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
469
+ return the_tensor.to(new_device)
392
470
 
393
- return Tensor(arr, self)
394
-
471
+ return Tensor(arr, self)
395
472
 
396
- def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
397
- # Always use the default `randint` to get the next seed
473
+ def get_and_rotate_prng_key(self,
474
+ generator: Optional[torch.Generator] = None):
475
+ if generator is not None:
398
476
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
399
- next_key = torch.randint(
400
- 0, 2**32, (), dtype=torch.uint32, generator=generator).numpy()
401
-
402
- return jax.random.key(next_key)
477
+ self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63))
478
+ old_key = self._prng_key[...]
479
+ new_prng_key, next_key = jax.random.split(old_key)
480
+ self._prng_key[...] = new_prng_key
481
+ return next_key
482
+
483
+ def _handle_tensor_constructor(self, func, args, kwargs):
484
+ device = kwargs.get("device")
485
+ jax_device = self.get_as_jax_device(device)
486
+ # TODO(qihqi) figure out better ways for device propagation
487
+ if not self._manually_entered and jax_device is None:
488
+ # let torch handle it
489
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
490
+ return func(*args, **kwargs)
491
+ with jax.default_device(jax_device):
492
+ requires_grad = kwargs.get("requires_grad", False)
493
+ op = self._get_op_or_decomp(func)
494
+ res = op.func(*args, **kwargs)
495
+ if isinstance(res, jax.Array):
496
+ res = Tensor(res, self)
497
+ if requires_grad:
498
+ res.requires_grad = True
499
+ return res
403
500
 
404
- def _handle_tensor_constructor(self, func, args, kwargs):
405
- device = kwargs.get('device')
406
- jax_device = self.get_as_jax_device(device)
407
- # TODO(qihqi) figure out better ways for device propagation
408
- if not self._manually_entered and jax_device is None:
409
- # let torch handle it
410
- with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
411
- return func(*args, **kwargs)
412
- with jax.default_device(jax_device):
413
- op = self._ops.get(func)
414
- if op is None and isinstance(func, torch._ops.OpOverload):
415
- op = self._ops.get(func.overloadpacket)
416
- res = op.func(*args, **kwargs)
417
- if isinstance(res, jax.Array):
418
- res = Tensor(res, self)
419
- return res
420
-
421
- def _torch_Tensor_to(self, args, kwargs):
422
- the_tensor = args[0]
423
- args = args[1:]
424
- if len(args) >= 1 and isinstance(args[0], torch.Tensor):
425
- dtype = args[0].dtype
426
- device = args[0].device
427
- return self._to_copy(the_tensor, dtype, device)
428
- device = kwargs.get('device')
429
- dtype = kwargs.get('dtype')
430
- # args like pin_memory etc that we will ignore
431
- args = list(filter(lambda x: not isinstance(x, bool), args))
432
- if len(args) >= 2:
433
- device, dtype, *_ = args
434
- elif len(args) == 1 and isinstance(args[0], torch.dtype):
435
- dtype = args[0]
436
- elif len(args) == 1:
437
- device = args[0]
501
+ def _torch_Tensor_to(self, args, kwargs):
502
+ the_tensor = args[0]
503
+ args = args[1:]
504
+ if len(args) >= 1 and isinstance(args[0], torch.Tensor):
505
+ dtype = args[0].dtype
506
+ device = args[0].device
438
507
  return self._to_copy(the_tensor, dtype, device)
508
+ device = kwargs.get("device")
509
+ dtype = kwargs.get("dtype")
510
+ # args like pin_memory etc that we will ignore
511
+ args = list(filter(lambda x: not isinstance(x, bool), args))
512
+ if len(args) >= 2:
513
+ device, dtype, *_ = args
514
+ elif len(args) == 1 and isinstance(args[0], torch.dtype):
515
+ dtype = args[0]
516
+ elif len(args) == 1:
517
+ device = args[0]
518
+ return self._to_copy(the_tensor, dtype, device)
519
+
520
+ def dispatch(self, func, types, args, kwargs):
521
+ kwargs = kwargs or {}
522
+ if func in TENSOR_CONSTRUCTORS:
523
+ return self._handle_tensor_constructor(func, args, kwargs)
524
+ if func in (
525
+ torch.Tensor.to,
526
+ torch.ops.aten.lift_fresh.default,
527
+ torch.ops.aten._to_copy,
528
+ torch.ops.aten._to_copy.default,
529
+ ):
530
+ return self._torch_Tensor_to(args, kwargs)
531
+
532
+ # If the func doesn't act on Tensor, and is not a tensor constructor,
533
+ # We should skip and let torch handle it.
534
+
535
+ tensor_args = [
536
+ t for t in torch_pytree.tree_flatten(args)[0]
537
+ if isinstance(t, torch.Tensor)
538
+ ]
539
+
540
+ def is_not_torchax_tensor(x):
541
+ return not isinstance(x, Tensor) and not isinstance(x, View)
542
+
543
+ if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
544
+ res = func(*args, **kwargs)
545
+ return res
439
546
 
547
+ with jax.named_scope(_name_of_func(func)):
548
+ op = self._get_op_or_decomp(func)
440
549
 
441
- def dispatch(self, func, types, args, kwargs):
550
+ old_args, old_kwargs = args, kwargs
551
+ with self._dispatch_mode:
552
+ args, kwargs = torch_pytree.tree_map_only(
553
+ torch.distributed._functional_collectives.AsyncCollectiveTensor,
554
+ torch.distributed._functional_collectives.wait_tensor,
555
+ (args, kwargs),
556
+ )
442
557
 
443
- kwargs = kwargs or {}
444
- if func in TENSOR_CONSTRUCTORS:
445
- return self._handle_tensor_constructor(func, args, kwargs)
446
- if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default ,torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
447
- return self._torch_Tensor_to(args, kwargs)
558
+ try:
559
+ if not op.is_view_op:
560
+ args, kwargs = self.v2t_iso((args, kwargs))
448
561
 
449
- # If the func doesn't act on Tensor, and is not a tensor constructor,
450
- # We should skip and let torch handle it.
451
-
452
- tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)]
453
- if tensor_args and all(not isinstance(t, Tensor) for t in tensor_args):
454
- return func(*args, **kwargs)
562
+ with self:
563
+ if self.autocast_dtype is not None:
564
+ autocast_policy = amp.autocast_policy.get(func)
565
+ if autocast_policy is not None:
566
+ args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
567
+ self.autocast_dtype)
455
568
 
456
- with jax.named_scope(_name_of_func(func)):
457
- op = self._ops.get(func)
569
+ if op.is_jax_function:
570
+ args, kwargs = self.t2j_iso((args, kwargs))
571
+ except AssertionError:
572
+ if self.config.debug_mixed_tensor:
573
+ breakpoint()
574
+ else:
575
+ raise
458
576
 
459
- if op is None and isinstance(func, torch._ops.OpOverloadPacket):
460
- op = self._ops.get(func.default)
577
+ if op.needs_env:
578
+ kwargs["env"] = self
461
579
 
462
- if op is None and isinstance(func, torch._ops.OpOverload):
463
- op = self._ops.get(func.overloadpacket)
580
+ if op.is_jax_function:
581
+ res = op.func(*args, **kwargs)
582
+ else:
583
+ # enable dispatch mode because this op could be a composite autograd op
584
+ # meaning, it will decompose in C++
585
+ with self._dispatch_mode:
586
+ res = op.func(*args, **kwargs)
464
587
 
465
- if op is None:
466
- raise OperatorNotFound(
467
- f'Operator with name {_name_of_func(func)} has no lowering')
588
+ if op.is_jax_function:
589
+ res = self.j2t_iso(res)
468
590
 
469
- old_args, old_kwargs = args, kwargs
470
- args, kwargs = torch_pytree.tree_map_only(
471
- torch.distributed._functional_collectives.AsyncCollectiveTensor,
472
- torch.distributed._functional_collectives.wait_tensor,
473
- (args, kwargs))
474
- try:
475
- if op.is_jax_function:
476
- args, kwargs = self.t2j_iso((args, kwargs))
477
- except AssertionError:
478
- if self.config.debug_mixed_tensor:
479
- import pdb; pdb.set_trace()
480
- else:
481
- raise
591
+ if self.config.force_materialize_views and isinstance(res, View):
592
+ res = res.torch()
482
593
 
594
+ if self.config.debug_accuracy_for_each_op:
595
+ debug_accuracy(func, old_args, old_kwargs, res)
596
+ return res
483
597
 
484
- if op.needs_env:
485
- kwargs['env'] = self
598
+ def enable_torch_modes(self):
599
+ self._dispatch_mode.__enter__()
600
+ self._function_mode.__enter__()
601
+ self.enabled = True
602
+
603
+ def disable_torch_modes(self, *exc):
604
+ if not exc:
605
+ exc = (None, None, None)
606
+ self._function_mode.__exit__(*exc)
607
+ self._dispatch_mode.__exit__(*exc)
608
+ self.enabled = False
609
+
610
+ def __enter__(self):
611
+ self.enable_torch_modes()
612
+ self._manually_entered = True
613
+ return self
486
614
 
487
- with self:
488
- res = op.func(*args, **kwargs)
615
+ def __exit__(self, *exc):
616
+ self._manually_entered = False
617
+ self.disable_torch_modes(*exc)
489
618
 
490
- if op.is_jax_function:
491
- res = self.j2t_iso(res)
619
+ def _move_one_value(self, val):
620
+ if isinstance(val, torch.nn.Module):
621
+ with self:
622
+ return val.to("jax")
623
+ if isinstance(val, Tensor):
624
+ return val
625
+ if isinstance(val, torch.Tensor):
626
+ return Tensor(self.t2j_copy(val), self)
627
+ return val
492
628
 
493
- if self.config.debug_accuracy_for_each_op:
494
- debug_accuracy(func, old_args, old_kwargs, res)
495
- return res
629
+ def to_xla(self, torchvalues):
630
+ # tensors are torch.Tensors (not XLATensor)
631
+ res = torch_pytree.tree_map(self._move_one_value, torchvalues)
632
+ return res
496
633
 
497
- def enable_torch_modes(self):
498
- self._dispatch_mode.__enter__()
499
- self._function_mode.__enter__()
500
- self.enabled = True
634
+ def t2j_iso(self, torchtensors):
635
+ """Convert torchax Tensor to jax array.
501
636
 
502
- def disable_torch_modes(self, *exc):
503
- if not exc:
504
- exc = (None, None, None)
505
- self._function_mode.__exit__(*exc)
506
- self._dispatch_mode.__exit__(*exc)
507
- self.enabled = False
508
-
509
- def __enter__(self):
510
- self.enable_torch_modes()
511
- self._manually_entered = True
512
- return self
513
-
514
- def __exit__(self, *exc):
515
- self._manually_entered = False
516
- self.disable_torch_modes(*exc)
517
-
518
- def _move_one_value(self, val):
519
- if isinstance(val, torch.nn.Module):
520
- with self:
521
- return val.to('jax')
522
- if isinstance(val, Tensor):
523
- return val
524
- if isinstance(val, torch.Tensor):
525
- return Tensor(t2j(val), self)
526
- return val
637
+ This function will not copy, will just unwrap the inner jax array out.
638
+ Note: iso is short for "isomorphic"
639
+ """
527
640
 
528
- def to_xla(self, torchvalues):
529
- # tensors are torch.Tensors (not XLATensor)
530
- res = torch_pytree.tree_map(
531
- self._move_one_value,
532
- torchvalues)
533
- return res
641
+ def to_jax(x):
642
+ if isinstance(
643
+ x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
644
+ x = x.wait()
645
+ assert isinstance(x, Tensor) or isinstance(x, View), (
646
+ f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
647
+ )
648
+ return x.jax()
649
+
650
+ res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
651
+ return res
534
652
 
535
- def t2j_iso(self, torchtensors):
536
- def to_jax(x):
537
- if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
538
- x = x.wait()
539
- assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor'
540
- return x.jax()
541
- return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
653
+ def v2t_iso(self, views):
542
654
 
543
- def j2t_iso(self, jaxarray):
544
- return torch_pytree.tree_map_only(
545
- jnp.ndarray, lambda x: Tensor(x, self), jaxarray)
655
+ def to_tensor(x):
656
+ if isinstance(x, View):
657
+ return x.torch()
658
+ return x
546
659
 
547
- def j2t_copy(self, args):
548
- pass
660
+ res = torch_pytree.tree_map_only(View, to_tensor, views)
661
+ return res
549
662
 
550
- def override_op_definition(self, op_to_override, op_impl):
551
- self._ops[op_to_override] = ops_registry.Operator(
663
+ def j2t_iso(self, jaxarray):
664
+ """Convert jax array to torchax Tensor.
665
+
666
+ This function will not copy, will just wrap the jax array with a torchax Tensor
667
+ Note: iso is short for "isomorphic"
668
+ """
669
+ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
670
+ jaxarray)
671
+
672
+ def j2t_copy(self, args):
673
+ """Convert torch.Tensor in cpu to a jax array
674
+
675
+ This might involves copying the data (depending if dlpack is enabled)
676
+ """
677
+ return torch_pytree.tree_map_only(
678
+ jax.Array,
679
+ lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
680
+ args)
681
+
682
+ def t2j_copy(self, args):
683
+ """Convert jax array to torch.Tensor in cpu.
684
+
685
+ This might involves copying the data (depending if dlpack is enabled)
686
+ """
687
+ return torch_pytree.tree_map_only(
688
+ torch.Tensor,
689
+ lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
690
+ args)
691
+
692
+ def override_op_definition(self, op_to_override, op_impl):
693
+ self._ops[op_to_override] = ops_registry.Operator(
552
694
  op_to_override,
553
695
  op_impl,
554
696
  is_jax_function=False,
555
697
  is_user_defined=True,
556
- needs_env=False
557
- )
698
+ needs_env=False,
699
+ )