torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jax_reimplement.py
CHANGED
@@ -1,4 +1,3 @@
-
 from collections.abc import Sequence
 from jax._src.numpy.util import promote_dtypes_inexact
 import numpy as np
@@ -15,12 +14,9 @@ from typing import Callable

 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52

-def compute_weight_mat(input_size: core.DimSize,
-                       output_size: core.DimSize,
-                       scale,
-                       translation,
-                       kernel: Callable,
-                       antialias: bool):
+
+def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize,
+                       scale, translation, kernel: Callable, antialias: bool):
   dtype = jnp.result_type(scale, translation)
   inv_scale = 1. / scale
   # When downsampling the kernel should be scaled since we want to low pass
@@ -38,8 +34,8 @@ def compute_weight_mat(input_size: core.DimSize,
   total_weight_sum = jnp.sum(weights, axis=0, keepdims=True)
   weights = jnp.where(
       jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps),
-      jnp.divide(weights, jnp.where(total_weight_sum != 0,
-                                    total_weight_sum, 1)), 0)
+      jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum,
+                                    1)), 0)
   # Zero out weights where the sample location is completely outside the input
   # range.
   # Note sample_f has already had the 0.5 removed, hence the weird range below.
@@ -48,12 +44,14 @@ def compute_weight_mat(input_size: core.DimSize,
     return weights
   input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5
   return jnp.where(
-      jnp.logical_and(sample_f >= -0.5,
-                      sample_f <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
+      jnp.logical_and(sample_f >= -0.5, sample_f
+                      <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0)
   # (barney-s) -------------- END returning weights without zeroing ---------------------

+
 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86

+
 def _scale_and_translate(x, output_shape: core.Shape,
                          spatial_dims: Sequence[int], scale, translation,
                          kernel, antialias: bool, precision):
@@ -70,8 +68,8 @@ def _scale_and_translate(x, output_shape: core.Shape,
     d = canonicalize_axis(d, x.ndim)
     m = input_shape[d]
     n = output_shape[d]
-    w = compute_weight_mat(m, n, scale[i], translation[i],
-                           kernel, antialias).astype(x.dtype)
+    w = compute_weight_mat(m, n, scale[i], translation[i], kernel,
+                           antialias).astype(x.dtype)
     contractions.append(w)
     contractions.append([d, len(output_shape) + i])
     out_indices[d] = len(output_shape) + i
@@ -81,15 +79,19 @@ def _scale_and_translate(x, output_shape: core.Shape,

 # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172

+
 # scale and translation here are scalar elements of an np.array, what is the
 # correct type annotation?
-def scale_and_translate(image, shape: core.Shape,
-                        spatial_dims: Sequence[int],
-                        scale, translation,
-                        # (barney-s) use string
-                        method: str, #(barney-s) | ResizeMethod,
-                        antialias: bool = True,
-                        precision=lax.Precision.HIGHEST):
+def scale_and_translate(
+    image,
+    shape: core.Shape,
+    spatial_dims: Sequence[int],
+    scale,
+    translation,
+    # (barney-s) use string
+    method: str, #(barney-s) | ResizeMethod,
+    antialias: bool = True,
+    precision=lax.Precision.HIGHEST):
   """Apply a scale and translation to an image.

   Generates a new image of shape 'shape' by resampling from the input image
@@ -165,5 +167,5 @@ def scale_and_translate(image, shape: core.Shape,
   return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                               kernel, antialias, precision)

-# END ----------------- END JAX code copied for testing -----------------------------

+# END ----------------- END JAX code copied for testing -----------------------------
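The file above is a lightly patched copy of jax.image's resize machinery: compute_weight_mat builds one (input_size, output_size) weight matrix per resized axis, _scale_and_translate contracts the image against those matrices, and scale_and_translate selects the kernel by a plain string name. A minimal usage sketch of that call pattern, written against the upstream jax.image.scale_and_translate API as a stand-in; the 2x upscale, the "linear" kernel, and the concrete shapes are illustrative assumptions, not taken from torchax:

import jax.numpy as jnp
from jax import image as jax_image

x = jnp.arange(16.0).reshape(1, 1, 4, 4)      # (B, C, H, W) input
y = jax_image.scale_and_translate(
    x,
    shape=(1, 1, 8, 8),                       # target shape
    spatial_dims=(2, 3),                      # resize H and W only
    scale=jnp.array([2.0, 2.0]),              # per-axis scale factors
    translation=jnp.array([0.0, 0.0]),
    method="linear",                          # kernel selected by name
    antialias=True)
print(y.shape)                                # (1, 1, 8, 8)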
torchax/ops/jc10d.py
CHANGED
@@ -6,6 +6,7 @@ from torchax.ops import ops_registry


 def op(*aten, **kwargs):
+
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
@@ -21,7 +22,7 @@ def _c10d_all_gather(input, group_size: int, group_name: str):

 @op(torch.ops._c10d_functional.all_reduce)
 def _c10d_all_reduce(self, reduceOp: str, group_name: str):
-
+
   if reduceOp == "sum":
     res = jax.lax.psum(self, axis_name="torch_dist")
   elif reduceOp == "avg":
@@ -38,9 +39,9 @@ def _c10d_all_reduce(self, reduceOp: str, group_name: str):

 @op(torch.ops._c10d_functional.broadcast)
 def _c10d_broadcast(self, src: int, group_name: str):
   masked = jnp.where(
-    jax.lax.axis_index("torch_dist") == src,
-    self,
-    jnp.zeros_like(self),
+      jax.lax.axis_index("torch_dist") == src,
+      self,
+      jnp.zeros_like(self),
   )
   return jax.lax.psum(masked, "torch_dist")

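These lowerings map the functional torch.distributed collectives onto JAX collectives over a named axis called "torch_dist": a sum all-reduce becomes jax.lax.psum, avg becomes a mean reduction, and broadcast is expressed as a masked psum. A minimal sketch of that named-axis semantics; using jax.pmap to create the axis here is an illustrative assumption (torchax wires the axis up through its own distributed setup):

import jax
import jax.numpy as jnp


def all_reduce_sum(x):
  # Mirrors what _c10d_all_reduce does for reduceOp == "sum".
  return jax.lax.psum(x, axis_name="torch_dist")


n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)           # one value per device
out = jax.pmap(all_reduce_sum, axis_name="torch_dist")(xs)
print(out)                                      # every entry holds 0 + 1 + ... + (n - 1)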
torchax/ops/jimage.py
ADDED
import jax
import jax.numpy as jnp


def cubic_kernel(x, a=-0.75):
  """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)"""
  absx = jnp.abs(x)
  x2 = absx * absx
  x3 = x2 * absx
  cond1 = (absx <= 1)
  cond2 = (absx > 1) & (absx < 2)
  f1 = (a + 2) * x3 - (a + 3) * x2 + 1
  f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
  return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0))


def compute_contribs(in_size,
                     out_size,
                     scale,
                     support=2.0,
                     align_corners=False,
                     dtype=None):
  if align_corners:
    if out_size == 1:
      in_coords = jnp.zeros((1,), dtype=dtype)
    else:
      in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype)
  else:
    out_coords = jnp.arange(out_size, dtype=dtype) + 0.5
    in_coords = out_coords / scale - 0.5

  left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
  idxs = left_idx[:, None] + jnp.arange(4)

  dx = in_coords[:, None] - idxs

  weights = cubic_kernel(dx)

  weights = weights / jnp.sum(weights, axis=1, keepdims=True)
  return idxs, weights


def gather_weights(img, idxs, axis):
  """Safely gather with boundary handling"""
  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)
  return jnp.take(img, idxs, axis=axis)


def interpolate_along_axis_bchw(img, idxs, weights, axis):
  """
  Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W).
  idxs: (out_size, 4) int32 indices
  weights: (out_size, 4) float32 weights
  """
  assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)"
  out_size = idxs.shape[0]
  k = idxs.shape[1]  # Typically 4 for cubic

  # Clip to input bounds
  idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, 4)

  def gather_and_weight(i):
    idx = idxs[i]  # (4,)
    w = weights[i]  # (4,)

    def gather_one(offset):
      return jnp.take(img, idx[offset], axis=axis)  # shape (B, C, H, W)

    gathered = jnp.stack([gather_one(o) for o in range(k)],
                         axis=0)  # (4, B, C, H, W)
    weighted = jnp.tensordot(w, gathered, axes=(0, 0))  # (B, C, H, W)
    return weighted

  out = jax.vmap(gather_and_weight)(
      jnp.arange(out_size))  # (out_size, B, C, H, W)

  # Move the interpolated axis back into place
  if axis == 2:  # interpolated over H
    return jnp.moveaxis(out, 0, 2)  # (B, C, out_H, W)
  else:  # axis == 3, interpolated over W
    return jnp.moveaxis(out, 0, 3)  # (B, C, H, out_W)


def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
  h, w = img.shape[-2:]
  if align_corners and out_h > 1:
    scale_y = (h - 1) / (out_h - 1)
  else:
    scale_y = out_h / h

  if align_corners and out_w > 1:
    scale_x = (w - 1) / (out_w - 1)
  else:
    scale_x = out_w / w

  idxs_y, weights_y = compute_contribs(
      h,
      out_h,
      scale_y,
      align_corners=align_corners,
      dtype=img.dtype,
  )
  tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)

  idxs_x, weights_x = compute_contribs(
      w,
      out_w,
      scale_x,
      align_corners=align_corners,
      dtype=img.dtype,
  )
  out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
  return out
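A quick usage sketch for the new helper (assumed, not taken from the package docs): resize a (B, C, H, W) batch with interpolate_bicubic_no_aa, which is meant to mirror PyTorch's bicubic interpolation with antialias=False.

import jax.numpy as jnp
from torchax.ops import jimage

img = jnp.linspace(0.0, 1.0, 3 * 8 * 8, dtype=jnp.float32).reshape(1, 3, 8, 8)
out = jimage.interpolate_bicubic_no_aa(img, out_h=16, out_w=16, align_corners=False)
print(out.shape)   # (1, 3, 16, 16)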
torchax/ops/jlibrary.py
CHANGED
@@ -14,19 +14,22 @@ def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args):
   """Wrap a jaxpr in a jitted function with the proper composite name
   TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op.
   """
+
   def composite_impl(*args):
     return jaxpr_impl(*args)
+
   composite_impl.__name__ = composite_name
   composite_impl.__qualname__ = composite_name
   return jax.jit(composite_impl, **jit_args)

+
 def register_jax_composite(composite_name, impl, *ops, **jit_args):
   """Register a composite using a JAX implementation.
   composite_name - The name of the library op to use in the exported composite
   impl - A JAX lowering for the library operation
   *ops - Variadic torch.ops to lower using `impl`.
   **jit_args - Additional parameters to forward to JAX jit.
-
+
   This is used to register custom lowerings with an explicit jaxpr
   implementation, such as preserving a specific aten op using a jaten impl.

@@ -36,10 +39,12 @@ def register_jax_composite(composite_name, impl, *ops, **jit_args):
   For jit params and troubleshooting see:
   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
   """
+
   @jaten.op(*ops)
   def _composite_impl(*args):
     return _jit_composite_impl(composite_name, impl, **jit_args)(*args)

+
 def register_torch_composite(composite_name, impl, *ops, **jit_args):
   """Register a torch decomposition as a composite.
   This is useful for registerring custom torch op libraries as composite ops.
@@ -53,10 +58,12 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args):
   For jit params and troubleshooting see:
   https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
   """
-
+
   @jaten.op(*ops)
   def _composite_impl(*args):
+
     class ImplWrapper(torch.nn.Module):
+
       def __init__(self):
         super().__init__()

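register_jax_composite ties these pieces together: it registers a JAX lowering for one or more torch ops and records the composite name for export. A hedged sketch of the intended call, in which the custom library "mylib" and its gelu op are hypothetical placeholders, not part of torchax:

import jax
import torch
from torchax.ops import jlibrary

# Hypothetical custom op assumed to be registered elsewhere as torch.ops.mylib.gelu.
jlibrary.register_jax_composite(
    "mylib.gelu",          # composite name used in the export
    jax.nn.gelu,           # JAX lowering for the op
    torch.ops.mylib.gelu,  # torch op(s) lowered with this implementation
)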