torchax-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/ops/mappings.py
ADDED
@@ -0,0 +1,97 @@

```python
from jax import dlpack as jaxdl
import jax.numpy as jnp
import numpy
import torch
import torch.func
import torch.utils.dlpack as torchdl
import torch.utils._mode_utils as mode_utils


def t2j(t):
  is_bool = False
  if t.dtype == torch.bool:
    is_bool = True
    t = t.to(torch.int8)

  t = t.to_dense()

  if not t.is_contiguous():
    t = t.contiguous()

  try:
    res = jaxdl.from_dlpack(t)
  except Exception:
    # https://github.com/google/jax/issues/7657
    # https://github.com/google/jax/issues/17784
    if t.dtype == torch.bfloat16:
      nparray = (t.cpu().detach().to(torch.float32).numpy()
                )  # numpy don't support bfloat16
    else:
      nparray = t.cpu().detach().numpy()
    res = jnp.asarray(nparray)
    if t.dtype == torch.bfloat16:
      res = res.astype(jnp.bfloat16)

  if is_bool:
    res = res.astype(jnp.bool_)
  return res


def j2t(x):
  with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
    try:
      dl = jaxdl.to_dlpack(x)
      res = torchdl.from_dlpack(dl)
    except Exception:
      res = torch.from_numpy(numpy.asarray(x))
  if x.dtype == jnp.bool_:
    res = res.to(torch.bool)
  return res


TORCH_DTYPE_TO_JAX = {
    # NO_MAPPING        : jnp.float0.dtype (signless scalar int),
    torch.bool          : jnp.bool_.dtype,
    # NO_MAPPING        : jnp.int4.dtype,
    torch.int8          : jnp.int8.dtype,
    torch.int16         : jnp.int16.dtype,
    torch.int32         : jnp.int32.dtype,
    torch.int64         : jnp.int64.dtype,
    torch.long          : jnp.int64.dtype,
    # NO_MAPPING        : jnp.uint4
    torch.uint8         : jnp.uint8.dtype,
    torch.uint16        : jnp.uint16.dtype,
    torch.uint32        : jnp.uint32.dtype,
    torch.uint64        : jnp.uint64.dtype,
    # NO_MAPPING        : jnp.float8_e4m3b11fnuz.dtype,
    torch.float8_e4m3fn : jnp.float8_e4m3fn.dtype,
    # NO_MAPPING        : jnp.float8_e4m3fnuz.dtype,
    torch.float8_e5m2   : jnp.float8_e5m2.dtype,
    # NO_MAPPING        : jnp.float8_e5m2fnuz.dtype,
    torch.bfloat16      : jnp.bfloat16.dtype,
    torch.half          : jnp.float16.dtype,
    torch.float16       : jnp.float16.dtype,
    torch.float32       : jnp.float32.dtype,
    torch.float64       : jnp.float64.dtype,
    torch.double        : jnp.double.dtype,
    torch.complex64     : jnp.complex64.dtype,
    torch.complex128    : jnp.complex128.dtype,
    None                : None,
}

JAX_DTYPE_TO_TORCH = {
    value: key for key, value in TORCH_DTYPE_TO_JAX.items()
}
# Add imprecise mappings for some JAX dtypes which don't have torch analogues
JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8


def t2j_dtype(dtype):
  if dtype not in TORCH_DTYPE_TO_JAX:
    raise RuntimeError(f'Attempting to convert unknown type: {dtype} to jax type,')
  return TORCH_DTYPE_TO_JAX[dtype]


def j2t_dtype(dtype):
  if dtype not in JAX_DTYPE_TO_TORCH:
    raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,')
  return JAX_DTYPE_TO_TORCH[dtype]
```
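Taken together, these helpers form the tensor and dtype bridge between the two frameworks: `t2j` moves a `torch.Tensor` to a `jax.Array` (DLPack first, with a NumPy fallback for cases such as bfloat16), `j2t` goes the other way, and `t2j_dtype`/`j2t_dtype` translate dtypes through the lookup tables. A minimal round-trip sketch using only the helpers defined above:

```python
import torch
import jax.numpy as jnp
from torchax.ops import mappings

# torch -> JAX: DLPack is attempted first; NumPy is the fallback path.
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
x = mappings.t2j(t)                      # jax.Array, dtype float32

# JAX -> torch: bool arrays are restored to torch.bool explicitly.
mask = jnp.array([True, False, True])
m = mappings.j2t(mask)                   # tensor of dtype torch.bool

# dtype translation goes through the lookup tables.
assert mappings.t2j_dtype(torch.bfloat16) == jnp.bfloat16.dtype
assert mappings.j2t_dtype(jnp.float32.dtype) == torch.float32
```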
torchax/ops/op_base.py
ADDED
@@ -0,0 +1,104 @@

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np
import torch
from torchax.ops import mappings
from torchax import types
import sys

from typing import Callable, Optional, ParamSpec, Concatenate


class InplaceOp:

  def __init__(self, functional_op, replace=False, position_to_mutate=0):
    self.functional = functional_op
    self.replace = replace
    self.position_to_mutate = position_to_mutate

  def __call__(self, *args, **kwargs):
    to_mutate = args[0]
    if self.replace:
      to_mutate._elem = self.functional(*args, **kwargs)._elem
    else:
      to_mutate.copy_(self.functional(*args, **kwargs))
    return to_mutate


class OutVariant:

  def __call__(self, *args, **kwargs):
    to_mutate = kwargs['out']
    del kwargs['out']
    to_mutate._elem = self.functional(*args, **kwargs)._elem
    return to_mutate


P = ParamSpec('P')


def convert_dtype(use_default_dtype: bool = True):
  """Converts `dtype` kwarg of function from torch to JAX.

  Args:
    use_default_dtype: Whether to use torch default dtype if none is provided.

  Returns:
    A decorator that wraps a JAX implementation of a torch function.
  """

  def decorator(func: types.TorchCallable):
    @functools.wraps(func)
    def wrapper(*args: P.args,
                dtype: Optional[torch.dtype] = None,
                **kwargs: P.kwargs):
      if not dtype and use_default_dtype:
        dtype = torch.get_default_dtype()
      if isinstance(dtype, torch.dtype):
        jax_dtype = mappings.t2j_dtype(dtype)
      else:
        jax_dtype = dtype

      return func(*args, dtype=jax_dtype, **kwargs)

    return wrapper

  return decorator


def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[jnp.dtype]):
  """Optionally converts scalar constant's dtype using `numpy`

  Use in cases where you require a constant and can't handle a traced array.
  """
  if val and dtype:
    if isinstance(val, jax.Array):
      return maybe_convert_constant_dtype(val.item(), dtype)

    return np.array(val, dtype)

  return val


def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
  """If the first argument is an int array, promote it to float32."""
  @functools.wraps(f)
  def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
    if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
      x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))

    return f(x, *args, **kwargs)

  return wrapper


def foreach_loop(
    seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
):
  """Run `fn` for each element of 1D array `seq`.

  Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
  assert len(seq.shape) == 1
  return jax.lax.fori_loop(
      0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val
  )
```
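These utilities are decorators and small helpers intended to wrap JAX implementations of torch operators. A short sketch of the two dtype-related helpers; `ones_like_impl` and `sin_impl` below are hypothetical implementations written only for illustration:

```python
import jax.numpy as jnp
import torch
from torchax.ops import op_base

# convert_dtype: the wrapper receives a torch.dtype (or None) and passes the
# matching JAX dtype on to the implementation, falling back to torch's default dtype.
@op_base.convert_dtype()
def ones_like_impl(x, *, dtype=None):
  return jnp.ones(x.shape, dtype=dtype)

y = ones_like_impl(jnp.zeros((2, 2)), dtype=torch.bfloat16)  # bfloat16 result
z = ones_like_impl(jnp.zeros((2, 2)))   # torch's default dtype (float32 unless changed)

# promote_int_input: integer inputs are cast to the default float dtype first.
@op_base.promote_int_input
def sin_impl(x):
  return jnp.sin(x)

s = sin_impl(jnp.arange(3))  # int32 input is promoted before jnp.sin
```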
torchax/ops/ops_registry.py
ADDED
@@ -0,0 +1,50 @@

```python
import dataclasses
import logging
from torchax.types import JaxCallable, TorchCallable

from typing import Union, Dict


@dataclasses.dataclass
class Operator:
  torch_op: TorchCallable
  func: Union[TorchCallable, JaxCallable]
  is_jax_function: bool
  is_user_defined: bool
  needs_env: bool


all_aten_ops: Dict[TorchCallable, Operator] = {}
all_torch_functions: Dict[TorchCallable, Operator] = {}


def register_torch_dispatch_op(
    aten_op, impl_callable,
    is_jax_function=True,
    is_user_defined=False,
    needs_env=False,
):
  op = Operator(
      aten_op, impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env)
  if aten_op in all_aten_ops:
    logging.warning(f'Duplicate op registration for {aten_op}')
  all_aten_ops[aten_op] = op
  return impl_callable


def register_torch_function_op(
    torch_func, impl_callable,
    is_jax_function=True,
    is_user_defined=False,
    needs_env=False,
):
  op = Operator(
      torch_func, impl_callable,
      is_jax_function=is_jax_function,
      is_user_defined=is_user_defined,
      needs_env=needs_env)
  all_torch_functions[torch_func] = op
  return impl_callable
```
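This last module keeps two module-level tables: `all_aten_ops` for torch-dispatch (aten) overrides and `all_torch_functions` for torch-function overrides, each mapping a torch callable to an `Operator` record. A rough sketch of registering a JAX lowering; `my_add` is an illustrative implementation, not one shipped by the package:

```python
import torch
from torchax.ops import ops_registry

# Hypothetical JAX lowering for aten::add.Tensor.
def my_add(x, y, *, alpha=1):
  return x + alpha * y

ops_registry.register_torch_dispatch_op(
    torch.ops.aten.add.Tensor, my_add,
    is_jax_function=True,
)

op = ops_registry.all_aten_ops[torch.ops.aten.add.Tensor]
assert op.is_jax_function and not op.is_user_defined
```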