returnn-1.20251027.232712-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +222 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +182 -172
- returnn/native_op.py +36 -31
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +8 -5
- returnn/tf/frontend_layers/_backend.py +7 -3
- returnn/tf/layers/basic.py +27 -40
- returnn/tf/native_op.py +27 -63
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +22 -197
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +280 -29
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/frontend/loss.py
CHANGED
@@ -3,11 +3,20 @@ Loss functions
 """

 from __future__ import annotations
+from typing import Optional, Tuple
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf


-__all__ = ["cross_entropy", "ctc_loss", "edit_distance"]
+__all__ = [
+    "cross_entropy",
+    "ctc_loss",
+    "ctc_best_path",
+    "ctc_greedy_decode",
+    "ctc_durations_from_path",
+    "ctc_no_label_loop_blank_durations_from_path",
+    "edit_distance",
+]


 def cross_entropy(
@@ -63,6 +72,8 @@ def ctc_loss(
     targets_spatial_dim: Dim,
     blank_index: int,
     max_approx: bool = False,
+    use_native_op: Optional[bool] = None,
+    label_loop: bool = True,
 ) -> Tensor:
     """
     Calculates the CTC loss.
@@ -79,6 +90,8 @@ def ctc_loss(
     :param targets_spatial_dim: spatial dim of targets
     :param blank_index: vocab index of the blank symbol
     :param max_approx: if True, use max instead of sum over alignments (max approx, Viterbi)
+    :param use_native_op: whether to use our native op
+    :param label_loop:
     :return: loss shape [B...]
     """
     # noinspection PyProtectedMember
@@ -90,9 +103,210 @@ def ctc_loss(
         targets_spatial_dim=targets_spatial_dim,
         blank_index=blank_index,
         max_approx=max_approx,
+        use_native_op=use_native_op,
+        label_loop=label_loop,
     )


+def ctc_best_path(
+    *,
+    logits: Tensor,
+    logits_normalized: bool = False,
+    targets: Tensor,
+    input_spatial_dim: Dim,
+    targets_spatial_dim: Dim,
+    blank_index: int,
+    label_loop: bool = True,
+) -> Tensor:
+    """
+    Calculates the CTC best path.
+
+    :param logits: (before softmax). shape [B...,input_spatial,C]
+    :param logits_normalized: whether the logits are already normalized (e.g. via log-softmax)
+    :param targets: sparse. shape [B...,targets_spatial] -> C
+    :param input_spatial_dim: spatial dim of input logits
+    :param targets_spatial_dim: spatial dim of targets
+    :param blank_index: vocab index of the blank symbol
+    :param label_loop: whether label loops are allowed (standard for CTC). False is like RNA topology.
+    :return: best path, shape [B...,targets_spatial] -> C
+    """
+    # noinspection PyProtectedMember
+    return logits._raw_backend.ctc_best_path(
+        logits=logits,
+        logits_normalized=logits_normalized,
+        targets=targets,
+        input_spatial_dim=input_spatial_dim,
+        targets_spatial_dim=targets_spatial_dim,
+        blank_index=blank_index,
+        label_loop=label_loop,
+    )
+
+
+def ctc_greedy_decode(
+    logits: Tensor,
+    *,
+    in_spatial_dim: Dim,
+    blank_index: int,
+    out_spatial_dim: Optional[Dim] = None,
+    target_dim: Optional[Dim] = None,
+    wb_target_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Greedy CTC decode.
+
+    :return: (labels, out_spatial_dim)
+    """
+    if wb_target_dim is None:
+        assert logits.feature_dim
+        wb_target_dim = logits.feature_dim
+
+    labels = rf.reduce_argmax(logits, axis=wb_target_dim)
+    labels = rf.cast(labels, "int32")
+
+    labels_shifted = rf.shift_right(labels, axis=in_spatial_dim, pad_value=blank_index)
+    mask_repeat = labels != labels_shifted
+    labels, out_spatial_dim = rf.masked_select(
+        labels,
+        mask=(labels != blank_index) & mask_repeat,
+        dims=[in_spatial_dim],
+        out_dim=out_spatial_dim,
+    )
+
+    if target_dim:
+        # Set correct sparse_dim. Only currently implemented if blank comes after.
+        assert target_dim.dimension == blank_index
+        labels.sparse_dim = target_dim
+
+    return labels, out_spatial_dim
+
+
+def ctc_durations_from_path(
+    *,
+    path: Tensor,
+    path_spatial_dim: Dim,
+    blank_index: int,
+    targets_spatial_dim: Optional[Dim] = None,
+    out_spatial_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Given a CTC path (alignment), compute the durations of each label + blanks.
+    Specifically, assuming that we have N labels in the target sequence,
+    there are N labels and N+1 blank durations,
+    (one before the first label, one after the last label, and one between each pair of labels),
+    resulting in a total of 2N+1 durations.
+    The returned durations tensor will have shape [B,...,T'] where T' = 2 * N + 1,
+    corresponding to durations for state sequence [blank_0, label_1, blank_1, label_2, ..., label_N, blank_N].
+
+    :param path: CTC path (alignment), shape [B...,path_spatial_dim] -> label indices (including blanks)
+    :param path_spatial_dim: spatial dim of path
+    :param blank_index: index of the blank label
+    :param targets_spatial_dim: if given, asserts that the computed number of labels matches this size
+    :param out_spatial_dim: if given, asserts that the output spatial dim size matches 2 * target_spatial_dim + 1
+    :return: (durations, out_spatial_dim).
+        durations shape [B...,out_spatial_dim] where out_spatial_dim = 2 * N + 1,
+        where N is the number of labels in the target sequence.
+    """
+    # example path: [_ _ a a b _ _ c c c _]
+    path_shifted = rf.shift_right(path, axis=path_spatial_dim, pad_value=blank_index)
+    # path_shifted: [_ _ _ a a b _ _ c c c]
+    new_label_mask = rf.logical_and(path != blank_index, path != path_shifted)
+    new_label_mask = new_label_mask.copy_masked(False, dims=[path_spatial_dim])
+    num_labels = rf.reduce_sum(rf.cast(new_label_mask, "int32"), axis=path_spatial_dim)
+    if targets_spatial_dim is not None:
+        rf.assert_(
+            targets_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels,
+            "target_spatial_dim size does not match number of labels in path",
+        )
+    else:
+        targets_spatial_dim = Dim(
+            rf.copy_to_device(num_labels, rf.get_default_dim_size_device()), name="target_spatial"
+        )
+    # new_label_mask: [0 0 1 0 1 0 0 1 0 0 0]
+    blank_idx = rf.cumsum(rf.cast(new_label_mask, "int32"), spatial_dim=path_spatial_dim)
+    # label_idx = blank_idx - 1
+    # label_idx: [-1 -1 0 0 1 1 1 2 2 2 2]
+    # blank_idx: [0 0 1 1 2 2 2 3 3 3 3]
+    blank_idx_x2 = blank_idx * 2
+    # blank_idx_x2: [0 0 2 2 4 4 4 6 6 6 6]
+    state_idx = blank_idx_x2 + rf.where(path == blank_index, 0, -1)
+    # state_idx: [0 0 1 1 3 4 4 5 5 5 6]
+    if out_spatial_dim is not None:
+        rf.assert_(
+            out_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels * 2 + 1,
+            "out_spatial_dim size does not match 2 * target_spatial_dim + 1",
+        )
+    else:
+        out_spatial_dim = targets_spatial_dim * 2 + 1
+    out = rf.scatter(rf.ones_like(state_idx), indices=state_idx, indices_dim=path_spatial_dim, out_dim=out_spatial_dim)
+    # out state seq: [ _ a _ b _ c _ ]
+    # out:           [ 2 2 0 1 2 3 1 ]
+    return out, out_spatial_dim
+
+
+def ctc_no_label_loop_blank_durations_from_path(
+    *,
+    path: Tensor,
+    path_spatial_dim: Dim,
+    blank_index: int,
+    targets_spatial_dim: Optional[Dim] = None,
+    out_spatial_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Given a CTC-without-label-loop (``label_loop=False`` in :func:`ctc_best_path`) (RNA) path (alignment),
+    compute the durations of all the blanks.
+    Specifically, assuming that we have N labels in the target sequence,
+    there are N+1 blank durations
+    (one before the first label, one after the last label, and one between each pair of labels).
+
+    :param path: CTC path (alignment), shape [B...,path_spatial_dim] -> label indices (including blanks)
+    :param path_spatial_dim: spatial dim of path
+    :param blank_index: index of the blank label
+    :param targets_spatial_dim: if given, asserts that the computed number of labels matches this size
+    :param out_spatial_dim: if given, asserts that the output spatial dim size matches target_spatial_dim + 1
+    :return: (durations, out_spatial_dim),
+        durations is for the blank labels,
+        durations shape [B...,out_spatial_dim] where out_spatial_dim = N + 1,
+        where N is the number of labels in the target sequence.
+    """
+    # example path: [_ _ _ a b _ _ c _]
+    new_label_mask = path != blank_index
+    new_label_mask = new_label_mask.copy_masked(False, dims=[path_spatial_dim])
+    num_labels = rf.reduce_sum(rf.cast(new_label_mask, "int32"), axis=path_spatial_dim)
+    if targets_spatial_dim is not None:
+        rf.assert_(
+            targets_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels,
+            "target_spatial_dim size does not match number of labels in path",
+        )
+    else:
+        targets_spatial_dim = Dim(
+            rf.copy_to_device(num_labels, rf.get_default_dim_size_device()), name="target_spatial"
+        )
+    # new_label_mask: [0 0 0 1 1 0 0 1 0]
+    blank_idx = rf.cumsum(rf.cast(new_label_mask, "int32"), spatial_dim=path_spatial_dim)
+    # blank_idx: [0 0 0 1 2 2 2 3 3]
+    blank_idx = rf.where(
+        (path == blank_index) & rf.sequence_mask(path_spatial_dim, device=path.device),
+        blank_idx,
+        rf.reduce_max(num_labels, axis=num_labels.dims) + 1,
+    )
+    # blank_idx: [0 0 0 4 4 2 2 4 3]
+    if out_spatial_dim is not None:
+        rf.assert_(
+            out_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels + 1,
+            "out_spatial_dim size does not match 2 * target_spatial_dim + 1",
+        )
+    else:
+        out_spatial_dim = targets_spatial_dim + 1
+    out_spatial_dim_ext = out_spatial_dim + 1  # for the extra label index used above
+    out = rf.scatter(
+        rf.ones_like(blank_idx), indices=blank_idx, indices_dim=path_spatial_dim, out_dim=out_spatial_dim_ext
+    )
+    out, _ = rf.slice(out, axis=out_spatial_dim_ext, size=out_spatial_dim)
+    # out state seq: [ _ a _ b _ c _ ]
+    # out:           [ 3 0 2 1 ]
+    return out, out_spatial_dim
+
+
 def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
     """
     :param a: [B,Ta]
@@ -102,13 +316,18 @@ def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim,
     :param dtype:
     :return: [B]
     """
-    import numpy
+    # noinspection PyProtectedMember
+    backend = a._raw_backend
+    if backend.have_edit_distance():
+        return backend.edit_distance(a, a_spatial_dim, b, b_spatial_dim)
+
+    from numpy import iinfo

     # The axis permutation is just an efficiency optimization.
     a = a.copy_transpose([a_spatial_dim] + a.remaining_dims(a_spatial_dim))
     b = b.copy_transpose([b_spatial_dim] + b.remaining_dims(b_spatial_dim))
     dev = a.device
-    max_dist_err = numpy.iinfo(dtype).max
+    max_dist_err = iinfo(dtype).max
     n_a_max_len = a_spatial_dim.get_dim_value()
     n_b_max_len = b_spatial_dim.get_dim_value()
     if int(n_a_max_len) < int(n_b_max_len):
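
The new loss.py helpers are designed to chain: ctc_best_path gives a framewise alignment from the same inputs as ctc_loss, ctc_durations_from_path turns such an alignment into per-state durations, and ctc_greedy_decode collapses raw logits into a label sequence. Below is a minimal usage sketch against the signatures shown in the diff above; the helper function, the tensor/dim names, the chosen keyword values, and the assumption that the returned best path is an alignment over the input frames are illustrative only and not part of the package.

from typing import Tuple
from returnn.tensor import Tensor, Dim
import returnn.frontend as rf


def align_and_decode(
    logits: Tensor,
    targets: Tensor,
    *,
    input_spatial_dim: Dim,
    targets_spatial_dim: Dim,
    blank_index: int,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Hypothetical helper combining the new CTC utilities (sketch, not package code)."""
    loss = rf.ctc_loss(
        logits=logits,
        targets=targets,
        input_spatial_dim=input_spatial_dim,
        targets_spatial_dim=targets_spatial_dim,
        blank_index=blank_index,
        use_native_op=None,  # new arg: None lets the backend decide, True/False forces the native op on/off
        label_loop=True,  # new arg: True is the standard CTC topology, False behaves RNA-like
    )
    # Viterbi alignment with the same topology settings as the loss.
    path = rf.ctc_best_path(
        logits=logits,
        targets=targets,
        input_spatial_dim=input_spatial_dim,
        targets_spatial_dim=targets_spatial_dim,
        blank_index=blank_index,
        label_loop=True,
    )
    # Durations of the 2*N+1 states [blank_0, label_1, blank_1, ..., label_N, blank_N],
    # assuming the path is an alignment over the input frames.
    durations, _dur_spatial_dim = rf.ctc_durations_from_path(
        path=path,
        path_spatial_dim=input_spatial_dim,
        blank_index=blank_index,
        targets_spatial_dim=targets_spatial_dim,
    )
    # Greedy (non-aligned) decoding straight from the logits.
    hyp_labels, _hyp_spatial_dim = rf.ctc_greedy_decode(
        logits, in_spatial_dim=input_spatial_dim, blank_index=blank_index
    )
    return loss, durations, hyp_labels
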
returnn/frontend/math_.py
CHANGED
@@ -3,7 +3,6 @@ Math ops
 """

 from __future__ import annotations
-import typing
 from typing import Optional, Sequence, Union, Tuple, overload
 import numpy
 from returnn.tensor import Tensor, Dim
@@ -77,7 +76,7 @@ __all__ = [
 ]


-@typing.overload
+@overload
 def compare(
     a: Tensor,
     kind: str,
@@ -86,7 +85,19 @@ def compare(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """compare
+    """compare"""
+
+
+@overload
+def compare(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Tensor:
+    """compare"""


 _CompareMap = {
@@ -138,7 +149,7 @@ def compare_bc(
     return compare(a, kind, b, allow_broadcast_all_sources=True, dim_order=dim_order)


-@typing.overload
+@overload
 def combine(
     a: Tensor,
     kind: str,
@@ -147,7 +158,19 @@ def combine(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """combine
+    """combine"""
+
+
+@overload
+def combine(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Union[Tensor, _RawTensorTypes]:
+    """combine"""


 _CombineMap = {
@@ -332,7 +355,12 @@ def logical_not(a: Tensor) -> Tensor:

 @overload
 def opt_logical_or(a: bool, b: bool) -> bool:
-    """logical or"""
+    """opt logical or"""
+
+
+@overload
+def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical or"""


 def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -350,7 +378,12 @@ def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tens

 @overload
 def opt_logical_and(a: bool, b: bool) -> bool:
-    """logical and"""
+    """opt logical and"""
+
+
+@overload
+def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical and"""


 def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -416,16 +449,23 @@ def minimum(a: Tensor, b: Union[Tensor, _RawTensorTypes], *other_tensors) -> Ten

 def clip_by_value(
     x: Tensor,
-    clip_value_min: Union[Tensor, _RawTensorTypes],
-    clip_value_max: Union[Tensor, _RawTensorTypes],
+    clip_value_min: Union[None, Tensor, _RawTensorTypes] = None,
+    clip_value_max: Union[None, Tensor, _RawTensorTypes] = None,
     *,
     allow_broadcast_all_sources: bool = False,
 ) -> Tensor:
     """clip by value"""
-    # noinspection PyProtectedMember
-    return x._raw_backend.clip_by_value(
-        x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
-    )
+    if clip_value_min is not None and clip_value_max is not None:
+        # noinspection PyProtectedMember
+        return x._raw_backend.clip_by_value(
+            x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
+        )
+    elif clip_value_min is not None and clip_value_max is None:
+        return maximum(x, clip_value_min)
+    elif clip_value_min is None and clip_value_max is not None:
+        return minimum(x, clip_value_max)
+    else:
+        return x


 def identity(x: Tensor) -> Tensor:
@@ -541,7 +581,7 @@ def floor(a: Tensor) -> Tensor:

 # noinspection PyShadowingBuiltins
 def round(a: Tensor) -> Tensor:
-    """round"""
+    """round. the result dtype is same as input dtype, still float"""
     # noinspection PyProtectedMember
     return a._raw_backend.activation(a, "round")

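
The clip_by_value change makes both bounds optional: with both bounds it still dispatches to the backend op, with only one bound it falls back to maximum/minimum, and with neither it returns the input unchanged. A small sketch of the resulting call patterns follows; the wrapper function, variable names, and bound values are illustrative only.

import returnn.frontend as rf
from returnn.tensor import Tensor


def clamp_variants(x: Tensor) -> Tensor:
    """Illustrative sketch of the call patterns enabled by the now-optional bounds."""
    x = rf.clip_by_value(x, 0.0, 1.0)  # both bounds: dispatches to the backend clip_by_value, as before
    x = rf.clip_by_value(x, clip_value_min=0.0)  # lower bound only: equivalent to rf.maximum(x, 0.0)
    x = rf.clip_by_value(x, clip_value_max=1.0)  # upper bound only: equivalent to rf.minimum(x, 1.0)
    x = rf.clip_by_value(x)  # no bounds: returns x unchanged
    return x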