torchax 0.0.6__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (116)
  1. {torchax-0.0.6 → torchax-0.0.7}/PKG-INFO +1 -1
  2. {torchax-0.0.6 → torchax-0.0.7}/test/llama/llama_model.py +1 -1
  3. torchax-0.0.7/test/test_misc.py +47 -0
  4. {torchax-0.0.6 → torchax-0.0.7}/torchax/__init__.py +1 -1
  5. {torchax-0.0.6 → torchax-0.0.7}/torchax/interop.py +30 -30
  6. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/mappings.py +12 -4
  7. torchax-0.0.7/torchax/ops/type_casting.py +47 -0
  8. {torchax-0.0.6 → torchax-0.0.7}/torchax/tensor.py +3 -3
  9. torchax-0.0.6/test/test_misc.py +0 -22
  10. {torchax-0.0.6 → torchax-0.0.7}/.gitignore +0 -0
  11. {torchax-0.0.6 → torchax-0.0.7}/=2.3.0 +0 -0
  12. {torchax-0.0.6 → torchax-0.0.7}/LICENSE +0 -0
  13. {torchax-0.0.6 → torchax-0.0.7}/README.md +0 -0
  14. {torchax-0.0.6 → torchax-0.0.7}/build_nightly.sh +0 -0
  15. {torchax-0.0.6 → torchax-0.0.7}/dev-requirements.txt +0 -0
  16. {torchax-0.0.6 → torchax-0.0.7}/docs/api_iterations.md +0 -0
  17. {torchax-0.0.6 → torchax-0.0.7}/docs/dispatch.png +0 -0
  18. {torchax-0.0.6 → torchax-0.0.7}/docs/fixing_op_info_test.md +0 -0
  19. {torchax-0.0.6 → torchax-0.0.7}/docs/how_it_works.md +0 -0
  20. {torchax-0.0.6 → torchax-0.0.7}/docs/ops_registry.md +0 -0
  21. {torchax-0.0.6 → torchax-0.0.7}/docs/support_a_new_model.md +0 -0
  22. {torchax-0.0.6 → torchax-0.0.7}/docs/torch_dispatch/README.md +0 -0
  23. {torchax-0.0.6 → torchax-0.0.7}/docs/torch_dispatch/example.py +0 -0
  24. {torchax-0.0.6 → torchax-0.0.7}/docs/torch_dispatch/run_env.py +0 -0
  25. {torchax-0.0.6 → torchax-0.0.7}/docs/torch_xla2_dynamo.md +0 -0
  26. {torchax-0.0.6 → torchax-0.0.7}/docs/understand_jax_jit/jax_grad.py +0 -0
  27. {torchax-0.0.6 → torchax-0.0.7}/docs/understand_jax_jit/jax_jit.py +0 -0
  28. {torchax-0.0.6 → torchax-0.0.7}/docs/understand_jax_jit/torch_module.py +0 -0
  29. {torchax-0.0.6 → torchax-0.0.7}/examples/README.md +0 -0
  30. {torchax-0.0.6 → torchax-0.0.7}/examples/__init__.py +0 -0
  31. {torchax-0.0.6 → torchax-0.0.7}/examples/_diffusion.py +0 -0
  32. {torchax-0.0.6 → torchax-0.0.7}/examples/_grad_of_attention.py +0 -0
  33. {torchax-0.0.6 → torchax-0.0.7}/examples/basic_training.py +0 -0
  34. {torchax-0.0.6 → torchax-0.0.7}/examples/basic_training_jax.py +0 -0
  35. {torchax-0.0.6 → torchax-0.0.7}/examples/eager_mode.py +0 -0
  36. {torchax-0.0.6 → torchax-0.0.7}/examples/lightning_training.py +0 -0
  37. {torchax-0.0.6 → torchax-0.0.7}/examples/requirements.txt +0 -0
  38. {torchax-0.0.6 → torchax-0.0.7}/examples/torchbench_models/BERT_pytorch.py +0 -0
  39. {torchax-0.0.6 → torchax-0.0.7}/examples/train_gpt/requirements.txt +0 -0
  40. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama/README.md +0 -0
  41. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama/__init__.py +0 -0
  42. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama/model.py +0 -0
  43. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama/train_llama_lightning.py +0 -0
  44. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama/utils.py +0 -0
  45. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/Dockerfile +0 -0
  46. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/README.md +0 -0
  47. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/__init__.py +0 -0
  48. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/helper.py +0 -0
  49. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/splash_attn.py +0 -0
  50. {torchax-0.0.6 → torchax-0.0.7}/examples/train_llama_torchtitan/train_llama.py +0 -0
  51. {torchax-0.0.6 → torchax-0.0.7}/format.sh +0 -0
  52. {torchax-0.0.6 → torchax-0.0.7}/pyproject.toml +0 -0
  53. {torchax-0.0.6 → torchax-0.0.7}/repro1.py +0 -0
  54. {torchax-0.0.6 → torchax-0.0.7}/temp +0 -0
  55. {torchax-0.0.6 → torchax-0.0.7}/test/BUILD +0 -0
  56. {torchax-0.0.6 → torchax-0.0.7}/test/__init__.py +0 -0
  57. {torchax-0.0.6 → torchax-0.0.7}/test/base_test_util.py +0 -0
  58. {torchax-0.0.6 → torchax-0.0.7}/test/gemma/__init__.py +0 -0
  59. {torchax-0.0.6 → torchax-0.0.7}/test/gemma/config.py +0 -0
  60. {torchax-0.0.6 → torchax-0.0.7}/test/gemma/model.py +0 -0
  61. {torchax-0.0.6 → torchax-0.0.7}/test/gemma/test_gemma.py +0 -0
  62. {torchax-0.0.6 → torchax-0.0.7}/test/gemma/tokenizer.py +0 -0
  63. {torchax-0.0.6 → torchax-0.0.7}/test/llama/BUILD +0 -0
  64. {torchax-0.0.6 → torchax-0.0.7}/test/llama/__init__.py +0 -0
  65. {torchax-0.0.6 → torchax-0.0.7}/test/llama/model_exportable.py +0 -0
  66. {torchax-0.0.6 → torchax-0.0.7}/test/llama/test_llama.py +0 -0
  67. {torchax-0.0.6 → torchax-0.0.7}/test/moe/__init__.py +0 -0
  68. {torchax-0.0.6 → torchax-0.0.7}/test/moe/model.py +0 -0
  69. {torchax-0.0.6 → torchax-0.0.7}/test/moe/moe_test.py +0 -0
  70. {torchax-0.0.6 → torchax-0.0.7}/test/test_amp.py +0 -0
  71. {torchax-0.0.6 → torchax-0.0.7}/test/test_base.py +0 -0
  72. {torchax-0.0.6 → torchax-0.0.7}/test/test_context.py +0 -0
  73. {torchax-0.0.6 → torchax-0.0.7}/test/test_conv.py +0 -0
  74. {torchax-0.0.6 → torchax-0.0.7}/test/test_core_aten_ops.py +0 -0
  75. {torchax-0.0.6 → torchax-0.0.7}/test/test_exports.py +0 -0
  76. {torchax-0.0.6 → torchax-0.0.7}/test/test_flax.py +0 -0
  77. {torchax-0.0.6 → torchax-0.0.7}/test/test_functions.py +0 -0
  78. {torchax-0.0.6 → torchax-0.0.7}/test/test_image.py +0 -0
  79. {torchax-0.0.6 → torchax-0.0.7}/test/test_interop.py +0 -0
  80. {torchax-0.0.6 → torchax-0.0.7}/test/test_jittable_module.py +0 -0
  81. {torchax-0.0.6 → torchax-0.0.7}/test/test_libraries.py +0 -0
  82. {torchax-0.0.6 → torchax-0.0.7}/test/test_mutations.py +0 -0
  83. {torchax-0.0.6 → torchax-0.0.7}/test/test_ops.py +0 -0
  84. {torchax-0.0.6 → torchax-0.0.7}/test/test_symbolic_shapes.py +0 -0
  85. {torchax-0.0.6 → torchax-0.0.7}/test/test_tf_integration.py +0 -0
  86. {torchax-0.0.6 → torchax-0.0.7}/test/test_train.py +0 -0
  87. {torchax-0.0.6 → torchax-0.0.7}/test/test_unbounded_dynamism.py +0 -0
  88. {torchax-0.0.6 → torchax-0.0.7}/test/test_util.py +0 -0
  89. {torchax-0.0.6 → torchax-0.0.7}/test/test_view.py +0 -0
  90. {torchax-0.0.6 → torchax-0.0.7}/test-requirements.txt +0 -0
  91. {torchax-0.0.6 → torchax-0.0.7}/test_dist/test_mesh_util.py +0 -0
  92. {torchax-0.0.6 → torchax-0.0.7}/torchax/CONTRIBUTING.md +0 -0
  93. {torchax-0.0.6 → torchax-0.0.7}/torchax/amp.py +0 -0
  94. {torchax-0.0.6 → torchax-0.0.7}/torchax/config.py +0 -0
  95. {torchax-0.0.6 → torchax-0.0.7}/torchax/configuration.py +0 -0
  96. {torchax-0.0.6 → torchax-0.0.7}/torchax/decompositions.py +0 -0
  97. {torchax-0.0.6 → torchax-0.0.7}/torchax/device_module.py +0 -0
  98. {torchax-0.0.6 → torchax-0.0.7}/torchax/environment.py +0 -0
  99. {torchax-0.0.6 → torchax-0.0.7}/torchax/export.py +0 -0
  100. {torchax-0.0.6 → torchax-0.0.7}/torchax/flax.py +0 -0
  101. {torchax-0.0.6 → torchax-0.0.7}/torchax/mesh_util.py +0 -0
  102. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/__init__.py +0 -0
  103. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jaten.py +0 -0
  104. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jax_reimplement.py +0 -0
  105. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jc10d.py +0 -0
  106. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jimage.py +0 -0
  107. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jlibrary.py +0 -0
  108. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jtorch.py +0 -0
  109. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/jtorchvision_nms.py +0 -0
  110. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/op_base.py +0 -0
  111. {torchax-0.0.6 → torchax-0.0.7}/torchax/ops/ops_registry.py +0 -0
  112. {torchax-0.0.6 → torchax-0.0.7}/torchax/tf_integration.py +0 -0
  113. {torchax-0.0.6 → torchax-0.0.7}/torchax/train.py +0 -0
  114. {torchax-0.0.6 → torchax-0.0.7}/torchax/types.py +0 -0
  115. {torchax-0.0.6 → torchax-0.0.7}/torchax/util.py +0 -0
  116. {torchax-0.0.6 → torchax-0.0.7}/torchax/view.py +0 -0
{torchax-0.0.6 → torchax-0.0.7}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchax
- Version: 0.0.6
+ Version: 0.0.7
  Summary: torchax is a library for running Jax and PyTorch together
  Project-URL: Homepage, https://github.com/pytorch/xla/tree/master/torchax
  Author-email: Han Qi <qihan.dev@gmail.com>, Pytorch/XLA team <pytorchxla-dev@google.com>
{torchax-0.0.6 → torchax-0.0.7}/test/llama/llama_model.py
@@ -4,7 +4,7 @@
  # This source code is licensed under the license found in the
  # LICENSE file in the root directory of this source tree.

- # This file is copied from https://github.com/pytorch-labs/gpt-fast
+ # This file is copied from https://github.com/meta-pytorch/gpt-fast
  # This is used for unit test purposes
  from dataclasses import dataclass
  import math
torchax-0.0.7/test/test_misc.py (new file)
@@ -0,0 +1,47 @@
+ """If you don't know which file a test should go, and don't want to make a new file
+ for a small test. PUt it here
+ """
+ import torch
+ import unittest
+ import torchax
+ import jax
+ import jax.numpy as jnp
+
+
+ class MiscTest(unittest.TestCase):
+
+   def test_extract_jax_kwargs(self):
+
+     class M(torch.nn.Module):
+
+       def forward(self, a, b):
+         return torch.sin(a) + torch.cos(b)
+
+     weights, func = torchax.extract_jax(M())
+     res = func(
+         weights,
+         args=(),
+         kwargs={
+             'a': jnp.array([1, 2, 3]),
+             'b': jnp.array([3, 4, 5])
+         })
+     self.assertTrue(
+         jnp.allclose(
+             res,
+             jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5]))))
+
+   def test_to_device(self):
+     env = torchax.default_env()
+     env.config.debug_print_each_op = True
+     with env:
+       step1 = torch.ones(
+           100,
+           100,
+       )
+       step2 = torch.triu(step1, diagonal=1)
+       step3 = step2.to(dtype=torch.bool, device='jax')
+       self.assertEqual(step3.device.type, 'jax')
+
+
+ if __name__ == '__main__':
+   unittest.main()
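
The new test_extract_jax_kwargs covers the keyword-argument calling convention of torchax.extract_jax: it returns a weights pytree plus a pure function taking (weights, args=..., kwargs=...). A minimal sketch of composing that extracted function with jax.jit; the module and inputs below are hypothetical and simply mirror the test above:

    import jax
    import jax.numpy as jnp
    import torch
    import torchax

    class M(torch.nn.Module):
      def forward(self, a, b):
        return torch.sin(a) + torch.cos(b)

    weights, func = torchax.extract_jax(M())
    # The extracted callable is pure, so it can be wrapped in jax.jit;
    # args/kwargs follow the signature of M.forward.
    jitted = jax.jit(func)
    out = jitted(weights, args=(), kwargs={'a': jnp.ones(3), 'b': jnp.ones(3)})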
{torchax-0.0.6 → torchax-0.0.7}/torchax/__init__.py
@@ -8,7 +8,7 @@ from torch.utils import _pytree as pytree
  from torchax import tensor
  from contextlib import contextmanager

- __version__ = "0.0.6"
+ __version__ = "0.0.7"
  VERSION = __version__

  __all__ = [
{torchax-0.0.6 → torchax-0.0.7}/torchax/interop.py
@@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax):
    the PyTorch autograd framework by saving the residuals into the context object.
    """

+   # NOTE(qihqi): This function cannot be inlined from the callsite
+   # Becuase if it does, then it won't hit the compilation cache for
+   # call_jax. Call jax uses functions' id as key.
+   # It is nested inside j2t_autograd to ensure it gets a unique ID for each
+   # wrapped pure function, preventing cache collisions between different pure modules.
+   def _jax_forward(fn, other, tree_def, tensors):
+     """JAX function to compute output and vjp function.
+
+     primals should be a tuple (args, kwargs).
+     """
+     import jax
+     from jax.tree_util import tree_flatten, tree_unflatten
+
+     def fn_wrapper(*tensors):
+       # Reconstruct the original args and kwargs
+       flat_inputs = util.merge(tensors, other)
+       args, kwargs = tree_unflatten(tree_def, flat_inputs)
+       return fn(*args, **kwargs)
+
+     return jax.vjp(fn_wrapper, *tensors)
+
+   def _jax_backward(vjp_spec, saved_tensors, grad_out):
+     """JAX function to compute input gradients.
+
+     Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+     """
+     from jax.tree_util import tree_unflatten
+     fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+     return fun_vjp(grad_out)
+
    @wraps(fn)
    def inner(*args, **kwargs):
      from jax.tree_util import tree_flatten
@@ -290,36 +320,6 @@ def j2t_autograd(fn, call_jax=call_jax):
    return inner


- # NOTE(qihqi): This function cannot be inlined from the callsite
- # Becuase if it does, then it won't hit the compilation cache for
- # call_jax. Call jax uses functions' id as key.
- def _jax_forward(fn, other, tree_def, tensors):
-   """JAX function to compute output and vjp function.
-
-   primals should be a tuple (args, kwargs).
-   """
-   import jax
-   from jax.tree_util import tree_flatten, tree_unflatten
-
-   def fn_wrapper(*tensors):
-     # Reconstruct the original args and kwargs
-     flat_inputs = util.merge(tensors, other)
-     args, kwargs = tree_unflatten(tree_def, flat_inputs)
-     return fn(*args, **kwargs)
-
-   return jax.vjp(fn_wrapper, *tensors)
-
-
- def _jax_backward(vjp_spec, saved_tensors, grad_out):
-   """JAX function to compute input gradients.
-
-   Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
-   """
-   from jax.tree_util import tree_unflatten
-   fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
-   return fun_vjp(grad_out)
-
-
  fori_loop = torch_view(jax.lax.fori_loop)

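The moved NOTE is the reason for this refactor: call_jax caches compiled callables keyed by the Python function object's id, so _jax_forward and _jax_backward must be created once per j2t_autograd wrapper rather than recreated per call or shared module-wide. A standalone sketch of that caching behavior, using a toy cache rather than torchax's actual call_jax:

    _compiled = {}

    def call_cached(fn, *args):
      # Toy stand-in for call_jax: cache keyed by the function object's identity.
      key = id(fn)
      if key not in _compiled:
        print('compile (cache miss) for', fn)
        _compiled[key] = fn           # imagine an expensive jax.jit(fn) here
      return _compiled[key](*args)

    def stable(x):                    # defined once; same id on every call
      return x + 1

    for _ in range(3):
      call_cached(stable, 1)          # one miss, then cache hits

    for _ in range(3):
      call_cached(lambda x: x + 1, 1) # fresh function object each time: always misses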
{torchax-0.0.6 → torchax-0.0.7}/torchax/ops/mappings.py
@@ -6,6 +6,14 @@ import torch.func
  import torch.utils.dlpack as torchdl
  import torch.utils._mode_utils as mode_utils

+ NUMPY_UNSUPPORTED_DTYPES = {
+     torch.bfloat16: jnp.bfloat16,
+     torch.float8_e4m3fn: jnp.float8_e4m3fn,
+     torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+     torch.float8_e5m2: jnp.float8_e5m2,
+     torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
+ }
+

  def t2j(t, use_dlpack=True):
    is_bool = False
@@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True):
    if res is None:
      # https://github.com/google/jax/issues/7657
      # https://github.com/google/jax/issues/17784
-     if t.dtype == torch.bfloat16:
+     if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
        nparray = (t.cpu().detach().to(torch.float32).numpy()
-                 )  # numpy don't support bfloat16
+                 )  # handle dtypes not supported by numpy
      else:
        nparray = t.cpu().detach().numpy()
      res = jnp.asarray(nparray)
-     if t.dtype == torch.bfloat16:
-       res = res.astype(jnp.bfloat16)
+     if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
+       res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype])

    if is_bool:
      res = res.astype(jnp.bool_)
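
With NUMPY_UNSUPPORTED_DTYPES, t2j's numpy fallback now round-trips float8 tensors the same way it already handled bfloat16: go through float32 for the numpy hop, then cast back on the JAX side. A small sketch of that path, assuming a PyTorch/JAX pair where the float8 dtypes are available (use_dlpack=False forces the numpy fallback):

    import torch
    from torchax.ops import mappings

    bf16 = torch.randn(4, dtype=torch.bfloat16)
    print(mappings.t2j(bf16, use_dlpack=False).dtype)   # expected: bfloat16

    f8 = torch.randn(4).to(torch.float8_e5m2)
    print(mappings.t2j(f8, use_dlpack=False).dtype)     # expected: float8_e5m2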
torchax-0.0.7/torchax/ops/type_casting.py (new file)
@@ -0,0 +1,47 @@
+
+ from torchax import config
+ from torchax.ops import mappings
+ import jax.numpy as jnp
+ import torch
+
+ def maybe_cast(result, torch_op):
+   """Casts the result to the torch op's return dtype if the config is set."""
+   if not config.DEFAULTS.internal_respect_torch_return_dtypes:
+     return result
+
+   if not hasattr(torch_op, '_schema'):
+     return result
+
+   schema = torch_op._schema
+   if not schema.returns:
+     return result
+
+   # TODO: Handle multiple return values
+   if len(schema.returns) > 1:
+     return result
+
+   return_type = schema.returns[0].type
+   if str(return_type) == 'Tensor':
+     # This is not quite right, we need to get the dtype of the tensor
+     # For now, let's assume we can get it from the first input argument
+     if not schema.arguments:
+       return result
+
+     input_type = schema.arguments[0].type
+     if str(input_type) != 'Tensor':
+       return result
+
+     # This is a hack, we need a better way to determine the return dtype
+     # For now, let's assume the return type is the same as the first input
+     # This is not always true, e.g. for comparison ops.
+     return result
+
+   try:
+     torch_dtype = getattr(torch, str(return_type))
+     jax_dtype = mappings.t2j_dtype(torch_dtype)
+     if isinstance(result, jnp.ndarray):
+       return result.astype(jax_dtype)
+     else:
+       return jax_dtype(result)
+   except (AttributeError, TypeError):
+     return result
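
maybe_cast drives its decision entirely off the op's FunctionSchema, which is standard PyTorch metadata rather than anything torchax-specific. A short exploration of what it sees, assuming a recent PyTorch (the trailing comments are the expected printed values):

    import torch

    add = torch.ops.aten.add.Tensor
    print(add._schema.returns[0].type)    # Tensor -> maybe_cast takes the early-return branch

    size = torch.ops.aten.size.int
    print(size._schema.returns[0].type)   # int -> mapped via getattr(torch, 'int') and t2j_dtype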
{torchax-0.0.6 → torchax-0.0.7}/torchax/tensor.py
@@ -469,12 +469,12 @@ class Environment(contextlib.ContextDecorator):
      arr = self.t2j_copy(the_tensor)
      res = Tensor(arr, self, the_tensor.requires_grad)

-   if new_dtype is not None and new_dtype != the_tensor.dtype:
-     if isinstance(the_tensor, Tensor):
+   if new_dtype is not None and new_dtype != res.dtype:
+     if isinstance(res, Tensor):
        res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
      else:
        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-         return the_tensor.to(device=new_device, dtype=new_dtype)
+         return res.to(device=new_device, dtype=new_dtype)
    return res

  def get_and_rotate_prng_key(self,
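
This is the code path the new test_to_device exercises: when a CPU tensor is moved to the 'jax' device with a dtype change in the same .to() call, the dtype comparison and the fallback now operate on the freshly converted torchax Tensor (res) rather than the original CPU tensor, so the cast happens on the JAX side instead of falling back to a CPU result. A minimal repro mirroring the test (expected behavior inferred from this diff):

    import torch
    import torchax

    with torchax.default_env():
      mask = torch.triu(torch.ones(8, 8), diagonal=1)
      jax_mask = mask.to(dtype=torch.bool, device='jax')
      print(jax_mask.device.type, jax_mask.dtype)   # expected: jax torch.bool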
torchax-0.0.6/test/test_misc.py (deleted)
@@ -1,22 +0,0 @@
- import unittest
- import torch
- import torchax
-
-
- class MiscTest(unittest.TestCase):
-
-   @classmethod
-   def setUpClass(cls):
-     torchax.enable_globally()
-
-   def test_mixed_tensor_math_with_scalar(self):
-     a = torch.tensor(2)
-     b = torch.ones((2, 2), device='jax')
-     c = a * b
-     self.assertTrue(
-         torch.allclose(c.cpu(),
-                        torch.tensor([[2, 2], [2, 2]], dtype=torch.float32)))
-
-
- if __name__ == '__main__':
-   unittest.main()