torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax was flagged as potentially problematic by the registry.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/tensor.py
CHANGED
@@ -1,3 +1,4 @@
+import threading
 import logging
 import sys
 import contextlib
@@ -5,15 +6,17 @@ from typing import Optional, Any
 import jax
 import jax.numpy as jnp
 import numpy
+import itertools
 import torch
 import torch.distributed._functional_collectives
 import torch.func
 import torch.utils._mode_utils as mode_utils
 import torch.utils._python_dispatch as torch_dispatch
 import torch.utils._pytree as torch_pytree
-
+from torchax.view import View
 from torchax import config
 from torchax.ops import mappings, ops_registry
+from torchax import amp

 logger = logging.getLogger(__name__)

@@ -22,63 +25,42 @@ class OperatorNotFound(Exception):
   pass


-def wrap(jaxarray):
-  return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray)
-
-
-def unwrap(torchtensors):
-  return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)
-
-
-def t2j(t):
-  if isinstance(t, Tensor):
-    return t._elem
-  return mappings.t2j(t)
-
-
-def j2t(x):
-  return mappings.j2t(x)
-
-
-def t2j_dtype(dtype):
-  return mappings.t2j_dtype(dtype)
-
-
-def j2t_dtype(dtype):
-  return mappings.j2t_dtype(dtype)
-
-
 @contextlib.contextmanager
 def log_nested(env, message):
   if env.config.debug_print_each_op:
-    print((
+    print((" " * log_nested.level) + message, file=sys.stderr)
   log_nested.level += 1
   yield
   log_nested.level -= 1

+
 log_nested.level = 0


 class Tensor(torch.Tensor):

   @staticmethod
-  def __new__(cls, elem, env):
-    dtype = j2t_dtype(elem.dtype)
+  def __new__(cls, elem, env, requires_grad=False):
+    dtype = mappings.j2t_dtype(elem.dtype)
     shape = list(elem.shape)
     for i, s in enumerate(shape):
       if not isinstance(s, int):
         shape[i] = 1
     if dtype is None:
       dtype = torch.float32
+    #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
+    if not (dtype.is_floating_point or dtype.is_complex):
+      requires_grad = False
+
     return torch.Tensor._make_wrapper_subclass(
         cls,
         shape,
         dtype=dtype,
         device='meta',
-        requires_grad=
+        requires_grad=requires_grad,
     )

-  def __init__(self, elem: jax.Array, env:
+  def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
     super().__init__()
     self._elem = elem
     self._env = env
@@ -88,12 +70,9 @@ class Tensor(torch.Tensor):

   __repr__ = __str__

-  def __jax_array__(self):
-    return self._elem
-
   @property
   def shape(self):
-    return self._elem.shape
+    return torch.Size(self._elem.shape)

   @property
   def ndim(self):
@@ -120,14 +99,15 @@ class Tensor(torch.Tensor):

   @classmethod
   def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
-
-
-
-
-
-
-
+    # TODO(hanq): figure out why is dispatch mode not sufficient
+    if func == torch.ops._c10d_functional.wait_tensor.default:
+      return args[0]._env.dispatch(func, types, args, kwargs)
+    if func == torch.ops.prim.device.default:
+      return torch.device('privateuseone', 0)
+    raise AssertionError(
+        'torchax Tensors can only do math within the torchax environment.'
+        'Please wrap your code with `with torchax.default_env()` or '
+        'call torchax.enable_globally() before.')

   def detach(self):
     return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -141,18 +121,18 @@ class Tensor(torch.Tensor):
     return self._elem

   def torch(self) -> torch.Tensor:
-    return
+    return self._env.j2t_copy(self.jax())

   @property
   def dtype(self):
-    return j2t_dtype(self._elem.dtype)
+    return mappings.j2t_dtype(self._elem.dtype)

   def dim(self):
     return self.ndim

   @property
   def device(self):
-    return torch.device(
+    return torch.device("jax:0")

   @property
   def jax_device(self):
@@ -160,7 +140,8 @@ class Tensor(torch.Tensor):

   @property
   def data(self):
-    logger.
+    logger.warning(
+        "In-place to .data modifications still results a copy on TPU")
     return self

   @data.setter
@@ -182,15 +163,15 @@ class Tensor(torch.Tensor):

   def shard_(self, sharding):
     self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
-
+

 def debug_accuracy(func, args, kwargs, current_output):
   args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
-      torch.Tensor, lambda x:
+      torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))

   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-    if
-      kwargs_torch[
+    if "device" in kwargs_torch:
+      kwargs_torch["device"] = "cpu"  # do the torch native for comparison
     expected_out = func(*args_torch, **kwargs_torch)

   flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
@@ -200,8 +181,8 @@ def debug_accuracy(func, args, kwargs, current_output):
     if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
       ex = ex.to(real.dtype)
     try:
-      if
-
+      if isinstance(ex, torch.Tensor) and not torch.allclose(
+          ex, real, atol=1e-3, equal_nan=True):
         import pdb

         pdb.set_trace()
@@ -212,46 +193,52 @@ def debug_accuracy(func, args, kwargs, current_output):

   return True

+
 def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
+
   def _display(a):
     if isinstance(a, torch.Tensor):
-      return f
+      return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
     elif isinstance(a, jax.Array):
-      return f
+      return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
     else:
       return str(a)

   kwargs = kwargs or {}
-  title =
-  args_msg =
-  kwargs_msg =
-
+  title = "DISPATCH" if is_dispatch else "FUNCTION"
+  args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
+  kwargs_msg = ("kwargs: " +
+                ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
+                if log_args else "")
+  return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"


 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
   """Context manager that dispatches torch function calls to JAX."""

   def __init__(self, env):
-
+    self.env = env

   def __torch_function__(self,
                          func,
                          types,
                          args=(),
                          kwargs=None) -> torch.Tensor:
-    message = f
+    message = f"FUNCTION: {_name_of_func(func)}"
     if self.env.config.debug_print_each_op_operands:
-      message = message +
-    message = _make_debug_msg(False,
+      message = message + "f"
+    message = _make_debug_msg(False,
+                              self.env.config.debug_print_each_op_operands,
                               func, args, kwargs)
     with log_nested(self.env, message):
       try:
         return self.env.dispatch(func, types, args, kwargs)
       except OperatorNotFound:
         pass
-      if _name_of_func(func) in (
+      if _name_of_func(func) in (
+          "rot90"):  # skip rot90 with k%4==0 due to no change
        if len(args) >= 2 and type(args[1]) == int:
-          if (
+          if (args[1]) % 4 == 0:
             return args[0]
       return func(*args, **(kwargs or {}))

@@ -262,296 +249,463 @@ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
     self.env = env

   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    message = _make_debug_msg(True,
+    message = _make_debug_msg(True,
+                              self.env.config.debug_print_each_op_operands,
                               func, args, kwargs)
     with log_nested(self.env, message):
       if isinstance(func, torch._ops.OpOverloadPacket):
         with self:
           return func(*args, **kwargs)
-
+      # Only functions under these namespaces will be intercepted
+      if func.namespace not in (
+          "aten",
+          "_c10d_functional",
+          "torchvision",
+          "xla",
+      ):
         return func(*args, **kwargs)
       return self.env.dispatch(func, types, args, kwargs)

+
 def _name_of_func(func):
-  if hasattr(func,
+  if hasattr(func, "name"):
     return func.name()
   return func.__name__


 # Constructors that don't take other tensor as input
 TENSOR_CONSTRUCTORS = {
-
-
-
-
-
-
-
-
-
-
-
-
+    torch.ones,
+    torch.zeros,
+    torch.empty,
+    torch.empty_strided,
+    torch.tensor,
+    torch.arange,
+    torch.eye,
+    torch.randn,
+    torch.rand,
+    torch.randint,
+    torch.full,
+    torch.as_tensor,
 }

+# TODO(wen): use existing types, either from torch or jax
+SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]

-class Environment(contextlib.ContextDecorator):
-  """This class holds a set of configurations and "globals" needed

-
-
+class RuntimeProperty:
+  mesh: Any
+  prng: Any
+  autocast_dtype: Any

-
-
-
+  def __init__(self, mesh, prng, autocast_dtype):
+    self.mesh = mesh
+    self.prng = prng
+    self.autocast_dtype = autocast_dtype

-
-
+  def override(self, **kwargs):
+    return OverrideProperty(self, kwargs)
+
+  def get_and_rotate_prng_key(self):
+    old_key = self.prng
+    new_prng_key, next_key = jax.random.split(old_key)
+    self.prng = new_prng_key
+    return next_key
+
+
+class OverrideProperty(RuntimeProperty):

-
+  def __init__(self, parent, override):
+    self.parent = parent
+    self._override = dict(override)

+  def __getattr__(self, name):
+    if name in self._override:
+      return self._override[name]
+    return getattr(self.parent, name)

-  def __init__(self, configuration=None):
-    self._function_mode = XLAFunctionMode(self)
-    self._dispatch_mode = XLADispatchMode(self)

-
-
-
+class Environment(contextlib.ContextDecorator):
+  """This class holds a set of configurations and "globals" needed
+
+  for executing torch program using jax.
+  Things included so far:
+
+  op registry
+  PRNGKey
+  Configs

-
-
+  Also helper functions to manipulate those.
+  """

-
-
-
+  def __init__(self, configuration=None):
+    self._function_mode = XLAFunctionMode(self)
+    self._dispatch_mode = XLADispatchMode(self)

-
-
-
+    # name is torch callable
+    self._ops = {}
+    self._decomps = {}

-
-    device = str(device)
+    self.load_ops()

-
-
-      return jax.devices('cpu')[0]
+    _mesh = None
+    self.config = configuration or config.Configuration()

-
-      return jax.local_devices()[0]
+    self.enabled = False

-
-      return jax.local_devices()[0]
+    autocast_dtype = None

-
-
+    _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
+    self._property = threading.local()
+    self._property.content = [
+        RuntimeProperty(
+            mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype)
+    ]

+  @property
+  def param(self):
+    return self._property.content[-1]

-
-
-
-
+  def manual_seed(self, key):
+    jax_key = jax.random.PRNGKey(key)
+    new_prop = self.param.override(prng=jax_key)
+    self._property.content.append(new_prop)

-
-
-
-
-
-
+  @property
+  def prng_key(self):
+    return self.param.prng
+
+  def _should_use_torchax_tensor(self, device):
+    if device is None:
+      device = torch.get_default_device()
+
+    if isinstance(device, torch.device):
+      device = device.type
+
+    if ':' in device:
+      device = device.split(':')[0]
+
+    match device:
+      case 'cpu':
+        return False
+      case 'cuda':
+        return self.config.treat_cuda_as_jax_device
+      case 'jax':
+        return True
+      case 'privateuseone':
+        return True
+      case 'meta':
+        return self.enabled
+    return False
+
+  def load_ops(self):
+    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+
+    for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
+                                ops_registry.all_torch_functions.items()):
+      if v.is_jax_function:
+        self._ops[k] = v
+      else:
+        self._decomps[k] = v
+
+    from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION
+
+    for k, v in DECOMPOSITIONS.items():
+      if k not in self._decomps:
+        self._decomps[k] = ops_registry.Operator(
             k,
             v,
             is_jax_function=False,
             is_user_defined=False,
-            needs_env=False
-
-
-
+            needs_env=False,
+            is_view_op=k in MUTABLE_DECOMPOSITION,
+        )
+
+  def _get_op_or_decomp(self, func):
+
+    def _get_from_dict(op_dict, op):
+      op = op_dict.get(func)
+      if op is None and isinstance(func, torch._ops.OpOverloadPacket):
+        op = op_dict.get(func.default)
+      if op is None and isinstance(func, torch._ops.OpOverload):
+        op = op_dict.get(func.overloadpacket)
+      return op
+
+    op = _get_from_dict(self._ops, func)
+
+    if op is None:
+      # fallback to decompose
+      op = _get_from_dict(self._decomps, func)
+
+    if op is None:
+      raise OperatorNotFound(
+          f"Operator with name {_name_of_func(func)} has no lowering")
+
+    return op
+
+  def _is_same_device(self, the_tensor, new_device):
+    if new_device is None:
+      return True
+    if new_device == 'meta' and the_tensor.device.type == 'jax':
+      return True
+    if the_tensor.device.type != new_device:
+      if the_tensor.device.type == 'cuda':
+        return self.config.treat_cuda_as_jax_device
+      return False
+    return True
+
+  def _to_copy(self, the_tensor, new_dtype, new_device):
+    if isinstance(the_tensor, View):
+      the_tensor = the_tensor.torch()
+    if isinstance(new_device, torch.device):
+      new_device = new_device.type
+    res = the_tensor
+    if not self._is_same_device(the_tensor, new_device):
       if isinstance(the_tensor, Tensor):
-
-
-
-      if new_device is not None:
-        # convert xla tensor to other device
-        # only supported is CPU
-        if str(new_device).startswith('cpu'):
-          # converting to a non-jax device: let torch native handle it
-          torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
-          with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-            return torch_tensor.to(new_device)
+        torch_tensor = self.j2t_copy(the_tensor._elem)
+        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+          return torch_tensor.to(device=new_device, dtype=new_dtype)
       else:
-
-
-        the_tensor = the_tensor.to(new_dtype)
-      jax_device = self.get_as_jax_device(new_device)
-      if jax_device:
-        arr = t2j(the_tensor)
-        arr = jax.device_put(arr, jax_device)
-      else:
-        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-          return the_tensor.to(new_device)
-
-    return Tensor(arr, self)
-
+        arr = self.t2j_copy(the_tensor)
+        res = Tensor(arr, self, the_tensor.requires_grad)

-
-
+    if new_dtype is not None and new_dtype != the_tensor.dtype:
+      if isinstance(the_tensor, Tensor):
+        res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
+      else:
+        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+          return the_tensor.to(device=new_device, dtype=new_dtype)
+    return res
+
+  def get_and_rotate_prng_key(self,
+                              generator: Optional[torch.Generator] = None):
+    if generator is not None:
+      return jax.random.PRNGKey(generator.initial_seed() % (2**63))
+    return self.param.get_and_rotate_prng_key()
+
+  def _handle_tensor_constructor(self, func, args, kwargs):
+    device = kwargs.get("device")
+    if self._should_use_torchax_tensor(device):
+      # don't set default device, let caller set it
+      requires_grad = kwargs.get("requires_grad", False)
+      op = self._get_op_or_decomp(func)
+      if op.needs_env:
+        kwargs['env'] = self
+      if op.is_jax_function:
+        (args, kwargs) = self.t2j_iso((args, kwargs))
+      res = op.func(*args, **kwargs)
+      if isinstance(res, jax.Array):
+        res = Tensor(res, self, requires_grad)
+      return res
+    else:
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-
-            0, 2**32, (), dtype=torch.uint32, generator=generator).numpy()
-
-      return jax.random.key(next_key)
+        return func(*args, **kwargs)

-
-
-
-
-
-
-    with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-      return func(*args, **kwargs)
-    with jax.default_device(jax_device):
-      op = self._ops.get(func)
-      if op is None and isinstance(func, torch._ops.OpOverload):
-        op = self._ops.get(func.overloadpacket)
-      res = op.func(*args, **kwargs)
-      if isinstance(res, jax.Array):
-        res = Tensor(res, self)
-      return res
-
-  def _torch_Tensor_to(self, args, kwargs):
-    the_tensor = args[0]
-    args = args[1:]
-    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
-      dtype = args[0].dtype
-      device = args[0].device
-      return self._to_copy(the_tensor, dtype, device)
-    device = kwargs.get('device')
-    dtype = kwargs.get('dtype')
-    # args like pin_memory etc that we will ignore
-    args = list(filter(lambda x: not isinstance(x, bool), args))
-    if len(args) >= 2:
-      device, dtype, *_ = args
-    elif len(args) == 1 and isinstance(args[0], torch.dtype):
-      dtype = args[0]
-    elif len(args) == 1:
-      device = args[0]
+  def _torch_Tensor_to(self, args, kwargs):
+    the_tensor = args[0]
+    args = args[1:]
+    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+      dtype = args[0].dtype
+      device = args[0].device
       return self._to_copy(the_tensor, dtype, device)
+    device = kwargs.get("device")
+    dtype = kwargs.get("dtype")
+    # args like pin_memory etc that we will ignore
+    args = list(filter(lambda x: not isinstance(x, bool), args))
+    if len(args) >= 2:
+      device, dtype, *_ = args
+    elif len(args) == 1 and isinstance(args[0], torch.dtype):
+      dtype = args[0]
+    elif len(args) == 1:
+      device = args[0]
+    return self._to_copy(the_tensor, dtype, device)
+
+  def dispatch(self, func, types, args, kwargs):
+    kwargs = kwargs or {}
+    if func in TENSOR_CONSTRUCTORS:
+      return self._handle_tensor_constructor(func, args, kwargs)
+    if func in (
+        torch.Tensor.to,
+        torch.ops.aten.lift_fresh.default,
+        torch.ops.aten._to_copy,
+        torch.ops.aten._to_copy.default,
+    ):
+      return self._torch_Tensor_to(args, kwargs)
+
+    # If the func doesn't act on Tensor, and is not a tensor constructor,
+    # We should skip and let torch handle it.
+
+    tensor_args = [
+        t for t in torch_pytree.tree_flatten(args)[0]
+        if isinstance(t, torch.Tensor)
+    ]
+
+    def is_not_torchax_tensor(x):
+      return not isinstance(x, Tensor) and not isinstance(x, View)
+
+    if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
+      res = func(*args, **kwargs)
+      return res

+    with jax.named_scope(_name_of_func(func)):
+      op = self._get_op_or_decomp(func)

-
+      old_args, old_kwargs = args, kwargs
+      with self._dispatch_mode:
+        args, kwargs = torch_pytree.tree_map_only(
+            torch.distributed._functional_collectives.AsyncCollectiveTensor,
+            torch.distributed._functional_collectives.wait_tensor,
+            (args, kwargs),
+        )

-
-
-
-    if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default ,torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
-      return self._torch_Tensor_to(args, kwargs)
+      try:
+        if not op.is_view_op:
+          args, kwargs = self.v2t_iso((args, kwargs))

-
-
-
-
-
-
+        with self:
+          if self.param.autocast_dtype is not None:
+            autocast_policy = amp.autocast_policy.get(func)
+            if autocast_policy is not None:
+              args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
+                                                self.param.autocast_dtype)

-
-
+        if op.is_jax_function:
+          args, kwargs = self.t2j_iso((args, kwargs))
+      except AssertionError:
+        if self.config.debug_mixed_tensor:
+          breakpoint()
+        else:
+          raise

-
-
+      if op.needs_env:
+        kwargs["env"] = self

-
-
+      if op.is_jax_function:
+        res = op.func(*args, **kwargs)
+      else:
+        # enable dispatch mode because this op could be a composite autograd op
+        # meaning, it will decompose in C++
+        with self._dispatch_mode:
+          res = op.func(*args, **kwargs)

-
-
-          f'Operator with name {_name_of_func(func)} has no lowering')
+      if op.is_jax_function:
+        res = self.j2t_iso(res)

-
-
-          torch.distributed._functional_collectives.AsyncCollectiveTensor,
-          torch.distributed._functional_collectives.wait_tensor,
-          (args, kwargs))
-    try:
-      if op.is_jax_function:
-        args, kwargs = self.t2j_iso((args, kwargs))
-    except AssertionError:
-      if self.config.debug_mixed_tensor:
-        import pdb; pdb.set_trace()
-      else:
-        raise
+      if self.config.force_materialize_views and isinstance(res, View):
+        res = res.torch()

+      if self.config.debug_accuracy_for_each_op:
+        debug_accuracy(func, old_args, old_kwargs, res)
+      return res

-
-
+  def enable_torch_modes(self):
+    self._dispatch_mode.__enter__()
+    self._function_mode.__enter__()
+    self.enabled = True

-
-
+  def disable_torch_modes(self, *exc):
+    if not exc:
+      exc = (None, None, None)
+    self._function_mode.__exit__(*exc)
+    self._dispatch_mode.__exit__(*exc)
+    self.enabled = False

-
-
+  def __enter__(self):
+    self.enable_torch_modes()
+    return self

-
-
-    return res
+  def __exit__(self, *exc):
+    self.disable_torch_modes(*exc)

-
-
-    self
-
-
-  def disable_torch_modes(self, *exc):
-    if not exc:
-      exc = (None, None, None)
-    self._function_mode.__exit__(*exc)
-    self._dispatch_mode.__exit__(*exc)
-    self.enabled = False
-
-  def __enter__(self):
-    self.enable_torch_modes()
-    self._manually_entered = True
-    return self
-
-  def __exit__(self, *exc):
-    self._manually_entered = False
-    self.disable_torch_modes(*exc)
-
-  def _move_one_value(self, val):
-    if isinstance(val, torch.nn.Module):
-      with self:
-        return val.to('jax')
-    if isinstance(val, Tensor):
-      return val
-    if isinstance(val, torch.Tensor):
-      return Tensor(t2j(val), self)
+  def _move_one_value(self, val):
+    if isinstance(val, torch.nn.Module):
+      with self:
+        return val.to("jax")
+    if isinstance(val, Tensor):
       return val
+    if isinstance(val, torch.Tensor):
+      return Tensor(self.t2j_copy(val), self)
+    return val

-
-
-
-
-
-
+  def to_xla(self, torchvalues):
+    # tensors are torch.Tensors (not XLATensor)
+    res = torch_pytree.tree_map(self._move_one_value, torchvalues)
+    return res
+
+  def t2j_iso(self, torchtensors):
+    """Convert torchax Tensor to jax array.
+
+    This function will not copy, will just unwrap the inner jax array out.
+    Note: iso is short for "isomorphic"
+    """
+
+    def to_jax(x):
+      if self.config.allow_mixed_math_with_scalar_tensor and not isinstance(
+          x, Tensor):
+        if x.squeeze().ndim == 0:
+          return x.item()
+      if isinstance(
+          x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+        x = x.wait()
+      assert isinstance(x, Tensor) or isinstance(x, View), (
+          f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
+      )
+      return x.jax()
+
+    res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+    return res

-
-    def to_jax(x):
-      if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
-        x = x.wait()
-      assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor'
-      return x.jax()
-    return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+  def v2t_iso(self, views):

-  def
-
-
+    def to_tensor(x):
+      if isinstance(x, View):
+        return x.torch()
+      return x

-
-
+    res = torch_pytree.tree_map_only(View, to_tensor, views)
+    return res

-
-
+  def j2t_iso(self, jaxarray):
+    """Convert jax array to torchax Tensor.
+
+    This function will not copy, will just wrap the jax array with a torchax Tensor
+    Note: iso is short for "isomorphic"
+    """
+    return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
+                                      jaxarray)
+
+  def j2t_copy(self, args):
+    """Convert torch.Tensor in cpu to a jax array
+
+    This might involves copying the data (depending if dlpack is enabled)
+    """
+    return torch_pytree.tree_map_only(
+        jax.Array,
+        lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+        args)
+
+  def t2j_copy(self, args):
+    """Convert jax array to torch.Tensor in cpu.
+
+    This might involves copying the data (depending if dlpack is enabled)
+    """
+    return torch_pytree.tree_map_only(
+        torch.Tensor,
+        lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+        args)
+
+  def override_op_definition(self, op_to_override, op_impl):
+    self._ops[op_to_override] = ops_registry.Operator(
         op_to_override,
         op_impl,
         is_jax_function=False,
         is_user_defined=True,
-        needs_env=False
-
+        needs_env=False,
+    )
+
+  @contextlib.contextmanager
+  def override_property(self, **kwargs):
+    new_prop = self.param.override(**kwargs)
+    self._property.content.append(new_prop)
+    yield
+    self._property.content.pop()