torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202617.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/licenses/LICENSE +0 -0
torchax/ops/jtorchvision_nms.py
CHANGED
@@ -16,12 +16,10 @@
  Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
  """

- import
- from typing import List, Union, Optional, Tuple
-
+ import jax.numpy as jnp
  import torch
  from jax import lax
-
+
  from . import ops_registry

  _NMS_TILE_SIZE = 256

@@ -38,9 +36,11 @@ def _bbox_overlap(boxes, gt_boxes):
  iou: Intersection over union matrix of all input bounding boxes
  """
  bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
-
+ ary=boxes, indices_or_sections=4, axis=2
+ )
  gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
-
+ ary=gt_boxes, indices_or_sections=4, axis=2
+ )

  # Calculates the intersection area.
  i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))

@@ -64,11 +64,16 @@ def _bbox_overlap(boxes, gt_boxes):
  def _self_suppression(in_args):
  iou, _, iou_sum = in_args
  batch_size = iou.shape[0]
- can_suppress_others = jnp.reshape(
-
-
-
-
+ can_suppress_others = jnp.reshape(jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(
+ iou.dtype
+ )
+ iou_suppressed = (
+ jnp.reshape(
+ (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
+ [batch_size, -1, 1],
+ )
+ * iou
+ )
  iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
  return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new

@@ -76,11 +81,14 @@ def _self_suppression(in_args):
  def _cross_suppression(in_args):
  boxes, box_slice, iou_threshold, inner_idx = in_args
  batch_size = boxes.shape[0]
- new_slice = lax.dynamic_slice(
-
+ new_slice = lax.dynamic_slice(
+ boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
+ )
  iou = _bbox_overlap(new_slice, box_slice)
- ret_slice =
-
+ ret_slice = (
+ jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2)
+ * box_slice
+ )
  return boxes, ret_slice, iou_threshold, inner_idx + 1


@@ -101,38 +109,44 @@ def _suppression_loop_body(in_args):
  batch_size = boxes.shape[0]

  # Iterates over tiles that can possibly suppress the current tile.
- box_slice = lax.dynamic_slice(
-
+ box_slice = lax.dynamic_slice(
+ boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4]
+ )

  def _loop_cond(in_args):
  _, _, _, inner_idx = in_args
  return inner_idx < idx

- _, box_slice, _, _ = lax.while_loop(
-
+ _, box_slice, _, _ = lax.while_loop(
+ _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0)
+ )

  # Iterates over the current tile to compute self-suppression.
  iou = _bbox_overlap(box_slice, box_slice)
  mask = jnp.expand_dims(
-
-
+ jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1])
+ > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]),
+ 0,
+ )
  iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)

  def _loop_cond2(in_args):
  _, loop_condition, _ = in_args
  return loop_condition

- suppressed_iou, _, _ = lax.while_loop(
-
+ suppressed_iou, _, _ = lax.while_loop(
+ _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2]))
+ )
  suppressed_box = jnp.sum(suppressed_iou, 1) > 0
  box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)

  # Uses box_slice to update the input boxes.
- mask = jnp.reshape(
-
-
-
-
+ mask = jnp.reshape(
+ (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), [1, -1, 1, 1]
+ )
+ boxes = jnp.tile(
+ jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1]
+ ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
  boxes = jnp.reshape(boxes, [batch_size, -1, 4])

  # Updates output_size.

@@ -193,8 +207,7 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
  """
  batch_size = boxes.shape[0]
  num_boxes = boxes.shape[1]
- pad = int(jnp.ceil(
- float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
+ pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes
  boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
  scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
  num_boxes += pad

@@ -202,29 +215,37 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
  def _loop_cond(in_args):
  unused_boxes, unused_threshold, output_size, idx = in_args
  return jnp.logical_and(
-
-
+ jnp.min(output_size) < max_output_size, idx < num_boxes // _NMS_TILE_SIZE
+ )

  selected_boxes, _, output_size, _ = lax.while_loop(
-
-
+ _loop_cond,
+ _suppression_loop_body,
+ (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0),
+ )
  idx = num_boxes - lax.top_k(
-
-
-
+ jnp.any(selected_boxes > 0, [2]).astype(jnp.int32)
+ * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
+ max_output_size,
+ )[0].astype(jnp.int32)
  idx = jnp.minimum(idx, num_boxes - 1)
  idx = jnp.reshape(
-
+ idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]
+ )

  return idx
- boxes = jnp.reshape(
-
-
-
-
-
-
-
+ boxes = jnp.reshape(
+ (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4]
+ )
+ boxes = boxes * (
+ jnp.reshape(jnp.arange(max_output_size), [1, -1, 1])
+ < jnp.reshape(output_size, [-1, 1, 1])
+ ).astype(boxes.dtype)
+ scores = jnp.reshape(jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size])
+ scores = scores * (
+ jnp.reshape(jnp.arange(max_output_size), [1, -1])
+ < jnp.reshape(output_size, [-1, 1])
+ ).astype(scores.dtype)
  return scores, boxes


@@ -235,14 +256,13 @@ def nms(boxes, scores, iou_threshold):
  max_output_size = boxes.shape[0]
  boxes = boxes.reshape((1, *boxes.shape))
  scores = scores.reshape((1, *scores.shape))
- res = non_max_suppression_padded(scores, boxes, max_output_size,
- iou_threshold)
+ res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
  return res


  try:
  import torch
-
+
  ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
  except Exception:
  pass
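
For orientation, here is a minimal usage sketch (not part of the diff) of the reformatted kernel. It assumes the module is imported as torchax.ops.jtorchvision_nms and that boxes use the [y_min, x_min, y_max, x_max] layout expected by the forked MLPerf implementation.

import jax.numpy as jnp
from torchax.ops import jtorchvision_nms as jnms

# Three boxes, the first two overlapping heavily; scores sorted descending.
boxes = jnp.array(
    [[0.0, 0.0, 10.0, 10.0], [0.5, 0.5, 10.5, 10.5], [20.0, 20.0, 30.0, 30.0]],
    dtype=jnp.float32,
)
scores = jnp.array([0.9, 0.8, 0.7], dtype=jnp.float32)

# nms() adds a batch dimension, pads the box count up to a multiple of
# _NMS_TILE_SIZE (256), runs the tiled suppression loop, and returns the
# (scores, boxes) pair produced by non_max_suppression_padded.
kept_scores, kept_boxes = jnms.nms(boxes, scores, 0.5)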
torchax/ops/mappings.py
CHANGED
@@ -12,20 +12,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from jax import dlpack as jaxdl
  import jax.numpy as jnp
  import numpy
  import torch
  import torch.func
- import torch.utils.dlpack as torchdl
  import torch.utils._mode_utils as mode_utils
+ import torch.utils.dlpack as torchdl
+ from jax import dlpack as jaxdl

  NUMPY_UNSUPPORTED_DTYPES = {
-
-
-
-
-
+ torch.bfloat16: jnp.bfloat16,
+ torch.float8_e4m3fn: jnp.float8_e4m3fn,
+ torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz,
+ torch.float8_e5m2: jnp.float8_e5m2,
+ torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz,
  }


@@ -51,8 +51,9 @@ def t2j(t, use_dlpack=True):
  # https://github.com/google/jax/issues/7657
  # https://github.com/google/jax/issues/17784
  if t.dtype in NUMPY_UNSUPPORTED_DTYPES:
- nparray = (
-
+ nparray = (
+ t.cpu().detach().to(torch.float32).numpy()
+ )  # handle dtypes not supported by numpy
  else:
  nparray = t.cpu().detach().numpy()
  res = jnp.asarray(nparray)

@@ -91,71 +92,49 @@ def j2t(x, use_dlpack=True):


  TORCH_DTYPE_TO_JAX = {
-
-
- jnp.
-
-
-
-
-
-
-
-
-
-
-
-
-
- jnp.
-
- jnp.
-
-
-
-
-
-
-
-
-
- jnp.float8_e5m2.dtype,
- # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
- torch.bfloat16:
- jnp.bfloat16.dtype,
- torch.half:
- jnp.float16.dtype,
- torch.float16:
- jnp.float16.dtype,
- torch.float32:
- jnp.float32.dtype,
- torch.float64:
- jnp.float64.dtype,
- torch.double:
- jnp.double.dtype,
- torch.complex64:
- jnp.complex64.dtype,
- torch.complex128:
- jnp.complex128.dtype,
- None:
- None,
+ # NO_MAPPING : jnp.float0.dtype (signless scalar int),
+ torch.bool: jnp.bool_.dtype,
+ # NO_MAPPING : jnp.int4.dtype,
+ torch.int8: jnp.int8.dtype,
+ torch.int16: jnp.int16.dtype,
+ torch.int32: jnp.int32.dtype,
+ torch.int64: jnp.int64.dtype,
+ torch.long: jnp.int64.dtype,
+ # NO_MAPPING : jnp.uint4
+ torch.uint8: jnp.uint8.dtype,
+ torch.uint16: jnp.uint16.dtype,
+ torch.uint32: jnp.uint32.dtype,
+ torch.uint64: jnp.uint64.dtype,
+ torch.float4_e2m1fn_x2: jnp.float4_e2m1fn.dtype,
+ # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype,
+ torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype,
+ # NO_MAPPING : jnp.float8_e4m3fnuz.dtype,
+ torch.float8_e5m2: jnp.float8_e5m2.dtype,
+ # NO_MAPPING : jnp.float8_e5m2fnuz.dtype,
+ torch.bfloat16: jnp.bfloat16.dtype,
+ torch.half: jnp.float16.dtype,
+ torch.float16: jnp.float16.dtype,
+ torch.float32: jnp.float32.dtype,
+ torch.float64: jnp.float64.dtype,
+ torch.double: jnp.double.dtype,
+ torch.complex64: jnp.complex64.dtype,
+ torch.complex128: jnp.complex128.dtype,
+ None: None,
  }

  JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
  # Add imprecise mappings for some JAX dtypes which don't have torch analogues
- JAX_DTYPE_TO_TORCH[jnp.dtype(
- JAX_DTYPE_TO_TORCH[jnp.dtype(
+ JAX_DTYPE_TO_TORCH[jnp.dtype("int4")] = torch.int8
+ JAX_DTYPE_TO_TORCH[jnp.dtype("uint4")] = torch.uint8


  def t2j_dtype(dtype):
  if dtype not in TORCH_DTYPE_TO_JAX:
- raise RuntimeError(
- f'Attempting to convert unknown type: {dtype} to jax type,')
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to jax type,")
  return TORCH_DTYPE_TO_JAX[dtype]


  def j2t_dtype(dtype):
  if dtype not in JAX_DTYPE_TO_TORCH:
- raise RuntimeError(
- f'Attempting to convert unknown type: {dtype} to torch type,')
+ raise RuntimeError(f"Attempting to convert unknown type: {dtype} to torch type,")
  return JAX_DTYPE_TO_TORCH[dtype]
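
As a quick illustration of the tables above (not part of the diff), assuming the module is importable as torchax.ops.mappings as the file path suggests:

import torch
import jax.numpy as jnp
from torchax.ops import mappings

# Forward lookups go through TORCH_DTYPE_TO_JAX.
assert mappings.t2j_dtype(torch.bfloat16) == jnp.bfloat16.dtype
assert mappings.t2j_dtype(torch.float8_e5m2) == jnp.float8_e5m2.dtype

# Reverse lookups use JAX_DTYPE_TO_TORCH; int4/uint4 only get the imprecise
# mappings added after the dict inversion, so they widen to int8/uint8.
assert mappings.j2t_dtype(jnp.float32.dtype) == torch.float32
assert mappings.j2t_dtype(jnp.dtype("int4")) == torch.int8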
torchax/ops/op_base.py
CHANGED
@@ -13,25 +13,23 @@
  # limitations under the License.

  import functools
+ from collections.abc import Callable
+ from typing import Concatenate, ParamSpec
+
  import jax
  import jax.numpy as jnp
  import numpy as np
  import torch
+
+ from torchax import types
  from torchax.ops import mappings
  from torchax.view import View
- from torchax import types
- import sys
-
- from typing import Callable, Optional, ParamSpec, Concatenate


  class InplaceOp:
-
-
-
- replace=False,
- position_to_mutate=0,
- is_jax_func=False):
+ def __init__(
+ self, functional_op, replace=False, position_to_mutate=0, is_jax_func=False
+ ):
  self.functional = functional_op
  self.replace = replace
  self.position_to_mutate = position_to_mutate

@@ -65,15 +63,14 @@ class InplaceOp:


  class OutVariant:
-
  def __call__(self, *args, **kwargs):
- to_mutate = kwargs[
- del kwargs[
+ to_mutate = kwargs["out"]
+ del kwargs["out"]
  to_mutate._elem = self.functional(*args, **kwargs)._elem
  return to_mutate


- P = ParamSpec(
+ P = ParamSpec("P")


  def convert_dtype(use_default_dtype: bool = True):

@@ -87,11 +84,8 @@ def convert_dtype(use_default_dtype: bool = True):
  """

  def decorator(func: types.TorchCallable):
-
  @functools.wraps(func)
- def wrapper(*args: P.args,
- dtype: Optional[torch.dtype] = None,
- **kwargs: P.kwargs):
+ def wrapper(*args: P.args, dtype: torch.dtype | None = None, **kwargs: P.kwargs):
  if not dtype and use_default_dtype:
  dtype = torch.get_default_dtype()
  if isinstance(dtype, torch.dtype):

@@ -106,8 +100,7 @@ def convert_dtype(use_default_dtype: bool = True):
  return decorator


- def maybe_convert_constant_dtype(val:
- dtype: Optional[jnp.dtype]):
+ def maybe_convert_constant_dtype(val: types.JaxValue | None, dtype: jnp.dtype | None):
  """Optionally converts scalar constant's dtype using `numpy`

  Use in cases where you require a constant and can't handle a traced array.

@@ -134,12 +127,11 @@ def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]):
  return wrapper


- def foreach_loop(
-
-
+ def foreach_loop(
+ seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0
+ ):
  """Run `fn` for each element of 1D array `seq`.

  Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`."""
  assert len(seq.shape) == 1
- return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]),
- init_val)
+ return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val)
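
A small usage sketch (not from the package) of the reformatted foreach_loop helper, which reduces a 1D array with jax.lax.fori_loop as its docstring says:

import jax.numpy as jnp
from torchax.ops import op_base

seq = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

# fn receives (carry, element); an explicitly typed init_val keeps the loop
# carry dtype stable across iterations.
total = op_base.foreach_loop(seq, lambda carry, x: carry + x, init_val=jnp.float32(0.0))
assert float(total) == 10.0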
torchax/ops/ops_registry.py
CHANGED
@@ -14,56 +14,61 @@

  import dataclasses
  import logging
- from torchax.types import JaxCallable, TorchCallable

- from
+ from torchax.types import JaxCallable, TorchCallable


  @dataclasses.dataclass
  class Operator:
  torch_op: TorchCallable
- func:
+ func: TorchCallable | JaxCallable
  is_jax_function: bool
  is_user_defined: bool
  needs_env: bool
  is_view_op: bool


- all_aten_ops:
- all_torch_functions:
+ all_aten_ops: dict[TorchCallable, Operator] = {}
+ all_torch_functions: dict[TorchCallable, Operator] = {}


- def register_torch_dispatch_op(
-
-
-
-
-
+ def register_torch_dispatch_op(
+ aten_op,
+ impl_callable,
+ is_jax_function=True,
+ is_user_defined=False,
+ needs_env=False,
+ is_view_op=False,
+ ):
  op = Operator(
-
-
-
-
-
-
+ aten_op,
+ impl_callable,
+ is_jax_function=is_jax_function,
+ is_user_defined=is_user_defined,
+ needs_env=needs_env,
+ is_view_op=is_view_op,
+ )
  if aten_op in all_aten_ops:
- logging.warning(f
+ logging.warning(f"Duplicate op registration for {aten_op}")
  all_aten_ops[aten_op] = op
  return impl_callable


- def register_torch_function_op(
-
-
-
-
-
+ def register_torch_function_op(
+ torch_func,
+ impl_callable,
+ is_jax_function=True,
+ is_user_defined=False,
+ needs_env=False,
+ is_view_op=False,
+ ):
  op = Operator(
-
-
-
-
-
-
+ torch_func,
+ impl_callable,
+ is_jax_function=is_jax_function,
+ is_user_defined=is_user_defined,
+ needs_env=needs_env,
+ is_view_op=is_view_op,
+ )
  all_torch_functions[torch_func] = op
  return impl_callable
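
Finally, a hedged sketch (not part of the diff) of how the registration helper above is meant to be called; the lowering function here is hypothetical:

import jax.numpy as jnp
import torch
from torchax.ops import ops_registry

def _abs_lowering(x):
    # Hypothetical JAX lowering, used purely for illustration.
    return jnp.abs(x)

# Keys the entry by the ATen op. If torchax already registered this op, the
# helper logs the "Duplicate op registration" warning shown above and the new
# entry replaces the old one.
ops_registry.register_torch_dispatch_op(
    torch.ops.aten.abs, _abs_lowering, is_jax_function=True
)
assert torch.ops.aten.abs in ops_registry.all_aten_ops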