torchax 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
- torchax/__init__.py +1 -1
- torchax/interop.py +30 -30
- torchax/ops/mappings.py +12 -4
- torchax/ops/type_casting.py +47 -0
- torchax/tensor.py +3 -3
- {torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/METADATA +1 -1
- {torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/RECORD +9 -8
- {torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/WHEEL +0 -0
- {torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/licenses/LICENSE +0 -0
torchax/__init__.py
CHANGED
torchax/interop.py
CHANGED
@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
   the PyTorch autograd framework by saving the residuals into the context object.
   """

+  # NOTE(qihqi): This function cannot be inlined from the callsite
+  # Becuase if it does, then it won't hit the compilation cache for
+  # call_jax. Call jax uses functions' id as key.
+  # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+  # wrapped pure function, preventing cache collisions between different pure modules.
+  def _jax_forward(fn, other, tree_def, tensors):
+    """JAX function to compute output and vjp function.
+
+    primals should be a tuple (args, kwargs).
+    """
+    import jax
+    from jax.tree_util import tree_flatten, tree_unflatten
+
+    def fn_wrapper(*tensors):
+      # Reconstruct the original args and kwargs
+      flat_inputs = util.merge(tensors, other)
+      args, kwargs = tree_unflatten(tree_def, flat_inputs)
+      return fn(*args, **kwargs)
+
+    return jax.vjp(fn_wrapper, *tensors)
+
+  def _jax_backward(vjp_spec, saved_tensors, grad_out):
+    """JAX function to compute input gradients.
+
+    Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+    """
+    from jax.tree_util import tree_unflatten
+    fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+    return fun_vjp(grad_out)
+
   @wraps(fn)
   def inner(*args, **kwargs):
     from jax.tree_util import tree_flatten
@@ -290,36 +320,6 @@ def j2t_autograd(fn, call_jax=call_jax):
   return inner


-# NOTE(qihqi): This function cannot be inlined from the callsite
-# Becuase if it does, then it won't hit the compilation cache for
-# call_jax. Call jax uses functions' id as key.
-def _jax_forward(fn, other, tree_def, tensors):
-  """JAX function to compute output and vjp function.
-
-  primals should be a tuple (args, kwargs).
-  """
-  import jax
-  from jax.tree_util import tree_flatten, tree_unflatten
-
-  def fn_wrapper(*tensors):
-    # Reconstruct the original args and kwargs
-    flat_inputs = util.merge(tensors, other)
-    args, kwargs = tree_unflatten(tree_def, flat_inputs)
-    return fn(*args, **kwargs)
-
-  return jax.vjp(fn_wrapper, *tensors)
-
-
-def _jax_backward(vjp_spec, saved_tensors, grad_out):
-  """JAX function to compute input gradients.
-
-  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
-  """
-  from jax.tree_util import tree_unflatten
-  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-  return fun_vjp(grad_out)
-
-
 fori_loop = torch_view(jax.lax.fori_loop)
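The two hunks above move `_jax_forward` and `_jax_backward` from module scope to inside `j2t_autograd`, so every wrapped pure function gets its own helper objects and `call_jax`'s id-keyed compilation cache cannot collide across wrapped modules. The helpers are thin wrappers around `jax.vjp`; the following is a minimal sketch of the forward/backward split they implement, using only the public `jax.vjp` API with a hypothetical `pure_fn` (not torchax code).

import jax
import jax.numpy as jnp

def pure_fn(x, w):
  # hypothetical stand-in for the wrapped pure function
  return jnp.tanh(x @ w)

x = jnp.ones((2, 3))
w = jnp.full((3, 4), 0.1)

# Forward: jax.vjp returns the primal output and a vjp closure; _jax_forward
# returns exactly this pair so the closure can be saved as autograd residuals.
out, vjp_fn = jax.vjp(pure_fn, x, w)

# Backward: applying the saved closure to the upstream gradient yields one
# gradient per primal, which is what _jax_backward computes from the residuals.
grad_x, grad_w = vjp_fn(jnp.ones_like(out))
print(grad_x.shape, grad_w.shape)  # (2, 3) (3, 4)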
torchax/ops/mappings.py
CHANGED
@@ -6,6 +6,14 @@ import torch.func
 import torch.utils.dlpack as torchdl
 import torch.utils._mode_utils as mode_utils

+NUMPY_UNSUPPORTED_DTYPES = {
+    torch.bfloat16: jnp.bfloat16,
+    torch.float8_e4m3fn: jnp.float8_e4m3fn,
+    torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+    torch.float8_e5m2: jnp.float8_e5m2,
+    torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+}
+

 def t2j(t, use_dlpack=True):
   is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
   if res is None:
     # https://github.com/google/jax/issues/7657
     # https://github.com/google/jax/issues/17784
-    if t.dtype
+    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
       nparray = (t.cpu().detach().to(torch.float32).numpy()
-      ) #
+      ) # handle dtypes not supported by numpy
     else:
       nparray = t.cpu().detach().numpy()
     res = jnp.asarray(nparray)
-    if t.dtype
-      res = res.astype(
+    if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+      res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])

   if is_bool:
     res = res.astype(jnp.bool_)
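The new `NUMPY_UNSUPPORTED_DTYPES` table drives the NumPy fallback in `t2j`: NumPy has no native bfloat16 or float8 dtypes, so the tensor is widened to float32 on the torch side and cast back to the matching JAX dtype afterwards. A minimal sketch of that detour for bfloat16, independent of torchax:

import torch
import jax.numpy as jnp

t = torch.randn(4, dtype=torch.bfloat16)

# NumPy cannot represent bfloat16, so widen to float32 before calling .numpy() ...
nparray = t.cpu().detach().to(torch.float32).numpy()

# ... then restore the intended dtype on the JAX side, as t2j now does via
# NUMPY_UNSUPPORTED_DTYPES[t.dtype].
res = jnp.asarray(nparray).astype(jnp.bfloat16)
print(res.dtype)  # bfloat16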
torchax/ops/type_casting.py
ADDED
@@ -0,0 +1,47 @@
+
+from torchax import config
+from torchax.ops import mappings
+import jax.numpy as jnp
+import torch
+
+def maybe_cast(result, torch_op):
+  """Casts the result to the torch op's return dtype if the config is set."""
+  if not config.DEFAULTS.internal_respect_torch_return_dtypes:
+    return result
+
+  if not hasattr(torch_op, '_schema'):
+    return result
+
+  schema = torch_op._schema
+  if not schema.returns:
+    return result
+
+  # TODO: Handle multiple return values
+  if len(schema.returns) > 1:
+    return result
+
+  return_type = schema.returns[0].type
+  if str(return_type) == 'Tensor':
+    # This is not quite right, we need to get the dtype of the tensor
+    # For now, let's assume we can get it from the first input argument
+    if not schema.arguments:
+      return result
+
+    input_type = schema.arguments[0].type
+    if str(input_type) != 'Tensor':
+      return result
+
+    # This is a hack, we need a better way to determine the return dtype
+    # For now, let's assume the return type is the same as the first input
+    # This is not always true, e.g. for comparison ops.
+    return result
+
+  try:
+    torch_dtype = getattr(torch, str(return_type))
+    jax_dtype = mappings.t2j_dtype(torch_dtype)
+    if isinstance(result, jnp.ndarray):
+      return result.astype(jax_dtype)
+    else:
+      return jax_dtype(result)
+  except (AttributeError, TypeError):
+    return result
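The new `maybe_cast` helper decides whether to cast a JAX result by inspecting the torch operator's `_schema`: it bails out unless there is exactly one return value, leaves plain `Tensor` returns alone, and otherwise maps the declared return type to a JAX dtype via `mappings.t2j_dtype`. A small sketch of the schema inspection it relies on, using a real aten overload (the output comments are approximate):

import torch

op = torch.ops.aten.add.Tensor
schema = op._schema

print(schema)                        # aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
print(len(schema.returns))           # 1  -> passes the single-return check
print(str(schema.returns[0].type))   # 'Tensor' -> maybe_cast returns the result unchanged
print(str(schema.arguments[0].type)) # 'Tensor'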
torchax/tensor.py
CHANGED
@@ -469,12 +469,12 @@ class Environment(contextlib.ContextDecorator):
       arr = self.t2j_copy(the_tensor)
       res = Tensor(arr, self, the_tensor.requires_grad)

-    if new_dtype is not None and new_dtype !=
-      if isinstance(
+    if new_dtype is not None and new_dtype != res.dtype:
+      if isinstance(res, Tensor):
         res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
       else:
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-          return
+          return res.to(device=new_device, dtype=new_dtype)
     return res

   def get_and_rotate_prng_key(self,
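With this change, a dtype-changing `.to()` on a torchax `Tensor` converts the backing JAX array through `mappings.t2j_dtype` and `jnp.astype`. A minimal sketch of that branch on a plain JAX array, with a hypothetical `_T2J_DTYPE` mapping standing in for `mappings.t2j_dtype`:

import torch
import jax.numpy as jnp

_T2J_DTYPE = {torch.float32: jnp.float32, torch.bfloat16: jnp.bfloat16}  # hypothetical subset

arr = jnp.ones((2, 2), dtype=jnp.float32)   # stands in for the Tensor's backing array
new_dtype = torch.bfloat16                  # dtype requested via .to()

# Mirror of the dtype branch: only convert when a different dtype was requested.
if new_dtype is not None and _T2J_DTYPE[new_dtype] != arr.dtype:
  arr = jnp.astype(arr, _T2J_DTYPE[new_dtype])

print(arr.dtype)  # bfloat16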
{torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchax
-Version: 0.0.6
+Version: 0.0.7
 Summary: torchax is a library for running Jax and PyTorch together
 Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
 Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
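The only metadata change is the version bump. A quick way to confirm which wheel is actually installed, using only the standard library:

from importlib.metadata import version

print(version("torchax"))  # expected: 0.0.7 after upgrading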
{torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
-torchax/__init__.py,sha256=
+torchax/__init__.py,sha256=fVp0Hgq6-FwGzj7Gt9yH0qwzAzZ3Z7TZdSyLMHc-nrY,3157
 torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
 torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
 torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
@@ -8,9 +8,9 @@ torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
 torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
 torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
-torchax/interop.py,sha256=
+torchax/interop.py,sha256=5r3ZRUQAJj9n-7NGBxbP-N87-K-8GoYftULq1r2CDxE,11285
 torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
-torchax/tensor.py,sha256=
+torchax/tensor.py,sha256=vU-RR6LArrQlO62fTNQQ4RFLRyKJ3Oa9GXsbmq4K8rI,20872
 torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
 torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
 torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
@@ -24,10 +24,11 @@ torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
 torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
 torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
 torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
-torchax/ops/mappings.py,sha256=
+torchax/ops/mappings.py,sha256=H-2jlG9ODuV9VzCFqZEC-djTrbcYXmw4fAVwn5Yilc4,3787
 torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
 torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
-torchax
-torchax-0.0.
-torchax-0.0.
-torchax-0.0.
+torchax/ops/type_casting.py,sha256=gNz3mbA9XtRhkHcx-qpF1bFzsnsila-jkCE9BPQD9GI,1391
+torchax-0.0.7.dist-info/METADATA,sha256=_F_gU0Ea6epTCngRXcBeur4oH8NgOvgq78DBhjt6zEo,10753
+torchax-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+torchax-0.0.7.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
+torchax-0.0.7.dist-info/RECORD,,
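Each RECORD line has the form path,sha256=<digest>,size, where the digest is the urlsafe base64 SHA-256 of the file with trailing '=' padding stripped (per the wheel spec). A sketch for recomputing an entry when verifying the wheel contents; the path argument is whatever file inside the unpacked wheel you want to check:

import base64
import hashlib
from pathlib import Path

def record_entry(path):
  data = Path(path).read_bytes()
  digest = hashlib.sha256(data).digest()
  # RECORD uses urlsafe base64 with the trailing '=' padding removed
  b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
  return f"{path},sha256={b64},{len(data)}"

# e.g. record_entry("torchax/ops/type_casting.py") run inside an unpacked wheel
# should reproduce the new RECORD line for that file.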
{torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/WHEEL
File without changes
{torchax-0.0.6.dist-info → torchax-0.0.7.dist-info}/licenses/LICENSE
File without changes