torchax-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic; see the package's registry page for more details.

@@ -0,0 +1,97 @@
1
+ from jax import dlpack as jaxdl
2
+ import jax.numpy as jnp
3
+ import numpy
4
+ import torch
5
+ import torch.func
6
+ import torch.utils.dlpack as torchdl
7
+ import torch.utils._mode_utils as mode_utils
8
+
9
+
10
def t2j(t):
  """Convert a torch tensor to a JAX array.

  The tensor is densified and made contiguous, then handed to JAX through
  DLPack. If that hand-off raises, the data is instead moved to the CPU,
  detached, and routed through numpy. Bool tensors travel as int8 and are
  cast back to bool on the JAX side.
  """
  was_bool = t.dtype == torch.bool
  if was_bool:
    t = t.to(torch.int8)

  t = t.to_dense()
  t = t if t.is_contiguous() else t.contiguous()

  try:
    out = jaxdl.from_dlpack(t)
  except Exception:
    # DLPack import failed; fall back to a host round-trip through numpy.
    # https://github.com/google/jax/issues/7657
    # https://github.com/google/jax/issues/17784
    host = t.cpu().detach()
    if t.dtype == torch.bfloat16:
      # numpy has no native bfloat16: widen to float32, then narrow in JAX.
      out = jnp.asarray(host.to(torch.float32).numpy()).astype(jnp.bfloat16)
    else:
      out = jnp.asarray(host.numpy())

  return out.astype(jnp.bool_) if was_bool else out
38
+
39
+
40
+ def j2t(x):
41
+ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
42
+ try:
43
+ dl = jaxdl.to_dlpack(x)
44
+ res = torchdl.from_dlpack(dl)
45
+ except Exception:
46
+ res = torch.from_numpy(numpy.asarray(x))
47
+ if x.dtype == jnp.bool_:
48
+ res = res.to(torch.bool)
49
+ return res
50
+
51
+ TORCH_DTYPE_TO_JAX = {
52
+ # NO_MAPPING : jnp.float0.dtype (signless scalar int),
53
+ torch.bool : jnp.bool_.dtype,
54
+ # NO_MAPPING : jnp.int4.dtype,
55
+ torch.int8 : jnp.int8.dtype,
56
+ torch.int16 : jnp.int16.dtype,
57
+ torch.int32 : jnp.int32.dtype,
58
+ torch.int64 : jnp.int64.dtype,
59
+ torch.long : jnp.int64.dtype,
60
+ # NO_MAPPING : jnp.uint4
61
+ torch.uint8 : jnp.uint8.dtype,
62
+ torch.uint16 : jnp.uint16.dtype,
63
+ torch.uint32 : jnp.uint32.dtype,
64
+ torch.uint64 : jnp.uint64.dtype,
65
+ # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
66
+ torch.float8_e4m3fn : jnp.float8_e4m3fn.dtype,
67
+ # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
68
+ torch.float8_e5m2 : jnp.float8_e5m2.dtype,
69
+ # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
70
+ torch.bfloat16 : jnp.bfloat16.dtype,
71
+ torch.half : jnp.float16.dtype,
72
+ torch.float16 : jnp.float16.dtype,
73
+ torch.float32 : jnp.float32.dtype,
74
+ torch.float64 : jnp.float64.dtype,
75
+ torch.double : jnp.double.dtype,
76
+ torch.complex64 : jnp.complex64.dtype,
77
+ torch.complex128 : jnp.complex128.dtype,
78
+ None : None,
79
+ }
80
+
81
+ JAX_DTYPE_TO_TORCH = {
82
+ value: key for key, value in TORCH_DTYPE_TO_JAX.items()
83
+ }
84
+ # Add imprecise mappings for some JAX dtypes which don't have torch analogues
85
+ JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
86
+ JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8
87
+
88
def t2j_dtype(dtype):
  """Look up the JAX dtype that corresponds to torch dtype `dtype`.

  Raises:
    RuntimeError: if `dtype` has no entry in TORCH_DTYPE_TO_JAX.
  """
  try:
    return TORCH_DTYPE_TO_JAX[dtype]
  except KeyError:
    raise RuntimeError(
        f'Attempting to convert unknown type: {dtype} to jax type,') from None
92
+
93
+
94
def j2t_dtype(dtype):
  """Look up the torch dtype that corresponds to JAX/numpy dtype `dtype`.

  The keys of JAX_DTYPE_TO_TORCH are `numpy.dtype` objects (plus None).
  Scalar types such as `jnp.float32` / `numpy.float32`, and dtype names such
  as 'float32', compare equal to those dtypes but do not hash the same, so a
  direct dict lookup misses them. Such inputs are normalized with `jnp.dtype`
  before the lookup.

  Raises:
    RuntimeError: if `dtype` cannot be mapped to a torch dtype.
  """
  key = dtype
  if dtype is not None and dtype not in JAX_DTYPE_TO_TORCH:
    try:
      key = jnp.dtype(dtype)
    except (TypeError, ValueError):
      # Not interpretable as a numpy dtype; report it as unknown below.
      key = dtype
  if key not in JAX_DTYPE_TO_TORCH:
    raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,')
  return JAX_DTYPE_TO_TORCH[key]
torchax/ops/op_base.py ADDED
@@ -0,0 +1,104 @@
1
+ import functools
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ import torch
6
+ from torchax.ops import mappings
7
+ from torchax import types
8
+ import sys
9
+
10
+ from typing import Callable, Optional, ParamSpec, Concatenate
11
+
12
+
13
class InplaceOp:
  """Adapt a functional op into an in-place one.

  Calling the instance runs `functional_op` on the same arguments and writes
  the result back into the positional argument at index `position_to_mutate`,
  which is then returned.

  Args:
    functional_op: callable computing the out-of-place result.
    replace: if True, rebind the target's `_elem` to the result's `_elem`;
      otherwise write the result into the target with `copy_`.
    position_to_mutate: index into the positional arguments of the object to
      mutate (defaults to the first argument).
  """

  def __init__(self, functional_op, replace=False, position_to_mutate=0):
    self.functional = functional_op
    self.replace = replace
    self.position_to_mutate = position_to_mutate

  def __call__(self, *args, **kwargs):
    # Honor position_to_mutate so ops whose mutated operand is not args[0]
    # write into the right argument.
    to_mutate = args[self.position_to_mutate]
    result = self.functional(*args, **kwargs)
    if self.replace:
      to_mutate._elem = result._elem
    else:
      to_mutate.copy_(result)
    return to_mutate
27
+
28
+
29
class OutVariant:
  """Adapt a functional op into its `out=` variant.

  The call's `out` keyword names the object to write into. The remaining
  arguments go to the functional op, the result's `_elem` is rebound onto
  `out`, and `out` is returned.

  Args:
    functional_op: callable computing the out-of-place result. It is optional
      so subclasses that define `functional` themselves keep working.
  """

  def __init__(self, functional_op=None):
    # __call__ reads self.functional, so it must be settable at construction.
    # Only assign when given, so a class-level `functional` is not shadowed.
    if functional_op is not None:
      self.functional = functional_op

  def __call__(self, *args, **kwargs):
    to_mutate = kwargs.pop('out')
    to_mutate._elem = self.functional(*args, **kwargs)._elem
    return to_mutate
36
+
37
+
38
+
39
P = ParamSpec('P')


def convert_dtype(use_default_dtype: bool = True):
  """Converts `dtype` kwarg of function from torch to JAX.

  A torch `dtype` is translated with `mappings.t2j_dtype`; any other value
  (for example a dtype that is already a JAX/numpy dtype) is forwarded as is.

  Args:
    use_default_dtype: Whether to use torch default dtype if none is provided.

  Returns:
    A decorator that wraps a JAX implementation of a torch function.
  """

  def decorator(func: types.TorchCallable):
    @functools.wraps(func)
    def wrapper(*args: P.args,
                dtype: Optional[torch.dtype] = None,
                **kwargs: P.kwargs):
      # Test for None explicitly: numpy dtype objects (e.g. jnp.float32.dtype)
      # may evaluate as falsy, and `not dtype` would then silently replace an
      # explicitly requested dtype with the torch default.
      if dtype is None and use_default_dtype:
        dtype = torch.get_default_dtype()
      if isinstance(dtype, torch.dtype):
        jax_dtype = mappings.t2j_dtype(dtype)
      else:
        jax_dtype = dtype

      return func(*args, dtype=jax_dtype, **kwargs)

    return wrapper

  return decorator
67
+
68
+
69
def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[jnp.dtype]):
  """Optionally converts scalar constant's dtype using `numpy`

  Use in cases where you require a constant and can't handle a traced array.

  Args:
    val: the constant. A `jax.Array` is first reduced to a Python scalar with
      `.item()`. None is returned unchanged.
    dtype: target dtype. When None, `val` is returned unchanged.

  Returns:
    `np.array(val, dtype)` when both `val` and `dtype` are given, else `val`.
  """
  # Compare against None rather than relying on truthiness: a zero constant
  # still needs converting, and numpy dtype objects may evaluate as falsy.
  if val is None or dtype is None:
    return val
  if isinstance(val, jax.Array):
    return maybe_convert_constant_dtype(val.item(), dtype)
  return np.array(val, dtype)
81
+
82
+
83
def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
  """Cast a signed-integer first argument to torch's default float dtype.

  The returned wrapper checks `x.dtype`. If it is int8/int16/int32/int64, `x`
  is cast with `astype` to the JAX equivalent of `torch.get_default_dtype()`
  before `f` is called. Any other dtype is passed through untouched.
  """
  signed_int_dtypes = (jnp.int8, jnp.int16, jnp.int32, jnp.int64)

  @functools.wraps(f)
  def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
    if x.dtype in signed_int_dtypes:
      target = mappings.t2j_dtype(torch.get_default_dtype())
      x = x.astype(target)
    return f(x, *args, **kwargs)

  return wrapper
93
+
94
+
95
def foreach_loop(
    seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
):
  """Fold `fn` over the elements of the 1D array `seq`.

  Like `functools.reduce(fn, seq, init_val)`: the accumulator is passed as
  the first argument and the current element as the second. Implemented with
  `jax.lax.fori_loop`.
  """
  assert seq.ndim == 1
  length = seq.shape[0]

  def step(idx, acc):
    return fn(acc, seq[idx])

  return jax.lax.fori_loop(0, length, step, init_val)
@@ -0,0 +1,50 @@
1
+ import dataclasses
2
+ import logging
3
+ from torchax.types import JaxCallable, TorchCallable
4
+
5
+ from typing import Union, Dict
6
+
7
+
8
@dataclasses.dataclass()
class Operator:
  """Registry record pairing a torch callable with its implementation."""

  # The torch callable (aten op or torch function) being implemented.
  torch_op: TorchCallable
  # The implementation registered for `torch_op`.
  func: Union[TorchCallable, JaxCallable]
  # Whether `func` is a JAX function (the register_* helpers default to True).
  is_jax_function: bool
  # Presumably marks registrations made by library users — confirm with callers.
  is_user_defined: bool
  # Presumably means `func` must be handed an environment — confirm with callers.
  needs_env: bool


# Registry keyed by aten op; written by register_torch_dispatch_op.
all_aten_ops: Dict[TorchCallable, Operator] = {}
# Registry keyed by torch function; written by register_torch_function_op.
all_torch_functions: Dict[TorchCallable, Operator] = {}
19
+
20
+
21
def register_torch_dispatch_op(
    aten_op, impl_callable,
    is_jax_function=True,
    is_user_defined=False,
    needs_env=False,
):
  """Record `impl_callable` as the implementation of aten op `aten_op`.

  An Operator entry is stored in `all_aten_ops`. If `aten_op` was already
  registered, a warning is logged and the new entry replaces the old one.

  Returns:
    `impl_callable`, unchanged.
  """
  if aten_op in all_aten_ops:
    logging.warning(f'Duplicate op registration for {aten_op}')
  all_aten_ops[aten_op] = Operator(
      aten_op,
      impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env,
  )
  return impl_callable
36
+
37
+
38
def register_torch_function_op(
    torch_func, impl_callable,
    is_jax_function=True,
    is_user_defined=False,
    needs_env=False,
):
  """Record `impl_callable` as the implementation of torch function `torch_func`.

  An Operator entry is stored in `all_torch_functions`. Like
  register_torch_dispatch_op, re-registering the same function logs a warning
  (instead of silently overwriting) and the new entry replaces the old one.

  Returns:
    `impl_callable`, unchanged.
  """
  op = Operator(
      torch_func, impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env)
  if torch_func in all_torch_functions:
    logging.warning(f'Duplicate torch function registration for {torch_func}')
  all_torch_functions[torch_func] = op
  return impl_callable