torchax 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic. Click here for more details.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/tensor.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
import contextlib
|
|
4
|
+
from typing import Optional, Any
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy
|
|
8
|
+
import torch
|
|
9
|
+
import torch.distributed._functional_collectives
|
|
10
|
+
import torch.func
|
|
11
|
+
import torch.utils._mode_utils as mode_utils
|
|
12
|
+
import torch.utils._python_dispatch as torch_dispatch
|
|
13
|
+
import torch.utils._pytree as torch_pytree
|
|
14
|
+
|
|
15
|
+
from torchax import config
|
|
16
|
+
from torchax.ops import mappings, ops_registry
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OperatorNotFound(Exception):
    """Raised when a torch operator has no registered lowering."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def wrap(jaxarray, env=None):
    """Wrap every JAX array found in a pytree into a ``Tensor``.

    Args:
      jaxarray: a pytree (nested containers) that may contain JAX arrays.
      env: the ``Environment`` the created tensors are bound to. ``Tensor``
        cannot be constructed without one, so this must be supplied whenever
        the pytree actually contains an array.

    Returns:
      The same pytree structure with each JAX array replaced by a ``Tensor``;
      non-array leaves are returned unchanged.

    Raises:
      TypeError: if an array leaf is encountered and ``env`` is None. (The
        previous implementation also raised TypeError in that case, from the
        missing constructor argument.)
    """

    def _to_tensor(arr):
        if env is None:
            raise TypeError(
                'wrap() needs an `env` to build a Tensor from a jax array')
        return Tensor(arr, env)

    return torch_pytree.tree_map_only(jnp.ndarray, _to_tensor, jaxarray)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def unwrap(torchtensors):
    """Replace every ``Tensor`` in a pytree with the value of its ``_elem``."""

    def _inner_array(tensor):
        return tensor._elem

    return torch_pytree.tree_map_only(Tensor, _inner_array, torchtensors)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def t2j(t):
    """Convert ``t`` to its JAX counterpart.

    A ``Tensor`` already carries its value in ``_elem``, which is returned
    directly; anything else is handed to ``mappings.t2j``.
    """
    return t._elem if isinstance(t, Tensor) else mappings.t2j(t)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def j2t(x):
    """Convert ``x`` with ``mappings.j2t`` and return the result."""
    converted = mappings.j2t(x)
    return converted
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def t2j_dtype(dtype):
    """Translate ``dtype`` with ``mappings.t2j_dtype`` and return the result."""
    jax_dtype = mappings.t2j_dtype(dtype)
    return jax_dtype
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def j2t_dtype(dtype):
    """Translate ``dtype`` with ``mappings.j2t_dtype`` and return the result."""
    torch_dtype = mappings.j2t_dtype(dtype)
    return torch_dtype
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@contextlib.contextmanager
def log_nested(env, message):
    """Context manager that logs ``message`` indented by the nesting depth.

    When ``env.config.debug_print_each_op`` is truthy the message is printed to
    stderr, prefixed with one space per currently open ``log_nested`` block.
    The depth is stored on the function itself as ``log_nested.level``.

    Args:
      env: object exposing ``config.debug_print_each_op``.
      message: text to print.
    """
    if env.config.debug_print_each_op:
        print((' ' * log_nested.level) + message, file=sys.stderr)
    log_nested.level += 1
    try:
        yield
    finally:
        # Always restore the depth, even when the body raises; otherwise every
        # failing op would permanently shift the indentation of later logs.
        log_nested.level -= 1


# Current nesting depth; starts at zero when the module is loaded.
log_nested.level = 0
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class Tensor(torch.Tensor):
    """A ``torch.Tensor`` wrapper subclass whose value is held in ``_elem``.

    The torch-side object is created on the ``meta`` device (see ``__new__``);
    the actual data is the JAX array stored in ``self._elem``. ``self._env``
    is entered as a context manager around operations dispatched on this
    tensor (see ``__torch_dispatch__``).

    NOTE: several methods below are named after modules (``numpy``, ``jax``,
    ``torch``). Class-body level expressions and annotations that refer to
    those modules must stay *above* the method of the same name, so the
    statement order in this class matters.
    """

    @staticmethod
    def __new__(cls, elem, env):
        """Create the wrapper object from the shape and dtype of ``elem``."""
        dtype = j2t_dtype(elem.dtype)
        # Any dimension that is not a plain int is reported to torch as 1.
        shape = [s if isinstance(s, int) else 1 for s in elem.shape]
        if dtype is None:
            dtype = torch.float32
        return torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=dtype,
            device='meta',
            requires_grad=False,
        )

    def __init__(self, elem: jax.Array, env: 'Environment'):
        super().__init__()
        self._elem = elem
        self._env = env

    def __str__(self):
        return "Tensor({} {})".format(str(type(self._elem)), str(self._elem))

    __repr__ = __str__

    def __jax_array__(self):
        return self._elem

    @property
    def shape(self):
        """Shape of the underlying array."""
        return self._elem.shape

    @property
    def ndim(self):
        """Number of dimensions of the underlying array."""
        return len(self._elem.shape)

    def flatten(self, start_dim=0, end_dim=-1):
        """Return a new Tensor with dims ``start_dim..end_dim`` merged."""
        if end_dim == -1:
            # ``end_dim + 1`` would be 0 and select the whole shape as the
            # suffix; using ndim makes the suffix empty instead.
            end_dim = self.ndim
        new_shape = (
            self._elem.shape[:start_dim] + (-1,) +
            self._elem.shape[end_dim + 1:])
        new_elem = jnp.reshape(self._elem, new_shape)
        return Tensor(new_elem, self._env)

    def __setitem__(self, key, val):
        # The array is replaced by an updated one rather than written to.
        key, val = self._env.t2j_iso((key, val))
        self._elem = self._elem.at[key].set(val)

    def type_as(self, other):
        # NOTE(review): this casts *in place* and returns self; confirm that
        # callers do not expect the receiver to keep its dtype.
        self._elem = self._elem.astype(other._elem.dtype)
        return self

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` inside the environment of the first Tensor argument."""
        # kwargs defaults to None; normalise it so it can be unpacked below.
        kwargs = kwargs or {}
        env = None
        for arg in torch_pytree.arg_tree_leaves(*args, **kwargs):
            if isinstance(arg, Tensor):
                env = arg._env
                break

        with env:
            return func(*args, **kwargs)

    def detach(self):
        """Return a new Tensor wrapping ``stop_gradient`` of the array."""
        return Tensor(jax.lax.stop_gradient(self.jax()), self._env)

    def numpy(self) -> numpy.ndarray:
        """Return the value as a numpy array."""
        return numpy.array(self._elem)

    def jax(self) -> jax.Array:
        """Return the underlying array."""
        return self._elem

    def torch(self) -> torch.Tensor:
        """Return the value converted with ``j2t``."""
        return j2t(self.jax())

    @property
    def dtype(self):
        return j2t_dtype(self._elem.dtype)

    def dim(self):
        return self.ndim

    @property
    def device(self):
        return torch.device('jax:0')

    @property
    def jax_device(self):
        return self._elem.device

    @property
    def data(self):
        # Logger.warn is a deprecated alias (removed in Python 3.13).
        logger.warning(
            "In-place to .data modifications still results a copy on TPU")
        return self

    @data.setter
    def data(self, other):
        # Only another Tensor is accepted; any other value is ignored.
        if isinstance(other, Tensor):
            self._elem = other._elem

    def apply_jax(self, jax_function, *args, **kwargs):
        """Call ``jax_function`` on the array and wrap the result."""
        res = jax_function(self._elem, *args, **kwargs)
        return self._env.j2t_iso(res)

    def apply_jax_(self, jax_function, *args, **kwargs):
        """Replace the array with ``jax_function(array, ...)``; return self."""
        self._elem = jax_function(self._elem, *args, **kwargs)
        return self

    def tolist(self):
        return self._elem.tolist()

    def shard_(self, sharding):
        """Apply ``with_sharding_constraint`` to the array in place."""
        self.apply_jax_(jax.lax.with_sharding_constraint, sharding)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def debug_accuracy(func, args, kwargs, current_output):
    """Re-run ``func`` natively in torch and compare with ``current_output``.

    Wrapped tensors in ``args``, ``kwargs`` and ``current_output`` are
    converted with ``j2t``; ``func`` is then executed with dispatch and torch
    function overrides disabled. Each expected tensor is compared with the
    corresponding actual value using ``torch.allclose`` (``atol=1e-3``,
    ``equal_nan=True``). On a mismatch, or if the comparison itself raises,
    the process drops into ``pdb``.

    Returns:
      Always True.
    """
    import pdb

    def _to_torch(x):
        # Plain torch tensors carry no ``_elem``; pass them through untouched.
        return j2t(x._elem) if isinstance(x, Tensor) else x

    args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
        torch.Tensor, _to_torch, (args, kwargs, current_output))

    with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
        if 'device' in kwargs_torch:
            kwargs_torch['device'] = 'cpu'  # do the torch native for comparison
        expected_out = func(*args_torch, **kwargs_torch)

    flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
    flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out)

    for ex, real in zip(flattened_expected_out, flattened_current_out):
        if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype:
            ex = ex.to(real.dtype)
        try:
            if (isinstance(ex, torch.Tensor) and
                    not torch.allclose(ex, real, atol=1e-3, equal_nan=True)):
                pdb.set_trace()
        except Exception:
            # Comparison failed outright (e.g. shape mismatch): inspect it.
            # KeyboardInterrupt / SystemExit are deliberately not trapped.
            pdb.set_trace()

    return True
|
|
214
|
+
|
|
215
|
+
def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
    """Build the one-line log message describing a call to ``func``.

    The message starts with ``DISPATCH`` or ``FUNCTION`` (per ``is_dispatch``)
    and the name of ``func``. When ``log_args`` is truthy, the positional and
    keyword arguments are summarised as well; tensors and JAX arrays are shown
    by type, dtype and shape, everything else via ``str``.
    """

    def _describe(value):
        if isinstance(value, torch.Tensor):
            return f'Tensor of {type(value)}: {value.dtype}{value.shape}'
        if isinstance(value, jax.Array):
            return f'Jax Array of {type(value)}: {value.dtype}{value.shape}'
        return str(value)

    if not kwargs:
        kwargs = {}

    if is_dispatch:
        title = 'DISPATCH'
    else:
        title = 'FUNCTION'

    if log_args:
        args_msg = 'args: ' + ','.join(map(_describe, args))
        kwargs_msg = 'kwargs: ' + ','.join(
            f'{key}: {_describe(val)}' for key, val in kwargs.items())
    else:
        args_msg = ''
        kwargs_msg = ''

    return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class XLAFunctionMode(torch.overrides.TorchFunctionMode):
    """Context manager that dispatches torch function calls to JAX."""

    def __init__(self, env):
        self.env = env

    def __torch_function__(self,
                           func,
                           types,
                           args=(),
                           kwargs=None) -> torch.Tensor:
        """Route ``func`` through ``env.dispatch``, falling back to torch.

        If the environment has no lowering for ``func`` (``OperatorNotFound``)
        the original function is called instead.
        """
        message = _make_debug_msg(
            False, self.env.config.debug_print_each_op_operands, func, args,
            kwargs)
        with log_nested(self.env, message):
            try:
                return self.env.dispatch(func, types, args, kwargs)
            except OperatorNotFound:
                pass
            # skip rot90 with k%4==0 due to no change.
            # Compare by equality: ``in ('rot90')`` was a substring test on a
            # plain string and also matched names such as 't' or 'rot'.
            if _name_of_func(func) == 'rot90':
                if len(args) >= 2 and type(args[1]) == int:
                    if args[1] % 4 == 0:
                        return args[0]
            return func(*args, **(kwargs or {}))
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class XLADispatchMode(torch_dispatch.TorchDispatchMode):
    """Dispatch mode that forwards selected operator namespaces to ``env``."""

    def __init__(self, env):
        self.env = env

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Handle one dispatched operator call.

        * An ``OpOverloadPacket`` is re-invoked with this mode active again.
        * Operators outside the ``aten``, ``_c10d_functional`` and
          ``torchvision`` namespaces are called directly.
        * Everything else goes through ``env.dispatch``.
        """
        # kwargs defaults to None; normalise it so ``**kwargs`` cannot fail.
        kwargs = kwargs or {}
        message = _make_debug_msg(
            True, self.env.config.debug_print_each_op_operands, func, args,
            kwargs)
        with log_nested(self.env, message):
            if isinstance(func, torch._ops.OpOverloadPacket):
                with self:
                    return func(*args, **kwargs)
            if func.namespace not in ('aten', '_c10d_functional',
                                      'torchvision'):
                return func(*args, **kwargs)
            return self.env.dispatch(func, types, args, kwargs)
|
|
274
|
+
|
|
275
|
+
def _name_of_func(func):
|
|
276
|
+
if hasattr(func, 'name'):
|
|
277
|
+
return func.name()
|
|
278
|
+
return func.__name__
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# Constructors that don't take other tensor as input
TENSOR_CONSTRUCTORS = {
    # constant-filled / uninitialised
    torch.zeros,
    torch.ones,
    torch.full,
    torch.empty,
    torch.empty_strided,
    # built from python data
    torch.tensor,
    torch.as_tensor,
    # ranges and identity
    torch.arange,
    torch.eye,
    # random
    torch.rand,
    torch.randn,
    torch.randint,
}
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class Environment(contextlib.ContextDecorator):
    """This class holds a set of configurations and "globals" needed

    for executing torch program using jax.
    Things included so far:

    op registry
    PRNGKey
    Configs

    Also helper functions to manipulate those.

    Entering an instance (``with env:``) pushes its torch function mode and
    dispatch mode; leaving it pops them again. Entering may be nested:
    ``dispatch`` itself re-enters the environment around each operator call.
    """

    _prng_key: jax.random.PRNGKey

    def __init__(self, configuration=None):
        self._function_mode = XLAFunctionMode(self)
        self._dispatch_mode = XLADispatchMode(self)

        # Maps a torch callable to its registered Operator.
        self._ops = {}
        self.load_ops()

        self._mesh = None
        self.config = configuration or config.Configuration()

        self._manually_entered = False
        self.enabled = False
        # Nesting counters. ``dispatch`` re-enters the environment, so plain
        # booleans would be reset to False by the inner exit while an outer
        # ``with env:`` (or a global enable) is still active.
        self._enter_depth = 0
        self._mode_depth = 0
        self._jax_devices = {'jax', 'jax_cpu', 'xla'}

    def get_as_jax_device(self, device: Any):
        """Map a torch device specification to a JAX device.

        Args:
          device: a ``torch.device``, a device string, or None (meaning the
            current torch default device).

        Returns:
          A JAX device, or None when the device should be left to torch.
        """
        if device is None:
            device = torch.get_default_device()

        # Accept torch.device as well as any other non-string spec (for
        # example an int index) without failing on ``.startswith``.
        if not isinstance(device, str):
            device = str(device)

        if (not self.config.use_torch_native_for_cpu_tensor and
                device.startswith('cpu')):
            return jax.devices('cpu')[0]

        if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
            return jax.local_devices()[0]

        if device.startswith('jax'):
            return jax.local_devices()[0]

        return None  # fallback to torch

    def load_ops(self):
        """Populate ``self._ops`` from the registries and decompositions."""
        # Imported here before the registries are read below.
        # NOTE(review): presumably importing these modules is what fills
        # ops_registry -- confirm against torchax.ops.
        from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
        self._ops.update(ops_registry.all_aten_ops)
        self._ops.update(ops_registry.all_torch_functions)

        decomps = torch._decomp.core_aten_decompositions()
        from torchax.decompositions import EXTRA_DECOMP
        decomps.update(EXTRA_DECOMP)
        # Decompositions only fill gaps; registered ops take precedence.
        for k, v in decomps.items():
            if k not in self._ops:
                self._ops[k] = ops_registry.Operator(
                    k,
                    v,
                    is_jax_function=False,
                    is_user_defined=False,
                    needs_env=False,
                )

    def _to_copy(self, the_tensor, new_dtype, new_device):
        """Implement ``Tensor.to`` / ``_to_copy`` for wrapped and plain tensors.

        Returns a wrapped ``Tensor`` unless the target is a device handled by
        torch itself, in which case a plain torch tensor is returned.
        """
        if isinstance(the_tensor, Tensor):
            arr = the_tensor.jax()
            if new_dtype is not None and new_dtype != arr.dtype:
                arr = arr.astype(mappings.t2j_dtype(new_dtype))
            if new_device is not None:
                # convert xla tensor to other device
                # only supported is CPU
                if str(new_device).startswith('cpu'):
                    # converting to a non-jax device: let torch native handle it
                    torch_tensor = j2t(arr)
                    with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                        return torch_tensor.to(new_device)
        else:
            if new_dtype is not None and new_dtype != the_tensor.dtype:
                with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                    the_tensor = the_tensor.to(new_dtype)
            jax_device = self.get_as_jax_device(new_device)
            if jax_device:
                arr = t2j(the_tensor)
                arr = jax.device_put(arr, jax_device)
            else:
                with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                    return the_tensor.to(new_device)

        return Tensor(arr, self)

    def get_and_rotate_prng_key(self, generator: Optional[torch.Generator] = None):
        """Draw a seed from torch's RNG and return a JAX key built from it."""
        # Always use the default `randint` to get the next seed
        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
            next_key = torch.randint(
                0, 2**32, (), dtype=torch.uint32, generator=generator).numpy()

        return jax.random.key(next_key)

    def _handle_tensor_constructor(self, func, args, kwargs):
        """Run a tensor constructor either natively or through its lowering.

        Raises:
          OperatorNotFound: if no lowering is registered for ``func``.
        """
        device = kwargs.get('device')
        jax_device = self.get_as_jax_device(device)
        # TODO(qihqi) figure out better ways for device propagation
        if not self._manually_entered and jax_device is None:
            # let torch handle it
            with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
                return func(*args, **kwargs)
        with jax.default_device(jax_device):
            op = self._ops.get(func)
            if op is None and isinstance(func, torch._ops.OpOverload):
                op = self._ops.get(func.overloadpacket)
            if op is None:
                # Previously this fell through to ``None.func``.
                raise OperatorNotFound(
                    f'Operator with name {_name_of_func(func)} has no lowering')
            res = op.func(*args, **kwargs)
            if isinstance(res, jax.Array):
                res = Tensor(res, self)
            return res

    def _torch_Tensor_to(self, args, kwargs):
        """Parse the overloads of ``Tensor.to`` and forward to ``_to_copy``."""
        the_tensor = args[0]
        args = args[1:]
        # ``t.to(other)``: adopt the dtype and device of ``other``.
        if len(args) >= 1 and isinstance(args[0], torch.Tensor):
            dtype = args[0].dtype
            device = args[0].device
            return self._to_copy(the_tensor, dtype, device)
        device = kwargs.get('device')
        dtype = kwargs.get('dtype')
        # args like pin_memory etc that we will ignore
        args = [a for a in args if not isinstance(a, bool)]
        if len(args) >= 2:
            device, dtype, *_ = args
        elif len(args) == 1 and isinstance(args[0], torch.dtype):
            dtype = args[0]
        elif len(args) == 1:
            device = args[0]
        return self._to_copy(the_tensor, dtype, device)

    def dispatch(self, func, types, args, kwargs):
        """Execute ``func`` using the registered lowering.

        Raises:
          OperatorNotFound: if ``func`` has no registered lowering.
          AssertionError: re-raised from ``t2j_iso`` when wrapped and plain
            tensors are mixed and ``config.debug_mixed_tensor`` is off.
        """
        kwargs = kwargs or {}
        if func in TENSOR_CONSTRUCTORS:
            return self._handle_tensor_constructor(func, args, kwargs)
        if func in (torch.Tensor.to, torch.ops.aten.lift_fresh.default,
                    torch.ops.aten._to_copy, torch.ops.aten._to_copy.default):
            return self._torch_Tensor_to(args, kwargs)

        # If the func doesn't act on Tensor, and is not a tensor constructor,
        # We should skip and let torch handle it.
        tensor_args = [
            t for t in torch_pytree.tree_flatten(args)[0]
            if isinstance(t, torch.Tensor)
        ]
        if tensor_args and all(not isinstance(t, Tensor) for t in tensor_args):
            return func(*args, **kwargs)

        with jax.named_scope(_name_of_func(func)):
            op = self._ops.get(func)

            if op is None and isinstance(func, torch._ops.OpOverloadPacket):
                op = self._ops.get(func.default)

            if op is None and isinstance(func, torch._ops.OpOverload):
                op = self._ops.get(func.overloadpacket)

            if op is None:
                raise OperatorNotFound(
                    f'Operator with name {_name_of_func(func)} has no lowering')

            old_args, old_kwargs = args, kwargs
            # Resolve pending async collectives before running the op.
            args, kwargs = torch_pytree.tree_map_only(
                torch.distributed._functional_collectives.AsyncCollectiveTensor,
                torch.distributed._functional_collectives.wait_tensor,
                (args, kwargs))
            try:
                if op.is_jax_function:
                    args, kwargs = self.t2j_iso((args, kwargs))
            except AssertionError:
                if self.config.debug_mixed_tensor:
                    import pdb; pdb.set_trace()
                else:
                    raise

            if op.needs_env:
                kwargs['env'] = self

            with self:
                res = op.func(*args, **kwargs)

            if op.is_jax_function:
                res = self.j2t_iso(res)

            if self.config.debug_accuracy_for_each_op:
                debug_accuracy(func, old_args, old_kwargs, res)
            return res

    def enable_torch_modes(self):
        """Push the dispatch and function modes and mark the env enabled."""
        self._dispatch_mode.__enter__()
        self._function_mode.__enter__()
        self._mode_depth += 1
        self.enabled = True

    def disable_torch_modes(self, *exc):
        """Pop the modes pushed by the matching ``enable_torch_modes``."""
        if not exc:
            exc = (None, None, None)
        self._function_mode.__exit__(*exc)
        self._dispatch_mode.__exit__(*exc)
        self._mode_depth = max(self._mode_depth - 1, 0)
        # Stay enabled while an outer enable is still outstanding.
        self.enabled = self._mode_depth > 0

    def __enter__(self):
        self.enable_torch_modes()
        self._enter_depth += 1
        self._manually_entered = True
        return self

    def __exit__(self, *exc):
        self._enter_depth = max(self._enter_depth - 1, 0)
        # Only the outermost exit clears the flag.
        self._manually_entered = self._enter_depth > 0
        self.disable_torch_modes(*exc)

    def _move_one_value(self, val):
        """Move a module or plain tensor into this env; pass others through."""
        if isinstance(val, torch.nn.Module):
            with self:
                return val.to('jax')
        if isinstance(val, Tensor):
            return val
        if isinstance(val, torch.Tensor):
            return Tensor(t2j(val), self)
        return val

    def to_xla(self, torchvalues):
        """Apply ``_move_one_value`` to every leaf of a pytree."""
        # tensors are torch.Tensors (not XLATensor)
        res = torch_pytree.tree_map(self._move_one_value, torchvalues)
        return res

    def t2j_iso(self, torchtensors):
        """Unwrap every tensor in a pytree to its JAX array.

        Raises:
          AssertionError: if a leaf is a ``torch.Tensor`` that is not a
            wrapped ``Tensor`` (``dispatch`` relies on this exception type).
        """

        def to_jax(x):
            if isinstance(
                    x,
                    torch.distributed._functional_collectives.AsyncCollectiveTensor):
                x = x.wait()
            assert isinstance(x, Tensor), f'Expect a Tensor but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor'
            return x.jax()

        return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)

    def j2t_iso(self, jaxarray):
        """Wrap every JAX array in a pytree as a ``Tensor`` bound to this env."""
        return torch_pytree.tree_map_only(
            jnp.ndarray, lambda x: Tensor(x, self), jaxarray)

    def j2t_copy(self, args):
        # Not implemented: currently does nothing and returns None.
        pass

    def override_op_definition(self, op_to_override, op_impl):
        """Register ``op_impl`` as a user-defined implementation of an op."""
        self._ops[op_to_override] = ops_registry.Operator(
            op_to_override,
            op_impl,
            is_jax_function=False,
            is_user_defined=True,
            needs_env=False,
        )
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# pylint: disable
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Tuple
|
|
4
|
+
|
|
5
|
+
from jax.experimental import jax2tf
|
|
6
|
+
import tensorflow as tf
|
|
7
|
+
import torch
|
|
8
|
+
from torchax import export
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def exported_program_to_tf_function(ep, enable_xla=True):
    """Convert an exported program into a ``tf.function``.

    The program is first turned into a JAX callable plus weights via
    ``export.exported_program_to_jax``; that callable is converted with
    ``jax2tf.convert`` (gradients disabled) and wrapped in a ``tf.function``
    whose input signature is built from ``export.extract_avals(ep)``.

    Args:
      ep: the exported program to convert.
      enable_xla: forwarded to ``jax2tf.convert``.

    Returns:
      The resulting ``tf.function``.
    """
    weights, jax_program = export.exported_program_to_jax(ep)
    # Kept as a lambda so the callable handed to jax2tf keeps the same name.
    wrapped = lambda *args: jax_program(weights, (args,))

    input_signature = []
    for index, aval in enumerate(export.extract_avals(ep)):
        spec = tf.TensorSpec(
            shape=aval.shape, dtype=aval.dtype, name=f"args_{index}")
        input_signature.append(spec)

    converted = jax2tf.convert(
        wrapped,
        with_gradient=False,
        enable_xla=enable_xla,
    )
    return tf.function(
        converted,
        autograph=False,
        input_signature=input_signature,
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def exported_program_to_tf_module(ep: torch.export.ExportedProgram,
                                  enable_xla=True) -> tf.Module:
    """Return a ``tf.Module`` whose attribute ``f`` is the converted program."""
    module = tf.Module()
    tf_function = exported_program_to_tf_function(ep, enable_xla)
    module.f = tf_function
    return module
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def save_exported_program_as_tf_saved_model(
    ep: torch.export.ExportedProgram,
    saved_model_dir: os.PathLike,
    serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    function_alias: str = "",
    enable_xla=True,
):
    """Export and save a pytorch ExportedProgram in tf.saved_model format.

    The resulting tf.saved_model can be used for inference with a tf.serving
    model server, or further converted to a tflite flatbuffer for on-device
    serving.

    Args:
      ep: torch.export.ExportedProgram - the exported program to save.
      saved_model_dir: os.PathLike - location of an empty directory to store
        the saved_model.
      serving_key: str - serving key tag, this is used by tf.serving to know
        which function to run.
      function_alias: str - passed through saved_model.save, used to tag a
        function for inference converter or other tools.
      enable_xla: forwarded to the jax2tf conversion of the program.
    """
    tf_module = exported_program_to_tf_module(ep, enable_xla=enable_xla)
    tf_function = tf_module.f

    concrete = tf_function.get_concrete_function(*tf_function.input_signature)
    signatures = {serving_key: concrete}

    aliases = {function_alias: tf_module.f}
    save_options = tf.saved_model.SaveOptions(function_aliases=aliases)

    tf.saved_model.save(
        tf_module,
        saved_model_dir,
        signatures=signatures,
        options=save_options,
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def save_torch_module_as_tf_saved_model(
    torch_model: torch.nn.Module,
    args: Tuple[Any],
    saved_model_dir: os.PathLike,
    serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    function_alias: str = "",
    enable_xla=True,
):
    """Export and save a pytorch nn.Module in tf.saved_model format.

    The module is exported with ``torch.export.export`` and the result is
    handed to ``save_exported_program_as_tf_saved_model``.

    The resulting tf.saved_model can be used for inference with a tf.serving
    model server, or further converted to a tflite flatbuffer for on-device
    serving.

    Args:
      torch_model: torch.nn.Module - model to export and save.
      args: Tuple[Any] - a set of args to trace the model with, i.e.
        torch_model(*args) must run.
      saved_model_dir: os.PathLike - location of an empty directory to store
        the saved_model.
      serving_key: str - serving key tag, this is used by tf.serving to know
        which function to run.
      function_alias: str - passed through saved_model.save, used to tag a
        function for inference converter or other tools.
      enable_xla: forwarded to the jax2tf conversion of the program.
    """
    exported = torch.export.export(torch_model, args)
    save_exported_program_as_tf_saved_model(
        exported,
        saved_model_dir,
        serving_key,
        function_alias,
        enable_xla,
    )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram):
    """Convert an exported program with ``tf.lite.TFLiteConverter``.

    Returns:
      Whatever ``converter.convert()`` returns for the converted program.
    """
    tf_module = exported_program_to_tf_module(ep)
    tf_function = tf_module.f
    concrete = tf_function.get_concrete_function(*tf_function.input_signature)
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete], tf_module)
    return converter.convert()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module,
                                      args: Tuple[Any]):
    """Export ``torch_model`` with ``args`` and convert the result to tflite."""
    exported = torch.export.export(torch_model, args)
    flatbuffer = exported_program_to_tflite_flatbuffer(exported)
    return flatbuffer
|