returnn 1.20250216.155246__py3-none-any.whl → 1.20250220.200053__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/_backend.py +7 -9
- returnn/frontend/_native/module.cpp +3 -2
- returnn/frontend/_native/tensor_ops.cpp +8 -0
- returnn/frontend/_numpy_backend.py +2 -1
- returnn/frontend/array_.py +73 -1
- returnn/frontend/dims.py +31 -0
- returnn/frontend/rand.py +30 -0
- returnn/tf/frontend_layers/_backend.py +3 -0
- returnn/tf/frontend_low_level/_backend.py +5 -10
- returnn/torch/frontend/_backend.py +21 -10
- {returnn-1.20250216.155246.dist-info → returnn-1.20250220.200053.dist-info}/METADATA +1 -1
- {returnn-1.20250216.155246.dist-info → returnn-1.20250220.200053.dist-info}/RECORD +17 -17
- {returnn-1.20250216.155246.dist-info → returnn-1.20250220.200053.dist-info}/LICENSE +0 -0
- {returnn-1.20250216.155246.dist-info → returnn-1.20250220.200053.dist-info}/WHEEL +0 -0
- {returnn-1.20250216.155246.dist-info → returnn-1.20250220.200053.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250216.155246'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250220.200053'
|
|
2
|
+
long_version = '1.20250220.200053+git.bb5c0aa'
|
returnn/frontend/_backend.py
CHANGED
|
@@ -784,18 +784,11 @@ class Backend(Generic[T]):
|
|
|
784
784
|
dims: Sequence[Dim],
|
|
785
785
|
dtype: str,
|
|
786
786
|
sparse_dim: Optional[Dim] = None,
|
|
787
|
+
feature_dim: Optional[Dim] = None,
|
|
787
788
|
device: Optional[str] = None,
|
|
788
789
|
name: Optional[str] = None,
|
|
789
790
|
) -> Tensor[T]:
|
|
790
|
-
"""
|
|
791
|
-
:param value: tensor, or scalar raw tensor or some other scalar value
|
|
792
|
-
:param dims:
|
|
793
|
-
:param dtype:
|
|
794
|
-
:param sparse_dim:
|
|
795
|
-
:param device:
|
|
796
|
-
:param name:
|
|
797
|
-
:return: tensor
|
|
798
|
-
"""
|
|
791
|
+
"""convert (raw/any) tensor to tensor"""
|
|
799
792
|
raise NotImplementedError
|
|
800
793
|
|
|
801
794
|
@staticmethod
|
|
@@ -956,6 +949,11 @@ class Backend(Generic[T]):
|
|
|
956
949
|
"""where"""
|
|
957
950
|
raise NotImplementedError
|
|
958
951
|
|
|
952
|
+
@staticmethod
|
|
953
|
+
def sort(source: Tensor, *, axis: Dim, descending: bool, stable: bool) -> Tuple[Tensor, Tensor, Dim]:
|
|
954
|
+
"""sort. return values and indices"""
|
|
955
|
+
raise NotImplementedError
|
|
956
|
+
|
|
959
957
|
@staticmethod
|
|
960
958
|
def search_sorted(
|
|
961
959
|
sorted_seq: Tensor, values: Tensor, *, axis: Dim, side: str = "left", out_dtype: str = "int32"
|
|
@@ -312,8 +312,9 @@ bool PyModuleState::_cachedOpInitTorch() {
|
|
|
312
312
|
AddOp(TOp_FloorDiv, "floor_divide");
|
|
313
313
|
AddOp(TOp_Mod, "remainder");
|
|
314
314
|
AddOp(TOp_Pow, "pow");
|
|
315
|
-
|
|
316
|
-
AddOp(
|
|
315
|
+
// Use clamp_min/clamp_max instead of maximum/minimum because the former allow number arguments.
|
|
316
|
+
AddOp(TOp_Maximum, "clamp_min");
|
|
317
|
+
AddOp(TOp_Minimum, "clamp_max");
|
|
317
318
|
AddOpAlt(TOp_SquaredDifference, "squared_difference");
|
|
318
319
|
AddOp(TOp_And, "logical_and");
|
|
319
320
|
AddOp(TOp_Or, "logical_or");
|
|
@@ -1368,6 +1368,14 @@ static PyObject* compareOrCombineViaCached(
|
|
|
1368
1368
|
case TOp_FloorDiv:
|
|
1369
1369
|
case TOp_Mod:
|
|
1370
1370
|
case TOp_Pow:
|
|
1371
|
+
case TOp_Maximum:
|
|
1372
|
+
case TOp_Minimum:
|
|
1373
|
+
case TOp_Eq:
|
|
1374
|
+
case TOp_Ne:
|
|
1375
|
+
case TOp_Lt:
|
|
1376
|
+
case TOp_Le:
|
|
1377
|
+
case TOp_Gt:
|
|
1378
|
+
case TOp_Ge:
|
|
1371
1379
|
needConvertToTensor = false;
|
|
1372
1380
|
default:
|
|
1373
1381
|
break;
|
|
@@ -83,6 +83,7 @@ class NumpyBackend(Backend[numpy.ndarray]):
|
|
|
83
83
|
dims: Sequence[Dim],
|
|
84
84
|
dtype: str,
|
|
85
85
|
sparse_dim: Optional[Dim] = None,
|
|
86
|
+
feature_dim: Optional[Dim] = None,
|
|
86
87
|
device: Optional[str] = None,
|
|
87
88
|
name: Optional[str] = None,
|
|
88
89
|
) -> Tensor[numpy.ndarray]:
|
|
@@ -95,7 +96,7 @@ class NumpyBackend(Backend[numpy.ndarray]):
|
|
|
95
96
|
name = name or "const"
|
|
96
97
|
value = numpy.array(value, dtype=NumpyBackend.as_dtype_raw(dtype))
|
|
97
98
|
assert isinstance(value, numpy.ndarray)
|
|
98
|
-
return Tensor(name, dims=dims, dtype=dtype, sparse_dim=sparse_dim, raw_tensor=value)
|
|
99
|
+
return Tensor(name, dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim, raw_tensor=value)
|
|
99
100
|
|
|
100
101
|
@staticmethod
|
|
101
102
|
def expand_dims_raw(raw_tensor: numpy.ndarray, axis: int) -> numpy.ndarray:
|
returnn/frontend/array_.py
CHANGED
|
@@ -45,9 +45,12 @@ __all__ = [
|
|
|
45
45
|
"shift_left",
|
|
46
46
|
"reverse_sequence",
|
|
47
47
|
"where",
|
|
48
|
+
"sort",
|
|
48
49
|
"search_sorted",
|
|
49
50
|
"sparse_to_dense",
|
|
50
51
|
"one_hot",
|
|
52
|
+
"top_k_mask",
|
|
53
|
+
"top_p_mask",
|
|
51
54
|
]
|
|
52
55
|
|
|
53
56
|
|
|
@@ -57,6 +60,7 @@ def convert_to_tensor(
|
|
|
57
60
|
dims: Sequence[Dim] = None,
|
|
58
61
|
dtype: Optional[str] = None,
|
|
59
62
|
sparse_dim: Optional[Dim] = None,
|
|
63
|
+
feature_dim: Optional[Dim] = None,
|
|
60
64
|
shape: Sequence[Dim] = None,
|
|
61
65
|
device: Optional[str] = None,
|
|
62
66
|
keep_scalar_on_cpu: bool = False,
|
|
@@ -68,6 +72,7 @@ def convert_to_tensor(
|
|
|
68
72
|
:param dims:
|
|
69
73
|
:param dtype:
|
|
70
74
|
:param sparse_dim:
|
|
75
|
+
:param feature_dim:
|
|
71
76
|
:param shape: alias for dims, for some older code
|
|
72
77
|
:param name:
|
|
73
78
|
:param device:
|
|
@@ -121,7 +126,7 @@ def convert_to_tensor(
|
|
|
121
126
|
if dtype is None:
|
|
122
127
|
dtype = value_backend.get_dtype_name_raw(value)
|
|
123
128
|
return _backend.convert_to_tensor(
|
|
124
|
-
value=value, dims=dims, dtype=dtype, sparse_dim=sparse_dim, device=device, name=name
|
|
129
|
+
value=value, dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim, device=device, name=name
|
|
125
130
|
)
|
|
126
131
|
|
|
127
132
|
|
|
@@ -996,6 +1001,27 @@ def where(
|
|
|
996
1001
|
return cond._raw_backend.where(cond, true_, false_, allow_broadcast_all_sources=allow_broadcast_all_sources)
|
|
997
1002
|
|
|
998
1003
|
|
|
1004
|
+
def sort(source: Tensor, *, axis: Dim, descending: bool = False, stable: bool = True) -> Tuple[Tensor, Tensor, Dim]:
|
|
1005
|
+
"""
|
|
1006
|
+
Sorts the source tensor along the given axis.
|
|
1007
|
+
|
|
1008
|
+
See also :func:`top_k`.
|
|
1009
|
+
:func:`top_k` with ``k=axis.get_size_tensor()`` is equivalent to this function.
|
|
1010
|
+
|
|
1011
|
+
:param source: {other_dims..., axis}
|
|
1012
|
+
:param axis: The axis to sort along.
|
|
1013
|
+
:param descending: If True, sort in descending order, otherwise in ascending order.
|
|
1014
|
+
:param stable: If True, use a stable sorting algorithm (not reordering equal elements).
|
|
1015
|
+
Note that many frameworks (Torch, TensorFlow) have ``stable=False`` by default.
|
|
1016
|
+
``stable=False`` can be faster.
|
|
1017
|
+
:return: sorted tensor, indices tensor, out_dim. both tensors have the shape {other_dims..., out_dim},
|
|
1018
|
+
i.e. ``axis`` replaced by ``out_dim``.
|
|
1019
|
+
indices tensor has sparse_dim set to ``axis``.
|
|
1020
|
+
"""
|
|
1021
|
+
# noinspection PyProtectedMember
|
|
1022
|
+
return source._raw_backend.sort(source, axis=axis, descending=descending, stable=stable)
|
|
1023
|
+
|
|
1024
|
+
|
|
999
1025
|
def search_sorted(
|
|
1000
1026
|
sorted_seq: Tensor, values: Tensor, *, axis: Dim, side: str = "left", out_dtype: str = "int32"
|
|
1001
1027
|
) -> Tensor:
|
|
@@ -1044,3 +1070,49 @@ def one_hot(source: Tensor) -> Tensor:
|
|
|
1044
1070
|
and much more efficiently than they would be with dense tensors.
|
|
1045
1071
|
"""
|
|
1046
1072
|
return sparse_to_dense(source, label_value=1.0, other_value=0.0)
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def top_k_mask(values: Tensor, *, axis: Dim, k: Union[int, Tensor]) -> Tensor:
|
|
1076
|
+
"""
|
|
1077
|
+
Top-k filtering.
|
|
1078
|
+
|
|
1079
|
+
:param values: {other_dims..., axis}
|
|
1080
|
+
:param axis:
|
|
1081
|
+
:param k: the number of top values to keep
|
|
1082
|
+
:return: mask {other_dims..., axis} of the top-k values
|
|
1083
|
+
"""
|
|
1084
|
+
_, indices, k_dim = rf.top_k(values, axis=axis, k=k)
|
|
1085
|
+
mask = rf.scatter(rf.full(dims=indices.dims, fill_value=True), indices=indices, indices_dim=k_dim, fill_value=False)
|
|
1086
|
+
return mask
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def top_p_mask(
|
|
1090
|
+
probs: Tensor,
|
|
1091
|
+
*,
|
|
1092
|
+
axis: Dim,
|
|
1093
|
+
p: Union[float, Tensor],
|
|
1094
|
+
one_more: bool = True,
|
|
1095
|
+
) -> Tensor:
|
|
1096
|
+
"""
|
|
1097
|
+
Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
|
|
1098
|
+
|
|
1099
|
+
:param probs: {probs_dims..., axis}
|
|
1100
|
+
:param axis:
|
|
1101
|
+
:param p: the probability mass to keep
|
|
1102
|
+
:param one_more: if True (default), keep also the first token above the threshold.
|
|
1103
|
+
(It's enabled by default to follow the behavior of the original implementation.)
|
|
1104
|
+
:return: mask {probs_dims..., axis} of the top-p tokens.
|
|
1105
|
+
``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
|
|
1106
|
+
"""
|
|
1107
|
+
assert 0.0 <= p <= 1.0
|
|
1108
|
+
if isinstance(p, Tensor):
|
|
1109
|
+
assert axis not in p.dims
|
|
1110
|
+
# https://github.com/ari-holtzman/degen/blob/master/gen.py
|
|
1111
|
+
sorted_probs, sorted_indices, sorted_dim = rf.sort(probs, axis=axis, descending=True)
|
|
1112
|
+
cum_probs = rf.cumsum(sorted_probs, spatial_dim=sorted_dim)
|
|
1113
|
+
mask = cum_probs <= p # {probs_dims..., sorted_dim}
|
|
1114
|
+
if one_more:
|
|
1115
|
+
# keep also the first token above the threshold
|
|
1116
|
+
mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
|
|
1117
|
+
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1118
|
+
return mask
|
returnn/frontend/dims.py
CHANGED
|
@@ -14,6 +14,7 @@ __all__ = [
|
|
|
14
14
|
"range_over_dim",
|
|
15
15
|
"range_over_dim_strided",
|
|
16
16
|
"range_over_merged_dims",
|
|
17
|
+
"linspace_over_dim",
|
|
17
18
|
"replace_dim",
|
|
18
19
|
"replace_dim_v2",
|
|
19
20
|
"set_sparse_dim",
|
|
@@ -81,6 +82,36 @@ def range_over_merged_dims(
|
|
|
81
82
|
return indices
|
|
82
83
|
|
|
83
84
|
|
|
85
|
+
def linspace_over_dim(
|
|
86
|
+
dim: Dim,
|
|
87
|
+
start: Union[float, Tensor] = 0.0,
|
|
88
|
+
end: Union[float, Tensor] = 1.0,
|
|
89
|
+
*,
|
|
90
|
+
dtype: Optional[str] = None,
|
|
91
|
+
device: Optional[str] = None,
|
|
92
|
+
) -> Tensor:
|
|
93
|
+
"""
|
|
94
|
+
Linearly spaced values over a dim.
|
|
95
|
+
|
|
96
|
+
:param dim: dim to range over
|
|
97
|
+
:param start: start value
|
|
98
|
+
:param end: end value
|
|
99
|
+
:param dtype: dtype of the output tensor
|
|
100
|
+
:param device: device of the output tensor
|
|
101
|
+
:return: tensor with shape [dim] containing linearly spaced values between start and end
|
|
102
|
+
"""
|
|
103
|
+
if dtype is None:
|
|
104
|
+
dtype = rf.get_default_float_dtype()
|
|
105
|
+
indices = rf.range_over_dim(dim, dtype=dtype, device=device)
|
|
106
|
+
linspace = indices / rf.cast(rf.maximum(dim.get_size_tensor(device=indices.device), 1), dtype=indices.dtype)
|
|
107
|
+
space_len = end - start
|
|
108
|
+
if not isinstance(space_len, (int, float)) or space_len != 1:
|
|
109
|
+
linspace *= space_len
|
|
110
|
+
if not isinstance(start, (int, float)) or start != 0:
|
|
111
|
+
linspace += start
|
|
112
|
+
return linspace
|
|
113
|
+
|
|
114
|
+
|
|
84
115
|
def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Optional[Dim] = None) -> Tuple[Tensor, Dim]:
|
|
85
116
|
"""
|
|
86
117
|
Also see: :func:`replace_dim_v2`, :func:`rf.merge_dims`, :func:`rf.split_dims`.
|
returnn/frontend/rand.py
CHANGED
|
@@ -64,6 +64,7 @@ __all__ = [
|
|
|
64
64
|
"random_uniform",
|
|
65
65
|
"random_normal",
|
|
66
66
|
"random_truncated_normal",
|
|
67
|
+
"random_choice_without_replacement",
|
|
67
68
|
]
|
|
68
69
|
|
|
69
70
|
|
|
@@ -349,3 +350,32 @@ def random_truncated_normal(
|
|
|
349
350
|
static=static,
|
|
350
351
|
out=out,
|
|
351
352
|
)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def random_choice_without_replacement(
|
|
356
|
+
*,
|
|
357
|
+
log_probs: Tensor,
|
|
358
|
+
axis: Union[Dim, Sequence[Dim]],
|
|
359
|
+
num_samples_dim: Dim,
|
|
360
|
+
noise_scale: Union[float, Tensor] = 1.0,
|
|
361
|
+
) -> Union[Tensor, Sequence[Tensor]]:
|
|
362
|
+
"""
|
|
363
|
+
Randomly sample without replacement.
|
|
364
|
+
|
|
365
|
+
:param log_probs: {log_probs_dims..., axis}
|
|
366
|
+
:param axis: same as in :func:`top_k`
|
|
367
|
+
:param num_samples_dim: how many samples to draw
|
|
368
|
+
:param noise_scale: scale the noise. with scale=0, you get :func:`top_k`.
|
|
369
|
+
:return: random indices shape {log_probs_dims..., num_samples_dim} -> axis.
|
|
370
|
+
if axis was a sequence, will return a sequence of tensors.
|
|
371
|
+
"""
|
|
372
|
+
# https://github.com/tensorflow/tensorflow/issues/9260
|
|
373
|
+
# https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
|
|
374
|
+
scores_random_sample = -rf.log(
|
|
375
|
+
-rf.log(random_uniform(log_probs.dims, dtype=log_probs.dtype, device=log_probs.device))
|
|
376
|
+
)
|
|
377
|
+
if not isinstance(noise_scale, (int, float)) or noise_scale != 1.0:
|
|
378
|
+
scores_random_sample *= noise_scale
|
|
379
|
+
scores = log_probs + scores_random_sample
|
|
380
|
+
_, indices, _ = rf.top_k(scores, k_dim=num_samples_dim, axis=axis)
|
|
381
|
+
return indices
|
|
@@ -559,6 +559,7 @@ class ReturnnLayersBackend(Backend[Layer]):
|
|
|
559
559
|
dims: Sequence[Dim],
|
|
560
560
|
dtype: str,
|
|
561
561
|
sparse_dim: Optional[Dim] = None,
|
|
562
|
+
feature_dim: Optional[Dim] = None,
|
|
562
563
|
device: Optional[str] = None,
|
|
563
564
|
name: Optional[str] = None,
|
|
564
565
|
) -> Tensor[Layer]:
|
|
@@ -568,6 +569,8 @@ class ReturnnLayersBackend(Backend[Layer]):
|
|
|
568
569
|
kwargs = {}
|
|
569
570
|
if sparse_dim:
|
|
570
571
|
kwargs["sparse_dim"] = sparse_dim
|
|
572
|
+
if feature_dim:
|
|
573
|
+
kwargs["feature_dim"] = feature_dim
|
|
571
574
|
dim_deps = _dims.get_dim_deps(dims)
|
|
572
575
|
if dim_deps:
|
|
573
576
|
kwargs["shape_deps"] = dim_deps
|
|
@@ -411,24 +411,19 @@ class TFBackend(Backend[tf.Tensor]):
|
|
|
411
411
|
dims: Sequence[Dim],
|
|
412
412
|
dtype: str,
|
|
413
413
|
sparse_dim: Optional[Dim] = None,
|
|
414
|
+
feature_dim: Optional[Dim] = None,
|
|
414
415
|
device: Optional[str] = None,
|
|
415
416
|
name: Optional[str] = None,
|
|
416
417
|
) -> _TT:
|
|
417
|
-
"""
|
|
418
|
-
:param value:
|
|
419
|
-
:param dims:
|
|
420
|
-
:param dtype:
|
|
421
|
-
:param sparse_dim:
|
|
422
|
-
:param device:
|
|
423
|
-
:param name:
|
|
424
|
-
:return: tensor
|
|
425
|
-
"""
|
|
418
|
+
"""convert to tensor"""
|
|
426
419
|
if isinstance(value, Tensor):
|
|
427
420
|
return value
|
|
428
421
|
with tf.control_dependencies(None):
|
|
429
422
|
value = tf.convert_to_tensor(value, dtype=dtype)
|
|
430
423
|
assert isinstance(value, tf.Tensor)
|
|
431
|
-
return Tensor(
|
|
424
|
+
return Tensor(
|
|
425
|
+
name or "const", raw_tensor=value, dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim
|
|
426
|
+
)
|
|
432
427
|
|
|
433
428
|
@staticmethod
|
|
434
429
|
def range_over_dim(dim: Dim, *, dtype: Optional[str] = None, device: Optional[str] = None) -> _TT:
|
|
@@ -895,18 +895,11 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
895
895
|
dims: Sequence[Dim],
|
|
896
896
|
dtype: str,
|
|
897
897
|
sparse_dim: Optional[Dim] = None,
|
|
898
|
+
feature_dim: Optional[Dim] = None,
|
|
898
899
|
device: Optional[str] = None,
|
|
899
900
|
name: Optional[str] = None,
|
|
900
901
|
) -> Tensor[torch.Tensor]:
|
|
901
|
-
"""
|
|
902
|
-
:param value:
|
|
903
|
-
:param dims:
|
|
904
|
-
:param dtype:
|
|
905
|
-
:param sparse_dim:
|
|
906
|
-
:param device:
|
|
907
|
-
:param name:
|
|
908
|
-
:return: tensor
|
|
909
|
-
"""
|
|
902
|
+
"""convert to tensor"""
|
|
910
903
|
if isinstance(value, Tensor):
|
|
911
904
|
return value
|
|
912
905
|
if isinstance(value, torch.Tensor):
|
|
@@ -926,7 +919,7 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
926
919
|
device=device or rf.get_default_device(),
|
|
927
920
|
)
|
|
928
921
|
assert isinstance(value, torch.Tensor)
|
|
929
|
-
return Tensor(name, dims=dims, dtype=dtype, sparse_dim=sparse_dim, raw_tensor=value)
|
|
922
|
+
return Tensor(name, dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim, raw_tensor=value)
|
|
930
923
|
|
|
931
924
|
@staticmethod
|
|
932
925
|
def full(
|
|
@@ -1223,6 +1216,21 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
1223
1216
|
out.raw_tensor = torch.where(cond_bc_raw, true_bc_raw, false_bc_raw)
|
|
1224
1217
|
return out
|
|
1225
1218
|
|
|
1219
|
+
@staticmethod
|
|
1220
|
+
def sort(source: Tensor, *, axis: Dim, descending: bool, stable: bool) -> Tuple[Tensor, Tensor, Dim]:
|
|
1221
|
+
"""sort. return values and indices"""
|
|
1222
|
+
axis_int = source.get_axis_from_description(axis, allow_int=False)
|
|
1223
|
+
# Move to last axis. Should be more efficient.
|
|
1224
|
+
source = source.copy_move_axis(axis_int, -1)
|
|
1225
|
+
axis_int = source.batch_ndim - 1
|
|
1226
|
+
values_raw, indices_raw = torch.sort(source.raw_tensor, dim=axis_int, descending=descending, stable=stable)
|
|
1227
|
+
out_dims = list(source.dims)
|
|
1228
|
+
out_dim = axis.copy(same_as_self=False, description=f"{axis.description}:sorted")
|
|
1229
|
+
out_dims[axis_int] = out_dim
|
|
1230
|
+
values = rf.convert_to_tensor(values_raw, dims=out_dims, feature_dim={axis: out_dim}.get(source.feature_dim))
|
|
1231
|
+
indices = rf.convert_to_tensor(indices_raw, dims=out_dims, sparse_dim=axis)
|
|
1232
|
+
return values, indices, out_dim
|
|
1233
|
+
|
|
1226
1234
|
@staticmethod
|
|
1227
1235
|
def search_sorted(
|
|
1228
1236
|
sorted_seq: Tensor, values: Tensor, *, axis: Dim, side: str = "left", out_dtype: str = "int32"
|
|
@@ -1566,6 +1574,9 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
1566
1574
|
return values, indices_out, k_dim
|
|
1567
1575
|
assert isinstance(axis, Dim)
|
|
1568
1576
|
axis_int = source.get_axis_from_description(axis, allow_int=False)
|
|
1577
|
+
# Move to last axis. Should be more efficient.
|
|
1578
|
+
source = source.copy_move_axis(axis_int, -1)
|
|
1579
|
+
axis_int = source.batch_ndim - 1
|
|
1569
1580
|
values_raw, indices_raw = torch.topk(
|
|
1570
1581
|
source.raw_tensor, k=k_dim.get_dim_value(), dim=axis_int, largest=True, sorted=sorted
|
|
1571
1582
|
)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=U58QGiF-75H5Ac8V3JUwKdPkzP3TPwuPkhfzHhpa7Vc,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=OojdMrmzo4naqIdlDTwnSiMHtnmVuqlosY9_dqmm20c,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -75,12 +75,12 @@ returnn/extern/graph_editor/subgraph.py,sha256=R3uIFqWgiL7L5S4YATm9o9a3wfEa_mSb4
|
|
|
75
75
|
returnn/extern/graph_editor/transform.py,sha256=d9fEgu0JC342q0g9niVxRWMKzkQQA9mrrajBGcU1o_s,29349
|
|
76
76
|
returnn/extern/graph_editor/util.py,sha256=QMrQeQZ7lJwsrNQub9tof0h3quEaoHiGJaZmogQ7jXE,18707
|
|
77
77
|
returnn/frontend/__init__.py,sha256=2aS7nbxXniIrBp2DODl0xN0f3IJ_dX4Bi9ZlR7W5_DE,1472
|
|
78
|
-
returnn/frontend/_backend.py,sha256=
|
|
78
|
+
returnn/frontend/_backend.py,sha256=W3J3ZSOxonX6wk-wY2dX_aokXHpm1VQ1V0qSjllQxUM,50165
|
|
79
79
|
returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,8403
|
|
80
|
-
returnn/frontend/_numpy_backend.py,sha256=
|
|
80
|
+
returnn/frontend/_numpy_backend.py,sha256=2oCtG0YCWL_89v4cD_jDj8em1O_Fp-_YWl5EblGi_yo,7858
|
|
81
81
|
returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
|
|
82
82
|
returnn/frontend/_utils.py,sha256=4A3MSRM0i86J77550uR_AjcBEPu6nymLUZ9Xd1V3Fkc,12073
|
|
83
|
-
returnn/frontend/array_.py,sha256=
|
|
83
|
+
returnn/frontend/array_.py,sha256=UHTQmb_cFsjVStAELcCqMkCbQNQiBiwN4gQZu6CloIA,44126
|
|
84
84
|
returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
|
|
85
85
|
returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
|
|
86
86
|
returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
|
|
@@ -90,7 +90,7 @@ returnn/frontend/container.py,sha256=wF3OlQN7WlOVmmdapUth_Unha3DVf6h1B7okBJAuJDA
|
|
|
90
90
|
returnn/frontend/control_flow_ctx.py,sha256=v17CsNwRnZYe8GdMtGJt2ftibfxMCGK1i0l-GX5ILu0,699
|
|
91
91
|
returnn/frontend/conv.py,sha256=51LZovcRzITDLXvPcJs_MFsGEY_MFvO_MFF9D-jZstA,22481
|
|
92
92
|
returnn/frontend/device.py,sha256=K7Y1qoQcO4GIHgLkPLQWK-GVT8gKL8GwyQrmPo8LgBE,1438
|
|
93
|
-
returnn/frontend/dims.py,sha256=
|
|
93
|
+
returnn/frontend/dims.py,sha256=hKA7IQRB0DbohN1ngNw31W44BsyjdHCtYAccxOcumzQ,10872
|
|
94
94
|
returnn/frontend/dropout.py,sha256=rsx3p5b0NblBfXXSQZTQFJ8jUUS3fj4Qzc39iffBMCA,5006
|
|
95
95
|
returnn/frontend/dtype.py,sha256=Ooc5BrcNrTp6XShuFEV9g5V6-niuy4ImP_Lt_Qgq3jE,1886
|
|
96
96
|
returnn/frontend/gradient.py,sha256=dOUvLqN-vxsvjKQfpfIvEYlx4TlpHkOk-p9hsB680iA,3376
|
|
@@ -110,7 +110,7 @@ returnn/frontend/parameter.py,sha256=w6SN-uv87OyeWBt90_3UBbK0h6sftSOCxkqXPg76caY
|
|
|
110
110
|
returnn/frontend/parametrizations.py,sha256=hVbOlgm1pQAmZnAnNxq8Tk23rykr_iy3-6R1H6CwlMA,2798
|
|
111
111
|
returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
|
|
112
112
|
returnn/frontend/piecewise_linear.py,sha256=TdL6wzop8P1dcIZwkEbJFvSUZSI1cbhS3XKzlWQkEVI,1964
|
|
113
|
-
returnn/frontend/rand.py,sha256=
|
|
113
|
+
returnn/frontend/rand.py,sha256=Levgf5VtOOBKDSgz0869Jf3VW4BWxYZuRXsa_fOxNI4,12969
|
|
114
114
|
returnn/frontend/rec.py,sha256=4m20LvsPJ75pRYykVrup6Csj_D7duG-dW28SaJh-sq8,7863
|
|
115
115
|
returnn/frontend/reduce.py,sha256=-Zt-OH6Zbtb9uR6YEzurCyrowH-anIXvuga6Pla2V70,10220
|
|
116
116
|
returnn/frontend/run_ctx.py,sha256=ItcZwuFItkZjYWrg715L1Za2Xg7__MQCrRCAwBeTUxA,21411
|
|
@@ -122,10 +122,10 @@ returnn/frontend/types.py,sha256=gpevnXZSlF_BgA76duIkkzN-ed_MflhSlOnHj1xJnAs,111
|
|
|
122
122
|
returnn/frontend/_native/__init__.py,sha256=fVjazAujt0rdICXZL-GgW1sjFeL1HB4NPuy2m5rmMsc,6480
|
|
123
123
|
returnn/frontend/_native/backend.cpp,sha256=MeHczHypwj_ncntOxRqanK8SqGyV9Eq1X0cpMWb_WII,4768
|
|
124
124
|
returnn/frontend/_native/backend.hpp,sha256=Wq80dcEzXfRNxGOXFnIgHllkiv1rDi3KpHK-xxJsSDI,791
|
|
125
|
-
returnn/frontend/_native/module.cpp,sha256=
|
|
125
|
+
returnn/frontend/_native/module.cpp,sha256=lS1Oypo3n6oCu6cxKAmqpNjSvQN9aMZIOeMec96FWYU,15626
|
|
126
126
|
returnn/frontend/_native/module.hpp,sha256=uf4HPSTrFP2brGR_x9G5N1ZlZ-ok5GakMbNo4LbqxUg,6670
|
|
127
127
|
returnn/frontend/_native/py_utils.hpp,sha256=vcxKGmOyDRuwsmmSEjoaCJyKMy1BNYoGlso2pZu7VoE,3139
|
|
128
|
-
returnn/frontend/_native/tensor_ops.cpp,sha256=
|
|
128
|
+
returnn/frontend/_native/tensor_ops.cpp,sha256=bYtwwn_NeJfAEHWYPEJlkoLDKt9baZ3RA8av7gtz2qc,70246
|
|
129
129
|
returnn/frontend/_native/tensor_ops.hpp,sha256=dDqvUejRNHjItnmOP5aHyAQbAmXmXoDVXSe3tveEU8A,3732
|
|
130
130
|
returnn/frontend/audio/__init__.py,sha256=8mahwucBje8qHKw0bOvoySlvvD0rFKxviSvcAHSjiJY,67
|
|
131
131
|
returnn/frontend/audio/mel.py,sha256=VZdxf2mTLzLOXsLRzCvaad712Zf0c2iwdthrzeVfgxk,7885
|
|
@@ -177,7 +177,7 @@ returnn/tf/sprint.py,sha256=Yqjh0-6sCWHpdDPQCzHKx7TwQCOjJyjfd0KHtnYdd-8,5471
|
|
|
177
177
|
returnn/tf/updater.py,sha256=St4Z5iBjlkWaB6CiS-K1VNc_iLaan2e6-mVMTTPldzk,72034
|
|
178
178
|
returnn/tf/frontend_layers/README.md,sha256=P4vVl_EK-4jT55m40mq-K4Nr9yFY0tJR5fmDzTHSDFE,1096
|
|
179
179
|
returnn/tf/frontend_layers/__init__.py,sha256=MGUn7rv6fOefbtkX-5pq6fC1T6Y5h0oh1uOPSEcv1_I,506
|
|
180
|
-
returnn/tf/frontend_layers/_backend.py,sha256=
|
|
180
|
+
returnn/tf/frontend_layers/_backend.py,sha256=8lWE6LxxdNx8FnFvp2Pnk-UqJ8oymxXwx7s9HTEgDug,47443
|
|
181
181
|
returnn/tf/frontend_layers/_utils.py,sha256=ijByaDOqPDod5mZC9EoTkt8PHBEODXHsWbkwDOF9XW4,4205
|
|
182
182
|
returnn/tf/frontend_layers/cond.py,sha256=yQ2h5W0sgMZndJdrWv2EE9k9yIcspQ1U0HwBSh3hOKE,14830
|
|
183
183
|
returnn/tf/frontend_layers/config_entry_points.py,sha256=t01RWOiaZohzuqPXX-MLV0P5yCOfE0dz-9dZ77_pK4c,5751
|
|
@@ -190,7 +190,7 @@ returnn/tf/frontend_layers/masked_computation.py,sha256=I_TW0Qm4Yl_wPZ6TkuK7a-wB
|
|
|
190
190
|
returnn/tf/frontend_layers/parameter_assign.py,sha256=B_7kgobRyFtExiuSy2MsVGpAR36-jdG-xKABGc6EUGM,5103
|
|
191
191
|
returnn/tf/frontend_layers/prev_tensor_ref.py,sha256=EqTAanOgYAhl8o2fMylN52mfReH9heAQFdzn9CwqAX4,2282
|
|
192
192
|
returnn/tf/frontend_low_level/__init__.py,sha256=34469k3KzMUIGowxReOZnbf6WdTjxY73Gp1a4WqDN1M,62
|
|
193
|
-
returnn/tf/frontend_low_level/_backend.py,sha256=
|
|
193
|
+
returnn/tf/frontend_low_level/_backend.py,sha256=JwwRRIGnElqBC4bTImdB7w3U1u_SJESeZHYLmq86wog,24479
|
|
194
194
|
returnn/tf/layers/__init__.py,sha256=Ngu-X84nWFgz7ndDu88DqoZ-5lUMMTQWH4g7N8pSoCg,72
|
|
195
195
|
returnn/tf/layers/base.py,sha256=KcADpZUxqLkoFpQPMe_l9thRC7rpyBJIZCHITmnOd7M,153169
|
|
196
196
|
returnn/tf/layers/basic.py,sha256=IVQ_6PkM-uuBN_vVg-VeGM74bb1pc6TjJhKf92pPS1I,610870
|
|
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
|
|
|
216
216
|
returnn/torch/data/returnn_dataset_wrapper.py,sha256=1Bw82-Ge_8m_DSDXZNqQ3zGDic2HQlp6jysELL0NVK0,7369
|
|
217
217
|
returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
|
|
218
218
|
returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
|
|
219
|
-
returnn/torch/frontend/_backend.py,sha256=
|
|
219
|
+
returnn/torch/frontend/_backend.py,sha256=mjR6Ilt2zlnIO4_CpVPCLQ0XVJa_QmW3HsZtR2KT8yk,101110
|
|
220
220
|
returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
|
|
221
221
|
returnn/torch/frontend/bridge.py,sha256=Z2_UW8AagezC7zsXDc5PKcd8G9WwisV7j9SWGHU0m4U,7840
|
|
222
222
|
returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250216.155246.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250216.155246.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
259
|
-
returnn-1.20250216.155246.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250216.155246.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250220.200053.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250220.200053.dist-info/METADATA,sha256=U58QGiF-75H5Ac8V3JUwKdPkzP3TPwuPkhfzHhpa7Vc,5215
|
|
258
|
+
returnn-1.20250220.200053.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
259
|
+
returnn-1.20250220.200053.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250220.200053.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|