torchax-0.0.10.dev20251118-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- torchax/CONTRIBUTING.md +43 -0
- torchax/__init__.py +149 -0
- torchax/amp.py +218 -0
- torchax/checkpoint.py +85 -0
- torchax/config.py +44 -0
- torchax/decompositions.py +796 -0
- torchax/device_module.py +47 -0
- torchax/export.py +258 -0
- torchax/flax.py +55 -0
- torchax/interop.py +369 -0
- torchax/mesh_util.py +236 -0
- torchax/ops/__init__.py +25 -0
- torchax/ops/jaten.py +5753 -0
- torchax/ops/jax_reimplement.py +211 -0
- torchax/ops/jc10d.py +64 -0
- torchax/ops/jimage.py +122 -0
- torchax/ops/jlibrary.py +94 -0
- torchax/ops/jtorch.py +608 -0
- torchax/ops/jtorchvision_nms.py +268 -0
- torchax/ops/mappings.py +139 -0
- torchax/ops/op_base.py +137 -0
- torchax/ops/ops_registry.py +74 -0
- torchax/tensor.py +732 -0
- torchax/train.py +130 -0
- torchax/types.py +27 -0
- torchax/util.py +104 -0
- torchax/view.py +399 -0
- torchax-0.0.10.dev20251118.dist-info/METADATA +507 -0
- torchax-0.0.10.dev20251118.dist-info/RECORD +31 -0
- torchax-0.0.10.dev20251118.dist-info/WHEEL +4 -0
- torchax-0.0.10.dev20251118.dist-info/licenses/LICENSE +201 -0
torchax/tensor.py
ADDED
@@ -0,0 +1,732 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import itertools
import logging
import sys
import threading
from typing import Any

import jax
import jax.numpy as jnp
import numpy
import torch
import torch.distributed._functional_collectives
import torch.func
import torch.utils._mode_utils as mode_utils
import torch.utils._python_dispatch as torch_dispatch
import torch.utils._pytree as torch_pytree

from torchax import amp, config
from torchax.ops import mappings, ops_registry
from torchax.view import View

logger = logging.getLogger(__name__)


class OperatorNotFound(Exception):
    pass


@contextlib.contextmanager
def log_nested(env, message):
    if env.config.debug_print_each_op:
        print((" " * log_nested.level) + message, file=sys.stderr)
    log_nested.level += 1
    yield
    log_nested.level -= 1


log_nested.level = 0


class Tensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, env, requires_grad=False):
        dtype = mappings.j2t_dtype(elem.dtype)
        shape = list(elem.shape)
        for i, s in enumerate(shape):
            if not isinstance(s, int):
                shape[i] = 1
        if dtype is None:
            dtype = torch.float32
        # dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1)
        if not (dtype.is_floating_point or dtype.is_complex):
            requires_grad = False

        return torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=dtype,
            device="meta",
            requires_grad=requires_grad,
        )

    def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False):
        super().__init__()
        self._elem = elem
        self._env = env

    def __str__(self):
        return f"Tensor({str(type(self._elem))} {str(self._elem)})"

    __repr__ = __str__

    @property
    def shape(self):
        return torch.Size(self._elem.shape)

    @property
    def ndim(self):
        return len(self._elem.shape)

    def flatten(self, start_dim=0, end_dim=-1):
        if end_dim == -1:
            end_dim = self.ndim
        new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :]
        new_elem = jnp.reshape(self._elem, new_shape)
        return Tensor(new_elem, self._env)
        # return torch.reshape(self, new_shape)

    def __setitem__(self, key, val):
        key, val = self._env.t2j_iso((key, val))
        self._elem = self._elem.at[key].set(val)

    def type_as(self, other):
        self._elem = self._elem.astype(other._elem.dtype)
        return self

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # TODO(hanq): figure out why is dispatch mode not sufficient
        if func == torch.ops._c10d_functional.wait_tensor.default:
            return args[0]._env.dispatch(func, types, args, kwargs)
        if func == torch.ops.prim.device.default:
            return torch.device("privateuseone", 0)
        raise AssertionError(
            "torchax Tensors can only do math within the torchax environment."
            "Please wrap your code with `with torchax.default_env()` or "
            "call torchax.enable_globally() before."
        )

    def detach(self):
        return Tensor(jax.lax.stop_gradient(self.jax()), self._env)

    def numpy(self) -> numpy.ndarray:
        import numpy as np

        return np.array(self._elem)

    def jax(self) -> jax.Array:
        return self._elem

    def torch(self) -> torch.Tensor:
        return self._env.j2t_copy(self.jax())

    @property
    def dtype(self):
        return mappings.j2t_dtype(self._elem.dtype)

    def dim(self):
        return self.ndim

    @property
    def device(self):
        return torch.device("jax:0")

    @property
    def jax_device(self):
        return self._elem.device

    @property
    def data(self):
        logger.warning("In-place to .data modifications still results a copy on TPU")
        return self

    @data.setter
    def data(self, other):
        if isinstance(other, Tensor):
            self._elem = other._elem

    def apply_jax(self, jax_function, *args, **kwargs):
        # Call a jax function on _elem
        res = jax_function(self._elem, *args, **kwargs)
        return self._env.j2t_iso(res)

    def apply_jax_(self, jax_function, *args, **kwargs):
        self._elem = jax_function(self._elem, *args, **kwargs)
        return self

    def tolist(self):
        return self._elem.tolist()

    def shard_(self, sharding):
        self.apply_jax_(jax.lax.with_sharding_constraint, sharding)


def debug_accuracy(func, args, kwargs, current_output):
    args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
        torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)
    )

    with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
        if "device" in kwargs_torch:
            kwargs_torch["device"] = "cpu"  # do the torch native for comparison
        expected_out = func(*args_torch, **kwargs_torch)

    flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
    flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)

    for ex, real in zip(flattened_expected_out, flattened_current_out, strict=False):
        if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
            ex = ex.to(real.dtype)
        try:
            if isinstance(ex, torch.Tensor) and not torch.allclose(
                ex, real, atol=1e-3, equal_nan=True
            ):
                import pdb

                pdb.set_trace()
        except Exception:
            import pdb

            pdb.set_trace()

    return True


def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
    def _display(a):
        if isinstance(a, torch.Tensor):
            return f"Tensor of {type(a)}: {a.dtype}{a.shape}"
        elif isinstance(a, jax.Array):
            return f"Jax Array of {type(a)}: {a.dtype}{a.shape}"
        else:
            return str(a)

    kwargs = kwargs or {}
    title = "DISPATCH" if is_dispatch else "FUNCTION"
    args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else ""
    kwargs_msg = (
        "kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items())
        if log_args
        else ""
    )
    return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}"


class XLAFunctionMode(torch.overrides.TorchFunctionMode):
    """Context manager that dispatches torch function calls to JAX."""

    def __init__(self, env):
        self.env = env

    def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
        message = f"FUNCTION: {_name_of_func(func)}"
        if self.env.config.debug_print_each_op_operands:
            message = message + "f"
        message = _make_debug_msg(
            False, self.env.config.debug_print_each_op_operands, func, args, kwargs
        )
        with log_nested(self.env, message):
            try:
                return self.env.dispatch(func, types, args, kwargs)
            except OperatorNotFound:
                pass
            if _name_of_func(func) in ("rot90"):  # skip rot90 with k%4==0 due to no change
                if len(args) >= 2 and isinstance(args[1], int):
                    if (args[1]) % 4 == 0:
                        return args[0]
            return func(*args, **(kwargs or {}))


class XLADispatchMode(torch_dispatch.TorchDispatchMode):
    def __init__(self, env):
        self.env = env

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        message = _make_debug_msg(
            True, self.env.config.debug_print_each_op_operands, func, args, kwargs
        )
        with log_nested(self.env, message):
            if isinstance(func, torch._ops.OpOverloadPacket):
                with self:
                    return func(*args, **kwargs)
            # Only functions under these namespaces will be intercepted
            if func.namespace not in (
                "aten",
                "_c10d_functional",
                "torchvision",
                "xla",
            ):
                return func(*args, **kwargs)
            return self.env.dispatch(func, types, args, kwargs)


def _name_of_func(func):
    if hasattr(func, "name"):
        return func.name()
    return func.__name__


# Constructors that don't take other tensor as input
TENSOR_CONSTRUCTORS = {
    torch.ones,
    torch.zeros,
    torch.empty,
    torch.empty_strided,
    torch.tensor,
    torch.arange,
    torch.eye,
    torch.randn,
    torch.rand,
    torch.randint,
    torch.full,
    torch.as_tensor,
}

# TODO(wen): use existing types, either from torch or jax
SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"]


class RuntimeProperty:
    mesh: Any
    prng: Any
    autocast_dtype: Any

    def __init__(self, mesh, prng, autocast_dtype):
        self.mesh = mesh
        self.prng = prng
        self.autocast_dtype = autocast_dtype

    def override(self, **kwargs):
        return OverrideProperty(self, kwargs)

    def get_and_rotate_prng_key(self):
        old_key = self.prng
        new_prng_key, next_key = jax.random.split(old_key)
        self.prng = new_prng_key
        return next_key


class OverrideProperty(RuntimeProperty):
    def __init__(self, parent, override):
        self.parent = parent
        self._override = dict(override)

    def __getattr__(self, name):
        if name in self._override:
            return self._override[name]
        return getattr(self.parent, name)


class Environment(contextlib.ContextDecorator):
    """This class holds a set of configurations and "globals" needed

    for executing torch program using jax.
    Things included so far:

    op registry
    PRNGKey
    Configs

    Also helper functions to manipulate those.
    """

    def __init__(self, configuration=None):
        self._function_mode = XLAFunctionMode(self)
        self._dispatch_mode = XLADispatchMode(self)

        # name is torch callable
        self._ops = {}
        self._decomps = {}

        self.load_ops()

        _mesh = None
        self.config = configuration or config.Configuration()

        self.enabled = False

        autocast_dtype = None

        _prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
        self._property = threading.local()
        self._initial_content = RuntimeProperty(
            mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype
        )

    @property
    def param(self):
        if not hasattr(self._property, "content"):
            self._property.content = [self._initial_content]
        return self._property.content[-1]

    def manual_seed(self, key):
        if isinstance(key, torch.Tensor):
            assert key.ndim == 0, "manual seed can only take scalars"
            assert not key.dtype.is_floating_point, "manual seed can only be integers"

            if isinstance(key, Tensor):
                key = key._elem
            else:
                key = key.item()
        jax_key = jax.random.PRNGKey(key)
        new_prop = self.param.override(prng=jax_key)
        self._property.content.append(new_prop)

    @property
    def prng_key(self):
        return self.param.prng

    def _should_use_torchax_tensor(self, device):
        if device is None:
            device = torch.get_default_device()

        if isinstance(device, torch.device):
            device = device.type

        if ":" in device:
            device = device.split(":")[0]

        match device:
            case "cpu":
                return False
            case "cuda":
                return self.config.treat_cuda_as_jax_device
            case "jax":
                return True
            case "privateuseone":
                return True
            case "meta":
                return self.enabled
        return False

    def load_ops(self):
        from torchax.ops import jaten, jc10d, jtorch, jtorchvision_nms  # noqa: F401

        for k, v in itertools.chain(
            ops_registry.all_aten_ops.items(), ops_registry.all_torch_functions.items()
        ):
            if v.is_jax_function:
                self._ops[k] = v
            else:
                self._decomps[k] = v

        from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION

        for k, v in DECOMPOSITIONS.items():
            if k not in self._decomps:
                self._decomps[k] = ops_registry.Operator(
                    k,
                    v,
                    is_jax_function=False,
                    is_user_defined=False,
                    needs_env=False,
                    is_view_op=k in MUTABLE_DECOMPOSITION,
                )

    def _get_op_or_decomp(self, func):
        def _get_from_dict(op_dict, op):
            op = op_dict.get(func)
            if op is None and isinstance(func, torch._ops.OpOverloadPacket):
                op = op_dict.get(func.default)
            if op is None and isinstance(func, torch._ops.OpOverload):
                op = op_dict.get(func.overloadpacket)
            return op

        op = _get_from_dict(self._ops, func)

        if op is None:
            # fallback to decompose
            op = _get_from_dict(self._decomps, func)

        if op is None:
            raise OperatorNotFound(
                f"Operator with name {_name_of_func(func)} has no lowering"
            )

        return op

    def _is_same_device(self, the_tensor, new_device):
        if new_device is None:
            return True
        if new_device == "meta" and the_tensor.device.type == "jax":
            return True
        if the_tensor.device.type != new_device:
            if the_tensor.device.type == "cuda":
                return self.config.treat_cuda_as_jax_device
            return False
        return True

    def _to_copy(self, the_tensor, new_dtype, new_device):
        if isinstance(the_tensor, View):
            the_tensor = the_tensor.torch()
        if isinstance(new_device, torch.device):
            new_device = new_device.type
        res = the_tensor
        if not self._is_same_device(the_tensor, new_device):
            if isinstance(the_tensor, Tensor):
                torch_tensor = self.j2t_copy(the_tensor._elem)
                with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                    return torch_tensor.to(device=new_device, dtype=new_dtype)
            else:
                arr = self.t2j_copy(the_tensor)
                res = Tensor(arr, self, the_tensor.requires_grad)

        if new_dtype is not None and new_dtype != res.dtype:
            if isinstance(res, Tensor):
                res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
            else:
                with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                    return res.to(device=new_device, dtype=new_dtype)
        return res

    def get_and_rotate_prng_key(self, generator: torch.Generator | None = None):
        if generator is not None:
            return jax.random.PRNGKey(generator.initial_seed() % (2**63))
        return self.param.get_and_rotate_prng_key()

    def _handle_tensor_constructor(self, func, args, kwargs):
        device = kwargs.get("device")
        if self._should_use_torchax_tensor(device):
            # don't set default device, let caller set it
            requires_grad = kwargs.get("requires_grad", False)
            op = self._get_op_or_decomp(func)
            if op.needs_env:
                kwargs["env"] = self
            if op.is_jax_function:
                (args, kwargs) = self.t2j_iso((args, kwargs))
            res = op.func(*args, **kwargs)
            if isinstance(res, jax.Array):
                res = Tensor(res, self, requires_grad)
            return res
        else:
            with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                return func(*args, **kwargs)

    def _torch_Tensor_to(self, args, kwargs):
        the_tensor = args[0]
        args = args[1:]
        if len(args) >= 1 and isinstance(args[0], torch.Tensor):
            dtype = args[0].dtype
            device = args[0].device
            return self._to_copy(the_tensor, dtype, device)
        device = kwargs.get("device")
        dtype = kwargs.get("dtype")
        # args like pin_memory etc that we will ignore
        args = list(filter(lambda x: not isinstance(x, bool), args))
        if len(args) >= 2:
            device, dtype, *_ = args
        elif len(args) == 1 and isinstance(args[0], torch.dtype):
            dtype = args[0]
        elif len(args) == 1:
            device = args[0]
        return self._to_copy(the_tensor, dtype, device)

    def dispatch(self, func, types, args, kwargs):
        kwargs = kwargs or {}
        if func in TENSOR_CONSTRUCTORS:
            return self._handle_tensor_constructor(func, args, kwargs)
        if func in (
            torch.Tensor.to,
            torch.ops.aten.lift_fresh.default,
            torch.ops.aten._to_copy,
            torch.ops.aten._to_copy.default,
        ):
            return self._torch_Tensor_to(args, kwargs)

        # If the func doesn't act on Tensor, and is not a tensor constructor,
        # We should skip and let torch handle it.

        tensor_args = [
            t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)
        ]

        def is_not_torchax_tensor(x):
            return not isinstance(x, Tensor) and not isinstance(x, View)

        if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args):
            res = func(*args, **kwargs)
            return res

        with jax.named_scope(_name_of_func(func)):
            op = self._get_op_or_decomp(func)

            old_args, old_kwargs = args, kwargs
            with self._dispatch_mode:
                args, kwargs = torch_pytree.tree_map_only(
                    torch.distributed._functional_collectives.AsyncCollectiveTensor,
                    torch.distributed._functional_collectives.wait_tensor,
                    (args, kwargs),
                )

            try:
                if not op.is_view_op:
                    args, kwargs = self.v2t_iso((args, kwargs))

                with self:
                    if self.param.autocast_dtype is not None:
                        autocast_policy = amp.autocast_policy.get(func)
                        if autocast_policy is not None:
                            args, kwargs = amp.execute_policy(
                                autocast_policy, args, kwargs, self.param.autocast_dtype
                            )

                if op.is_jax_function:
                    args, kwargs = self.t2j_iso((args, kwargs))
            except AssertionError:
                if self.config.debug_mixed_tensor:
                    breakpoint()
                else:
                    raise

            if op.needs_env:
                kwargs["env"] = self

            if op.is_jax_function:
                res = op.func(*args, **kwargs)
            else:
                # enable dispatch mode because this op could be a composite autograd op
                # meaning, it will decompose in C++
                with self._dispatch_mode:
                    res = op.func(*args, **kwargs)

            if op.is_jax_function:
                res = self.j2t_iso(res)

            if self.config.force_materialize_views and isinstance(res, View):
                res = res.torch()

            if self.config.debug_accuracy_for_each_op:
                debug_accuracy(func, old_args, old_kwargs, res)
            return res

    def enable_torch_modes(self):
        self._dispatch_mode.__enter__()
        self._function_mode.__enter__()
        self.enabled = True

    def disable_torch_modes(self, *exc):
        if not exc:
            exc = (None, None, None)
        self._function_mode.__exit__(*exc)
        self._dispatch_mode.__exit__(*exc)
        self.enabled = False

    def __enter__(self):
        self.enable_torch_modes()
        return self

    def __exit__(self, *exc):
        self.disable_torch_modes(*exc)

    def _move_one_value(self, val):
        if isinstance(val, torch.nn.Module):
            with self:
                return val.to("jax")
        if isinstance(val, Tensor):
            return val
        if isinstance(val, torch.Tensor):
            return Tensor(self.t2j_copy(val), self)
        return val

    def to_xla(self, torchvalues):
        # tensors are torch.Tensors (not XLATensor)
        res = torch_pytree.tree_map(self._move_one_value, torchvalues)
        return res

    def t2j_iso(self, torchtensors):
        """Convert torchax Tensor to jax array.

        This function will not copy, will just unwrap the inner jax array out.
        Note: iso is short for "isomorphic"
        """

        def to_jax(x):
            if (
                self.config.allow_mixed_math_with_scalar_tensor
                and not isinstance(x, Tensor)
                and not isinstance(x, View)
            ):
                if x.squeeze().ndim == 0:
                    return x.item()
            if isinstance(x, torch.distributed._functional_collectives.AsyncCollectiveTensor):
                x = x.wait()
            assert isinstance(x, Tensor) or isinstance(x, View), (
                f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor"
            )
            return x.jax()

        res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)
        return res

    def v2t_iso(self, views):
        def to_tensor(x):
            if isinstance(x, View):
                return x.torch()
            return x

        res = torch_pytree.tree_map_only(View, to_tensor, views)
        return res

    def j2t_iso(self, jaxarray):
        """Convert jax array to torchax Tensor.

        This function will not copy, will just wrap the jax array with a torchax Tensor
        Note: iso is short for "isomorphic"
        """
        return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray)

    def j2t_copy(self, args):
        """Convert torch.Tensor in cpu to a jax array

        This might involves copying the data (depending if dlpack is enabled)
        """
        return torch_pytree.tree_map_only(
            jax.Array,
            lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion),
            args,
        )

    def t2j_copy(self, args):
        """Convert jax array to torch.Tensor in cpu.

        This might involves copying the data (depending if dlpack is enabled)
        """
        return torch_pytree.tree_map_only(
            torch.Tensor,
            lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion),
            args,
        )

    def override_op_definition(self, op_to_override, op_impl, is_view_op=False):
        self._ops[op_to_override] = ops_registry.Operator(
            op_to_override,
            op_impl,
            is_jax_function=False,
            is_user_defined=True,
            needs_env=False,
            is_view_op=is_view_op,
        )

    @contextlib.contextmanager
    def override_property(self, **kwargs):
        new_prop = self.param.override(**kwargs)
        self._property.content.append(new_prop)
        yield
        self._property.content.pop()
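
For context on how the pieces in this file fit together, the following is a minimal usage sketch and is not part of the released wheel. It assumes `torchax.default_env()` returns an `Environment` instance and that importing `torchax` registers the "jax" device name (both are suggested by the assertion message in `Tensor.__torch_dispatch__` and by `torchax/device_module.py` in the file list above, but neither is shown in this diff); everything else uses only methods defined in this file.

# Hypothetical usage sketch, assuming torchax.default_env() returns the
# Environment defined above and that the "jax" device is registered on import.
import jax.numpy as jnp
import torch
import torchax

env = torchax.default_env()  # assumption: returns an Environment instance

with env:  # Environment.__enter__ installs XLAFunctionMode and XLADispatchMode
    # torch.randn is in TENSOR_CONSTRUCTORS, so with device="jax" it is routed
    # through Environment._handle_tensor_constructor and returns a torchax
    # Tensor wrapping a jax.Array.
    a = torch.randn(4, 4, device="jax")
    # Math on wrapped tensors goes through Environment.dispatch, which looks up
    # the registered JAX lowering (or a decomposition) for the torch op.
    b = a @ a

# Outside the environment, the wrapper still exposes its conversion helpers.
row_sums = b.apply_jax(jnp.sum, axis=1)  # run a raw JAX function on the inner array
print(row_sums.jax())    # the underlying jax.Array, no copy (t2j_iso-style unwrap)
print(row_sums.torch())  # a CPU torch.Tensor copy via Environment.j2t_copy

The same `Environment` object also exposes `manual_seed`, `get_and_rotate_prng_key`, `override_op_definition`, and `override_property`, defined above, for controlling PRNG state and overriding individual op lowerings.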