torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Potentially problematic release.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202612.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202612.dist-info}/licenses/LICENSE +0 -0
torchax/tensor.py
CHANGED
@@ -12,25 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import
+import contextlib
+import itertools
 import logging
 import sys
-import
-from typing import
+import threading
+from typing import Any
+
 import jax
 import jax.numpy as jnp
 import numpy
-import itertools
 import torch
 import torch.distributed._functional_collectives
 import torch.func
 import torch.utils._mode_utils as mode_utils
 import torch.utils._python_dispatch as torch_dispatch
 import torch.utils._pytree as torch_pytree
-
-from torchax import config
+
+from torchax import amp, config
 from torchax.ops import mappings, ops_registry
-from torchax import
+from torchax.view import View

 logger = logging.getLogger(__name__)

@@ -52,7 +53,6 @@ log_nested.level = 0


 class Tensor(torch.Tensor):
-
 @staticmethod
 def __new__(cls, elem, env, requires_grad=False):
 dtype = mappings.j2t_dtype(elem.dtype)
@@ -62,16 +62,16 @@ class Tensor(torch.Tensor):
 shape[i] = 1
 if dtype is None:
 dtype = torch.float32
-#dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
+# dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
 if not (dtype.is_floating_point or dtype.is_complex):
 requires_grad = False

 return torch.Tensor._make_wrapper_subclass(
-
-
-
-
-
+cls,
+shape,
+dtype=dtype,
+device="meta",
+requires_grad=requires_grad,
 )

 def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
@@ -80,7 +80,7 @@ class Tensor(torch.Tensor):
 self._env = env

 def __str__(self):
-return "Tensor({
+return f"Tensor({str(type(self._elem))} {str(self._elem)})"

 __repr__ = __str__

@@ -95,8 +95,7 @@ class Tensor(torch.Tensor):
 def flatten(self, start_dim=0, end_dim=-1):
 if end_dim == -1:
 end_dim = self.ndim
-new_shape = (
-self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:])
+new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :]
 new_elem = jnp.reshape(self._elem, new_shape)
 return Tensor(new_elem, self._env)
 # return torch.reshape(self, new_shape)
@@ -117,11 +116,12 @@
 if func == torch.ops._c10d_functional.wait_tensor.default:
 return args[0]._env.dispatch(func, types, args, kwargs)
 if func == torch.ops.prim.device.default:
-return torch.device(
+return torch.device("privateuseone", 0)
 raise AssertionError(
-
-
-
+"torchax Tensors can only do math within the torchax environment."
+"Please wrap your code with `with torchax.default_env()` or "
+"call torchax.enable_globally() before."
+)

 def detach(self):
 return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -154,8 +154,7 @@ class Tensor(torch.Tensor):

 @property
 def data(self):
-logger.warning(
-"In-place to .data modifications still results a copy on TPU")
+logger.warning("In-place to .data modifications still results a copy on TPU")
 return self

 @data.setter
@@ -181,7 +180,8 @@ class Tensor(torch.Tensor):

 def debug_accuracy(func, args, kwargs, current_output):
 args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
-
+torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)
+)

 with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
 if "device" in kwargs_torch:
@@ -191,16 +191,17 @@ def debug_accuracy(func, args, kwargs, current_output):
 flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
 flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)

-for ex, real in zip(flattened_expected_out, flattened_current_out):
+for ex, real in zip(flattened_expected_out, flattened_current_out, strict=False):
 if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
 ex = ex.to(real.dtype)
 try:
 if isinstance(ex, torch.Tensor) and not torch.allclose(
-
+ex, real, atol=1e-3, equal_nan=True
+):
 import pdb

 pdb.set_trace()
-except:
+except Exception:
 import pdb

 pdb.set_trace()
@@ -209,7 +210,6 @@ def debug_accuracy(func, args, kwargs, current_output):


 def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
-
 def _display(a):
 if isinstance(a, torch.Tensor):
 return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
@@ -221,9 +221,11 @@ def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
 kwargs = kwargs or {}
 title = "DISPATCH" if is_dispatch else "FUNCTION"
 args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
-kwargs_msg = (
-
-
+kwargs_msg = (
+"kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
+if log_args
+else ""
+)
 return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"


@@ -233,49 +235,43 @@ class XLAFunctionMode(torch.overrides.TorchFunctionMode):
 def __init__(self, env):
 self.env = env

-def __torch_function__(self,
-func,
-types,
-args=(),
-kwargs=None) -> torch.Tensor:
+def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
 message = f"FUNCTION: {_name_of_func(func)}"
 if self.env.config.debug_print_each_op_operands:
 message = message + "f"
-message = _make_debug_msg(
-
-
+message = _make_debug_msg(
+False, self.env.config.debug_print_each_op_operands, func, args, kwargs
+)
 with log_nested(self.env, message):
 try:
 return self.env.dispatch(func, types, args, kwargs)
 except OperatorNotFound:
 pass
-if _name_of_func(func) in (
-
-if len(args) >= 2 and type(args[1]) == int:
+if _name_of_func(func) in ("rot90"): # skip rot90 with k%4==0 due to no change
+if len(args) >= 2 and isinstance(args[1], int):
 if (args[1]) % 4 == 0:
 return args[0]
 return func(*args, **(kwargs or {}))


 class XLADispatchMode(torch_dispatch.TorchDispatchMode):
-
 def __init__(self, env):
 self.env = env

 def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-message = _make_debug_msg(
-
-
+message = _make_debug_msg(
+True, self.env.config.debug_print_each_op_operands, func, args, kwargs
+)
 with log_nested(self.env, message):
 if isinstance(func, torch._ops.OpOverloadPacket):
 with self:
 return func(*args, **kwargs)
 # Only functions under these namespaces will be intercepted
 if func.namespace not in (
-
-
-
-
+"aten",
+"_c10d_functional",
+"torchvision",
+"xla",
 ):
 return func(*args, **kwargs)
 return self.env.dispatch(func, types, args, kwargs)
@@ -289,18 +285,18 @@ def _name_of_func(func):

 # Constructors that don't take other tensor as input
 TENSOR_CONSTRUCTORS = {
-
-
-
-
-
-
-
-
-
-
-
-
+torch.ones,
+torch.zeros,
+torch.empty,
+torch.empty_strided,
+torch.tensor,
+torch.arange,
+torch.eye,
+torch.randn,
+torch.rand,
+torch.randint,
+torch.full,
+torch.as_tensor,
 }

 # TODO(wen): use existing types, either from torch or jax
@@ -328,7 +324,6 @@ class RuntimeProperty:


 class OverrideProperty(RuntimeProperty):
-
 def __init__(self, parent, override):
 self.parent = parent
 self._override = dict(override)
@@ -372,25 +367,24 @@ class Environment(contextlib.ContextDecorator):
 _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
 self._property = threading.local()
 self._initial_content = RuntimeProperty(
-
+mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype
+)

 @property
 def param(self):
-if not hasattr(self._property,
-self._property.content = [
-self._initial_content
-]
+if not hasattr(self._property, "content"):
+self._property.content = [self._initial_content]
 return self._property.content[-1]

 def manual_seed(self, key):
 if isinstance(key, torch.Tensor):
-
-
+assert key.ndim == 0, "manual seed can only take scalars"
+assert not key.dtype.is_floating_point, "manual seed can only be integers"

-
-
-
-
+if isinstance(key, Tensor):
+key = key._elem
+else:
+key = key.item()
 jax_key = jax.random.PRNGKey(key)
 new_prop = self.param.override(prng=jax_key)
 self._property.content.append(new_prop)
@@ -406,27 +400,28 @@ class Environment(contextlib.ContextDecorator):
 if isinstance(device, torch.device):
 device = device.type

-if
-device = device.split(
+if ":" in device:
+device = device.split(":")[0]

 match device:
-case
+case "cpu":
 return False
-case
+case "cuda":
 return self.config.treat_cuda_as_jax_device
-case
+case "jax":
 return True
-case
+case "privateuseone":
 return True
-case
+case "meta":
 return self.enabled
 return False

 def load_ops(self):
-from torchax.ops import jaten,
+from torchax.ops import jaten, jc10d, jtorch, jtorchvision_nms # noqa: F401

-for k, v in itertools.chain(
-
+for k, v in itertools.chain(
+ops_registry.all_aten_ops.items(), ops_registry.all_torch_functions.items()
+):
 if v.is_jax_function:
 self._ops[k] = v
 else:
@@ -437,16 +432,15 @@ class Environment(contextlib.ContextDecorator):
 for k, v in DECOMPOSITIONS.items():
 if k not in self._decomps:
 self._decomps[k] = ops_registry.Operator(
-
-
-
-
-
-
+k,
+v,
+is_jax_function=False,
+is_user_defined=False,
+needs_env=False,
+is_view_op=k in MUTABLE_DECOMPOSITION,
 )

 def _get_op_or_decomp(self, func):
-
 def _get_from_dict(op_dict, op):
 op = op_dict.get(func)
 if op is None and isinstance(func, torch._ops.OpOverloadPacket):
@@ -463,17 +457,18 @@ class Environment(contextlib.ContextDecorator):

 if op is None:
 raise OperatorNotFound(
-
+f"Operator with name {_name_of_func(func)} has no lowering"
+)

 return op

 def _is_same_device(self, the_tensor, new_device):
 if new_device is None:
 return True
-if new_device ==
+if new_device == "meta" and the_tensor.device.type == "jax":
 return True
 if the_tensor.device.type != new_device:
-if the_tensor.device.type ==
+if the_tensor.device.type == "cuda":
 return self.config.treat_cuda_as_jax_device
 return False
 return True
@@ -501,8 +496,7 @@ class Environment(contextlib.ContextDecorator):
 return res.to(device=new_device, dtype=new_dtype)
 return res

-def get_and_rotate_prng_key(self,
-generator: Optional[torch.Generator] = None):
+def get_and_rotate_prng_key(self, generator: torch.Generator | None = None):
 if generator is not None:
 return jax.random.PRNGKey(generator.initial_seed() % (2**63))
 return self.param.get_and_rotate_prng_key()
@@ -514,7 +508,7 @@ class Environment(contextlib.ContextDecorator):
 requires_grad = kwargs.get("requires_grad", False)
 op = self._get_op_or_decomp(func)
 if op.needs_env:
-kwargs[
+kwargs["env"] = self
 if op.is_jax_function:
 (args, kwargs) = self.t2j_iso((args, kwargs))
 res = op.func(*args, **kwargs)
@@ -549,10 +543,10 @@ class Environment(contextlib.ContextDecorator):
 if func in TENSOR_CONSTRUCTORS:
 return self._handle_tensor_constructor(func, args, kwargs)
 if func in (
-
-
-
-
+torch.Tensor.to,
+torch.ops.aten.lift_fresh.default,
+torch.ops.aten._to_copy,
+torch.ops.aten._to_copy.default,
 ):
 return self._torch_Tensor_to(args, kwargs)

@@ -560,8 +554,7 @@ class Environment(contextlib.ContextDecorator):
 # We should skip and let torch handle it.

 tensor_args = [
-
-if isinstance(t, torch.Tensor)
+t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)
 ]

 def is_not_torchax_tensor(x):
@@ -577,9 +570,9 @@ class Environment(contextlib.ContextDecorator):
 old_args, old_kwargs = args, kwargs
 with self._dispatch_mode:
 args, kwargs = torch_pytree.tree_map_only(
-
-
-
+torch.distributed._functional_collectives.AsyncCollectiveTensor,
+torch.distributed._functional_collectives.wait_tensor,
+(args, kwargs),
 )

 try:
@@ -590,8 +583,9 @@ class Environment(contextlib.ContextDecorator):
 if self.param.autocast_dtype is not None:
 autocast_policy = amp.autocast_policy.get(func)
 if autocast_policy is not None:
-args, kwargs = amp.execute_policy(
-
+args, kwargs = amp.execute_policy(
+autocast_policy, args, kwargs, self.param.autocast_dtype
+)

 if op.is_jax_function:
 args, kwargs = self.t2j_iso((args, kwargs))
@@ -664,15 +658,17 @@ class Environment(contextlib.ContextDecorator):
 """

 def to_jax(x):
-if
-
+if (
+self.config.allow_mixed_math_with_scalar_tensor
+and not isinstance(x, Tensor)
+and not isinstance(x, View)
+):
 if x.squeeze().ndim == 0:
 return x.item()
-if isinstance(
-x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
 x = x.wait()
 assert isinstance(x, Tensor) or isinstance(x, View), (
-
+f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
 )
 return x.jax()

@@ -680,7 +676,6 @@ class Environment(contextlib.ContextDecorator):
 return res

 def v2t_iso(self, views):
-
 def to_tensor(x):
 if isinstance(x, View):
 return x.torch()
@@ -695,8 +690,7 @@ class Environment(contextlib.ContextDecorator):
 This function will not copy, will just wrap the jax array with a torchax Tensor
 Note: iso is short for "isomorphic"
 """
-return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
-jaxarray)
+return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray)

 def j2t_copy(self, args):
 """Convert torch.Tensor in cpu to a jax array
@@ -704,9 +698,10 @@ class Environment(contextlib.ContextDecorator):
 This might involves copying the data (depending if dlpack is enabled)
 """
 return torch_pytree.tree_map_only(
-
-
-
+jax.Array,
+lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+args,
+)

 def t2j_copy(self, args):
 """Convert jax array to torch.Tensor in cpu.
@@ -714,18 +709,19 @@ class Environment(contextlib.ContextDecorator):
 This might involves copying the data (depending if dlpack is enabled)
 """
 return torch_pytree.tree_map_only(
-
-
-
+torch.Tensor,
+lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+args,
+)

 def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
 self._ops[op_to_override] = ops_registry.Operator(
-
-
-
-
-
-
+op_to_override,
+op_impl,
+is_jax_function=False,
+is_user_defined=True,
+needs_env=False,
+is_view_op=is_view_op,
 )

 @contextlib.contextmanager
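The reworded assertion in Tensor.__torch_dispatch__ above spells out the intended usage: tensor math has to run inside the torchax environment, either scoped with `with torchax.default_env()` or enabled process-wide via `torchax.enable_globally()`. A minimal sketch of that pattern follows; it relies only on names that appear in this diff (torchax.default_env, torchax.enable_globally, Tensor.jax), and the exact constructor arguments (for example device="jax") are illustrative assumptions, not taken from this release.

import torch
import torchax

# Scope the torchax environment with the context manager named in the
# assertion message; tensors created here are backed by jax arrays.
with torchax.default_env():
    a = torch.randn(4, 4, device="jax")
    b = torch.randn(4, 4, device="jax")
    c = a @ b          # dispatched through the torchax Environment
    print(c.jax())     # unwrap the underlying jax.Array via Tensor.jax()

# Alternatively, enable the environment for the whole process.
torchax.enable_globally()
d = torch.ones(2, 2, device="jax")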