returnn 1.20250304.113330__py3-none-any.whl → 1.20250305.150759__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/loop.py +19 -5
- returnn/frontend/nested.py +98 -2
- returnn/frontend/rec.py +3 -0
- {returnn-1.20250304.113330.dist-info → returnn-1.20250305.150759.dist-info}/METADATA +1 -1
- {returnn-1.20250304.113330.dist-info → returnn-1.20250305.150759.dist-info}/RECORD +10 -10
- {returnn-1.20250304.113330.dist-info → returnn-1.20250305.150759.dist-info}/LICENSE +0 -0
- {returnn-1.20250304.113330.dist-info → returnn-1.20250305.150759.dist-info}/WHEEL +0 -0
- {returnn-1.20250304.113330.dist-info → returnn-1.20250305.150759.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250305.150759'
|
|
2
|
+
long_version = '1.20250305.150759+git.fd21be2'
|
returnn/frontend/loop.py
CHANGED
|
@@ -128,6 +128,16 @@ def scan(
|
|
|
128
128
|
like selecting the right beam entries.
|
|
129
129
|
:return: outputs ys, final state, and the new spatial_dim
|
|
130
130
|
"""
|
|
131
|
+
device = None
|
|
132
|
+
if initial is not None:
|
|
133
|
+
vs = [v for v in tree.flatten(initial) if isinstance(v, Tensor) and v.device not in (None, "cpu")]
|
|
134
|
+
if vs:
|
|
135
|
+
device = vs[0].device
|
|
136
|
+
if device is None and xs is not None:
|
|
137
|
+
vs = [v for v in tree.flatten(xs) if isinstance(v, Tensor) and v.device not in (None, "cpu")]
|
|
138
|
+
if vs:
|
|
139
|
+
device = vs[0].device
|
|
140
|
+
|
|
131
141
|
if spatial_dim is None or not spatial_dim.is_dim_known():
|
|
132
142
|
assert cond is not None, f"scan: spatial_dim {spatial_dim} is None/unknown, need to provide `cond`"
|
|
133
143
|
assert cond_dims is not None, f"scan: spatial_dim {spatial_dim} is None/unknown, need to provide `cond_dims`"
|
|
@@ -145,12 +155,13 @@ def scan(
|
|
|
145
155
|
def _body(_s: Tuple[Tensor, Tensor, Tensor, S, Y]) -> Tuple[Tensor, Tensor, Tensor, S, Y]:
|
|
146
156
|
i, seq_len_, prev_cond, s, ys_ = _s
|
|
147
157
|
seq_len_ = seq_len_ + rf.cast(prev_cond, dtype=seq_len_.dtype)
|
|
148
|
-
y,
|
|
158
|
+
y, s_ = body(None, s)
|
|
149
159
|
tree.assert_same_structure(ys_, y)
|
|
150
160
|
ys_ = tree.map_structure(lambda ys__, y_: ys__.push_back(y_) if ys__ is not None else None, ys_, y)
|
|
151
|
-
c = cond(None,
|
|
161
|
+
c = cond(None, s_)
|
|
152
162
|
c = rf.logical_and(c, prev_cond)
|
|
153
|
-
|
|
163
|
+
s_ = rf.nested.mask_nested(s_, mask=c, mask_value=s, allow_dim_extension=False)
|
|
164
|
+
return i + 1, seq_len_, c, s_, ys_
|
|
154
165
|
|
|
155
166
|
if cond_before_body:
|
|
156
167
|
initial_cond = cond(None, initial)
|
|
@@ -187,10 +198,13 @@ def scan(
|
|
|
187
198
|
|
|
188
199
|
def _body(_s: Tuple[Tensor, S, Y]) -> Tuple[Tensor, S, Y]:
|
|
189
200
|
i, s, ys_ = _s
|
|
190
|
-
y,
|
|
201
|
+
y, s_ = body(tree.map_structure(lambda x: x[i], xs), s)
|
|
191
202
|
tree.assert_same_structure(ys_, y)
|
|
192
203
|
ys_ = tree.map_structure(lambda ys__, y_: ys__.push_back(y_) if ys__ is not None else None, ys_, y)
|
|
193
|
-
|
|
204
|
+
s_ = rf.nested.mask_nested(
|
|
205
|
+
s_, mask=i < spatial_dim.get_size_tensor(device=device), mask_value=s, allow_dim_extension=False
|
|
206
|
+
)
|
|
207
|
+
return i + 1, s_, ys_
|
|
194
208
|
|
|
195
209
|
_, final_s, ys = while_loop(
|
|
196
210
|
_cond,
|
returnn/frontend/nested.py
CHANGED
|
@@ -3,7 +3,7 @@ Some utility functions on nested structures.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
|
-
from typing import TypeVar, Optional, Sequence, Tuple, Dict
|
|
6
|
+
from typing import TypeVar, Optional, Union, Sequence, Tuple, Dict
|
|
7
7
|
import functools
|
|
8
8
|
import re
|
|
9
9
|
import tree
|
|
@@ -11,12 +11,108 @@ from returnn.tensor import Tensor, Dim
|
|
|
11
11
|
import returnn.frontend as rf
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
__all__ = ["gather_nested", "masked_select_nested", "masked_scatter_nested"]
|
|
14
|
+
__all__ = ["mask_nested", "gather_nested", "masked_select_nested", "masked_scatter_nested"]
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
T = TypeVar("T")
|
|
18
18
|
|
|
19
19
|
|
|
20
|
+
def mask_nested(
    s: T,
    *,
    mask: Tensor,
    mask_value: Union[T, Tensor, float, None],
    dim_map: Optional[Dict[Dim, Dim]] = None,
    allow_dim_extension: bool = True,
) -> T:
    """
    Nested-structure variant of ``where(mask, s, mask_value)``.

    The leaves of ``s`` are visited twice via ``tree.map_structure``:
    first with :func:`_mask_prepare_dims` (its results are discarded; it is only run
    for the entries it may add to ``dim_map``), then with :func:`_mask`,
    whose results form the returned structure.

    :param s: nested structure whose leaves get masked
    :param mask: condition tensor, handed unchanged to both helpers
    :param mask_value: if ``type(mask_value) is type(s)``, it is required to have the same structure as ``s``
        and is paired with ``s`` leaf by leaf. Otherwise the very same value is used for every leaf of ``s``.
    :param dim_map: mapping shared by both passes and updated in-place by the preparation pass.
        A fresh dict is used when None.
    :param allow_dim_extension: forwarded unchanged to both helpers
    :return: s with masked values
    """
    shared_dim_map = {} if dim_map is None else dim_map
    helper_kwargs = {"mask": mask, "dim_map": shared_dim_map, "allow_dim_extension": allow_dim_extension}

    if type(s) is type(mask_value):
        # Same container type: mask_value is a parallel structure, so walk both in lockstep.
        tree.assert_same_structure(s, mask_value)
        inputs = (s, mask_value)
    else:
        # Single value for all leaves: bind it as a keyword of the per-leaf helpers.
        helper_kwargs["mask_value"] = mask_value
        inputs = (s,)

    prepare_leaf = functools.partial(_mask_prepare_dims, **helper_kwargs)
    mask_leaf = functools.partial(_mask, **helper_kwargs)

    # First pass only matters for its side effect on the shared dim map.
    tree.map_structure(prepare_leaf, *inputs)
    return tree.map_structure(mask_leaf, *inputs)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _mask_prepare_dims(
    s: T, mask_value: Union[T, Tensor, float, None], *, mask: Tensor, dim_map: Dict[Dim, Dim], allow_dim_extension: bool
) -> T:
    """
    Per-leaf preparation step: for a pair of differing :class:`Dim` leaves,
    create a combined dim and register it in ``dim_map``.

    Any leaf which is not a :class:`Dim` is returned untouched.

    :param s: leaf of the structure being masked
    :param mask_value: corresponding leaf (or shared value) to take where the mask is false.
        When ``s`` is a dim, this must be a dim or None.
    :param mask: condition tensor
    :param dim_map: updated in-place: both ``s`` and ``mask_value`` get mapped to the newly created dim
    :param allow_dim_extension: if False, no new dim is created unless
        the dims of ``mask`` are all among the dims of the two dynamic-size tensors
    :return: ``s`` itself, or the newly created dim
    """
    if not isinstance(s, Dim):
        return s
    if mask_value is None:
        return s  # upstream note: not sure if always correct...
    assert isinstance(mask_value, Dim)
    if s == mask_value:
        return s

    if not allow_dim_extension:
        # Gather all dims over which either of the two dynamic sizes is defined.
        size_dims = set()
        for dim in (s, mask_value):
            if dim.dyn_size_ext is not None:
                size_dims.update(dim.dyn_size_ext.dims_set)
        if not mask.dims_set.issubset(size_dims):
            # Partial overlap is not handled (upstream marks this check as "not sure").
            assert not mask.dims_set.intersection(size_dims)
            return s

    # Select the size per mask entry, reusing the tensor logic of _mask.
    combined_size = _mask(
        s.get_size_tensor(),
        mask=mask,
        mask_value=mask_value.get_size_tensor(),
        dim_map=dim_map,
        allow_dim_extension=allow_dim_extension,
    )
    combined_dim = Dim(combined_size, name=_extend_dim_name(s.name))
    dim_map[s] = combined_dim
    dim_map[mask_value] = combined_dim
    return combined_dim
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _mask(
    s: T, mask_value: Union[T, Tensor, float, None], *, mask: Tensor, dim_map: Dict[Dim, Dim], allow_dim_extension: bool
) -> T:
    """
    Per-leaf masking step: ``where(mask, s, mask_value)`` for a single leaf.

    :param s: leaf to mask: None, a :class:`Tensor` or a :class:`Dim`
    :param mask_value: value to take where the mask is false
    :param mask: condition tensor
    :param dim_map: dims of ``s`` (and of a tensor ``mask_value``) found in here are replaced first.
        For a dim leaf, this is the lookup table for its replacement.
    :param allow_dim_extension: if False, ``s`` is returned unchanged whenever applying the mask
        would need dims which ``s`` does not have
    :return: masked leaf
    :raises TypeError: if ``s`` is neither None, a Tensor nor a Dim
    """
    if s is None:
        return s

    if isinstance(s, Tensor):
        if dim_map:
            # Bring both operands onto the replacement dims before combining them.
            for old_dim in s.dims:
                if old_dim in dim_map:
                    s = rf.replace_dim_v2(s, in_dim=old_dim, out_dim=dim_map[old_dim])
            if isinstance(mask_value, Tensor):
                for old_dim in mask_value.dims:
                    if old_dim in dim_map:
                        mask_value = rf.replace_dim_v2(mask_value, in_dim=old_dim, out_dim=dim_map[old_dim])

        value_is_tensor = isinstance(mask_value, Tensor)
        if value_is_tensor and not allow_dim_extension and not s.dims_set.issuperset(mask_value.dims_set):
            # mask_value has dims which s lacks, and extending s is not allowed.
            return s

        value_is_zero = isinstance(mask_value, (int, float)) and mask_value == 0
        if not allow_dim_extension or mask_value is None or value_is_zero:
            if mask.dims_set.issubset(s.dims_set):
                return rf.where(mask, s, mask_value)
            # Partial overlap is not handled (upstream marks this check as "not sure").
            assert not mask.dims_set.intersection(s.dims_set)
            return s

        assert isinstance(mask_value, (int, float, Tensor))
        return rf.where(mask, s, mask_value, allow_broadcast_all_sources=True)

    if isinstance(s, Dim):
        if mask_value is None:
            return s
        assert isinstance(mask_value, Dim)
        # Differing dims resolve to their entry in dim_map, if there is one.
        return s if s == mask_value else dim_map.get(s, s)

    raise TypeError(f"_mask: unexpected {s!r} type {type(s).__name__}")
|
|
114
|
+
|
|
115
|
+
|
|
20
116
|
def gather_nested(s: T, *, indices: Tensor, dim_map: Optional[Dict[Dim, Dim]] = None) -> T:
|
|
21
117
|
"""
|
|
22
118
|
This is like :func:`gather`, but for nested structures.
|
returnn/frontend/rec.py
CHANGED
|
@@ -70,6 +70,9 @@ class LSTM(rf.Module):
|
|
|
70
70
|
out_dim=self.out_dim,
|
|
71
71
|
)
|
|
72
72
|
new_state = LstmState(h=new_state_h, c=new_state_c)
|
|
73
|
+
result.feature_dim = self.out_dim
|
|
74
|
+
new_state.h.feature_dim = self.out_dim
|
|
75
|
+
new_state.c.feature_dim = self.out_dim
|
|
73
76
|
|
|
74
77
|
return result, new_state
|
|
75
78
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=WGNUtpero6ia7jcFsuoDC8KQaPT077SI434pgHtd6y0,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=_BUxCpQI2oODCG2knN-6EoX8imvgddkHg0F7kXRyik0,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -99,19 +99,19 @@ returnn/frontend/hooks.py,sha256=jYPbsb4gy5HORRZvKTEJbLcoJri5hOt5ADbhnTCytQo,550
|
|
|
99
99
|
returnn/frontend/init.py,sha256=bVB7bpghaY8DI_HL0mkB_9z95onWnIX2zlW4hlMYnRw,7494
|
|
100
100
|
returnn/frontend/label_smoothing.py,sha256=lxmaowNr61sCMzMewqHhu1r0CcklYfhLXlFnBu8DeAU,5676
|
|
101
101
|
returnn/frontend/linear.py,sha256=xRUjnkD3MTWDezSaYATBYJQ2fa1RhKMNrTuhC54hhVs,2252
|
|
102
|
-
returnn/frontend/loop.py,sha256=
|
|
102
|
+
returnn/frontend/loop.py,sha256=vNfq-9U6Jb0ZW2zM9QmbTaXZBkBFQ5MzC9V0aT1ATCU,16158
|
|
103
103
|
returnn/frontend/loss.py,sha256=r_qRAiFIXgXTnFuLoAhE5jsyAZYMCucRzf55XfWbkC8,7441
|
|
104
104
|
returnn/frontend/math_.py,sha256=KlJxdIib8ENlid7cc4lcwHv5e21tzTjTEV8VgEDAijo,16984
|
|
105
105
|
returnn/frontend/matmul.py,sha256=3QaGiZtSs9PriT40T7Vc3KnYKPgYSN4tCZytYeq9qMA,1945
|
|
106
106
|
returnn/frontend/module.py,sha256=219rh5mE0CD0-NdxXLsKyhv3BNtOI9jSyiI1Rb8MOyU,10700
|
|
107
|
-
returnn/frontend/nested.py,sha256=
|
|
107
|
+
returnn/frontend/nested.py,sha256=oDiqnyTML7ZtCxrufU4ypG0fOZ_WsZPIvfdSn1Phs6M,14698
|
|
108
108
|
returnn/frontend/normalization.py,sha256=QIjXYg0C8BD2g_1lAkVO4Cara729uHC_bsQh99VsWeI,14061
|
|
109
109
|
returnn/frontend/parameter.py,sha256=w6SN-uv87OyeWBt90_3UBbK0h6sftSOCxkqXPg76caY,10375
|
|
110
110
|
returnn/frontend/parametrizations.py,sha256=hVbOlgm1pQAmZnAnNxq8Tk23rykr_iy3-6R1H6CwlMA,2798
|
|
111
111
|
returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
|
|
112
112
|
returnn/frontend/piecewise_linear.py,sha256=TdL6wzop8P1dcIZwkEbJFvSUZSI1cbhS3XKzlWQkEVI,1964
|
|
113
113
|
returnn/frontend/rand.py,sha256=Levgf5VtOOBKDSgz0869Jf3VW4BWxYZuRXsa_fOxNI4,12969
|
|
114
|
-
returnn/frontend/rec.py,sha256=
|
|
114
|
+
returnn/frontend/rec.py,sha256=la-VXR_hzvwNzpAgn4Okl-yDx3F4gOW-81EKm-jAAlg,7999
|
|
115
115
|
returnn/frontend/reduce.py,sha256=-Zt-OH6Zbtb9uR6YEzurCyrowH-anIXvuga6Pla2V70,10220
|
|
116
116
|
returnn/frontend/run_ctx.py,sha256=ItcZwuFItkZjYWrg715L1Za2Xg7__MQCrRCAwBeTUxA,21411
|
|
117
117
|
returnn/frontend/signal.py,sha256=XgOBL1iy-cJgulePH5HRPAwp2cScy60q4RItr7xzvGc,4412
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
256
|
+
returnn-1.20250305.150759.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250305.150759.dist-info/METADATA,sha256=WGNUtpero6ia7jcFsuoDC8KQaPT077SI434pgHtd6y0,5215
|
|
258
|
+
returnn-1.20250305.150759.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
259
|
+
returnn-1.20250305.150759.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250305.150759.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|