torchax 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. See the registry's advisory for this release for more details.

torchax/__init__.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils import _pytree as pytree
8
8
  from torchax import tensor
9
9
  from contextlib import contextmanager
10
10
 
11
- __version__ = "0.0.6"
11
+ __version__ = "0.0.7"
12
12
  VERSION = __version__
13
13
 
14
14
  __all__ = [
torchax/interop.py CHANGED
@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
237
237
  the PyTorch autograd framework by saving the residuals into the context object.
238
238
  """
239
239
 
240
+ # NOTE(qihqi): This function cannot be inlined from the callsite
241
+ # Becuase if it does, then it won't hit the compilation cache for
242
+ # call_jax. Call jax uses functions' id as key.
243
+ # It is nested inside j2t_autograd to ensure it gets a unique ID for each
244
+ # wrapped pure function, preventing cache collisions between different pure modules.
245
+ def _jax_forward(fn, other, tree_def, tensors):
246
+ """JAX function to compute output and vjp function.
247
+
248
+ primals should be a tuple (args, kwargs).
249
+ """
250
+ import jax
251
+ from jax.tree_util import tree_flatten, tree_unflatten
252
+
253
+ def fn_wrapper(*tensors):
254
+ # Reconstruct the original args and kwargs
255
+ flat_inputs = util.merge(tensors, other)
256
+ args, kwargs = tree_unflatten(tree_def, flat_inputs)
257
+ return fn(*args, **kwargs)
258
+
259
+ return jax.vjp(fn_wrapper, *tensors)
260
+
261
+ def _jax_backward(vjp_spec, saved_tensors, grad_out):
262
+ """JAX function to compute input gradients.
263
+
264
+ Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
265
+ """
266
+ from jax.tree_util import tree_unflatten
267
+ fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
268
+ return fun_vjp(grad_out)
269
+
240
270
  @wraps(fn)
241
271
  def inner(*args, **kwargs):
242
272
  from jax.tree_util import tree_flatten
@@ -290,36 +320,6 @@ def j2t_autograd(fn, call_jax=call_jax):
290
320
  return inner
291
321
 
292
322
 
293
- # NOTE(qihqi): This function cannot be inlined from the callsite
294
- # Becuase if it does, then it won't hit the compilation cache for
295
- # call_jax. Call jax uses functions' id as key.
296
- def _jax_forward(fn, other, tree_def, tensors):
297
- """JAX function to compute output and vjp function.
298
-
299
- primals should be a tuple (args, kwargs).
300
- """
301
- import jax
302
- from jax.tree_util import tree_flatten, tree_unflatten
303
-
304
- def fn_wrapper(*tensors):
305
- # Reconstruct the original args and kwargs
306
- flat_inputs = util.merge(tensors, other)
307
- args, kwargs = tree_unflatten(tree_def, flat_inputs)
308
- return fn(*args, **kwargs)
309
-
310
- return jax.vjp(fn_wrapper, *tensors)
311
-
312
-
313
- def _jax_backward(vjp_spec, saved_tensors, grad_out):
314
- """JAX function to compute input gradients.
315
-
316
- Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
317
- """
318
- from jax.tree_util import tree_unflatten
319
- fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
320
- return fun_vjp(grad_out)
321
-
322
-
323
323
  fori_loop = torch_view(jax.lax.fori_loop)
324
324
 
325
325
 
torchax/ops/mappings.py CHANGED
@@ -6,6 +6,14 @@ import torch.func
6
6
  import torch.utils.dlpack as torchdl
7
7
  import torch.utils._mode_utils as mode_utils
8
8
 
9
+ NUMPY_UNSUPPORTED_DTYPES = {
10
+ torch.bfloat16: jnp.bfloat16,
11
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
12
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
13
+ torch.float8_e5m2: jnp.float8_e5m2,
14
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
15
+ }
16
+
9
17
 
10
18
  def t2j(t, use_dlpack=True):
11
19
  is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
28
36
  if res is None:
29
37
  # https://github.com/google/jax/issues/7657
30
38
  # https://github.com/google/jax/issues/17784
31
- if t.dtype == torch.bfloat16:
39
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
32
40
  nparray = (t.cpu().detach().to(torch.float32).numpy()
33
- ) # numpy don't support bfloat16
41
+ ) # handle dtypes not supported by numpy
34
42
  else:
35
43
  nparray = t.cpu().detach().numpy()
36
44
  res = jnp.asarray(nparray)
37
- if t.dtype == torch.bfloat16:
38
- res = res.astype(jnp.bfloat16)
45
+ if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
46
+ res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])
39
47
 
40
48
  if is_bool:
41
49
  res = res.astype(jnp.bool_)
@@ -0,0 +1,47 @@
1
+
2
+ from torchax import config
3
+ from torchax.ops import mappings
4
+ import jax.numpy as jnp
5
+ import torch
6
+
7
+ def maybe_cast(result, torch_op):
8
+ """Casts the result to the torch op's return dtype if the config is set."""
9
+ if not config.DEFAULTS.internal_respect_torch_return_dtypes:
10
+ return result
11
+
12
+ if not hasattr(torch_op, '_schema'):
13
+ return result
14
+
15
+ schema = torch_op._schema
16
+ if not schema.returns:
17
+ return result
18
+
19
+ # TODO: Handle multiple return values
20
+ if len(schema.returns) > 1:
21
+ return result
22
+
23
+ return_type = schema.returns[0].type
24
+ if str(return_type) == 'Tensor':
25
+ # This is not quite right, we need to get the dtype of the tensor
26
+ # For now, let's assume we can get it from the first input argument
27
+ if not schema.arguments:
28
+ return result
29
+
30
+ input_type = schema.arguments[0].type
31
+ if str(input_type) != 'Tensor':
32
+ return result
33
+
34
+ # This is a hack, we need a better way to determine the return dtype
35
+ # For now, let's assume the return type is the same as the first input
36
+ # This is not always true, e.g. for comparison ops.
37
+ return result
38
+
39
+ try:
40
+ torch_dtype = getattr(torch, str(return_type))
41
+ jax_dtype = mappings.t2j_dtype(torch_dtype)
42
+ if isinstance(result, jnp.ndarray):
43
+ return result.astype(jax_dtype)
44
+ else:
45
+ return jax_dtype(result)
46
+ except (AttributeError, TypeError):
47
+ return result
torchax/tensor.py CHANGED
@@ -469,12 +469,12 @@ class Environment(contextlib.ContextDecorator):
469
469
  arr = self.t2j_copy(the_tensor)
470
470
  res = Tensor(arr, self, the_tensor.requires_grad)
471
471
 
472
- if new_dtype is not None and new_dtype != the_tensor.dtype:
473
- if isinstance(the_tensor, Tensor):
472
+ if new_dtype is not None and new_dtype != res.dtype:
473
+ if isinstance(res, Tensor):
474
474
  res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
475
475
  else:
476
476
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
477
- return the_tensor.to(device=new_device, dtype=new_dtype)
477
+ return res.to(device=new_device, dtype=new_dtype)
478
478
  return res
479
479
 
480
480
  def get_and_rotate_prng_key(self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchax
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: torchax is a library for running Jax and PyTorch together
5
5
  Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
6
6
  Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
@@ -1,5 +1,5 @@
1
1
  torchax/CONTRIBUTING.md,sha256=VOL0us6kS-uc4yE6IlSm6SDHYHnx-gw-0upFnP0VkSQ,1369
2
- torchax/__init__.py,sha256=c98iIGugRTbEVcsx8eWnbAjsC4mpcDrK23ZQqiMycLg,3157
2
+ torchax/__init__.py,sha256=fVp0Hgq6-FwGzj7Gt9yH0qwzAzZ3Z7TZdSyLMHc-nrY,3157
3
3
  torchax/amp.py,sha256=-k8t4lrCsJLKHEhI6J0aHE3MAPEL-4DP6wCKtMwo1AM,11791
4
4
  torchax/config.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
5
5
  torchax/configuration.py,sha256=O9yF96AShWb02hcwkT5ToPTt_hpOo3dMJNO30A7dmac,922
@@ -8,9 +8,9 @@ torchax/device_module.py,sha256=7fkdPwXG0qCBTmvDYHp0fvv4xK0W9avV_Ua3MeMzczE,349
8
8
  torchax/environment.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
9
9
  torchax/export.py,sha256=xU-UbrQBvQWUy-GM2FfeIHymlEdmYDYcPymjlcXM23w,8969
10
10
  torchax/flax.py,sha256=2Tg8inGskgAfByPxJQh4ItZHHAb-960gYq156bSO8V4,1280
11
- torchax/interop.py,sha256=7HvJwtxdodcCrMyJzs-Wr47hkHuoh6CWb2-YKoBwqV0,11076
11
+ torchax/interop.py,sha256=5r3ZRUQAJj9n-7NGBxbP-N87-K-8GoYftULq1r2CDxE,11285
12
12
  torchax/mesh_util.py,sha256=Ab4ic2eHWmQ3Mw3jpERvi-TKLIcDvQQoC6tuIZ9ig7Q,9314
13
- torchax/tensor.py,sha256=XjAp7khpQNhoVsSMzDj-V8l4DFT9jBaL4NVCi88a6K0,20893
13
+ torchax/tensor.py,sha256=vU-RR6LArrQlO62fTNQQ4RFLRyKJ3Oa9GXsbmq4K8rI,20872
14
14
  torchax/tf_integration.py,sha256=d_h4vSJm7N9rJXpUPNCDOiUz3J1-UPo3KU8D9Wi4nnc,4074
15
15
  torchax/train.py,sha256=rtvj6HkdnG9fc3VWYPNwHuxGlUxFJkUXJWED8azgtok,3855
16
16
  torchax/types.py,sha256=j4ERjkgDgwhgi9zrwwbbiv4HMDlrJ1IEMUCmP_BIJ9M,388
@@ -24,10 +24,11 @@ torchax/ops/jimage.py,sha256=P0lAauYX_au_xjIHDsG7H6jO7Jf54_VCAjzZuIZdhO0,3182
24
24
  torchax/ops/jlibrary.py,sha256=YfYUQbf5dKiMtEHUMfdgHTeLuNvvSTJ-l8s7wQNIvO0,2930
25
25
  torchax/ops/jtorch.py,sha256=wR4ZdDscxqG4VpxjcLGzgdUKmipa3fp7S0mK3DcD--A,17161
26
26
  torchax/ops/jtorchvision_nms.py,sha256=HSnhwU0gFaHucT7EvrEruJdnWkAWTw4T35GY525ohO8,8903
27
- torchax/ops/mappings.py,sha256=AESERtXJ6i_Hm0ycwEw7z5OJnHu-7QteWlSs-mlUPE4,3492
27
+ torchax/ops/mappings.py,sha256=H-2jlG9ODuV9VzCFqZEC-djTrbcYXmw4fAVwn5Yilc4,3787
28
28
  torchax/ops/op_base.py,sha256=MLKFxMojIXgz4lkTE6k-8F-ddve-9vEiXkzj3P-YJPs,3739
29
29
  torchax/ops/ops_registry.py,sha256=qADpG1up0JOThoybiOQoRDWtAe5TOkHlqcj1bSHjtGY,1594
30
- torchax-0.0.6.dist-info/METADATA,sha256=uB9hoyxdfrAD14pHy0U8Gh1uCHbYwok-oEW12pEa6qs,10753
31
- torchax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
32
- torchax-0.0.6.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
33
- torchax-0.0.6.dist-info/RECORD,,
30
+ torchax/ops/type_casting.py,sha256=gNz3mbA9XtRhkHcx-qpF1bFzsnsila-jkCE9BPQD9GI,1391
31
+ torchax-0.0.7.dist-info/METADATA,sha256=_F_gU0Ea6epTCngRXcBeur4oH8NgOvgq78DBhjt6zEo,10753
32
+ torchax-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
33
+ torchax-0.0.7.dist-info/licenses/LICENSE,sha256=ZHyir3-ltOerFLt9JH1bjf7lIxIWipFmqeMnB_8z_aU,1498
34
+ torchax-0.0.7.dist-info/RECORD,,