returnn 1.20250304.113330__py3-none-any.whl → 1.20250305.150759__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic; see the package's registry page for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250304.113330
3
+ Version: 1.20250305.150759
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
1
- version = '1.20250304.113330'
2
- long_version = '1.20250304.113330+git.acf09da'
1
+ version = '1.20250305.150759'
2
+ long_version = '1.20250305.150759+git.fd21be2'
returnn/frontend/loop.py CHANGED
@@ -128,6 +128,16 @@ def scan(
128
128
  like selecting the right beam entries.
129
129
  :return: outputs ys, final state, and the new spatial_dim
130
130
  """
131
+ device = None
132
+ if initial is not None:
133
+ vs = [v for v in tree.flatten(initial) if isinstance(v, Tensor) and v.device not in (None, "cpu")]
134
+ if vs:
135
+ device = vs[0].device
136
+ if device is None and xs is not None:
137
+ vs = [v for v in tree.flatten(xs) if isinstance(v, Tensor) and v.device not in (None, "cpu")]
138
+ if vs:
139
+ device = vs[0].device
140
+
131
141
  if spatial_dim is None or not spatial_dim.is_dim_known():
132
142
  assert cond is not None, f"scan: spatial_dim {spatial_dim} is None/unknown, need to provide `cond`"
133
143
  assert cond_dims is not None, f"scan: spatial_dim {spatial_dim} is None/unknown, need to provide `cond_dims`"
@@ -145,12 +155,13 @@ def scan(
145
155
  def _body(_s: Tuple[Tensor, Tensor, Tensor, S, Y]) -> Tuple[Tensor, Tensor, Tensor, S, Y]:
146
156
  i, seq_len_, prev_cond, s, ys_ = _s
147
157
  seq_len_ = seq_len_ + rf.cast(prev_cond, dtype=seq_len_.dtype)
148
- y, s = body(None, s)
158
+ y, s_ = body(None, s)
149
159
  tree.assert_same_structure(ys_, y)
150
160
  ys_ = tree.map_structure(lambda ys__, y_: ys__.push_back(y_) if ys__ is not None else None, ys_, y)
151
- c = cond(None, s)
161
+ c = cond(None, s_)
152
162
  c = rf.logical_and(c, prev_cond)
153
- return i + 1, seq_len_, c, s, ys_
163
+ s_ = rf.nested.mask_nested(s_, mask=c, mask_value=s, allow_dim_extension=False)
164
+ return i + 1, seq_len_, c, s_, ys_
154
165
 
155
166
  if cond_before_body:
156
167
  initial_cond = cond(None, initial)
@@ -187,10 +198,13 @@ def scan(
187
198
 
188
199
  def _body(_s: Tuple[Tensor, S, Y]) -> Tuple[Tensor, S, Y]:
189
200
  i, s, ys_ = _s
190
- y, s = body(tree.map_structure(lambda x: x[i], xs), s)
201
+ y, s_ = body(tree.map_structure(lambda x: x[i], xs), s)
191
202
  tree.assert_same_structure(ys_, y)
192
203
  ys_ = tree.map_structure(lambda ys__, y_: ys__.push_back(y_) if ys__ is not None else None, ys_, y)
193
- return i + 1, s, ys_
204
+ s_ = rf.nested.mask_nested(
205
+ s_, mask=i < spatial_dim.get_size_tensor(device=device), mask_value=s, allow_dim_extension=False
206
+ )
207
+ return i + 1, s_, ys_
194
208
 
195
209
  _, final_s, ys = while_loop(
196
210
  _cond,
returnn/frontend/nested.py CHANGED
@@ -3,7 +3,7 @@ Some utility functions on nested structures.
3
3
  """
4
4
 
5
5
  from __future__ import annotations
6
- from typing import TypeVar, Optional, Sequence, Tuple, Dict
6
+ from typing import TypeVar, Optional, Union, Sequence, Tuple, Dict
7
7
  import functools
8
8
  import re
9
9
  import tree
@@ -11,12 +11,108 @@ from returnn.tensor import Tensor, Dim
11
11
  import returnn.frontend as rf
12
12
 
13
13
 
14
- __all__ = ["gather_nested", "masked_select_nested", "masked_scatter_nested"]
14
+ __all__ = ["mask_nested", "gather_nested", "masked_select_nested", "masked_scatter_nested"]
15
15
 
16
16
 
17
17
  T = TypeVar("T")
18
18
 
19
19
 
20
+ def mask_nested(
21
+ s: T,
22
+ *,
23
+ mask: Tensor,
24
+ mask_value: Union[T, Tensor, float, None],
25
+ dim_map: Optional[Dict[Dim, Dim]] = None,
26
+ allow_dim_extension: bool = True,
27
+ ) -> T:
28
+ """
29
+ Applies where(mask, s, mask_value) for nested structures.
30
+
31
+ :param s:
32
+ :param mask:
33
+ :param mask_value:
34
+ :param dim_map:
35
+ :param allow_dim_extension:
36
+ :return: s with masked values
37
+ """
38
+ if dim_map is None:
39
+ dim_map = {}
40
+ partial_kwargs = dict(mask=mask, dim_map=dim_map, allow_dim_extension=allow_dim_extension)
41
+ structures = [s]
42
+ if type(s) is type(mask_value): # mask_value also same nested structure?
43
+ tree.assert_same_structure(s, mask_value)
44
+ structures.append(mask_value)
45
+ else:
46
+ partial_kwargs["mask_value"] = mask_value
47
+ tree.map_structure(functools.partial(_mask_prepare_dims, **partial_kwargs), *structures)
48
+ return tree.map_structure(functools.partial(_mask, **partial_kwargs), *structures)
49
+
50
+
51
+ def _mask_prepare_dims(
52
+ s: T, mask_value: Union[T, Tensor, float, None], *, mask: Tensor, dim_map: Dict[Dim, Dim], allow_dim_extension: bool
53
+ ) -> T:
54
+ if isinstance(s, Dim):
55
+ if mask_value is None:
56
+ return s # not sure if always correct...
57
+ assert isinstance(mask_value, Dim)
58
+ if s == mask_value:
59
+ return s
60
+ if not allow_dim_extension:
61
+ dim_size_dims = set()
62
+ if s.dyn_size_ext is not None:
63
+ dim_size_dims.update(s.dyn_size_ext.dims_set)
64
+ if mask_value.dyn_size_ext is not None:
65
+ dim_size_dims.update(mask_value.dyn_size_ext.dims_set)
66
+ if not mask.dims_set.issubset(dim_size_dims):
67
+ assert not mask.dims_set.intersection(dim_size_dims) # not sure...
68
+ return s
69
+ new_dyn_size = _mask(
70
+ s.get_size_tensor(),
71
+ mask=mask,
72
+ mask_value=mask_value.get_size_tensor(),
73
+ dim_map=dim_map,
74
+ allow_dim_extension=allow_dim_extension,
75
+ )
76
+ new_dim = Dim(new_dyn_size, name=_extend_dim_name(s.name))
77
+ dim_map[s] = dim_map[mask_value] = new_dim
78
+ return new_dim
79
+ return s
80
+
81
+
82
+ def _mask(
83
+ s: T, mask_value: Union[T, Tensor, float, None], *, mask: Tensor, dim_map: Dict[Dim, Dim], allow_dim_extension: bool
84
+ ) -> T:
85
+ if s is None:
86
+ return s
87
+ if isinstance(s, Tensor):
88
+ if dim_map:
89
+ for d in s.dims:
90
+ if d in dim_map:
91
+ s = rf.replace_dim_v2(s, in_dim=d, out_dim=dim_map[d])
92
+ if isinstance(mask_value, Tensor):
93
+ for d in mask_value.dims:
94
+ if d in dim_map:
95
+ mask_value = rf.replace_dim_v2(mask_value, in_dim=d, out_dim=dim_map[d])
96
+ if not allow_dim_extension and isinstance(mask_value, Tensor):
97
+ if not s.dims_set.issuperset(mask_value.dims_set):
98
+ return s
99
+ if not allow_dim_extension or mask_value is None or (isinstance(mask_value, (int, float)) and mask_value == 0):
100
+ if mask.dims_set.issubset(s.dims_set):
101
+ return rf.where(mask, s, mask_value)
102
+ assert not mask.dims_set.intersection(s.dims_set) # not sure...
103
+ return s
104
+ assert isinstance(mask_value, (int, float, Tensor))
105
+ return rf.where(mask, s, mask_value, allow_broadcast_all_sources=True)
106
+ if isinstance(s, Dim):
107
+ if mask_value is None:
108
+ return s
109
+ assert isinstance(mask_value, Dim)
110
+ if s == mask_value:
111
+ return s
112
+ return dim_map.get(s, s)
113
+ raise TypeError(f"_mask: unexpected {s!r} type {type(s).__name__}")
114
+
115
+
20
116
  def gather_nested(s: T, *, indices: Tensor, dim_map: Optional[Dict[Dim, Dim]] = None) -> T:
21
117
  """
22
118
  This is like :func:`gather`, but for nested structures.
returnn/frontend/rec.py CHANGED
@@ -70,6 +70,9 @@ class LSTM(rf.Module):
70
70
  out_dim=self.out_dim,
71
71
  )
72
72
  new_state = LstmState(h=new_state_h, c=new_state_c)
73
+ result.feature_dim = self.out_dim
74
+ new_state.h.feature_dim = self.out_dim
75
+ new_state.c.feature_dim = self.out_dim
73
76
 
74
77
  return result, new_state
75
78
 
returnn-1.20250305.150759.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250304.113330
3
+ Version: 1.20250305.150759
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn-1.20250305.150759.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
1
+ returnn/PKG-INFO,sha256=WGNUtpero6ia7jcFsuoDC8KQaPT077SI434pgHtd6y0,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=94BElbYUGmjpsoY8BzvfW39RUTXw9Fy3UwlPoEjrkU8,77
6
+ returnn/_setup_info_generated.py,sha256=_BUxCpQI2oODCG2knN-6EoX8imvgddkHg0F7kXRyik0,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -99,19 +99,19 @@ returnn/frontend/hooks.py,sha256=jYPbsb4gy5HORRZvKTEJbLcoJri5hOt5ADbhnTCytQo,550
99
99
  returnn/frontend/init.py,sha256=bVB7bpghaY8DI_HL0mkB_9z95onWnIX2zlW4hlMYnRw,7494
100
100
  returnn/frontend/label_smoothing.py,sha256=lxmaowNr61sCMzMewqHhu1r0CcklYfhLXlFnBu8DeAU,5676
101
101
  returnn/frontend/linear.py,sha256=xRUjnkD3MTWDezSaYATBYJQ2fa1RhKMNrTuhC54hhVs,2252
102
- returnn/frontend/loop.py,sha256=anIdKjrS0PdHwcqRR5NG9OTPpoydeEGd0QawlG-AF_k,15498
102
+ returnn/frontend/loop.py,sha256=vNfq-9U6Jb0ZW2zM9QmbTaXZBkBFQ5MzC9V0aT1ATCU,16158
103
103
  returnn/frontend/loss.py,sha256=r_qRAiFIXgXTnFuLoAhE5jsyAZYMCucRzf55XfWbkC8,7441
104
104
  returnn/frontend/math_.py,sha256=KlJxdIib8ENlid7cc4lcwHv5e21tzTjTEV8VgEDAijo,16984
105
105
  returnn/frontend/matmul.py,sha256=3QaGiZtSs9PriT40T7Vc3KnYKPgYSN4tCZytYeq9qMA,1945
106
106
  returnn/frontend/module.py,sha256=219rh5mE0CD0-NdxXLsKyhv3BNtOI9jSyiI1Rb8MOyU,10700
107
- returnn/frontend/nested.py,sha256=CT3C0wXkeWGjJcAoF6yebsRXuN8-YpjO2eqgdl1-vaE,11005
107
+ returnn/frontend/nested.py,sha256=oDiqnyTML7ZtCxrufU4ypG0fOZ_WsZPIvfdSn1Phs6M,14698
108
108
  returnn/frontend/normalization.py,sha256=QIjXYg0C8BD2g_1lAkVO4Cara729uHC_bsQh99VsWeI,14061
109
109
  returnn/frontend/parameter.py,sha256=w6SN-uv87OyeWBt90_3UBbK0h6sftSOCxkqXPg76caY,10375
110
110
  returnn/frontend/parametrizations.py,sha256=hVbOlgm1pQAmZnAnNxq8Tk23rykr_iy3-6R1H6CwlMA,2798
111
111
  returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
112
112
  returnn/frontend/piecewise_linear.py,sha256=TdL6wzop8P1dcIZwkEbJFvSUZSI1cbhS3XKzlWQkEVI,1964
113
113
  returnn/frontend/rand.py,sha256=Levgf5VtOOBKDSgz0869Jf3VW4BWxYZuRXsa_fOxNI4,12969
114
- returnn/frontend/rec.py,sha256=4m20LvsPJ75pRYykVrup6Csj_D7duG-dW28SaJh-sq8,7863
114
+ returnn/frontend/rec.py,sha256=la-VXR_hzvwNzpAgn4Okl-yDx3F4gOW-81EKm-jAAlg,7999
115
115
  returnn/frontend/reduce.py,sha256=-Zt-OH6Zbtb9uR6YEzurCyrowH-anIXvuga6Pla2V70,10220
116
116
  returnn/frontend/run_ctx.py,sha256=ItcZwuFItkZjYWrg715L1Za2Xg7__MQCrRCAwBeTUxA,21411
117
117
  returnn/frontend/signal.py,sha256=XgOBL1iy-cJgulePH5HRPAwp2cScy60q4RItr7xzvGc,4412
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250304.113330.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250304.113330.dist-info/METADATA,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
258
- returnn-1.20250304.113330.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
- returnn-1.20250304.113330.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250304.113330.dist-info/RECORD,,
256
+ returnn-1.20250305.150759.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250305.150759.dist-info/METADATA,sha256=WGNUtpero6ia7jcFsuoDC8KQaPT077SI434pgHtd6y0,5215
258
+ returnn-1.20250305.150759.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
+ returnn-1.20250305.150759.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250305.150759.dist-info/RECORD,,