torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/ops/mappings.py CHANGED
@@ -7,7 +7,7 @@ import torch.utils.dlpack as torchdl
 import torch.utils._mode_utils as mode_utils
 
 
-def t2j(t):
+def t2j(t, use_dlpack=True):
   is_bool = False
   if t.dtype == torch.bool:
     is_bool = True
@@ -18,9 +18,14 @@ def t2j(t):
   if not t.is_contiguous():
     t = t.contiguous()
 
-  try:
-    res = jaxdl.from_dlpack(t)
-  except Exception:
+  res = None
+  if use_dlpack:
+    try:
+      res = jaxdl.from_dlpack(t)
+    except Exception:
+      pass
+
+  if res is None:
     # https://github.com/google/jax/issues/7657
     # https://github.com/google/jax/issues/17784
     if t.dtype == torch.bfloat16:
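
The new `use_dlpack` flag turns the zero-copy DLPack path into an opt-in: left at True, the fast path is still tried first, and when it is False (or `jaxdl.from_dlpack` raises) `t2j` falls through to the copy-based conversion below. A minimal sketch of calling it, assuming a CPU tensor:

    import torch
    from torchax.ops import mappings

    t = torch.ones(2, 2)
    a = mappings.t2j(t)                    # tries DLPack first, zero-copy when it works
    b = mappings.t2j(t, use_dlpack=False)  # skip the DLPack fast path, use the fallback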
@@ -37,61 +42,98 @@ def t2j(t):
   return res
 
 
-def j2t(x):
+def j2t(x, use_dlpack=True):
   with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
-    try:
-      dl = jaxdl.to_dlpack(x)
-      res = torchdl.from_dlpack(dl)
-    except Exception:
+    res = None
+    if use_dlpack:
+      try:
+        dl = jaxdl.to_dlpack(x)
+        res = torchdl.from_dlpack(dl)
+      except Exception:
+        res = None
+
+    orig_dtype = None
+    if res is None:
+      orig_dtype = None
+      if x.dtype == jnp.bfloat16.dtype:
+        orig_dtype = x.dtype
+        x = x.astype(jnp.float32.dtype)
       res = torch.from_numpy(numpy.asarray(x))
+
     if x.dtype == jnp.bool_:
       res = res.to(torch.bool)
+
+    if orig_dtype is not None:
+      res = res.to(j2t_dtype(orig_dtype))
     return res
 
+
 TORCH_DTYPE_TO_JAX = {
     # NO_MAPPING : jnp.float0.dtype (signless scalar int),
-    torch.bool : jnp.bool_.dtype,
+    torch.bool:
+        jnp.bool_.dtype,
     # NO_MAPPING : jnp.int4.dtype,
-    torch.int8 : jnp.int8.dtype,
-    torch.int16 : jnp.int16.dtype,
-    torch.int32 : jnp.int32.dtype,
-    torch.int64 : jnp.int64.dtype,
-    torch.long : jnp.int64.dtype,
+    torch.int8:
+        jnp.int8.dtype,
+    torch.int16:
+        jnp.int16.dtype,
+    torch.int32:
+        jnp.int32.dtype,
+    torch.int64:
+        jnp.int64.dtype,
+    torch.long:
+        jnp.int64.dtype,
     # NO_MAPPING : jnp.uint4
-    torch.uint8 : jnp.uint8.dtype,
-    torch.uint16 : jnp.uint16.dtype,
-    torch.uint32 : jnp.uint32.dtype,
-    torch.uint64 : jnp.uint64.dtype,
+    torch.uint8:
+        jnp.uint8.dtype,
+    torch.uint16:
+        jnp.uint16.dtype,
+    torch.uint32:
+        jnp.uint32.dtype,
+    torch.uint64:
+        jnp.uint64.dtype,
     # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
-    torch.float8_e4m3fn : jnp.float8_e4m3fn.dtype,
+    torch.float8_e4m3fn:
+        jnp.float8_e4m3fn.dtype,
     # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
-    torch.float8_e5m2 : jnp.float8_e5m2.dtype,
+    torch.float8_e5m2:
+        jnp.float8_e5m2.dtype,
     # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
-    torch.bfloat16 : jnp.bfloat16.dtype,
-    torch.half : jnp.float16.dtype,
-    torch.float16 : jnp.float16.dtype,
-    torch.float32 : jnp.float32.dtype,
-    torch.float64 : jnp.float64.dtype,
-    torch.double : jnp.double.dtype,
-    torch.complex64 : jnp.complex64.dtype,
-    torch.complex128 : jnp.complex128.dtype,
-    None : None,
+    torch.bfloat16:
+        jnp.bfloat16.dtype,
+    torch.half:
+        jnp.float16.dtype,
+    torch.float16:
+        jnp.float16.dtype,
+    torch.float32:
+        jnp.float32.dtype,
+    torch.float64:
+        jnp.float64.dtype,
+    torch.double:
+        jnp.double.dtype,
+    torch.complex64:
+        jnp.complex64.dtype,
+    torch.complex128:
+        jnp.complex128.dtype,
+    None:
+        None,
 }
 
-JAX_DTYPE_TO_TORCH = {
-    value: key for key, value in TORCH_DTYPE_TO_JAX.items()
-}
+JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
 # Add imprecise mappings for some JAX dtypes which don't have torch analogues
 JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8
 JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8
 
+
 def t2j_dtype(dtype):
   if dtype not in TORCH_DTYPE_TO_JAX:
-    raise RuntimeError(f'Attempting to convert unknown type: {dtype} to jax type,')
+    raise RuntimeError(
+        f'Attempting to convert unknown type: {dtype} to jax type,')
   return TORCH_DTYPE_TO_JAX[dtype]
 
 
 def j2t_dtype(dtype):
   if dtype not in JAX_DTYPE_TO_TORCH:
-    raise RuntimeError(f'Attempting to convert unknown type: {dtype} to torch type,')
+    raise RuntimeError(
+        f'Attempting to convert unknown type: {dtype} to torch type,')
   return JAX_DTYPE_TO_TORCH[dtype]
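
The fallback branch of `j2t` also learned to survive `bfloat16`, which `numpy.asarray` cannot represent: the array is upcast to `float32`, converted through numpy, and cast back via `j2t_dtype(orig_dtype)` at the end. A sketch of the round-trip this enables, forcing the non-DLPack path:

    import jax.numpy as jnp
    import torch
    from torchax.ops import mappings

    x = jnp.ones(4, dtype=jnp.bfloat16)
    t = mappings.j2t(x, use_dlpack=False)  # numpy path, via a float32 copy
    assert t.dtype == torch.bfloat16       # original dtype restored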
torchax/ops/op_base.py CHANGED
@@ -4,6 +4,7 @@ import jax.numpy as jnp
 import numpy as np
 import torch
 from torchax.ops import mappings
+from torchax.view import View
 from torchax import types
 import sys
 
@@ -12,31 +13,55 @@ from typing import Callable, Optional, ParamSpec, Concatenate
 
 class InplaceOp:
 
-    def __init__(self, functional_op, replace=False, position_to_mutate=0):
-        self.functional = functional_op
-        self.replace = replace
-        self.position_to_mutate = position_to_mutate
-
-    def __call__(self, *args, **kwargs):
-        to_mutate = args[0]
-        if self.replace:
-            to_mutate._elem = self.functional(*args, **kwargs)._elem
-        else:
-            to_mutate.copy_(self.functional(*args, **kwargs))
-        return to_mutate
+  def __init__(self,
+               functional_op,
+               replace=False,
+               position_to_mutate=0,
+               is_jax_func=False):
+    self.functional = functional_op
+    self.replace = replace
+    self.position_to_mutate = position_to_mutate
+    self.is_jax_func = is_jax_func
+
+  def __call__(self, *args, **kwargs):
+    to_mutate = args[self.position_to_mutate]
+    view_value = to_mutate
+    if isinstance(to_mutate, View):
+      view_value = to_mutate.torch()
+      # Convert the target View to a Tensor, and
+      # leave the rest args as is. If other args are
+      # also View, they will be converted to tensors
+      # in the self.functional dispatch.
+    env = view_value._env
+    if self.is_jax_func:
+      view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs))
+      new_value_jax = self.functional(view_value, *args[1:], **kwargs)
+      new_value = env.j2t_iso(new_value_jax)
+    else:
+      new_value = self.functional(view_value, *args[1:], **kwargs)
+
+    if isinstance(to_mutate, View):
+      to_mutate.update(new_value)
+    else:
+      if self.replace:
+        to_mutate._elem = new_value._elem
+      else:
+        to_mutate.copy_(new_value)
+    return to_mutate
 
 
 class OutVariant:
 
-    def __call__(self, *args, **kwargs):
-        to_mutate = kwargs['out']
-        del kwargs['out']
-        to_mutate._elem = self.functional(*args, **kwargs)._elem
-        return to_mutate
-
+  def __call__(self, *args, **kwargs):
+    to_mutate = kwargs['out']
+    del kwargs['out']
+    to_mutate._elem = self.functional(*args, **kwargs)._elem
+    return to_mutate
 
 
 P = ParamSpec('P')
+
+
 def convert_dtype(use_default_dtype: bool = True):
   """Converts `dtype` kwarg of function from torch to JAX.
 
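
`InplaceOp.__call__` now mutates `args[self.position_to_mutate]` instead of always `args[0]`, can wrap a JAX-level functional implementation (`is_jax_func=True` converts arguments with `env.t2j_iso` and the result back with `env.j2t_iso`), and writes back through `View.update` when the target is a `View`. A hedged sketch of constructing one; using `torch.ops.aten.add` as the functional op here is illustrative, not necessarily how the package registers `add_`:

    import torch
    from torchax.ops import op_base

    # add_(self, other): compute the functional add, then copy the
    # result back into the mutated argument (or View target).
    add_inplace = op_base.InplaceOp(torch.ops.aten.add)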
@@ -48,6 +73,7 @@ def convert_dtype(use_default_dtype: bool = True):
   """
 
   def decorator(func: types.TorchCallable):
+
     @functools.wraps(func)
     def wrapper(*args: P.args,
                 dtype: Optional[torch.dtype] = None,
@@ -66,7 +92,8 @@ def convert_dtype(use_default_dtype: bool = True):
   return decorator
 
 
-def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[jnp.dtype]):
+def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
+                                 dtype: Optional[jnp.dtype]):
   """Optionally converts scalar constant's dtype using `numpy`
 
   Use in cases where you require a constant and can't handle a traced array.
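
`maybe_convert_constant_dtype` only had its signature rewrapped here; the surrounding `convert_dtype` decorator is likewise unchanged in behavior, still accepting a torch `dtype` kwarg and handing the mapped JAX dtype to the wrapped function. A sketch under that assumption, with `my_zeros` as a hypothetical lowering:

    import jax.numpy as jnp
    from torchax.ops import op_base

    @op_base.convert_dtype()
    def my_zeros(size, *, dtype=None):
      # `dtype` arrives as a jnp dtype, already mapped from the
      # torch.dtype (or default dtype) the caller supplied.
      return jnp.zeros(size, dtype=dtype)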
@@ -81,24 +108,24 @@ def maybe_convert_constant_dtype(val: Optional[types.JaxValue],
 
 
 def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
-    """If the first argument is an int array, promote it to float32."""
-    @functools.wraps(f)
-    def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
-        if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
-            x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))
+  """If the first argument is an int array, promote it to float32."""
+
+  @functools.wraps(f)
+  def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs):
+    if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]:
+      x = x.astype(mappings.t2j_dtype(torch.get_default_dtype()))
 
-        return f(x, *args, **kwargs)
+    return f(x, *args, **kwargs)
 
-    return wrapper
+  return wrapper
 
 
-def foreach_loop(
-    seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
-):
+def foreach_loop(seq: jax.Array,
+                 fn: Callable[[jax.Array, jax.Array], jax.Array],
+                 init_val=0.0):
   """Run `fn` for each element of 1D array `seq`.
 
   Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
   assert len(seq.shape) == 1
-  return jax.lax.fori_loop(
-      0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val
-  )
+  return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
+                           init_val)
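
`foreach_loop` changes only in formatting; it still reduces a 1-D array by threading a carry through `jax.lax.fori_loop`. For example, summing a vector:

    import jax.numpy as jnp
    from torchax.ops import op_base

    seq = jnp.asarray([1.0, 2.0, 3.0])
    total = op_base.foreach_loop(seq, lambda carry, x: carry + x)  # Array(6.0)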
torchax/ops/ops_registry.py CHANGED
@@ -7,44 +7,49 @@ from typing import Union, Dict
 
 @dataclasses.dataclass
 class Operator:
-    torch_op: TorchCallable
-    func: Union[TorchCallable, JaxCallable]
-    is_jax_function: bool
-    is_user_defined: bool
-    needs_env: bool
+  torch_op: TorchCallable
+  func: Union[TorchCallable, JaxCallable]
+  is_jax_function: bool
+  is_user_defined: bool
+  needs_env: bool
+  is_view_op: bool
 
 
 all_aten_ops: Dict[TorchCallable, Operator] = {}
 all_torch_functions: Dict[TorchCallable, Operator] = {}
 
 
-def register_torch_dispatch_op(
-    aten_op, impl_callable,
-    is_jax_function=True,
-    is_user_defined=False,
-    needs_env=False,
-):
-  op = Operator(
-      aten_op, impl_callable,
-      is_jax_function=is_jax_function,
-      is_user_defined=is_user_defined,
-      needs_env=needs_env)
-  if aten_op in all_aten_ops:
-    logging.warning(f'Duplicate op registration for {aten_op}')
-  all_aten_ops[aten_op] = op
-  return impl_callable
-
-
-def register_torch_function_op(
-    torch_func, impl_callable,
-    is_jax_function=True,
-    is_user_defined=False,
-    needs_env=False,
-):
-  op = Operator(
-      torch_func, impl_callable,
-      is_jax_function=is_jax_function,
-      is_user_defined=is_user_defined,
-      needs_env=needs_env)
-  all_torch_functions[torch_func] = op
-  return impl_callable
+def register_torch_dispatch_op(aten_op,
+                               impl_callable,
+                               is_jax_function=True,
+                               is_user_defined=False,
+                               needs_env=False,
+                               is_view_op=False):
+  op = Operator(
+      aten_op,
+      impl_callable,
+      is_jax_function=is_jax_function,
+      is_user_defined=is_user_defined,
+      needs_env=needs_env,
+      is_view_op=is_view_op)
+  if aten_op in all_aten_ops:
+    logging.warning(f'Duplicate op registration for {aten_op}')
+  all_aten_ops[aten_op] = op
+  return impl_callable
+
+
+def register_torch_function_op(torch_func,
+                               impl_callable,
+                               is_jax_function=True,
+                               is_user_defined=False,
+                               needs_env=False,
+                               is_view_op=False):
+  op = Operator(
+      torch_func,
+      impl_callable,
+      is_jax_function=is_jax_function,
+      is_user_defined=is_user_defined,
+      needs_env=needs_env,
+      is_view_op=is_view_op)
+  all_torch_functions[torch_func] = op
+  return impl_callable
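
Both registration helpers now thread an `is_view_op` flag into `Operator`, letting the dispatcher distinguish view-producing ops from materializing ones. A hedged sketch of a registration; the lowering body is illustrative rather than the package's actual implementation:

    import jax.numpy as jnp
    import torch
    from torchax.ops import ops_registry

    def _transpose_impl(x, dim0, dim1):
      return jnp.swapaxes(x, dim0, dim1)

    ops_registry.register_torch_dispatch_op(
        torch.ops.aten.transpose,
        _transpose_impl,
        is_jax_function=True,
        is_view_op=True)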