torchax-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax has been flagged as potentially problematic.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/ops/jax_reimplement.py
ADDED
@@ -0,0 +1,169 @@

from collections.abc import Sequence
from jax._src.numpy.util import promote_dtypes_inexact
import numpy as np
import jax
from jax import numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src import core
from jax._src.image.scale import _kernels, ResizeMethod
from jax import lax
from typing import Callable

# TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106
# START ----------------- JAX code copied for fixing scale_and_translate -----------------------------

# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52

def compute_weight_mat(input_size: core.DimSize,
                       output_size: core.DimSize,
                       scale,
                       translation,
                       kernel: Callable,
                       antialias: bool):
  dtype = jnp.result_type(scale, translation)
  inv_scale = 1. / scale
  # When downsampling the kernel should be scaled since we want to low pass
  # filter and interpolate, but when upsampling it should not be since we only
  # want to interpolate.
  kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1.
  sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale -
              translation * inv_scale - 0.5)
  x = (
      jnp.abs(sample_f[jnp.newaxis, :] -
              jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) /
      kernel_scale)
  weights = kernel(x)

  total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
  weights = jnp.where(
      jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)),
      0)
  # Zero out weights where the sample location is completely outside the input
  # range.
  # Note sample_f has already had the 0.5 removed, hence the weird range below.

  # (barney-s) -------------- returning weights without zeroing ---------------------
  return weights
  input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
  return jnp.where(
      jnp.logical_and(sample_f >= -0.5,
                      sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
  # (barney-s) -------------- END returning weights without zeroing ---------------------

# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86

def _scale_and_translate(x, output_shape: core.Shape,
                         spatial_dims: Sequence[int], scale, translation,
                         kernel, antialias: bool, precision):
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  assert len(spatial_dims) == len(scale)
  assert len(spatial_dims) == len(translation)
  if len(spatial_dims) == 0:
    return x
  contractions = []
  in_indices = list(range(len(output_shape)))
  out_indices = list(range(len(output_shape)))
  for i, d in enumerate(spatial_dims):
    d = canonicalize_axis(d, x.ndim)
    m = input_shape[d]
    n = output_shape[d]
    w = compute_weight_mat(m, n, scale[i], translation[i],
                           kernel, antialias).astype(x.dtype)
    contractions.append(w)
    contractions.append([d, len(output_shape) + i])
    out_indices[d] = len(output_shape) + i
  contractions.append(out_indices)
  return jnp.einsum(x, in_indices, *contractions, precision=precision)


# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172

# scale and translation here are scalar elements of an np.array, what is the
# correct type annotation?
def scale_and_translate(image, shape: core.Shape,
                        spatial_dims: Sequence[int],
                        scale, translation,
                        # (barney-s) use string
                        method: str, #(barney-s) | ResizeMethod,
                        antialias: bool = True,
                        precision=lax.Precision.HIGHEST):
  """Apply a scale and translation to an image.

  Generates a new image of shape 'shape' by resampling from the input image
  using the sampling method corresponding to method. For 2D images, this
  operation transforms a location in the input images, (x, y), to a location
  in the output image according to::

    (x * scale[1] + translation[1], y * scale[0] + translation[0])

  (Note the *inverse* warp is used to generate the sample locations.)
  Assumes half-centered pixels, i.e the pixel at integer location ``row, col``
  has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input
  image dimensions.

  If an output location(pixel) maps to an input sample location that is outside
  the input boundaries then the value for the output location will be set to
  zero.

  The ``method`` argument expects one of the following resize methods:

  ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``,
  ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a
  triangular filter when downsampling.

  ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"``
  `Cubic interpolation`_, using the Keys cubic kernel.

  ``ResizeMethod.LANCZOS3``, ``"lanczos3"``
  `Lanczos resampling`_, using a kernel of radius 3.

  ``ResizeMethod.LANCZOS5``, ``"lanczos5"``
  `Lanczos resampling`_, using a kernel of radius 5.

  .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
  .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
  .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling

  Args:
    image: a JAX array.
    shape: the output shape, as a sequence of integers with length equal to the
      number of dimensions of `image`.
    spatial_dims: A length K tuple specifying the spatial dimensions that the
      passed scale and translation should be applied to.
    scale: A [K] array with the same number of dimensions as image, containing
      the scale to apply in each dimension.
    translation: A [K] array with the same number of dimensions as image,
      containing the translation to apply in each dimension.
    method: the resizing method to use; either a ``ResizeMethod`` instance or a
      string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC.
    antialias: Should an antialiasing filter be used when downsampling? Defaults
      to ``True``. Has no effect when upsampling.

  Returns:
    The scale and translated image.
  """
  shape = core.canonicalize_shape(shape)
  if len(shape) != image.ndim:
    msg = ('shape must have length equal to the number of dimensions of x; '
           f' {shape} vs {image.shape}')
    raise ValueError(msg)
  if isinstance(method, str):
    method = ResizeMethod.from_string(method)
  if method == ResizeMethod.NEAREST:
    # Nearest neighbor is currently special-cased for straight resize, so skip
    # for now.
    raise ValueError('Nearest neighbor resampling is not currently supported '
                     'for scale_and_translate.')
  assert isinstance(method, ResizeMethod)

  kernel = _kernels[method]
  image, = promote_dtypes_inexact(image)
  scale, translation = promote_dtypes_inexact(scale, translation)
  return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                              kernel, antialias, precision)

# END ----------------- END JAX code copied for testing -----------------------------
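For orientation, here is a minimal usage sketch of the patched scale_and_translate above. It is not part of the package diff; the array shapes and values are illustrative, and only the import path follows the package layout shown in the file list.

# Illustrative only (not from the package): resize a 1x4x4x1 image to 1x8x8x1
# with the patched scale_and_translate defined above.
import jax.numpy as jnp
from torchax.ops.jax_reimplement import scale_and_translate

image = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
out = scale_and_translate(
    image,
    shape=(1, 8, 8, 1),                              # same rank as `image`
    spatial_dims=(1, 2),                             # resize height and width only
    scale=jnp.array([2.0, 2.0], jnp.float32),        # one factor per spatial dim
    translation=jnp.array([0.0, 0.0], jnp.float32),  # no shift
    method="linear",                                 # string form, per the (barney-s) change
    antialias=True,
)
print(out.shape)  # (1, 8, 8, 1)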
torchax/ops/jc10d.py
ADDED
@@ -0,0 +1,51 @@

import torch
import jax
import jax.numpy as jnp

from torchax.ops import ops_registry


def op(*aten, **kwargs):
  def inner(func):
    for a in aten:
      ops_registry.register_torch_dispatch_op(a, func, **kwargs)
    return func

  return inner


@op(torch.ops._c10d_functional.all_gather_into_tensor)
def _c10d_all_gather(input, group_size: int, group_name: str):
  return jax.lax.all_gather(input, "torch_dist")


@op(torch.ops._c10d_functional.all_reduce)
def _c10d_all_reduce(self, reduceOp: str, group_name: str):

  if reduceOp == "sum":
    res = jax.lax.psum(self, axis_name="torch_dist")
  elif reduceOp == "avg":
    res = jax.lax.pmean(self, axis_name="torch_dist")
  elif reduceOp == "min":
    res = jax.lax.pmin(self, axis_name="torch_dist")
  elif reduceOp == "max":
    res = jax.lax.pmax(self, axis_name="torch_dist")
  else:
    raise RuntimeError(f"Reduce op {reduceOp} not implemented")
  return res


@op(torch.ops._c10d_functional.broadcast)
def _c10d_broadcast(self, src: int, group_name: str):
  masked = jnp.where(
      jax.lax.axis_index("torch_dist") == src,
      self,
      jnp.zeros_like(self),
  )
  return jax.lax.psum(masked, "torch_dist")


@op(torch.ops._c10d_functional.wait_tensor)
def _c10d_wait_tensor(tensor):
  # Async tensor is already `wait`ed by dispatcher
  return tensor
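These lowerings all resolve against a JAX mapped axis named "torch_dist", which torchax's distributed runtime is expected to establish. A rough, self-contained sketch of the "sum" branch, faking that axis with jax.pmap purely for illustration:

# Illustrative only: emulate the "sum" branch of _c10d_all_reduce by running
# jax.lax.psum under a pmap that supplies the "torch_dist" axis name.
import jax
import jax.numpy as jnp

def all_reduce_sum(x):
  return jax.lax.psum(x, axis_name="torch_dist")

n = jax.local_device_count()
per_device = jnp.arange(n, dtype=jnp.float32)  # one value per device
print(jax.pmap(all_reduce_sum, axis_name="torch_dist")(per_device))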
torchax/ops/jlibrary.py
ADDED
@@ -0,0 +1,73 @@

"""The `jlibrary` module has functions which help to preserve torch.library ops
during export. This includes aten ops, and custom operations.
"""

import torch
import torch.nn as nn
import torchax
from torchax.ops import jaten
import jax
import functools


def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
  """Wrap a jaxpr in a jitted function with the proper composite name
  TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
  """
  def composite_impl(*args):
    return jaxpr_impl(*args)
  composite_impl.__name__ = composite_name
  composite_impl.__qualname__ = composite_name
  return jax.jit(composite_impl, **jit_args)

def register_jax_composite(composite_name, impl, *ops, **jit_args):
  """Register a composite using a JAX implementation.
  composite_name - The name of the library op to use in the exported composite
  impl - A JAX lowering for the library operation
  *ops - Variadic torch.ops to lower using `impl`.
  **jit_args - Additional parameters to forward to JAX jit.

  This is used to register custom lowerings with an explicit jaxpr
  implementation, such as preserving a specific aten op using a jaten impl.

  For custom torch op registration with a decomposition written in torch,
  use `register_torch_composite`.

  For jit params and troubleshooting see:
  https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
  """
  @jaten.op(*ops)
  def _composite_impl(*args):
    return _jit_composite_impl(composite_name, impl, **jit_args)(*args)

def register_torch_composite(composite_name, impl, *ops, **jit_args):
  """Register a torch decomposition as a composite.
  This is useful for registering custom torch op libraries as composite ops.

  The `impl` can be the `@impl` used to define the torch custom library op.
  This must be a function or module impl that provides the decompositions, and
  not an instance of the custom op.

  TODO: Better error handling, or can we make this an instance of the op as a param?

  For jit params and troubleshooting see:
  https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
  """

  @jaten.op(*ops)
  def _composite_impl(*args):
    class ImplWrapper(torch.nn.Module):
      def __init__(self):
        super().__init__()

      def forward(self, *args):
        return impl(*args)

    # Note: avoid refactoring to share code with register_jaxpr_composite.
    # The `extract_jax` call must live in the `@jaten.op` handler. If called
    # outside of the handler, we would build the jaxpr representation of the
    # module once during registration, potentially missing op registrations that
    # come after. I.e. may miss nested abstractions if we build jaxpr AoT.
    state, jfn = torchax.extract_jax(ImplWrapper())
    jaxpr_impl = lambda *args: jfn(state, tuple([*args]))
    return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args)
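A hedged usage sketch of register_torch_composite follows. The composite name, the choice of torch.ops.aten.gelu, and the decomposition body are hypothetical examples, not taken from this package; only the registration call itself reflects the API defined above.

# Hypothetical example: preserve aten.gelu as a named composite during export
# by routing it through a torch-level decomposition. Names are illustrative.
import torch
from torchax.ops import jlibrary

def gelu_decomp(x):
  # tanh approximation written with primitive torch ops, used only as the
  # composite body so the original op is not re-dispatched.
  return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

jlibrary.register_torch_composite(
    "example.gelu",          # composite name to emit in the exported program
    gelu_decomp,             # torch impl; traced to a jaxpr when dispatched
    torch.ops.aten.gelu,     # torch op(s) lowered through this composite
)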