returnn 1.20250828.2732__py3-none-any.whl → 1.20250829.151916__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250828.2732
+ Version: 1.20250829.151916
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
- version = '1.20250828.002732'
- long_version = '1.20250828.002732+git.06c221e'
+ version = '1.20250829.151916'
+ long_version = '1.20250829.151916+git.687fa49'
returnn/frontend/_cache.py CHANGED
@@ -6,7 +6,7 @@ One use case example is :func:`sinusoidal_positional_encoding` and :func:`relati
  """

  from __future__ import annotations
- from typing import Optional, Union, Any, Type, Callable, Tuple, Dict
+ from typing import Optional, Union, Any, Type, Callable, Tuple, Dict, List
  from weakref import ref
  import tree
  from returnn.util.lru_cache import lru_cache
@@ -59,6 +59,8 @@ class Cache:
  if isinstance(key_item_orig, DimWrapper):
  assert isinstance(key_item, DimWrapper)
  dim_orig = key_item_orig.dim_ref()
+ if dim_orig is None: # orig dim could be dead. but then it would not be used anyway
+ continue
  dim = key_item.dim_ref()
  assert isinstance(dim_orig, Dim) and isinstance(dim, Dim)
  dim_map[dim_orig] = dim
@@ -103,7 +105,7 @@ def _transform_key(
  key: Any, *, finalize_callback: Optional[Callable] = None, collected_dim_map: Optional[Dict[Dim, DimWrapper]] = None
  ) -> Tuple[Union[Type[Backend], ref[rf.RunCtx], _KeyItemType], ...]:
  backend = _get_backend(key)
- keys_flat = [backend]
+ keys_flat: List[Any] = [backend]
  if not backend.executing_eagerly():
  # See comment above: If graph-mode, the cached value becomes invalid
  # when the current run ctx goes out of scope.
returnn/frontend/array_.py CHANGED
@@ -188,22 +188,18 @@ def merge_dims(
  return source, dims[0]
  return rf.replace_dim(source, in_dim=dims[0], out_dim=out_dim)
  if out_dim is None:
- out_dim = dims[0]
- reset_dyn_size = False
- for d in dims[1:]:
- reset_dyn_size |= d.need_masking() and out_dim.capacity != 1
- out_dim = out_dim * d
- if reset_dyn_size:
+ from returnn.util.basic import prod
+
+ if any(d.need_masking() for d in dims[1:]):
  # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
  # This would then potentially discard some of the data in the tensor in subsequent operations,
  # when masking is applied.
  # Thus, discard the dynamic sizes, and just treat it as a flat dim with scalar dynamic size.
  # https://github.com/rwth-i6/returnn/issues/1694
- out_dim_size = dims[0].get_dim_value_tensor()
- for d in dims[1:]:
- out_dim_size *= d.get_dim_value_tensor()
- assert isinstance(out_dim_size, Tensor) and out_dim_size.dims == () # scalar
- out_dim.dyn_size_ext = out_dim_size
+ # See also similar logic in :func:`concat`.
+ out_dim = Dim(prod(d.get_dim_value_tensor() for d in dims), name="merged")
+ else:
+ out_dim = prod(dims)
  # noinspection PyProtectedMember
  return source._raw_backend.merge_dims(source, dims=dims, out_dim=out_dim), out_dim
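The comments above refer to rwth-i6/returnn#1694: once several axes are flattened into one, the per-sequence sizes obtained from dim math no longer describe which entries of the merged axis are valid. A minimal NumPy sketch of the mismatch (illustration only, not RETURNN code), assuming one sequence with capacities 3 and 4 and actual lengths 2 and 3:

import numpy as np

# Illustration only, not RETURNN code.
x = np.arange(12).reshape(3, 4)   # capacity layout [T1=3, T2=4]; the valid region is x[:2, :3]
merged = x.reshape(-1)            # what flattening produces: a flat axis of length 12

# Dim math would report a dynamic size of 2*3 = 6 for the merged axis,
# but the first 6 entries of `merged` are not the 6 valid entries:
print(merged[:6])                 # [0 1 2 3 4 5] -- includes the padded entry x[0, 3]
print(x[:2, :3].reshape(-1))      # [0 1 2 4 5 6] -- the entries that are actually valid

This is why the new code discards the per-axis dynamic sizes and gives the merged dim a single scalar size (the product of the dim values), so that later masking cannot drop real data.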
@@ -345,7 +341,9 @@ def window(
  """
  if spatial_dim.need_masking():
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(
+ default=True, default_false_for_behavior_version_up_to=22, func_name="window"
+ )
  if use_mask:
  source = source.copy_masked(0, dims=[spatial_dim])
  assert window_dim.dimension is not None
@@ -427,28 +425,39 @@ def concat(
  dims = sources[0][0].dims_set - {sources[0][1]}
  for src, dim in sources:
  assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
+ need_handle_dynamic_dims = False
+ for src, dim in sources[:-1]:
+ if dim.need_masking():
+ need_handle_dynamic_dims = True
+ if handle_dynamic_dims is None:
+ handle_dynamic_dims = need_handle_dynamic_dims
  if not out_dim:
- out_dim = sum(d for _, d in sources)
- # noinspection PyProtectedMember
- out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim)
- if handle_dynamic_dims is None or handle_dynamic_dims:
- need_to_handle = False
- for src, dim in sources[:-1]:
- if dim.need_masking():
- need_to_handle = True
- if need_to_handle:
- masks = []
- for _, dim in sources:
- masks.append(
- dim.get_mask(dim_order=(dim,) + dim.dyn_size_ext.dims, device=out.device)
- if dim.need_masking()
- else rf.constant(True, dims=[dim], device=out.device)
- )
- # noinspection PyProtectedMember
- mask_concat = sources[0][0]._raw_backend.concat(
- *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_dim
+ if handle_dynamic_dims or not need_handle_dynamic_dims:
+ out_dim = sum(d for _, d in sources)
+ else: # not handle_dynamic_dims but need_handle_dynamic_dims
+ # There are dynamic dims, but we don't want to handle them.
+ # So, summing the dims would be incorrect.
+ # Just add the dim values.
+ out_dim = Dim(sum(d.get_dim_value_tensor() for _, d in sources if d.dimension is not None), name="concat")
+ if handle_dynamic_dims:
+ out_non_masked_dim = Dim(sum(d.get_dim_value_tensor() for _, d in sources))
+ # noinspection PyProtectedMember
+ out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_non_masked_dim)
+ masks = []
+ for _, dim in sources:
+ masks.append(
+ dim.get_mask(dim_order=(dim,) + dim.dyn_size_ext.dims, device=out.device)
+ if dim.need_masking()
+ else rf.constant(True, dims=[dim], device=out.device)
  )
- out, out_dim = rf.masked_select(out, mask=mask_concat, dims=[out_dim])
+ # noinspection PyProtectedMember
+ mask_concat = sources[0][0]._raw_backend.concat(
+ *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_non_masked_dim
+ )
+ out, _ = rf.masked_select(out, mask=mask_concat, dims=[out_non_masked_dim], out_dim=out_dim)
+ else:
+ # noinspection PyProtectedMember
+ out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim)
  return out, out_dim
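With handle_dynamic_dims enabled, the rewritten concat first concatenates in the padded capacity layout (out_non_masked_dim, the sum of the dim values) and then compacts the result with rf.masked_select into out_dim (the dim-math sum, which keeps the per-sequence sizes). A rough NumPy analogue of that two-step idea (illustration only, not the RETURNN API):

import numpy as np

# Illustration only, not the RETURNN API.
a = np.array([1, 2, 0])        # capacity 3, actual length 2 (trailing 0 is padding)
b = np.array([7, 8, 9, 0])     # capacity 4, actual length 3

out_non_masked = np.concatenate([a, b])             # capacity layout, length 3 + 4 = 7
mask = np.concatenate([[True, True, False],         # validity mask of a
                       [True, True, True, False]])  # validity mask of b
print(out_non_masked[mask])                         # [1 2 7 8 9] -- compacted, length 2 + 3 = 5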
@@ -494,7 +503,12 @@ def pad(
  if handle_dynamic_dims is None:
  handle_dynamic_dims = _pad_handle_dynamic_dims_default(axes, padding, mode=mode)
  if not out_dims:
- out_dims = [left + middle + right for middle, (left, right) in zip(axes, padding)]
+ out_dims = [
+ (left + middle + right)
+ if handle_dynamic_dims or not _pad_need_dyn_dim_handling(middle, left, right, mode=mode)
+ else _pad_sum_dims_no_dyn_dim_handling(middle, left, right)
+ for middle, (left, right) in zip(axes, padding)
+ ]
  # noinspection PyProtectedMember
  return (
  source._raw_backend.pad(
@@ -560,6 +574,32 @@ def _pad_need_dyn_dim_handling(
  return True


+ def _pad_sum_dims_no_dyn_dim_handling(
+ middle: Dim, left: Union[Dim, int, Tensor], right: Union[Dim, int, Tensor]
+ ) -> Dim:
+ """
+ This gets called when we need to handle dyn dims, but handle_dynamic_dims=False.
+ See also the same logic in :func:`concat`.
+ """
+ if isinstance(left, Dim):
+ left = left.get_dim_value_tensor()
+ elif isinstance(left, int):
+ pass
+ elif isinstance(left, Tensor):
+ assert left.dims == () # scalar
+ else:
+ raise TypeError(f"invalid left pad {left}")
+ if isinstance(right, Dim):
+ right = right.get_dim_value_tensor()
+ elif isinstance(right, int):
+ pass
+ elif isinstance(right, Tensor):
+ assert right.dims == () # scalar
+ else:
+ raise TypeError(f"invalid right pad {right}")
+ return Dim(left + middle.get_dim_value_tensor() + right, name="pad")
+
+
  def cum_concat_step(
  source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Optional[Dim] = None
  ) -> Tuple[Tensor, Dim]:
@@ -867,7 +907,9 @@ def scatter(
  indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
  if any(dim.need_masking() for dim in indices_dim):
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(
+ default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
+ )
  if use_mask:
  source = source.copy_masked(fill_value, dims=indices_dim)
  else:
returnn/frontend/conv.py CHANGED
@@ -223,7 +223,7 @@ def conv(
  """
  if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="conv")
  if use_mask:
  source = source.copy_masked(0, dims=in_spatial_dims)
  for in_spatial_dim in in_spatial_dims:
@@ -391,7 +391,9 @@ def transposed_conv(
  """transposed conv"""
  if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(
+ default=True, default_false_for_behavior_version_up_to=22, func_name="transposed_conv"
+ )
  if use_mask:
  source = source.copy_masked(0, dims=in_spatial_dims)
  if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
@@ -503,7 +505,7 @@ def pool(

  if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="pool")
  if use_mask:
  source = source.copy_masked({"max": float("-inf"), "avg": 0}[mode], dims=in_spatial_dims)
  else:
@@ -862,8 +864,9 @@ def _consistent_same_padding(
  pad_right = (s - 1) * d - pad_left
  paddings.append((pad_left, pad_right))
  # We expect that masking was already done before (or we don't care about it), thus handle_dynamic_dims=False.
+ out_dims = [(left + middle + right) for middle, (left, right) in zip(in_spatial_dims, paddings)]
  source, in_spatial_dims = rf.pad(
- source, axes=in_spatial_dims, padding=paddings, value=pad_value, handle_dynamic_dims=False
+ source, axes=in_spatial_dims, padding=paddings, value=pad_value, handle_dynamic_dims=False, out_dims=out_dims
  )
  return source, in_spatial_dims, 0
returnn/frontend/dims.py CHANGED
@@ -3,7 +3,7 @@ Utilities for dimension tags, dimensions, axes.
  """

  from __future__ import annotations
- from typing import Optional, Union, TypeVar, Sequence, Tuple
+ from typing import TYPE_CHECKING, Optional, Union, TypeVar, Sequence, Tuple
  from returnn.tensor import Tensor, Dim
  import returnn.frontend as rf
  from ._backend import get_backend_by_tensor, global_backend
@@ -25,6 +25,9 @@ __all__ = [
  "use_mask_default",
  ]

+ if TYPE_CHECKING:
+ from returnn.config import Config
+

  def range_over_dim(dim: Dim, *, dtype: Optional[str] = None, device: Optional[str] = None) -> Tensor[T]:
  """
@@ -309,7 +312,10 @@ def last_frame_position_of_dim(


  def use_mask_default(
- *, default: Optional[bool] = None, default_false_for_behavior_version_up_to: Optional[int] = None
+ *,
+ default: Optional[bool] = None,
+ default_false_for_behavior_version_up_to: Optional[int] = None,
+ func_name: Optional[str] = None,
  ) -> Optional[bool]:
  """
  Check the global RETURNN config for the ``rf_use_mask``
@@ -324,20 +330,20 @@ def use_mask_default(
  and if this is set, and the behavior version is less or equal,
  then return False by default, i.e. do not use the mask by default, if it is not defined in the config.
  This takes precedence over `default`.
+ :param func_name: if specified, also check
  :return: what to use for the ``use_mask`` argument by default
  """
  from returnn.config import get_global_config

  config = get_global_config(raise_exception=False)
- config_value = None
  if config:
- if "rf_use_mask" in config.typed_dict:
- config_value = config.typed_dict["rf_use_mask"]
- assert config_value is None or isinstance(config_value, bool)
- elif "rf_use_mask" in config.dict:
- config_value = config.bool("rf_use_mask", None)
- if config_value is not None:
- return config_value
+ config_value = _get_opt_bool_from_config(config, "rf_use_mask")
+ if config_value is not None:
+ return config_value
+ if func_name:
+ config_value = _get_opt_bool_from_config(config, f"rf_use_mask_{func_name}")
+ if config_value is not None:
+ return config_value

  if default_false_for_behavior_version_up_to is not None:
  from returnn.util.basic import BehaviorVersion
@@ -345,3 +351,13 @@ def use_mask_default(
  if BehaviorVersion.get() <= default_false_for_behavior_version_up_to:
  return False
  return default
+
+
+ def _get_opt_bool_from_config(config: Config, key: str) -> Optional[bool]:
+ if key in config.typed_dict:
+ config_value = config.typed_dict[key]
+ assert config_value is None or isinstance(config_value, bool)
+ return config_value
+ elif key in config.dict:
+ return config.bool(key, None)
+ return None
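The lookup order after this refactor: the global rf_use_mask config option is checked first and wins if it is set; only when it is absent (or None) does the new per-function option rf_use_mask_<func_name> apply (e.g. rf_use_mask_window, rf_use_mask_conv), followed by the behavior-version-dependent default. A simplified sketch of that precedence, using a plain dict in place of the RETURNN Config object (illustration only; the hypothetical helper name and the behavior_version argument are not part of the package):

from typing import Optional


def use_mask_default_sketch(
    config: dict,  # stand-in for the RETURNN config; illustration only
    *,
    default: Optional[bool] = None,
    default_false_for_behavior_version_up_to: Optional[int] = None,
    behavior_version: int = 24,  # hypothetical current behavior version
    func_name: Optional[str] = None,
) -> Optional[bool]:
    value = config.get("rf_use_mask")  # 1. global switch takes precedence
    if value is not None:
        return value
    if func_name:  # 2. per-function switch, e.g. "rf_use_mask_window"
        value = config.get(f"rf_use_mask_{func_name}")
        if value is not None:
            return value
    if default_false_for_behavior_version_up_to is not None:  # 3. behavior-version default
        if behavior_version <= default_false_for_behavior_version_up_to:
            return False
    return default


print(use_mask_default_sketch({"rf_use_mask_window": False}, default=True, func_name="window"))  # False
print(use_mask_default_sketch({"rf_use_mask_window": False}, default=True, func_name="conv"))    # True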
returnn/frontend/hooks.py CHANGED
@@ -16,7 +16,7 @@ T = TypeVar("T")


  def setup_post_hook_on_method(
- obj: Any,
+ obj: T,
  attr: str,
  hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
  *,
@@ -40,7 +40,7 @@ class MethodWithHooks:
  """

  @classmethod
- def get(cls, obj: Any, attr: str) -> MethodWithHooks:
+ def get(cls, obj: T, attr: str) -> MethodWithHooks:
  """get existing or init new :class:`MethodWithHooks`"""
  method = getattr(obj, attr)
  if not isinstance(method, MethodWithHooks):
@@ -56,7 +56,7 @@ class MethodWithHooks:
  method.setup()
  return method

- def __init__(self, obj: Any, attr: str):
+ def __init__(self, obj: T, attr: str):
  """
  :param obj:
  :param attr:
returnn/frontend/normalization.py CHANGED
@@ -218,7 +218,7 @@ class BatchNorm(rf.Module):

  if any(d.need_masking() for d in source.dims if d != self.in_dim):
  if self.use_mask is None:
- use_mask = rf.use_mask_default(default=True)
+ use_mask = rf.use_mask_default(default=True, func_name="BatchNorm")
  else:
  use_mask = self.use_mask
  else:
returnn/frontend/signal.py CHANGED
@@ -71,7 +71,7 @@ def stft(
  """
  if in_spatial_dim.need_masking():
  if use_mask is None:
- use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+ use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="stft")
  if use_mask:
  x = x.copy_masked(0, dims=[in_spatial_dim])
  fft_length = fft_length or frame_length
returnn/tensor/_dim_extra.py CHANGED
@@ -1264,7 +1264,6 @@ class _DimMixin:
  raise TypeError(f"complete_dyn_size: _relu: unexpected type {type(a)}")

  y: Optional[_t.Tensor] = None # resulting dyn size
- y_max_value: Optional[_t.Tensor] = None # resulting dyn size max value
  inputs = list(op.inputs)
  assert inputs
  for x_dim in inputs:
@@ -1275,8 +1274,6 @@ class _DimMixin:
  if x_dim.dyn_size_ext is None and x_dim.dimension is None:
  return
  y = _bin_op(y, x_dim.dimension if x_dim.dimension is not None else x_dim.dyn_size_ext)
- if not template_only and y.raw_tensor is not None:
- y_max_value = _bin_op(y_max_value, x_dim.get_dim_value_tensor())
  assert y is not None, f"op {op}?"
  if self.dyn_size_ext is not None:
  assert self.dyn_size_ext.dim_tags == y.dim_tags
@@ -1286,9 +1283,14 @@ class _DimMixin:
  else:
  self.batch = y.batch
  self.dyn_size_ext = y
- if not template_only and y_max_value is not None:
- assert y_max_value is not None and y_max_value.raw_tensor is not None
- self._dyn_size_max_value = y_max_value
+ if not template_only and y.raw_tensor is not None:
+ # Note: Earlier, we had this wrong.
+ # It is not correct to replicate the same math (bin ops)
+ # on the dim values (_dyn_size_max_value of each dim).
+ # Consider sizes1=[2,3], sizes2=[5,4], and the op is "add".
+ # Then the result sizes would be [7,7], thus its max is 7,
+ # but max(sizes1)+max(sizes2)=3+5=8.
+ self._dyn_size_max_value = rf.reduce_max(y, axis=y.dims) if y.dims else y
  if tf and y.placeholder is not None:
  self.set_tag_on_size_tensor(y.placeholder)
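The counterexample in the comment is easy to verify numerically: the maximum of the combined sizes is in general not the combination of the per-dim maxima, which is why the max is now taken over the resulting dyn sizes directly (a quick NumPy check, not RETURNN code):

import numpy as np

# Quick check of the example from the comment above; not RETURNN code.
sizes1 = np.array([2, 3])
sizes2 = np.array([5, 4])
combined = sizes1 + sizes2               # per-entry result sizes: [7 7]
print(int(combined.max()))               # 7 -- the correct _dyn_size_max_value
print(int(sizes1.max() + sizes2.max()))  # 3 + 5 = 8 -- what the removed code computed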
@@ -2080,6 +2082,8 @@ class _DimMixin:
  :return: self + other. note that this is not commutative, i.e. different from other + self.
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 0:
+ return self
  cache_key = ("add", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2098,6 +2102,8 @@ class _DimMixin:
  :return: other + self
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 0:
+ return self
  cache_key = ("add_left", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2115,6 +2121,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 0:
+ return self
  return self.sub_right(other)

  def sub_right(self: Dim, other):
@@ -2123,6 +2131,8 @@ class _DimMixin:
  :return: self - other
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 0:
+ return self
  cache_key = ("sub", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2141,6 +2151,8 @@ class _DimMixin:
  :return: (-other) + self
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 0:
+ return self
  cache_key = ("sub_left", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2158,6 +2170,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("mul", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2175,6 +2189,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("mul_left", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2192,6 +2208,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("floordiv", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2209,6 +2227,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  return self.div_right(other)

  def div_left(self: Dim, other):
@@ -2216,6 +2236,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("truediv_left", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2233,6 +2255,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("truediv", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2250,6 +2274,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("ceildiv_left", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
@@ -2267,6 +2293,8 @@ class _DimMixin:
  :param Dim|int other:
  :rtype: Dim
  """
+ if isinstance(other, int) and other == 1:
+ return self
  cache_key = ("ceildiv", other)
  cache = self.get_same_base()._make_extra().cache_dim_math
  cache_entry = cache.get(cache_key, None)
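All of the hunks above add the same early-out to the Dim arithmetic helpers: adding or subtracting an int 0, or multiplying or dividing by an int 1, now returns self instead of going through the dim-math cache, so the identical Dim object is kept and no derived dim or cache entry is created. A generic sketch of the pattern (illustrative toy class, not the _DimMixin code):

class Size:
    """Toy stand-in for a Dim-like object with cached arithmetic (illustration only)."""

    def __init__(self, n: int):
        self.n = n
        self._cache = {}

    def __add__(self, other):
        if isinstance(other, int) and other == 0:
            return self  # neutral element: keep identity, create no cache entry
        key = ("add", other)
        if key not in self._cache:
            self._cache[key] = Size(self.n + other)
        return self._cache[key]


d = Size(7)
assert (d + 0) is d          # identity preserved
assert (d + 1) is (d + 1)    # non-trivial results still come from the cache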
returnn/util/basic.py CHANGED
@@ -1693,15 +1693,17 @@ def inplace_increment(x: numpy.ndarray, idx: numpy.ndarray, y: Union[numpy.ndarr
  raise NotImplementedError("This feature was removed with dropped Theano support")


- def prod(ls):
+ def prod(ls: Union[Iterable[T], numpy.ndarray]) -> Union[int, T, float]:
  """
- :param list[T]|tuple[T]|numpy.ndarray ls:
- :rtype: T|int|float
+ :param ls:
+ :return: ls[0] * ls[1] * ...
  """
- if len(ls) == 0:
+ it = iter(ls)
+ try:
+ x = next(it)
+ except StopIteration:
  return 1
- x = ls[0]
- for y in ls[1:]:
+ for y in it:
  x = x * y # *= doesn't work because x might be a tensor, and for e.g. torch.Tensor this op is in-place
  return x
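prod previously required a sequence supporting len() and indexing; the rewrite only needs an iterator, so the generator expression now passed from merge_dims works as well. A few usage examples (assuming prod is imported from returnn.util.basic, as the merge_dims hunk above does):

from returnn.util.basic import prod  # import as used in the merge_dims hunk above

assert prod([]) == 1                            # empty input still yields the multiplicative identity
assert prod((2, 3, 4)) == 24                    # lists/tuples behave as before
assert prod(x * x for x in range(1, 4)) == 36   # generators are now accepted as well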
returnn-1.20250828.2732.dist-info/METADATA → returnn-1.20250829.151916.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250828.2732
+ Version: 1.20250829.151916
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
returnn-1.20250828.2732.dist-info/RECORD → returnn-1.20250829.151916.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
- returnn/PKG-INFO,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
+ returnn/PKG-INFO,sha256=69VkJGV1pqtymng35zrww_x1nRV1ItxPtwRUITKi0OA,5215
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
- returnn/_setup_info_generated.py,sha256=b0sPsBSCJObrNcEG_PgvBTHyCMvourjvAfvgeL7uEAk,77
+ returnn/_setup_info_generated.py,sha256=i5XkrJfU4e2pYDZFMN1NXjv0difTuwrUKEBqaiMDWr0,77
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -76,11 +76,11 @@ returnn/extern/graph_editor/transform.py,sha256=qMGSenpbAnGqdG6QP6iWjlm6_ccySYJa
  returnn/extern/graph_editor/util.py,sha256=HfRbyQPmQ6_n5-O-096n0KeJtllQXFtaurpeJS_URZ0,18706
  returnn/frontend/__init__.py,sha256=2aS7nbxXniIrBp2DODl0xN0f3IJ_dX4Bi9ZlR7W5_DE,1472
  returnn/frontend/_backend.py,sha256=39l5MC1DaT0MPklMM8HXAW9nqisIIZQ9g2QSHOOtPQE,50741
- returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,8403
+ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,8550
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
  returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
- returnn/frontend/array_.py,sha256=5aCU-BCBH035QZlpYSRGtVXhxz78tteZ75e57FxCIRw,52182
+ returnn/frontend/array_.py,sha256=7uX5-Os2OyYUfC5soprIUx7rr-371yKf9DcckRKONXY,53855
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -88,14 +88,14 @@ returnn/frontend/cond.py,sha256=gh6wg0aSbAJQfKRv4BQAu-EfPWtWPLFjgc8IaPPFmwg,1023
  returnn/frontend/const.py,sha256=A5fP9w6Akv56d89pPvdoZaXvC9ZTYcexepnS9O2clOc,3945
  returnn/frontend/container.py,sha256=wF3OlQN7WlOVmmdapUth_Unha3DVf6h1B7okBJAuJDA,8011
  returnn/frontend/control_flow_ctx.py,sha256=v17CsNwRnZYe8GdMtGJt2ftibfxMCGK1i0l-GX5ILu0,699
- returnn/frontend/conv.py,sha256=Q0q90-uu9d6qV-v8_DlFGxpZtc6FjfXVpfkkXmv1Alk,31959
+ returnn/frontend/conv.py,sha256=4Mrq7MFc0f7SJ8g-wJEv4Lg3Stmju-fMwD09qKv6CuQ,32174
  returnn/frontend/device.py,sha256=Sjara0EmFLhu9O55cN_p6OwU0NgdNCCQjyAuQhiWpGw,1437
- returnn/frontend/dims.py,sha256=aH5FQ_m0xMD6Rj-BUWGx8lB-HkCuwZfMBf6mZbGGW5E,12611
+ returnn/frontend/dims.py,sha256=_HDU-Kxn3pApicFkm0F4Fs-ZAuF1gKXG8rroQHCFQQI,13073
  returnn/frontend/dropout.py,sha256=TjqZCKDIOBeHr14-NCemOm9m3p84LxQuPH1DvRAYg88,5028
  returnn/frontend/dtype.py,sha256=Ooc5BrcNrTp6XShuFEV9g5V6-niuy4ImP_Lt_Qgq3jE,1886
  returnn/frontend/gradient.py,sha256=G-Qv4gKGHYEeB92Zwco9ao4qjd6umZPUzQC4J-fbYWo,4033
  returnn/frontend/graph.py,sha256=PIv901WZ1rfTV0QGkyzBv6UxfWk9NsLGxdoJ5x9-8Xg,1818
- returnn/frontend/hooks.py,sha256=jYPbsb4gy5HORRZvKTEJbLcoJri5hOt5ADbhnTCytQo,5507
+ returnn/frontend/hooks.py,sha256=L7ITrlEQ6JUy8fEBE0SXg1dzFNkLrgb8gxZm88fxryU,5501
  returnn/frontend/init.py,sha256=bVB7bpghaY8DI_HL0mkB_9z95onWnIX2zlW4hlMYnRw,7494
  returnn/frontend/label_smoothing.py,sha256=lxmaowNr61sCMzMewqHhu1r0CcklYfhLXlFnBu8DeAU,5676
  returnn/frontend/linear.py,sha256=xRUjnkD3MTWDezSaYATBYJQ2fa1RhKMNrTuhC54hhVs,2252
@@ -105,7 +105,7 @@ returnn/frontend/math_.py,sha256=A_RkZ5lH2uXMchfPIH3itraWtMNNCVckQHHpf7aIIZQ,172
  returnn/frontend/matmul.py,sha256=xkueyxzSDz8MsYaWxPSjmV2Yy-tcaiOQDXbFt1IQM2A,1944
  returnn/frontend/module.py,sha256=219rh5mE0CD0-NdxXLsKyhv3BNtOI9jSyiI1Rb8MOyU,10700
  returnn/frontend/nested.py,sha256=P84u_cjoYdYRJ_0Cbt0vlKXxskmXTDfsnw_vFCCNKtU,15107
- returnn/frontend/normalization.py,sha256=-lYJ9IWcheOQu1gXJehSOA76qgVtxd1C07Jqps6Qg1o,14116
+ returnn/frontend/normalization.py,sha256=NrIIaZ3c2yf-WH2R9lPaL2TAq4IcNQc4OE5kFYdoihw,14139
  returnn/frontend/parameter.py,sha256=zvrkhSYC1c_O9kVwgHvOtOnWNurl5J28lkS0i1LQpWU,10627
  returnn/frontend/parametrizations.py,sha256=ptNgBw5IiPXVpB3QGse7AGAhdXp8X1rCqYUl2Mae8aI,2876
  returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
@@ -114,7 +114,7 @@ returnn/frontend/rand.py,sha256=2x7AHSYH_tZkzTk_q3t3GA_yYRNeKsVbJjw2InqSGDk,1354
  returnn/frontend/rec.py,sha256=6YSsSG7fdtfvvg24vmexSg8R2aVCcKHBdGLh-Mgn9Co,8037
  returnn/frontend/reduce.py,sha256=gRSvBJZNHa757IqBxGw4hu5eiO3pjie_ptEwUXHLSCs,10340
  returnn/frontend/run_ctx.py,sha256=yyOMUCKTOe19C4z2Nfly4YCLBmQ9ihip6nGrkW-Y6qg,23789
- returnn/frontend/signal.py,sha256=hfDipDhO0n9nXhGy7txwYUNbvg28NqkFq9p0Jq46f9c,4411
+ returnn/frontend/signal.py,sha256=iBRO2ywpJOjIUfVveJaqX4NT59013VCoE49IHkVn6p8,4429
  returnn/frontend/state.py,sha256=EePdrx6PtWL4mJ2XZmGlh5dl4nq6G9wZpqP4hdDEzfY,2935
  returnn/frontend/stepwise_scheduler.py,sha256=fMOTR7npGCDXrXDmSQ4VwmudoHEbY3Yr-QGyjFdQJSc,927
  returnn/frontend/tensor_array.py,sha256=Ej7CHtvpY0yBROlAk5vFe3CTXh-iAuqu9qcXS3Qxt2I,4328
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
  returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
  returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
  returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
- returnn/tensor/_dim_extra.py,sha256=oxJgPxrYQvew8MrFcYo5YjbKFC7Dd2yR1kcGWAf0afg,122380
+ returnn/tensor/_dim_extra.py,sha256=VN7Smn1Q0Y0DO7GSPM-aJUhp_jy5pzSMJbPkCk6JnqY,123448
  returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
  returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
  returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
  returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
  returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
  returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
- returnn/util/basic.py,sha256=UjHujX9pSu_dOgTxozWD0ujj5eSpyj_zD5vFU6bfyms,143096
+ returnn/util/basic.py,sha256=S2ABKcP0pf2UexuMXDNHGcfAu7GDSD2mr6OIByM152M,143168
  returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5w,70843
  returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
  returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
- returnn-1.20250828.2732.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
- returnn-1.20250828.2732.dist-info/METADATA,sha256=40gq368ieKKYuTEYmkvPOiakDG6N6_RPtPd5pBGCoQ0,5213
- returnn-1.20250828.2732.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- returnn-1.20250828.2732.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
- returnn-1.20250828.2732.dist-info/RECORD,,
+ returnn-1.20250829.151916.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+ returnn-1.20250829.151916.dist-info/METADATA,sha256=69VkJGV1pqtymng35zrww_x1nRV1ItxPtwRUITKi0OA,5215
+ returnn-1.20250829.151916.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ returnn-1.20250829.151916.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+ returnn-1.20250829.151916.dist-info/RECORD,,