torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
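Most of the behavioral changes in this release land in torchax/tensor.py, shown below: Environment now keeps separate tables for JAX lowerings and decompositions, stores its PRNG key in a jax.experimental.mutable_array, gains target_device and autocast_dtype settings, and dispatch() learns about the new View type and the amp autocast policies. The sketch that follows shows roughly how those pieces are exercised from user code; it only uses names that appear in the diff (torchax.default_env(), device="jax", Tensor.torch(), autocast_dtype), and the exact call sites are assumptions rather than documented API.

    # Hypothetical usage sketch, not shipped with the package.
    import torch
    import torchax

    env = torchax.default_env()          # the Environment defined in torchax/tensor.py
    env.autocast_dtype = torch.bfloat16  # assumption: enables the amp policy path in dispatch()

    with env:                               # enters XLAFunctionMode / XLADispatchMode
      x = torch.randn(4, 4, device="jax")   # routed through _handle_tensor_constructor
      y = x @ x                             # lowered to JAX by Environment.dispatch
      y_cpu = y.torch()                     # j2t_copy back to a plain CPU torch.Tensor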
torchax/tensor.py
CHANGED
@@ -5,15 +5,18 @@ from typing import Optional, Any
 import jax
 import jax.numpy as jnp
 import numpy
+import itertools
 import torch
 import torch.distributed._functional_collectives
 import torch.func
 import torch.utils._mode_utils as mode_utils
 import torch.utils._python_dispatch as torch_dispatch
 import torch.utils._pytree as torch_pytree
-
+from torchax.view import View
 from torchax import config
 from torchax.ops import mappings, ops_registry
+from torchax import amp
+from jax.experimental import mutable_array

 logger = logging.getLogger(__name__)

@@ -30,32 +33,15 @@ def unwrap(torchtensors):
   return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors)


-def t2j(t):
-  if isinstance(t, Tensor):
-    return t._elem
-  return mappings.t2j(t)
-
-
-def j2t(x):
-  return mappings.j2t(x)
-
-
-def t2j_dtype(dtype):
-  return mappings.t2j_dtype(dtype)
-
-
-def j2t_dtype(dtype):
-  return mappings.j2t_dtype(dtype)
-
-
 @contextlib.contextmanager
 def log_nested(env, message):
   if env.config.debug_print_each_op:
-    print((
+    print((" " * log_nested.level) + message, file=sys.stderr)
   log_nested.level += 1
   yield
   log_nested.level -= 1

+
 log_nested.level = 0


@@ -63,7 +49,7 @@ class Tensor(torch.Tensor):

   @staticmethod
   def __new__(cls, elem, env):
-    dtype = j2t_dtype(elem.dtype)
+    dtype = mappings.j2t_dtype(elem.dtype)
     shape = list(elem.shape)
     for i, s in enumerate(shape):
       if not isinstance(s, int):
@@ -74,11 +60,11 @@ class Tensor(torch.Tensor):
         cls,
         shape,
         dtype=dtype,
-        device=
+        device="meta",
         requires_grad=False,
     )

-  def __init__(self, elem: jax.Array, env:
+  def __init__(self, elem: jax.Array, env: "Environment"):
     super().__init__()
     self._elem = elem
     self._env = env
@@ -93,7 +79,7 @@ class Tensor(torch.Tensor):

   @property
   def shape(self):
-    return self._elem.shape
+    return torch.Size(self._elem.shape)

   @property
   def ndim(self):
@@ -120,14 +106,13 @@ class Tensor(torch.Tensor):

   @classmethod
   def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
-
-
-
-
-
-
-    return func(*args, **(kwargs or {}))
+    # TODO(hanq): figure out why is dispatch mode not sufficient
+    if func == torch.ops._c10d_functional.wait_tensor.default:
+      return args[0]._env.dispatch(func, types, args, kwargs)
+    raise AssertionError(
+        'torchax Tensors can only do math within the torchax environment.'
+        'Please wrap your code with `with torchax.default_env()` or '
+        'call torchax.enable_globally() before.')

   def detach(self):
     return Tensor(jax.lax.stop_gradient(self.jax()), self._env)
@@ -141,18 +126,18 @@ class Tensor(torch.Tensor):
     return self._elem

   def torch(self) -> torch.Tensor:
-    return
+    return self._env.j2t_copy(self.jax())

   @property
   def dtype(self):
-    return j2t_dtype(self._elem.dtype)
+    return mappings.j2t_dtype(self._elem.dtype)

   def dim(self):
     return self.ndim

   @property
   def device(self):
-    return torch.device(
+    return torch.device("jax:0")

   @property
   def jax_device(self):
@@ -160,7 +145,8 @@ class Tensor(torch.Tensor):

   @property
   def data(self):
-    logger.
+    logger.warning(
+        "In-place to .data modifications still results a copy on TPU")
     return self

   @data.setter
@@ -182,15 +168,15 @@ class Tensor(torch.Tensor):

   def shard_(self, sharding):
     self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
-
+

 def debug_accuracy(func, args, kwargs, current_output):
   args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
-      torch.Tensor, lambda x:
+      torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output))

   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-    if
-      kwargs_torch[
+    if "device" in kwargs_torch:
+      kwargs_torch["device"] = "cpu"  # do the torch native for comparison
     expected_out = func(*args_torch, **kwargs_torch)

   flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
@@ -200,8 +186,8 @@ def debug_accuracy(func, args, kwargs, current_output):
     if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
       ex = ex.to(real.dtype)
     try:
-      if
-
+      if isinstance(ex, torch.Tensor) and not torch.allclose(
+          ex, real, atol=1e-3, equal_nan=True):
         import pdb

         pdb.set_trace()
@@ -212,46 +198,52 @@ def debug_accuracy(func, args, kwargs, current_output):

   return True

+
 def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
+
   def _display(a):
     if isinstance(a, torch.Tensor):
-      return f
+      return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
     elif isinstance(a, jax.Array):
-      return f
+      return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
     else:
       return str(a)

   kwargs = kwargs or {}
-  title =
-  args_msg =
-  kwargs_msg =
-
+  title = "DISPATCH" if is_dispatch else "FUNCTION"
+  args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
+  kwargs_msg = ("kwargs: " +
+                ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
+                if log_args else "")
+  return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"


 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
   """Context manager that dispatches torch function calls to JAX."""

   def __init__(self, env):
-
+    self.env = env

   def __torch_function__(self,
                          func,
                          types,
                          args=(),
                          kwargs=None) -> torch.Tensor:
-    message = f
+    message = f"FUNCTION: {_name_of_func(func)}"
     if self.env.config.debug_print_each_op_operands:
-      message = message +
-    message = _make_debug_msg(False,
+      message = message + "f"
+    message = _make_debug_msg(False,
+                              self.env.config.debug_print_each_op_operands,
                               func, args, kwargs)
     with log_nested(self.env, message):
       try:
         return self.env.dispatch(func, types, args, kwargs)
       except OperatorNotFound:
         pass
-      if _name_of_func(func) in (
+      if _name_of_func(func) in (
+          "rot90"):  # skip rot90 with k%4==0 due to no change
         if len(args) >= 2 and type(args[1]) == int:
-          if (
+          if (args[1]) % 4 == 0:
             return args[0]
     return func(*args, **(kwargs or {}))

@@ -262,296 +254,446 @@ class XLADispatchMode(torch_dispatch.TorchDispatchMode):
     self.env = env

   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    message = _make_debug_msg(True,
+    message = _make_debug_msg(True,
+                              self.env.config.debug_print_each_op_operands,
                               func, args, kwargs)
     with log_nested(self.env, message):
       if isinstance(func, torch._ops.OpOverloadPacket):
         with self:
           return func(*args, **kwargs)
-
+      # Only functions under these namespaces will be intercepted
+      if func.namespace not in (
+          "aten",
+          "_c10d_functional",
+          "torchvision",
+          "xla",
+      ):
         return func(*args, **kwargs)
       return self.env.dispatch(func, types, args, kwargs)

+
 def _name_of_func(func):
-  if hasattr(func,
+  if hasattr(func, "name"):
     return func.name()
   return func.__name__


 # Constructors that don't take other tensor as input
 TENSOR_CONSTRUCTORS = {
-
-
-
-
-
-
-
-
-
-
-
-
+    torch.ones,
+    torch.zeros,
+    torch.empty,
+    torch.empty_strided,
+    torch.tensor,
+    torch.arange,
+    torch.eye,
+    torch.randn,
+    torch.rand,
+    torch.randint,
+    torch.full,
+    torch.as_tensor,
 }

+# TODO(wen): use existing types, either from torch or jax
+SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]

-class Environment(contextlib.ContextDecorator):
-  """This class holds a set of configurations and "globals" needed
-
-  for executing torch program using jax.
-  Things included so far:
-
-    op registry
-    PRNGKey
-    Configs

-
-
+class Environment(contextlib.ContextDecorator):
+  """This class holds a set of configurations and "globals" needed

-
+  for executing torch program using jax.
+  Things included so far:

+    op registry
+    PRNGKey
+    Configs

-
-
-    self._dispatch_mode = XLADispatchMode(self)
+  Also helper functions to manipulate those.
+  """

-
-
-
+  def __init__(self, configuration=None):
+    self._function_mode = XLAFunctionMode(self)
+    self._dispatch_mode = XLADispatchMode(self)

-
-
+    # name is torch callable
+    self._ops = {}
+    self._decomps = {}

-
-    self.enabled = False
-    self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
+    self.load_ops()

-
-
-      device = torch.get_default_device()
+    self._mesh = None
+    self.config = configuration or config.Configuration()

-
-
+    self._manually_entered = False
+    self.enabled = False

-
-
-
+    self._prng_key = mutable_array(
+        jax.random.key(torch.initial_seed() % (1 << 63)))
+    self.autocast_dtype = None
+    self._target_device = jax.local_devices()[0].platform

-
-
+  @property
+  def target_device(self):
+    return self._target_device

-
-
+  @target_device.setter
+  def target_device(self, device: str):
+    self._target_device = device.lower()

-
-
+  def manual_seed(self, key):
+    self._prng_key = mutable_array(jax.random.key(key))

+  @property
+  def prng_key(self):
+    return self._prng_key[...]
+
+  def get_as_jax_device(self, device: Any):
+    if device is None:
+      device = torch.get_default_device()
+
+    if isinstance(device, torch.device):
+      device = str(device)
+
+    if not self.config.use_torch_native_for_cpu_tensor and device.startswith(
+        "cpu"):
+      return jax.devices("cpu")[0]
+
+    if self.config.treat_cuda_as_jax_device and device.startswith("cuda"):
+      return jax.local_devices()[0]
+
+    if device.startswith("xla"):
+      return jax.local_devices()[0]
+
+    # TODO (wen): jax is NOT a device type,
+    # once we can register more than one backend, revisit
+    if device.startswith("jax"):
+      match self.target_device:
+        case "cpu":
+          return jax.devices("cpu")[0]
+        case "tpu":
+          return jax.devices("tpu")[0]
+        case _:
+          raise AttributeError(
+              f"Cannot handle env.target_device {self.target_device}")
+
+    return None  # fallback to torch
+
+  def load_ops(self):
+    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+
+    for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
+                                ops_registry.all_torch_functions.items()):
+      if v.is_jax_function:
+        self._ops[k] = v
+      else:
+        self._decomps[k] = v

-
-    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
-    self._ops.update(ops_registry.all_aten_ops)
-    self._ops.update(ops_registry.all_torch_functions)
+    from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION

-
-
-
-    for k, v in decomps.items():
-      if k not in self._ops:
-        self._ops[k] = ops_registry.Operator(
+    for k, v in DECOMPOSITIONS.items():
+      if k not in self._decomps:
+        self._decomps[k] = ops_registry.Operator(
             k,
             v,
             is_jax_function=False,
             is_user_defined=False,
-            needs_env=False
-
-
-
-
-
-
-
-
-
-
-
+            needs_env=False,
+            is_view_op=k in MUTABLE_DECOMPOSITION,
+        )
+
+  def _get_op_or_decomp(self, func):
+
+    def _get_from_dict(op_dict, op):
+      op = op_dict.get(func)
+      if op is None and isinstance(func, torch._ops.OpOverloadPacket):
+        op = op_dict.get(func.default)
+      if op is None and isinstance(func, torch._ops.OpOverload):
+        op = op_dict.get(func.overloadpacket)
+      return op
+
+    op = _get_from_dict(self._ops, func)
+
+    if op is None:
+      # fallback to decompose
+      op = _get_from_dict(self._decomps, func)
+
+    if op is None:
+      raise OperatorNotFound(
+          f"Operator with name {_name_of_func(func)} has no lowering")
+
+    return op
+
+  def _to_copy(self, the_tensor, new_dtype, new_device):
+    if isinstance(the_tensor, View):
+      the_tensor = the_tensor.torch()
+
+    if isinstance(the_tensor, Tensor):
+
+      arr = the_tensor.jax()
+
+      if new_dtype is not None and new_dtype != arr.dtype:
+        arr = arr.astype(mappings.t2j_dtype(new_dtype))
+
+      if new_device is not None:
+        match str(new_device).lower():
+          case "cpu":
             # converting to a non-jax device: let torch native handle it
-            torch_tensor =
+            torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor,
+                                                            Tensor) else arr
             with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
               return torch_tensor.to(new_device)
+          case "jax":
+            # move torchax.tensor / jax tensor between devices
+            # I don't know ifgit this will work after the model is jitted
+            if self.target_device != the_tensor.jax_device.platform:
+              arr = jax.device_put(the_tensor.jax(),
+                                   jax.devices(self.target_device)[0])
+            return Tensor(arr, self)
+          case _:
+            logging.error(f"torchax.Tenosr cannot handle device {new_device}")
+
+    else:
+      if new_dtype is not None and new_dtype != the_tensor.dtype:
+        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+          the_tensor = the_tensor.to(new_dtype)
+
+      if new_device is None:  ## device is None means don't change device
+        return the_tensor
+
+      jax_device = self.get_as_jax_device(new_device)
+      if jax_device:
+        arr = self.t2j_copy(the_tensor)
+        arr = jax.device_put(arr, jax_device)
       else:
-
-
-      the_tensor = the_tensor.to(new_dtype)
-    jax_device = self.get_as_jax_device(new_device)
-    if jax_device:
-      arr = t2j(the_tensor)
-      arr = jax.device_put(arr, jax_device)
-    else:
-      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-        return the_tensor.to(new_device)
+        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+          return the_tensor.to(new_device)

-
-
+    return Tensor(arr, self)

-
-
+  def get_and_rotate_prng_key(self,
+                              generator: Optional[torch.Generator] = None):
+    if generator is not None:
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-
-
-
-
+        self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63))
+    old_key = self._prng_key[...]
+    new_prng_key, next_key = jax.random.split(old_key)
+    self._prng_key[...] = new_prng_key
+    return next_key
+
+  def _handle_tensor_constructor(self, func, args, kwargs):
+    device = kwargs.get("device")
+    jax_device = self.get_as_jax_device(device)
+    # TODO(qihqi) figure out better ways for device propagation
+    if not self._manually_entered and jax_device is None:
+      # let torch handle it
+      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
+        return func(*args, **kwargs)
+    with jax.default_device(jax_device):
+      requires_grad = kwargs.get("requires_grad", False)
+      op = self._get_op_or_decomp(func)
+      res = op.func(*args, **kwargs)
+      if isinstance(res, jax.Array):
+        res = Tensor(res, self)
+      if requires_grad:
+        res.requires_grad = True
+      return res

-
-
-
-
-
-
-      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-        return func(*args, **kwargs)
-    with jax.default_device(jax_device):
-      op = self._ops.get(func)
-      if op is None and isinstance(func, torch._ops.OpOverload):
-        op = self._ops.get(func.overloadpacket)
-      res = op.func(*args, **kwargs)
-      if isinstance(res, jax.Array):
-        res = Tensor(res, self)
-      return res
-
-  def _torch_Tensor_to(self, args, kwargs):
-    the_tensor = args[0]
-    args = args[1:]
-    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
-      dtype = args[0].dtype
-      device = args[0].device
-      return self._to_copy(the_tensor, dtype, device)
-    device = kwargs.get('device')
-    dtype = kwargs.get('dtype')
-    # args like pin_memory etc that we will ignore
-    args = list(filter(lambda x: not isinstance(x, bool), args))
-    if len(args) >= 2:
-      device, dtype, *_ = args
-    elif len(args) == 1 and isinstance(args[0], torch.dtype):
-      dtype = args[0]
-    elif len(args) == 1:
-      device = args[0]
+  def _torch_Tensor_to(self, args, kwargs):
+    the_tensor = args[0]
+    args = args[1:]
+    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
+      dtype = args[0].dtype
+      device = args[0].device
       return self._to_copy(the_tensor, dtype, device)
+    device = kwargs.get("device")
+    dtype = kwargs.get("dtype")
+    # args like pin_memory etc that we will ignore
+    args = list(filter(lambda x: not isinstance(x, bool), args))
+    if len(args) >= 2:
+      device, dtype, *_ = args
+    elif len(args) == 1 and isinstance(args[0], torch.dtype):
+      dtype = args[0]
+    elif len(args) == 1:
+      device = args[0]
+    return self._to_copy(the_tensor, dtype, device)
+
+  def dispatch(self, func, types, args, kwargs):
+    kwargs = kwargs or {}
+    if func in TENSOR_CONSTRUCTORS:
+      return self._handle_tensor_constructor(func, args, kwargs)
+    if func in (
+        torch.Tensor.to,
+        torch.ops.aten.lift_fresh.default,
+        torch.ops.aten._to_copy,
+        torch.ops.aten._to_copy.default,
+    ):
+      return self._torch_Tensor_to(args, kwargs)
+
+    # If the func doesn't act on Tensor, and is not a tensor constructor,
+    # We should skip and let torch handle it.
+
+    tensor_args = [
+        t for t in torch_pytree.tree_flatten(args)[0]
+        if isinstance(t, torch.Tensor)
+    ]
+
+    def is_not_torchax_tensor(x):
+      return not isinstance(x, Tensor) and not isinstance(x, View)
+
+    if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
+      res = func(*args, **kwargs)
+      return res

+    with jax.named_scope(_name_of_func(func)):
+      op = self._get_op_or_decomp(func)

-
+      old_args, old_kwargs = args, kwargs
+      with self._dispatch_mode:
+        args, kwargs = torch_pytree.tree_map_only(
+            torch.distributed._functional_collectives.AsyncCollectiveTensor,
+            torch.distributed._functional_collectives.wait_tensor,
+            (args, kwargs),
+        )

-
-
-
-    if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default ,torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
-      return self._torch_Tensor_to(args, kwargs)
+      try:
+        if not op.is_view_op:
+          args, kwargs = self.v2t_iso((args, kwargs))

-
-
-
-
-
-
+        with self:
+          if self.autocast_dtype is not None:
+            autocast_policy = amp.autocast_policy.get(func)
+            if autocast_policy is not None:
+              args, kwargs = amp.execute_policy(autocast_policy, args, kwargs,
+                                                self.autocast_dtype)

-
-
+          if op.is_jax_function:
+            args, kwargs = self.t2j_iso((args, kwargs))
+      except AssertionError:
+        if self.config.debug_mixed_tensor:
+          breakpoint()
+        else:
+          raise

-
-
+      if op.needs_env:
+        kwargs["env"] = self

-
-
+      if op.is_jax_function:
+        res = op.func(*args, **kwargs)
+      else:
+        # enable dispatch mode because this op could be a composite autograd op
+        # meaning, it will decompose in C++
+        with self._dispatch_mode:
+          res = op.func(*args, **kwargs)

-
-
-          f'Operator with name {_name_of_func(func)} has no lowering')
+      if op.is_jax_function:
+        res = self.j2t_iso(res)

-
-
-          torch.distributed._functional_collectives.AsyncCollectiveTensor,
-          torch.distributed._functional_collectives.wait_tensor,
-          (args, kwargs))
-      try:
-        if op.is_jax_function:
-          args, kwargs = self.t2j_iso((args, kwargs))
-      except AssertionError:
-        if self.config.debug_mixed_tensor:
-          import pdb; pdb.set_trace()
-        else:
-          raise
+      if self.config.force_materialize_views and isinstance(res, View):
+        res = res.torch()

+      if self.config.debug_accuracy_for_each_op:
+        debug_accuracy(func, old_args, old_kwargs, res)
+      return res

-
-
+  def enable_torch_modes(self):
+    self._dispatch_mode.__enter__()
+    self._function_mode.__enter__()
+    self.enabled = True
+
+  def disable_torch_modes(self, *exc):
+    if not exc:
+      exc = (None, None, None)
+    self._function_mode.__exit__(*exc)
+    self._dispatch_mode.__exit__(*exc)
+    self.enabled = False
+
+  def __enter__(self):
+    self.enable_torch_modes()
+    self._manually_entered = True
+    return self

-
-
+  def __exit__(self, *exc):
+    self._manually_entered = False
+    self.disable_torch_modes(*exc)

-
-
+  def _move_one_value(self, val):
+    if isinstance(val, torch.nn.Module):
+      with self:
+        return val.to("jax")
+    if isinstance(val, Tensor):
+      return val
+    if isinstance(val, torch.Tensor):
+      return Tensor(self.t2j_copy(val), self)
+    return val

-
-
-
+  def to_xla(self, torchvalues):
+    # tensors are torch.Tensors (not XLATensor)
+    res = torch_pytree.tree_map(self._move_one_value, torchvalues)
+    return res

-
-
-    self._function_mode.__enter__()
-    self.enabled = True
+  def t2j_iso(self, torchtensors):
+    """Convert torchax Tensor to jax array.

-
-
-
-    self._function_mode.__exit__(*exc)
-    self._dispatch_mode.__exit__(*exc)
-    self.enabled = False
-
-  def __enter__(self):
-    self.enable_torch_modes()
-    self._manually_entered = True
-    return self
-
-  def __exit__(self, *exc):
-    self._manually_entered = False
-    self.disable_torch_modes(*exc)
-
-  def _move_one_value(self, val):
-    if isinstance(val, torch.nn.Module):
-      with self:
-        return val.to('jax')
-    if isinstance(val, Tensor):
-      return val
-    if isinstance(val, torch.Tensor):
-      return Tensor(t2j(val), self)
-    return val
+    This function will not copy, will just unwrap the inner jax array out.
+    Note: iso is short for "isomorphic"
+    """

-  def
-
-
-
-
-
+    def to_jax(x):
+      if isinstance(
+          x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+        x = x.wait()
+      assert isinstance(x, Tensor) or isinstance(x, View), (
+          f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
+      )
+      return x.jax()
+
+    res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+    return res

-
-    def to_jax(x):
-      if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
-        x = x.wait()
-      assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor'
-      return x.jax()
-    return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
+  def v2t_iso(self, views):

-  def
-
-
+    def to_tensor(x):
+      if isinstance(x, View):
+        return x.torch()
+      return x

-
-
+    res = torch_pytree.tree_map_only(View, to_tensor, views)
+    return res

-
-
+  def j2t_iso(self, jaxarray):
+    """Convert jax array to torchax Tensor.
+
+    This function will not copy, will just wrap the jax array with a torchax Tensor
+    Note: iso is short for "isomorphic"
+    """
+    return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self),
+                                      jaxarray)
+
+  def j2t_copy(self, args):
+    """Convert torch.Tensor in cpu to a jax array
+
+    This might involves copying the data (depending if dlpack is enabled)
+    """
+    return torch_pytree.tree_map_only(
+        jax.Array,
+        lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
+        args)
+
+  def t2j_copy(self, args):
+    """Convert jax array to torch.Tensor in cpu.
+
+    This might involves copying the data (depending if dlpack is enabled)
+    """
+    return torch_pytree.tree_map_only(
+        torch.Tensor,
+        lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
+        args)
+
+  def override_op_definition(self, op_to_override, op_impl):
+    self._ops[op_to_override] = ops_registry.Operator(
         op_to_override,
         op_impl,
         is_jax_function=False,
         is_user_defined=True,
-        needs_env=False
-
+        needs_env=False,
+    )
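For readers unfamiliar with jax.experimental.mutable_array, the PRNG bookkeeping added above (Environment._prng_key together with get_and_rotate_prng_key) boils down to the following standalone pattern: keep the key in a mutable array so it can be advanced in place, and hand each consumer a fresh subkey. The names below are illustrative only, not torchax API.

    # Standalone sketch of the split-and-store PRNG pattern; illustrative names.
    import jax
    from jax.experimental import mutable_array

    state = mutable_array(jax.random.key(0))

    def next_key():
      old = state[...]                       # read the current key
      new_state, subkey = jax.random.split(old)
      state[...] = new_state                 # rotate the stored key in place
      return subkey                          # one-use subkey for the caller

    samples = [jax.random.normal(next_key(), (2,)) for _ in range(3)]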