torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
torchax/ops/mappings.py
CHANGED
@@ -7,7 +7,7 @@ import torch.utils.dlpack as torchdl
 import torch.utils._mode_utils as mode_utils
 
 
-def t2j(t):
+def t2j(t, use_dlpack=True):
   is_bool = False
   if t.dtype == torch.bool:
     is_bool = True
@@ -18,9 +18,14 @@ def t2j(t):
   if not t.is_contiguous():
     t = t.contiguous()
 
-
-
-
+  res = None
+  if use_dlpack:
+    try:
+      res = jaxdl.from_dlpack(t)
+    except Exception:
+      pass
+
+  if res is None:
     # https://github.com/google/jax/issues/7657
     # https://github.com/google/jax/issues/17784
     if t.dtype == torch.bfloat16:
@@ -37,61 +42,98 @@ def t2j(t):
   return res
 
 
-def j2t(x):
+def j2t(x, use_dlpack=True):
   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-
-
-
-
+    res = None
+    if use_dlpack:
+      try:
+        dl = jaxdl.to_dlpack(x)
+        res = torchdl.from_dlpack(dl)
+      except Exception:
+        res = None
+
+    orig_dtype = None
+    if res is None:
+      orig_dtype = None
+      if x.dtype == jnp.bfloat16.dtype:
+        orig_dtype = x.dtype
+        x = x.astype(jnp.float32.dtype)
       res = torch.from_numpy(numpy.asarray(x))
+
     if x.dtype == jnp.bool_:
       res = res.to(torch.bool)
+
+    if orig_dtype is not None:
+      res = res.to(j2t_dtype(orig_dtype))
   return res
 
+
 TORCH_DTYPE_TO_JAX = {
     # NO_MAPPING : jnp.float0.dtype (signless scalar int),
-    torch.bool
+    torch.bool:
+        jnp.bool_.dtype,
     # NO_MAPPING : jnp.int4.dtype,
-    torch.int8
-
-    torch.
-
-    torch.
+    torch.int8:
+        jnp.int8.dtype,
+    torch.int16:
+        jnp.int16.dtype,
+    torch.int32:
+        jnp.int32.dtype,
+    torch.int64:
+        jnp.int64.dtype,
+    torch.long:
+        jnp.int64.dtype,
     # NO_MAPPING : jnp.uint4
-    torch.uint8
-
-    torch.
-
+    torch.uint8:
+        jnp.uint8.dtype,
+    torch.uint16:
+        jnp.uint16.dtype,
+    torch.uint32:
+        jnp.uint32.dtype,
+    torch.uint64:
+        jnp.uint64.dtype,
     # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
-    torch.float8_e4m3fn
+    torch.float8_e4m3fn:
+        jnp.float8_e4m3fn.dtype,
     # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
-    torch.float8_e5m2
+    torch.float8_e5m2:
+        jnp.float8_e5m2.dtype,
     # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
-    torch.bfloat16
-
-    torch.
-
-    torch.
-
-    torch.
-
-
+    torch.bfloat16:
+        jnp.bfloat16.dtype,
+    torch.half:
+        jnp.float16.dtype,
+    torch.float16:
+        jnp.float16.dtype,
+    torch.float32:
+        jnp.float32.dtype,
+    torch.float64:
+        jnp.float64.dtype,
+    torch.double:
+        jnp.double.dtype,
+    torch.complex64:
+        jnp.complex64.dtype,
+    torch.complex128:
+        jnp.complex128.dtype,
+    None:
+        None,
 }
 
-JAX_DTYPE_TO_TORCH = {
-    value: key for key, value in TORCH_DTYPE_TO_JAX.items()
-}
+JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
 # Add imprecise mappings for some JAX dtypes which don't have torch analogues
 JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
 JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8
 
+
 def t2j_dtype(dtype):
   if dtype not in TORCH_DTYPE_TO_JAX:
-    raise RuntimeError(
+    raise RuntimeError(
+        f'Attempting to convert unknown type: {dtype} to jax type,')
   return TORCH_DTYPE_TO_JAX[dtype]
 
 
 def j2t_dtype(dtype):
   if dtype not in JAX_DTYPE_TO_TORCH:
-    raise RuntimeError(
+    raise RuntimeError(
+        f'Attempting to convert unknown type: {dtype} to torch type,')
   return JAX_DTYPE_TO_TORCH[dtype]
torchax/ops/op_base.py
CHANGED
@@ -4,6 +4,7 @@ import jax.numpy as jnp
 import numpy as np
 import torch
 from torchax.ops import mappings
+from torchax.view import View
 from torchax import types
 import sys
 
@@ -12,31 +13,55 @@ from typing import Callable, Optional, ParamSpec, Concatenate
 
 class InplaceOp:
 
-
-
-
-
-
-
-
-
-
-
-
-
+  def __init__(self,
+               functional_op,
+               replace=False,
+               position_to_mutate=0,
+               is_jax_func=False):
+    self.functional = functional_op
+    self.replace = replace
+    self.position_to_mutate = position_to_mutate
+    self.is_jax_func = is_jax_func
+
+  def __call__(self, *args, **kwargs):
+    to_mutate = args[self.position_to_mutate]
+    view_value = to_mutate
+    if isinstance(to_mutate, View):
+      view_value = to_mutate.torch()
+    # Convert the target View to a Tensor, and
+    # leave the rest args as is. If other args are
+    # also View, they will be converted to tensors
+    # in the self.functional dispatch.
+    env = view_value._env
+    if self.is_jax_func:
+      view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs))
+      new_value_jax = self.functional(view_value, *args[1:], **kwargs)
+      new_value = env.j2t_iso(new_value_jax)
+    else:
+      new_value = self.functional(view_value, *args[1:], **kwargs)
+
+    if isinstance(to_mutate, View):
+      to_mutate.update(new_value)
+    else:
+      if self.replace:
+        to_mutate._elem = new_value._elem
+      else:
+        to_mutate.copy_(new_value)
+    return to_mutate
 
 
 class OutVariant:
 
-
-
-
-
-
-
+  def __call__(self, *args, **kwargs):
+    to_mutate = kwargs['out']
+    del kwargs['out']
+    to_mutate._elem = self.functional(*args, **kwargs)._elem
+    return to_mutate
 
 
 P = ParamSpec('P')
+
+
 def convert_dtype(use_default_dtype: bool = True):
   """Converts `dtype` kwarg of function from torch to JAX.
 
@@ -48,6 +73,7 @@ def convert_dtype(use_default_dtype: bool = True):
   """
 
   def decorator(func: types.TorchCallable):
+
     @functools.wraps(func)
     def wrapper(*args: P.args,
                 dtype: Optional[torch.dtype] = None,
@@ -66,7 +92,8 @@ def convert_dtype(use_default_dtype: bool = True):
   return decorator
 
 
-def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
+def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
+                                 dtype: Optional[jnp.dtype]):
   """Optionally converts scalar constant's dtype using `numpy`
 
   Use in cases where you require a constant and can't handle a traced array.
@@ -81,24 +108,24 @@ def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[
 
 
 def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
-
-
-
-
-
+  """If the first argument is an int array, promote it to float32."""
+
+  @functools.wraps(f)
+  def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
+    if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
+      x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))
 
-
+    return f(x, *args, **kwargs)
 
-
+  return wrapper
 
 
-def foreach_loop(
-
-):
+def foreach_loop(seq: jax.Array,
+                 fn: Callable[[jax.Array, jax.Array], jax.Array],
+                 init_val=0.0):
   """Run `fn` for each element of 1D array `seq`.
 
   Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
   assert len(seq.shape) == 1
-  return jax.lax.fori_loop(
-
-  )
+  return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
+                           init_val)
torchax/ops/ops_registry.py
CHANGED
@@ -7,44 +7,49 @@ from typing import Union, Dict
 
 @dataclasses.dataclass
 class Operator:
-
-
-
-
-
+  torch_op: TorchCallable
+  func: Union[TorchCallable, JaxCallable]
+  is_jax_function: bool
+  is_user_defined: bool
+  needs_env: bool
+  is_view_op: bool
 
 
 all_aten_ops: Dict[TorchCallable, Operator] = {}
 all_torch_functions: Dict[TorchCallable, Operator] = {}
 
 
-def register_torch_dispatch_op(
-
-
-
-
-):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def register_torch_dispatch_op(aten_op,
+                               impl_callable,
+                               is_jax_function=True,
+                               is_user_defined=False,
+                               needs_env=False,
+                               is_view_op=False):
+  op = Operator(
+      aten_op,
+      impl_callable,
+      is_jax_function=is_jax_function,
+      is_user_defined=is_user_defined,
+      needs_env=needs_env,
+      is_view_op=is_view_op)
+  if aten_op in all_aten_ops:
+    logging.warning(f'Duplicate op registration for {aten_op}')
+  all_aten_ops[aten_op] = op
+  return impl_callable
+
+
+def register_torch_function_op(torch_func,
+                               impl_callable,
+                               is_jax_function=True,
+                               is_user_defined=False,
+                               needs_env=False,
+                               is_view_op=False):
+  op = Operator(
+      torch_func,
+      impl_callable,
+      is_jax_function=is_jax_function,
+      is_user_defined=is_user_defined,
+      needs_env=needs_env,
+      is_view_op=is_view_op)
+  all_torch_functions[torch_func] = op
+  return impl_callable