returnn-1.20250304.10039-py3-none-any.whl → returnn-1.20250304.113330-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of returnn has been flagged as a potentially problematic release.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250304.10039
+ Version: 1.20250304.113330
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
- version = '1.20250304.010039'
- long_version = '1.20250304.010039+git.3e53d74'
+ version = '1.20250304.113330'
+ long_version = '1.20250304.113330+git.acf09da'
@@ -1223,7 +1223,7 @@ class Backend(Generic[T]):
  out_spatial_dims: Optional[Sequence[Dim]] = None,
  filter: Tensor,
  filter_size: Sequence[Dim], # to have the order well-defined
- padding: str,
+ padding: Union[str, int, Sequence[int]],
  strides: Optional[Union[int, Sequence[int]]] = None,
  dilation_rate: Optional[Union[int, Sequence[int]]] = None,
  groups: Optional[int] = None,
@@ -1258,7 +1258,7 @@ class Backend(Generic[T]):
  *,
  mode: str,
  pool_size: Sequence[int],
- padding: str = "valid",
+ padding: Union[str, int, Sequence[int]] = "valid",
  dilation_rate: Union[Sequence[int], int] = 1,
  strides: Sequence[int],
  in_spatial_dims: Sequence[Dim],
returnn/frontend/conv.py CHANGED
@@ -181,15 +181,46 @@ def conv(
  in_spatial_dims: Sequence[Dim],
  out_spatial_dims: Optional[Sequence[Dim]] = None,
  filter: Tensor,
- filter_size: Sequence[Dim], # to have the order well-defined
- padding: str,
+ filter_size: Sequence[Dim],
+ padding: Union[str, int, Sequence[int]],
  strides: Optional[Union[int, Sequence[int]]] = None,
  dilation_rate: Optional[Union[int, Sequence[int]]] = None,
  groups: Optional[int] = None,
  bias: Optional[Tensor] = None,
  use_mask: Optional[bool] = None,
  ) -> Tuple[Tensor, Sequence[Dim]]:
- """convolution"""
+ """
+ Generic N-D convolution.
+
+ :param source:
+ :param in_dim: input channels
+ :param out_dim: output channels
+ :param in_spatial_dims: On what dimensions to operate on.
+ The number of specified dims (1, 2 or 3) specifies whether this is 1D, 2D or 3D convolution.
+ The order is consistent with the order of the ``filter_size``, ``strides``, etc.
+ :param out_spatial_dims:
+ :param filter:
+ :param filter_size: defines the order of dims in ``filter``
+ such that it matches the order of ``in_spatial_dims``.
+ :param padding: "valid" or "same" or int. "valid" is like padding=0.
+ padding="same" will pad such that the output has the same spatial dimensions as the input
+ (in case of stride=1), or otherwise ceildiv(input, stride).
+ The specific padding in padding="same" with stride>1 has changed with behavior version >=24
+ (or global config option ``rf_use_consistent_same_padding``)
+ and is now consistent independent of dimension size.
+ See :func:`_consistent_same_padding` for more details.
+ :param strides: the default (if it is None) is 1
+ :param dilation_rate:
+ :param groups:
+ :param bias:
+ :param use_mask: Whether to mask the input tensor based on seq lengths
+ such that the padding in the padded tensor is ignored
+ (it will mask with 0).
+ With behavior version >=23, this is enabled by default,
+ or configured with global config option ``rf_use_mask``.
+ (Also see :func:`use_mask_default`).
+ :return: out, out_spatial_dims
+ """
  if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
  if use_mask is None:
  use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
@@ -198,6 +229,10 @@ def conv(
  for in_spatial_dim in in_spatial_dims:
  if in_spatial_dim not in source.dims:
  raise ValueError(f"conv: source {source} does not have spatial dim {in_spatial_dim}")
+ if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
+ source, in_spatial_dims, padding = _consistent_same_padding(
+ source, in_spatial_dims=in_spatial_dims, filter_size=filter_size, dilation_rate=dilation_rate, pad_value=0
+ )
  # noinspection PyProtectedMember
  out, out_spatial_dims = source._raw_backend.conv(
  source,
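For illustration (not part of the returnn source): the hunk above only kicks in for padding="same" with a non-default stride. A minimal sketch, assuming dilation_rate=1 and hypothetical helper names, of how the classic TF-style "same" padding depends on the input length while the new consistent variant does not:

    def tf_style_same_total_padding(input_len: int, filter_size: int, stride: int) -> int:
        # TF-style "same": total padding depends on the (batch-maximal) input length.
        if input_len % stride == 0:
            return max(filter_size - stride, 0)
        return max(filter_size - (input_len % stride), 0)

    def consistent_same_padding_per_side(filter_size: int) -> int:
        # Behavior version >= 24: always (filter_size - 1) // 2 per side, as with stride=1.
        return (filter_size - 1) // 2

    # filter_size=3, stride=2: the TF-style total padding is 1 for even input lengths and 2 for
    # odd ones (so it depends on the longest sequence in the batch), while the consistent
    # variant always pads 1 frame on each side.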
@@ -359,6 +394,9 @@ def transposed_conv(
  use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
  if use_mask:
  source = source.copy_masked(0, dims=in_spatial_dims)
+ if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
+ # I don't really know what this should mean here... Investigate this further...
+ raise NotImplementedError("consistent same padding not implemented for transposed conv")
  # noinspection PyProtectedMember
  out, out_spatial_dims = source._raw_backend.transposed_conv(
  source=source,
@@ -409,7 +447,7 @@ def pool(
  nd: Optional[int] = None,
  mode: str,
  pool_size: Union[Sequence[int], int],
- padding: str = "valid",
+ padding: Union[str, int, Sequence[int]] = "valid",
  dilation_rate: Union[Sequence[int], int] = 1,
  strides: Optional[Union[Sequence[int], int]] = None,
  in_spatial_dims: Union[Sequence[Dim], Dim],
@@ -417,19 +455,29 @@ def pool(
  use_mask: Optional[bool] = None,
  ) -> Tuple[Tensor, Sequence[Dim]]:
  """
- A generic N-D pooling layer.
- This would usually be done after a convolution for down-sampling.
+ Generic N-D pooling.

  :param source:
  :param nd:
  :param mode: "max" or "avg"
  :param pool_size: shape of the window of each reduce
- :param padding: "valid" or "same"
+ :param padding: "valid" or "same" or int. "valid" is like padding=0.
+ padding="same" will pad such that the output has the same spatial dimensions as the input
+ (in case of stride=1), or otherwise ceildiv(input, stride).
+ The specific padding in padding="same" with stride>1 has changed with behavior version >=24
+ (or global config option ``rf_use_consistent_same_padding``)
+ and is now consistent independent of dimension size.
+ See :func:`_consistent_same_padding` for more details.
  :param dilation_rate:
- :param strides: in contrast to tf.nn.pool, the default (if it is None) will be set to pool_size
+ :param strides: the default (if it is None) will be set to pool_size (in contrast to :func:`conv`)
  :param in_spatial_dims:
  :param out_spatial_dims:
- :param use_mask:
+ :param use_mask: Whether to mask the input tensor based on seq lengths
+ such that the padding in the padded tensor is ignored
+ (for max-pooling, it will mask with -inf, for avg-pooling with 0).
+ With behavior version >=23, this is enabled by default,
+ or configured with global config option ``rf_use_mask``.
+ (Also see :func:`use_mask_default`).
  :return: out, out_spatial_dims
  """
  if isinstance(in_spatial_dims, Dim):
@@ -451,8 +499,7 @@ def pool(
  strides = pool_size
  elif isinstance(strides, int):
  strides = [strides] * nd
- assert isinstance(strides, (list, tuple))
- assert len(strides) == nd
+ assert isinstance(strides, (list, tuple)) and len(strides) == nd and all(isinstance(s, int) for s in strides)

  if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
  if use_mask is None:
@@ -462,6 +509,15 @@ def pool(
  else:
  use_mask = False

+ if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
+ source, in_spatial_dims, padding = _consistent_same_padding(
+ source,
+ in_spatial_dims=in_spatial_dims,
+ filter_size=pool_size,
+ dilation_rate=dilation_rate,
+ pad_value={"max": float("-inf"), "avg": 0}[mode],
+ )
+
  # noinspection PyProtectedMember
  out, out_spatial_dims = source._raw_backend.pool(
  source=source,
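For illustration (not from the package): the pad_value chosen above makes the explicitly padded frames neutral for the subsequent reduction:

    window = [2.0, 5.0]
    max(window + [float("-inf")])  # 5.0 -- -inf can never win a max-pool
    sum(window + [0.0])            # 7.0 -- 0 contributes nothing to an avg-pool sum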
@@ -642,7 +698,7 @@ def make_conv_out_spatial_dims(
  in_spatial_dims: Sequence[Dim],
  *,
  filter_size: Union[Sequence[Union[int, Dim]], int, Dim],
- padding: str,
+ padding: Union[str, int, Sequence[int]],
  strides: Union[Sequence[int], int] = 1,
  dilation_rate: Union[Sequence[int], int] = 1,
  description_prefix: Optional[str] = None,
@@ -658,11 +714,15 @@ def make_conv_out_spatial_dims(
  if isinstance(dilation_rate, int):
  dilation_rate = [dilation_rate] * nd
  assert nd == len(in_spatial_dims) == len(filter_size) == len(strides) == len(dilation_rate)
- assert padding.lower() in ("valid", "same")
+ if isinstance(padding, (int, str)):
+ padding = [padding] * nd
+ padding = [p.lower() if isinstance(p, str) else p for p in padding]
  out_spatial_dims = []
  for i in range(nd):
  in_spatial_dim = in_spatial_dims[i]
- if filter_size[i] == strides[i] == 1 or (strides[i] == 1 and padding.lower() == "same"):
+ if (filter_size[i] == strides[i] == 1 and padding[i] in ("valid", "same", 0)) or (
+ strides[i] == 1 and padding[i] == "same"
+ ):
  out_spatial_dims.append(in_spatial_dim)
  else:
  out_spatial_dim = _calc_out_dim(
@@ -670,7 +730,7 @@ def make_conv_out_spatial_dims(
  filter_size=filter_size[i],
  stride=strides[i],
  dilation_rate=dilation_rate[i],
- padding=padding,
+ padding=padding[i],
  )
  assert isinstance(out_spatial_dim, Dim)
  if description_prefix and out_spatial_dim != in_spatial_dim:
@@ -681,7 +741,7 @@ def make_conv_out_spatial_dims(
  filter_size=filter_size[i],
  stride=strides[i],
  dilation_rate=dilation_rate[i],
- padding=padding,
+ padding=padding[i],
  )
  out_spatial_dims.append(out_spatial_dim)
  return out_spatial_dims
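For illustration (mirrors the list handling added above, not a new API): a scalar padding, whether a string or an int, is broadcast to one entry per spatial axis before the per-axis output dims are computed:

    padding = "same"          # or 0, or e.g. [1, 2] for a 2D case
    nd = 2
    if isinstance(padding, (int, str)):
        padding = [padding] * nd
    padding = [p.lower() if isinstance(p, str) else p for p in padding]
    # -> ["same", "same"]; each entry is then passed to _calc_out_dim(padding=padding[i])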
@@ -695,7 +755,7 @@ def _calc_out_dim(in_dim, filter_size, stride, padding, dilation_rate=1):
  :param int filter_size: e.g. 2, for the corresponding axis
  :param int stride: e.g. 1, for the corresponding axis
  :param int dilation_rate: e.g. 1
- :param str padding: "valid" or "same"
+ :param str|int padding: "valid" or "same" or int
  :return: the output dimension
  :rtype: T
  """
@@ -712,13 +772,16 @@ def _calc_out_dim(in_dim, filter_size, stride, padding, dilation_rate=1):
  return rf.ceil_divide(a, b)
  return -(-a // b)

- padding = padding.upper()
+ padding = padding.lower() if isinstance(padding, str) else padding
  # See tf.compat.v1.nn.convolution() documentation for more.
- if padding == "SAME":
+ if padding == "same":
  if isinstance(in_dim, Dim):
  return in_dim.ceildiv_right(stride)
  return ceildiv(in_dim, stride)
- elif padding == "VALID":
+ elif padding == "valid" or isinstance(padding, int):
+ if isinstance(padding, int) and padding != 0:
+ assert padding > 0
+ in_dim = padding + in_dim + padding
  if isinstance(in_dim, Dim):
  filter_left_dilated = (filter_size - 1) * dilation_rate // 2
  filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated
@@ -726,4 +789,95 @@ def _calc_out_dim(in_dim, filter_size, stride, padding, dilation_rate=1):
  return valid_part.ceildiv_right(stride)
  return ceildiv(in_dim - (filter_size - 1) * dilation_rate, stride)
  else:
- raise Exception("invalid padding %r" % padding)
+ raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__})")
+
+
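For illustration (ints only; the Dim branch delegates to ceildiv_right): the "valid"/int branch of _calc_out_dim above amounts to enlarging the input by the symmetric padding and applying the usual valid-convolution formula:

    def out_size(in_size: int, filter_size: int, stride: int, dilation: int, pad: int) -> int:
        in_size = pad + in_size + pad                 # symmetric integer padding
        effective = (filter_size - 1) * dilation      # span of the dilated filter minus 1
        return -(-(in_size - effective) // stride)    # ceildiv

    out_size(10, filter_size=3, stride=2, dilation=1, pad=1)  # == 5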
+ def _should_use_consistent_same_padding() -> bool:
+ """
+ :return: whether to use the new consistent same padding with :func:`_consistent_same_padding`.
+
+ This is only needed for the case when we have striding and padding="same".
+ See :func:`_consistent_same_padding` for more details.
+
+ Check the global RETURNN config for the ``rf_use_consistent_same_padding``
+ on how we should handle the ``padding="same"`` case for convolution/pooling when there is striding.
+ If that is not specified, with behavior version >=24, we will use the new consistent same padding,
+ with behavior version <=23, we will not use it.
+
+ See issue `#1693 <https://github.com/rwth-i6/returnn/issues/1693>`__.
+ """
+ from returnn.config import get_global_config
+
+ config = get_global_config(raise_exception=False)
+ config_value = None
+ if config:
+ if "rf_use_consistent_same_padding" in config.typed_dict:
+ config_value = config.typed_dict["rf_use_consistent_same_padding"]
+ assert config_value is None or isinstance(config_value, bool)
+ elif "rf_use_consistent_same_padding" in config.dict:
+ config_value = config.bool("rf_use_consistent_same_padding", None)
+ if config_value is not None:
+ return config_value
+
+ from returnn.util.basic import BehaviorVersion
+
+ return BehaviorVersion.get() >= 24
+
+
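For illustration only: the lookup above accepts the option from either the typed config dict or the plain config dict, so a user-side setting could look like this:

    # In a Python-style RETURNN config (ends up in config.typed_dict):
    rf_use_consistent_same_padding = True
    # In a non-typed config the value is instead parsed via config.bool("rf_use_consistent_same_padding", None).
    # If the option is absent, the decision falls back to BehaviorVersion.get() >= 24.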
+ def _consistent_same_padding(
+ source: Tensor,
+ *,
+ in_spatial_dims: Sequence[Dim],
+ filter_size: Optional[Union[int, Dim, Sequence[int], Sequence[Dim]]],
+ dilation_rate: Optional[Union[int, Sequence[int]]] = None,
+ pad_value: Union[int, float],
+ ) -> Tuple[Tensor, Sequence[Dim], Union[int, Sequence[int]]]:
+ """
+ In case of striding and padding="same", the standard padding that we do (following TensorFlow)
+ depends on the current dimension size.
+ It adds padding left and right such that the first and last window
+ will have the same amount of padding (+-1).
+ With stride=1, this is the standard (filter_size-1)/2 left and right padding,
+ but with stride>1, this is not the case anymore.
+ (See also the explanation and calculation of padding in :func:`returnn.torch.frontend._backend.TorchBackend.conv`.)
+ However, the problem with this behavior is with batching:
+ The padding now depends on the longest sequence in the batch,
+ and thus is arbitrary for any of the other sequences.
+
+ The new consistent same padding adds padding independent of the current dimension size (largest seq in batch).
+ We just do the same as with stride=1, i.e. (filter_size-1)/2 left and right padding.
+
+ :return: source or padded source, in_spatial_dims or new in_spatial_dims, new padding on top of the output
+ """
+ filter_size = _make_sequence(filter_size or 1, nd=len(in_spatial_dims))
+ dilation_rate = _make_sequence(dilation_rate or 1, nd=len(in_spatial_dims))
+ filter_size_ints = [s.dimension if isinstance(s, Dim) else s for s in filter_size]
+ if all(s % 2 == 1 for s in filter_size_ints):
+ # In this case, we can pass padding as integer to the backend, so that it adds the same padding left/right.
+ return source, in_spatial_dims, [(s // 2) * d for s, d in zip(filter_size_ints, dilation_rate)]
+ # Need to use the custom padding here.
+ paddings = []
+ for s, d in zip(filter_size, dilation_rate):
+ pad_left = (s - 1) * d // 2
+ pad_right = (s - 1) * d - pad_left
+ paddings.append((pad_left, pad_right))
+ # We expect that masking was already done before (or we don't care about it), thus handle_dynamic_dims=False.
+ source, in_spatial_dims = rf.pad(
+ source, axes=in_spatial_dims, padding=paddings, value=pad_value, handle_dynamic_dims=False
+ )
+ return source, in_spatial_dims, 0
+
+
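A worked example of what _consistent_same_padding returns, assuming dilation_rate=1 (illustration only):

    # filter_size=3 (odd):  (3 // 2) * 1 = 1  -> source unchanged, padding=1 per side is
    #                        handed to the backend as an int.
    # filter_size=4 (even): pad_left = (4 - 1) // 2 = 1, pad_right = 3 - 1 = 2
    #                        -> rf.pad() is applied with (1, 2) and padding 0 is returned,
    #                        so the backend then runs with "valid"-style padding.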
+ def _make_sequence(value: Union[int, Sequence[int]], *, nd: int) -> Sequence[int]:
+ if isinstance(value, int):
+ return [value] * nd
+ assert len(value) == nd
+ return value
+
+
+ def _any_is_non_default(single_or_seq: Optional[Union[int, Sequence[int]]], *, default: int) -> bool:
+ if single_or_seq is None:
+ return False
+ if isinstance(single_or_seq, int):
+ return single_or_seq != default
+ return any(i != default for i in single_or_seq)
@@ -998,7 +998,7 @@ class ReturnnLayersBackend(Backend[Layer]):
  out_spatial_dims: Optional[Sequence[Dim]] = None,
  filter: Tensor,
  filter_size: Sequence[Dim], # to have the order well-defined
- padding: str,
+ padding: Union[str, int, Sequence[int]],
  strides: Optional[Union[int, Sequence[int]]] = None,
  dilation_rate: Optional[Union[int, Sequence[int]]] = None,
  groups: Optional[int] = None,
@@ -1088,7 +1088,7 @@ class ReturnnLayersBackend(Backend[Layer]):
  *,
  mode: str,
  pool_size: Sequence[int],
- padding: str = "valid",
+ padding: Union[str, int, Sequence[int]] = "valid",
  dilation_rate: Union[Sequence[int], int] = 1,
  strides: Sequence[int],
  in_spatial_dims: Sequence[Dim],
@@ -4184,7 +4184,9 @@ class PadLayer(_ConcatInputLayer):
  self,
  *,
  axes: Union[Dim, str, Sequence[Union[Dim, str]]],
- padding: Union[int, Tuple[int, int], Sequence[Tuple[int, int]]],
+ padding: Union[
+ int, Dim, Tuple[Union[int, Dim], Union[int, Dim]], Sequence[Tuple[Union[int, Dim], Union[int, Dim]]]
+ ],
  out_dims: Optional[Union[Dim, Sequence[Dim]]] = None,
  handle_dynamic_dims: Optional[bool] = None,
  value: Union[int, float] = 0,
@@ -4211,7 +4213,10 @@ class PadLayer(_ConcatInputLayer):
  padding = self._transform_padding(padding=padding, axes=axes)
  paddings = [(0, 0)] * len(range(self.input_data.batch_ndim))
  for i, a in enumerate(axes):
- paddings[a] = padding[i]
+ pad_left, pad_right = padding[i]
+ pad_left = pad_left.dimension if isinstance(pad_left, Dim) else pad_left
+ pad_right = pad_right.dimension if isinstance(pad_right, Dim) else pad_right
+ paddings[a] = (pad_left, pad_right)
  mode = mode.lower()
  if handle_dynamic_dims is None:
  handle_dynamic_dims = self._handle_dynamic_dims_default(
@@ -4219,7 +4224,7 @@ class PadLayer(_ConcatInputLayer):
  padding=padding,
  mode=mode,
  )
- if all(sum(p) == 0 for p in padding):
+ if all(left == right == 0 for left, right in paddings):
  self.output.placeholder = self.input_data.placeholder
  elif mode == "replication":
  self.output.placeholder = tf_util.pad_replicate(self.input_data.placeholder, axes, padding)
@@ -4227,7 +4232,7 @@ class PadLayer(_ConcatInputLayer):
  self.output.placeholder = tf.pad(
  self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value
  )
- if all(right == 0 for left, right in padding) and mode != "circular":
+ if all(right == 0 for left, right in paddings) and mode != "circular":
  pass # no masking needed
  else:
  import returnn.frontend as rf
@@ -4257,9 +4262,9 @@ class PadLayer(_ConcatInputLayer):
  @classmethod
  def _transform_padding(cls, padding, axes):
  """
- :param list[(int,int)]|(int,int)|int padding:
+ :param Sequence[(int|Dim,int|Dim)]|(int|Dim,int|Dim)|int|Dim padding:
  :param list[int] axes:
- :rtype: list[(int,int)]
+ :rtype: Sequence[(int|Dim,int|Dim)]
  """
  if isinstance(padding, (list, tuple)):
  if isinstance(padding[0], (list, tuple)):
@@ -4316,9 +4321,9 @@ class PadLayer(_ConcatInputLayer):
  """
  :param str name:
  :param list[LayerBase] sources:
- :param Dim|str|list[Dim|str] axes:
- :param list[(int,int)]|(int,int)|int padding:
- :param Dim|list[Dim]|None out_dims:
+ :param Dim|str|Sequence[Dim|str] axes:
+ :param Sequence[(int|Dim,int|Dim)]|(int|Dim,int|Dim)|int|Dim padding:
+ :param Dim|Sequence[Dim]|None out_dims:
  :rtype: Data
  """
  from ..util.data import Dim
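With the PadLayer changes above, padding entries may now be Dim objects, whose static .dimension is used for tf.pad. A hypothetical net-dict entry (names and values illustrative, not from the package):

    win_dim = Dim(2, name="win")  # static Dim; PadLayer would use win_dim.dimension == 2
    network = {
        "padded": {"class": "pad", "from": "data", "axes": "T",
                   "padding": (1, win_dim), "value": 0},
    }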
@@ -6223,7 +6228,7 @@ class ConvLayer(_ConcatInputLayer):
  for 1D/2D/3D conv.
  The input data ndim must match, or you can add dimensions via input_expand_dims or input_add_feature_dim.
  It will automatically swap the batch-dim to the first axis of the input data.
- :param str padding: "same", "valid" or "same_static".
+ :param str|int|Sequence[int] padding: "same", "valid" or "same_static".
  "same_static" is calculated differently depending on whether an axis is static or dynamic.
  For static axes, "same_static" padding is the same as "same" padding,
  i.e. filter_size - 1 - (T + strides - 1) % strides.
@@ -6261,8 +6266,10 @@ class ConvLayer(_ConcatInputLayer):
  """
  from returnn.util import BehaviorVersion

- padding = padding.upper()
- assert padding in ["SAME", "VALID", "SAME_STATIC"], "no other padding supported at the moment"
+ padding = padding.upper() if isinstance(padding, str) else padding
+ assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(
+ padding, (int, tuple, list)
+ ), f"{self}: got unsupported padding {padding}"
  assert "out_type" not in kwargs, "don't set out_type explicitly for this layer"
  assert len(filter_size) in (1, 2, 3), "only 1D conv, 2D conv or 3D conv supported"
  super(ConvLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
@@ -6398,6 +6405,17 @@ class ConvLayer(_ConcatInputLayer):
  out_batch_feature_major=out_batch_feature_major,
  )
  padding = "VALID" # input is now already "same" padded, therefore use "valid" padding from here
+ elif isinstance(padding, int) and padding == 0:
+ x = input_data.placeholder
+ padding = "VALID"
+ elif isinstance(padding, (int, list, tuple)):
+ x = self.get_input_placeholder_with_int_padding(
+ input_data=input_data,
+ num_batch_dims=num_batch_dims,
+ out_batch_feature_major=out_batch_feature_major,
+ padding=padding,
+ )
+ padding = "VALID"
  else:
  x = input_data.placeholder
@@ -6539,7 +6557,7 @@ class ConvLayer(_ConcatInputLayer):
  :param Sequence[int|Dim] filter_size:
  :param Sequence[int] strides:
  :param Sequence[int] dilation_rate:
- :param str padding:
+ :param str|int|Sequence[int] padding:
  """
  if output.feature_dim_axis == num_batch_dims:
  out_spatial_dims_ = output.dim_tags[num_batch_dims + 1 :]
@@ -6558,7 +6576,7 @@ class ConvLayer(_ConcatInputLayer):
  filter_size=filter_size[i],
  stride=strides[i],
  dilation_rate=dilation_rate[i],
- padding=padding,
+ padding=padding if isinstance(padding, (str, int)) else padding[i],
  )
  assert isinstance(out_tag_calc, Dim)
  out_tag_calc.declare_same_as(out_tag)
@@ -6717,7 +6735,7 @@ class ConvLayer(_ConcatInputLayer):
  """
  Returns the placeholder of input_data with same_static padding applied to it.

- :param input_data:
+ :param input_data: [Batch..., Spatial..., Feature] or [Batch..., Feature, Spatial...]
  :param num_batch_dims:
  :param filter_size:
  :param strides:
@@ -6757,6 +6775,44 @@ class ConvLayer(_ConcatInputLayer):
  x = tf.pad(input_data.placeholder, paddings)
  return x

+ @classmethod
+ def get_input_placeholder_with_int_padding(
+ cls,
+ input_data: Data,
+ *,
+ num_batch_dims: int,
+ out_batch_feature_major: bool,
+ padding: Union[int, Sequence[int]],
+ pad_value: float = 0.0,
+ ) -> tf.Tensor:
+ """
+ Returns the placeholder of input_data with same_static padding applied to it.
+
+ :param input_data: [Batch..., Spatial..., Feature] or [Batch..., Feature, Spatial...]
+ :param num_batch_dims:
+ :param out_batch_feature_major:
+ :param padding:
+ :param pad_value:
+ """
+ num_spatial_dims = input_data.batch_ndim - num_batch_dims - 1
+ if isinstance(padding, int):
+ padding = [padding] * num_spatial_dims
+ paddings = [[0, 0] for _ in range(input_data.batch_ndim)]
+ for axis, dim in enumerate(input_data.dims):
+ if axis < num_batch_dims:
+ continue
+ if axis == num_batch_dims and out_batch_feature_major:
+ # input_data has dimensions [batch] * num_batch_dims + [channels] + [spatial] * num_spatial_dims
+ continue
+ if axis >= num_batch_dims + num_spatial_dims and not out_batch_feature_major:
+ # input_data has dimensions [batch] * num_batch_dims + [spatial] * num_spatial_dims + [channels]
+ break
+
+ padding_ = padding[axis - num_batch_dims - out_batch_feature_major]
+ paddings[axis] = [padding_, padding_]
+ x = tf.pad(input_data.placeholder, paddings, constant_values=pad_value)
+ return x
+
  @classmethod
  def calc_out_dim(cls, in_dim, filter_size, stride, padding, dilation_rate=1):
  """
@@ -6764,7 +6820,7 @@ class ConvLayer(_ConcatInputLayer):
  :param int|Dim filter_size: e.g. 2, for the corresponding axis
  :param int stride: e.g. 1, for the corresponding axis
  :param int dilation_rate: e.g. 1
- :param str padding: "valid" or "same"
+ :param str|int padding: "valid" or "same"
  :return: the output dimension
  :rtype: T
  """
@@ -6779,13 +6835,16 @@ class ConvLayer(_ConcatInputLayer):
  return a
  return -(-a // b)

- padding = padding.upper()
+ padding = padding.upper() if isinstance(padding, str) else padding
  # See tf.compat.v1.nn.convolution() documentation for more.
  if padding == "SAME":
  if isinstance(in_dim, Dim):
  return in_dim.ceildiv_right(stride)
  return ceildiv(in_dim, stride)
- elif padding == "VALID":
+ elif padding == "VALID" or isinstance(padding, int):
+ if isinstance(padding, int) and padding != 0:
+ assert padding > 0
+ in_dim = padding + in_dim + padding
  if isinstance(in_dim, Dim):
  filter_left_dilated = (filter_size - 1) * dilation_rate // 2
  filter_right_dilated = (filter_size - 1) * dilation_rate - filter_left_dilated
@@ -6826,7 +6885,7 @@ class ConvLayer(_ConcatInputLayer):
  :param Sequence[LayerBase] sources:
  :param returnn.tf.network.TFNetwork network:
  :param Sequence[int|Dim] filter_size:
- :param str padding:
+ :param str|int|Sequence[int] padding:
  :param int|Sequence[int] strides:
  :param int|Sequence[int] dilation_rate:
  :param int input_expand_dims: number of dynamic dims to add to the input
@@ -6839,6 +6898,7 @@ class ConvLayer(_ConcatInputLayer):
  :param Sequence[Dim]|None out_spatial_dims:
  :param int input_expand_dims: number of spatial dims to add to the input
  :param bool|NotSpecified auto_use_channel_first:
+ :rtype: Data
  """
  from returnn.util import BehaviorVersion

@@ -6857,7 +6917,8 @@ class ConvLayer(_ConcatInputLayer):
  assert len(dilation_rate) == len(filter_size)
  if in_spatial_dims:
  assert len(in_spatial_dims) == len(filter_size)
- padding = padding.upper()
+ if isinstance(padding, str):
+ padding = padding.upper()
  # Be relaxed about incorrect input data. Throw errors later. This can also work during template construction.
  if not input_data.have_batch_axis():
  input_data = input_data.copy_add_batch_dim(batch_dim_axis=0)
@@ -6889,7 +6950,11 @@ class ConvLayer(_ConcatInputLayer):
  for i in range(len(filter_size)):
  old_tag = old_spatial_dim_tags[i] if i < len(old_spatial_dim_tags) else None
  filter_size_ = filter_size[i].dimension if isinstance(filter_size[i], Dim) else filter_size[i]
- if old_tag and (filter_size_ == strides[i] == 1 or (strides[i] == 1 and padding == "SAME")):
+ padding_ = padding if isinstance(padding, (str, int)) else padding[i]
+ if old_tag and (
+ (filter_size_ == strides[i] == 1 and padding_ in ("SAME", "VALID", 0))
+ or (strides[i] == 1 and padding_ == "SAME")
+ ):
  dim_tags.append(old_tag) # identity in this axis
  continue
  new_dim = None
@@ -6899,7 +6964,7 @@ class ConvLayer(_ConcatInputLayer):
  filter_size=filter_size[i],
  stride=strides[i],
  dilation_rate=dilation_rate[i],
- padding=padding,
+ padding=padding_,
  )
  dim_tags.append(
  Dim(
@@ -7009,8 +7074,8 @@ class PoolLayer(_ConcatInputLayer):
  ):
  """
  :param str mode: "max" or "avg"
- :param tuple[int] pool_size: shape of the window of each reduce
- :param str padding: "same", "valid" or "same_static".
+ :param Sequence[int] pool_size: shape of the window of each reduce
+ :param str|int|Sequence[int] padding: "same", "valid" or "same_static".
  "same_static" is calculated differently depending on whether an axis is static or dynamic.
  For static axes, "same_static" padding is the same as "same" padding,
  i.e. filter_size - 1 - (T + strides - 1) % strides.
@@ -7018,13 +7083,13 @@ class PoolLayer(_ConcatInputLayer):
  filter_size - 1, i.e. it is independent of the length T of the axis and the striding.
  For dynamic axes, to avoid skipping any frames on the right,
  we set left_padding = (filter_size - strides) // 2.
- :param tuple[int]|int dilation_rate:
- :param tuple[int]|int|None strides: in contrast to tf.nn.pool, the default (if it is None)
+ :param Sequence[int]|int dilation_rate:
+ :param Sequence[int]|int|None strides: in contrast to tf.nn.pool, the default (if it is None)
  will be set to pool_size
  :param Dim|None in_dim:
- :param list[Dim|str]|None in_spatial_dims:
+ :param Sequence[Dim|str]|None in_spatial_dims:
  :param Dim|None out_dim:
- :param list[Dim]|None out_spatial_dims:
+ :param Sequence[Dim]|None out_spatial_dims:
  :param bool|NotSpecified use_channel_first: if set, will transform input to NCHW format
  :param bool use_time_mask:
  """
@@ -7032,8 +7097,15 @@ class PoolLayer(_ConcatInputLayer):
  assert "out_type" not in kwargs
  mode = mode.upper()
  assert mode in ["MAX", "AVG"]
- padding = padding.upper()
- assert padding in ["VALID", "SAME", "SAME_STATIC"]
+ if isinstance(padding, str):
+ padding = padding.upper()
+ assert padding in ["VALID", "SAME", "SAME_STATIC"]
+ elif isinstance(padding, int) or (
+ isinstance(padding, (list, tuple)) and all(isinstance(p, int) for p in padding)
+ ):
+ pass
+ else:
+ raise TypeError(f"invalid type ({type(padding).__name__}) for padding: {padding}")
  if isinstance(dilation_rate, int):
  dilation_rate = [dilation_rate] * len(pool_size)
  assert len(dilation_rate) == len(pool_size)
@@ -7102,6 +7174,18 @@ class PoolLayer(_ConcatInputLayer):
  out_batch_feature_major=out_batch_feature_major,
  )
  padding = "VALID" # input is now already "same" padded, therefore use "valid" padding from here
+ elif isinstance(padding, int) and padding == 0:
+ x = input_data.placeholder
+ padding = "VALID"
+ elif isinstance(padding, (int, list, tuple)):
+ x = ConvLayer.get_input_placeholder_with_int_padding(
+ input_data=input_data,
+ num_batch_dims=num_batch_dims,
+ out_batch_feature_major=out_batch_feature_major,
+ padding=padding,
+ pad_value={"MAX": float("-inf"), "AVG": 0}[mode],
+ )
+ padding = "VALID"
  else:
  x = input_data.placeholder
@@ -7145,14 +7229,14 @@ class PoolLayer(_ConcatInputLayer):
  :param str name:
  :param list[LayerBase] sources:
  :param returnn.tf.network.TFNetwork network:
- :param tuple[int]|list[int] pool_size:
- :param tuple[int]|list[int]|int strides:
- :param int|tuple[int]|list[int] dilation_rate:
- :param str padding:
+ :param Sequence[int] pool_size:
+ :param Sequence[int]|int strides:
+ :param int|Sequence[int] dilation_rate:
+ :param str|int|Sequence[int] padding:
  :param Dim|None in_dim:
- :param list[Dim|str]|None in_spatial_dims:
+ :param Sequence[Dim|str]|None in_spatial_dims:
  :param Dim|None out_dim:
- :param list[Dim]|None out_spatial_dims:
+ :param Sequence[Dim]|None out_spatial_dims:
  :param bool|NotSpecified use_channel_first:
  :rtype: Data
  """
returnn/tf/util/basic.py CHANGED
@@ -3733,13 +3733,14 @@ def single_strided_slice(x, axis, begin=None, end=None, step=None):
  def pad_replicate(x, axes, padding):
  """
  :param tf.Tensor x:
- :param list[int] axes:
- :param list[(int,int)] padding:
+ :param Sequence[int] axes:
+ :param Sequence[(int|Dim,int|Dim)] padding:
  :rtype: tf.Tensor
  """
  with tf.name_scope("pad_replicate"):
  assert len(padding) == 1, "Not implemented otherwise yet"
  assert len(axes) == 1, "Not implemented otherwise yet"
+ assert isinstance(padding[0][0], int) and isinstance(padding[0][1], int) # not implemented otherwise yet
  pad_left = tf.gather(x, 0, axis=axes[0])
  pad_left = tf.expand_dims(pad_left, axis=axes[0])
  pad_left = tf.repeat(pad_left, padding[0][0], axis=axes[0])
returnn/torch/engine.py CHANGED
@@ -3,7 +3,7 @@ Main engine for PyTorch
  """

  from __future__ import annotations
- from typing import Optional, Any, Union, Callable, Dict, Set, Tuple
+ from typing import Optional, Any, Union, Callable, Dict, Set
  from contextlib import nullcontext, ExitStack, contextmanager

  import gc
@@ -365,8 +365,6 @@ class Engine(EngineBase):
  zero_grad_next_step = True
  cur_count_grad_accum = 0
  extern_data = None
- num_seqs = None
- last_seq_idx = 0

  total_data_size_packed = NumbersDict()
  total_data_size_padded = NumbersDict()
@@ -400,20 +398,8 @@ class Engine(EngineBase):
  )

  complete_frac = float(extern_data_raw["complete_frac"])
- num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
- report_prefix=report_prefix,
- extern_data_raw=extern_data_raw,
- step_idx=step_idx,
- prev_num_seqs=num_seqs,
- prev_last_seq_idx=last_seq_idx,
- )
- epoch_continuous = (
- self.epoch - 1 + complete_frac
- if complete_frac >= 0.0
- else (self.epoch - 1 + (last_seq_idx + 1) / num_seqs)
- if num_seqs is not None
- else None
- )
+ epoch_continuous = self.epoch - 1 + complete_frac if complete_frac >= 0.0 else None
+ num_seqs = int(extern_data_raw["num_seqs"])

  # clear the gradients when every gradient accumulation loop starts
  if zero_grad_next_step:
@@ -490,7 +476,7 @@ class Engine(EngineBase):
  eval_info=dict(eval_info),
  step_duration=step_duration,
  start_elapsed=step_end_time - epoch_start_time,
- seq_idx=last_seq_idx,
+ complete_frac=complete_frac,
  num_seqs=num_seqs,
  batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -629,13 +615,18 @@ class Engine(EngineBase):
  accumulated_losses_dict = NumbersDict()
  accumulated_inv_norm_factors_dict = NumbersDict()
  step_idx = 0
+ eval_start_time = time.monotonic()

+ report_prefix = f"ep {self.epoch} {dataset_name} eval"
  with torch.no_grad():
  for extern_data_raw in data_loader:
  if self._torch_distributed_ctx and step_idx % 100 == 0:
  _has_data = torch.tensor([True], device="cpu", dtype=torch.int8)
  torch.distributed.broadcast(_has_data, src=0)

+ complete_frac = float(extern_data_raw["complete_frac"])
+ num_seqs = int(extern_data_raw["num_seqs"])
+
  extern_data = extern_data_util.raw_dict_to_extern_data(
  extern_data_raw,
  extern_data_template=self.extern_data,
@@ -644,6 +635,8 @@ class Engine(EngineBase):
  )

  self._run_step(extern_data, train_func=True)
+ step_end_time = time.monotonic()
+
  train_ctx = rf.get_run_ctx()

  losses_dict = NumbersDict(
@@ -664,9 +657,12 @@ class Engine(EngineBase):
  accumulated_inv_norm_factors_dict += inv_norm_factors_dict
  eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
  _print_process(
- f"ep {self.epoch} {dataset_name} eval",
+ report_prefix,
  step=step_idx,
  eval_info=dict(eval_info),
+ complete_frac=complete_frac,
+ num_seqs=num_seqs,
+ start_elapsed=step_end_time - eval_start_time,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
  )
  step_idx += 1
@@ -1290,8 +1286,6 @@ class Engine(EngineBase):
  new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
  return new_dim

- num_seqs = None
- last_seq_idx = 0
  report_prefix = f"ep {self.epoch} {dataset.name} forward"
  with torch.no_grad():
  callback.init(model=self._orig_model)
@@ -1300,13 +1294,8 @@ class Engine(EngineBase):
  for extern_data_raw in data_loader:
  step_begin_time = time.monotonic()

- num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
- report_prefix=report_prefix,
- extern_data_raw=extern_data_raw,
- step_idx=step_idx,
- prev_num_seqs=num_seqs,
- prev_last_seq_idx=last_seq_idx,
- )
+ complete_frac = float(extern_data_raw["complete_frac"])
+ num_seqs = int(extern_data_raw["num_seqs"])

  if self._forward_step_expected_outputs:
  # Also resets any dyn dims, which might have been set in the prev step.
@@ -1354,7 +1343,7 @@ class Engine(EngineBase):
  eval_info=None,
  step_duration=step_duration,
  start_elapsed=step_end_time - epoch_start_time,
- seq_idx=last_seq_idx,
+ complete_frac=complete_frac,
  num_seqs=num_seqs,
  batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -1442,7 +1431,7 @@ def _print_process(
  batch_size_info: Optional[Dict[str, Any]] = None,
  step_duration: Optional[float] = None,
  start_elapsed: Optional[float] = None,
- seq_idx: Optional[int] = None,
+ complete_frac: Optional[float] = None,
  num_seqs: Optional[int] = None,
  log_memory_usage_device: Optional[str] = None,
  ):
@@ -1455,11 +1444,14 @@ def _print_process(
  :param batch_size_info:
  :param step_duration: time elapsed for this step (secs)
  :param start_elapsed: time elapsed since epoch start (secs)
- :param num_seqs: total number of sequences for this epoch
+ :param complete_frac: how much of the current epoch is already consumed
+ :param num_seqs: total number of seqs this epoch
  :param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
  :return: nothing, will be printed to log
  """
  if log.verbose[5]: # report every minibatch
+ if step == 0 and num_seqs is not None and num_seqs >= 0:
+ print(f"{report_prefix} num_seqs: {num_seqs}", file=log.v5)
  info = [report_prefix, "step %i" % step]
  if eval_info: # Such as score.
  info += ["%s %s" % (k, _format_score_value(v)) for k, v in eval_info.items()]
@@ -1475,17 +1467,16 @@ def _print_process(
  info += ["%.3f sec/step" % step_duration]
  if start_elapsed is not None:
  info += ["elapsed %s" % hms(start_elapsed)]
- if num_seqs is not None:
- assert seq_idx is not None and start_elapsed is not None # unexpected combination...
- complete = (seq_idx + 1) / num_seqs
- assert 1 >= complete > 0, f"{step} step, {num_seqs} num_seqs"
- total_time_estimated = start_elapsed / complete
+ if complete_frac is not None:
+ assert 1 >= complete_frac > 0, f"{step} step, {complete_frac} complete_frac"
+ assert start_elapsed is not None
+ total_time_estimated = start_elapsed / complete_frac
  remaining_estimated = total_time_estimated - start_elapsed
  info += [
  "exp. remaining %s" % hms(remaining_estimated),
- "complete %.02f%%" % (complete * 100),
+ "complete %.02f%%" % (complete_frac * 100),
  ]
- if start_elapsed is not None and num_seqs is None:
+ if start_elapsed is not None and complete_frac is None:
  info += ["(unk epoch len)"]
  print(", ".join(filter(None, info)), file=log.v5)
@@ -1634,27 +1625,3 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
  p=p,
  ).item()
  )
-
-
- def _get_num_seqs_last_seq_idx(
- *,
- report_prefix: str,
- extern_data_raw: Dict[str, Any],
- step_idx: int,
- prev_num_seqs: Optional[int],
- prev_last_seq_idx: int,
- ) -> Tuple[Optional[int], int]:
- num_seqs = prev_num_seqs
- num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
- # Note: The batches might have been shuffled,
- # thus we cannot really assert that the seq_idx is always increasing.
- last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
- if step_idx == 0:
- if num_seqs_ >= 0:
- print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
- num_seqs = num_seqs_
- elif num_seqs_ >= 0:
- assert num_seqs_ == num_seqs
- if num_seqs is not None:
- assert last_seq_idx < num_seqs
- return num_seqs, last_seq_idx
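With complete_frac taken directly from the data pipeline, the progress/ETA logic in _print_process reduces to a simple proportion (the removed helper above used to derive it from seq_idx/num_seqs). Illustration only:

    # complete_frac in (0, 1]: fraction of the epoch consumed so far.
    total_time_estimated = start_elapsed / complete_frac
    remaining_estimated = total_time_estimated - start_elapsed
    # e.g. 120 s elapsed at complete_frac=0.25 -> ~480 s total, ~360 s remaining.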
@@ -1879,7 +1879,7 @@ class TorchBackend(Backend[torch.Tensor]):
  out_spatial_dims: Optional[Sequence[Dim]] = None,
  filter: Tensor,
  filter_size: Sequence[Dim], # to have the order well-defined
- padding: str,
+ padding: Union[str, int, Sequence[int]],
  strides: Optional[Union[int, Sequence[int]]] = None,
  dilation_rate: Optional[Union[int, Sequence[int]]] = None,
  groups: Optional[int] = None,
@@ -2008,7 +2008,7 @@ class TorchBackend(Backend[torch.Tensor]):
  *,
  mode: str,
  pool_size: Sequence[int],
- padding: str = "valid",
+ padding: Union[str, int, Sequence[int]] = "valid",
  dilation_rate: Union[Sequence[int], int] = 1,
  strides: Sequence[int],
  in_spatial_dims: Sequence[Dim],
@@ -2035,19 +2035,22 @@ class TorchBackend(Backend[torch.Tensor]):
  [-1, batch_dims[-1].get_dim_value() if batch_dims else 1] + [d.get_dim_value() for d in in_spatial_dims],
  )
  assert isinstance(strides, (list, tuple)) and len(strides) == len(in_spatial_dims) == len(pool_size)
- if padding.lower() == "same":
+ if isinstance(padding, str) and padding.lower() == "same":
  # padding='same' is not quite the same as ceil_mode=True, so we explicitly pad here.
  padding = []
  for i, s in enumerate(pool_size):
  # See comment in conv.
+ # I'm a bit unsure here... https://github.com/pytorch/pytorch/issues/148123
  pad = s - 1 - (src_raw.shape[2 + i] - 1) % strides[i]
  padding.append(pad // 2)
  ceil_mode = True
- elif padding.lower() == "valid":
+ elif isinstance(padding, str) and padding.lower() == "valid":
  padding = 0
  ceil_mode = False
+ elif isinstance(padding, (int, tuple, list)):
+ ceil_mode = False
  else:
- raise ValueError(f"invalid padding {padding!r}")
+ raise ValueError(f"invalid padding {padding!r} (type {type(padding).__name__}")
  func_name = f"{mode}_pool{len(in_spatial_dims)}d"
  func = getattr(torch.nn.functional, func_name) # e.g. torch.nn.functional.max_pool1d
  kwargs = {}
returnn/util/basic.py CHANGED
@@ -219,7 +219,7 @@ class BehaviorVersion:
  See :ref:`behavior_version`.
  """

- _latest_behavior_version = 23
+ _latest_behavior_version = 24
  _behavior_version = None # type: typing.Optional[int]
  _min_behavior_version = 0 # type: int
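Bumping _latest_behavior_version to 24 makes the consistent "same" padding the default for configs that opt in to the newest behavior; a hedged config sketch (illustrative):

    behavior_version = 24   # enables the new consistent same-padding default (see issue #1693)
    # Configs pinned to behavior_version <= 23 keep the old TF-style padding unless
    # rf_use_consistent_same_padding is set explicitly.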
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250304.10039
+ Version: 1.20250304.113330
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=et7Z9NstTVvnWjiIMXhquw3eiMnMxYMfnEEbVc755xQ,5214
1
+ returnn/PKG-INFO,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=d4hd9PkngTUKLJT4Q6GLMhVg4nXyV3Pym04_IKcblgc,77
6
+ returnn/_setup_info_generated.py,sha256=94BElbYUGmjpsoY8BzvfW39RUTXw9Fy3UwlPoEjrkU8,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -75,7 +75,7 @@ returnn/extern/graph_editor/subgraph.py,sha256=R3uIFqWgiL7L5S4YATm9o9a3wfEa_mSb4
75
75
  returnn/extern/graph_editor/transform.py,sha256=d9fEgu0JC342q0g9niVxRWMKzkQQA9mrrajBGcU1o_s,29349
76
76
  returnn/extern/graph_editor/util.py,sha256=QMrQeQZ7lJwsrNQub9tof0h3quEaoHiGJaZmogQ7jXE,18707
77
77
  returnn/frontend/__init__.py,sha256=2aS7nbxXniIrBp2DODl0xN0f3IJ_dX4Bi9ZlR7W5_DE,1472
78
- returnn/frontend/_backend.py,sha256=TNkEdj9GKxJfSM1ZMQ_SdAQzn2TU7SQbG6JGdaWhUeI,50374
78
+ returnn/frontend/_backend.py,sha256=JNqQomHCN4-1VLq5o9VRbs_L8gSZkvOgjUmRYt8jx1o,50428
79
79
  returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,8403
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
@@ -88,7 +88,7 @@ returnn/frontend/cond.py,sha256=gh6wg0aSbAJQfKRv4BQAu-EfPWtWPLFjgc8IaPPFmwg,1023
88
88
  returnn/frontend/const.py,sha256=bL51HXxq858dWmrKd61k8tWBWIe67jVf9pj1wZcZZAo,3945
89
89
  returnn/frontend/container.py,sha256=wF3OlQN7WlOVmmdapUth_Unha3DVf6h1B7okBJAuJDA,8011
90
90
  returnn/frontend/control_flow_ctx.py,sha256=v17CsNwRnZYe8GdMtGJt2ftibfxMCGK1i0l-GX5ILu0,699
91
- returnn/frontend/conv.py,sha256=p4R6j40GCvVrw3kbQQJtfxY6tfIR8Rb3tIzwAtiLuec,23858
91
+ returnn/frontend/conv.py,sha256=Q0q90-uu9d6qV-v8_DlFGxpZtc6FjfXVpfkkXmv1Alk,31959
92
92
  returnn/frontend/device.py,sha256=K7Y1qoQcO4GIHgLkPLQWK-GVT8gKL8GwyQrmPo8LgBE,1438
93
93
  returnn/frontend/dims.py,sha256=aH5FQ_m0xMD6Rj-BUWGx8lB-HkCuwZfMBf6mZbGGW5E,12611
94
94
  returnn/frontend/dropout.py,sha256=rsx3p5b0NblBfXXSQZTQFJ8jUUS3fj4Qzc39iffBMCA,5006
@@ -177,7 +177,7 @@ returnn/tf/sprint.py,sha256=Yqjh0-6sCWHpdDPQCzHKx7TwQCOjJyjfd0KHtnYdd-8,5471
177
177
  returnn/tf/updater.py,sha256=St4Z5iBjlkWaB6CiS-K1VNc_iLaan2e6-mVMTTPldzk,72034
178
178
  returnn/tf/frontend_layers/README.md,sha256=P4vVl_EK-4jT55m40mq-K4Nr9yFY0tJR5fmDzTHSDFE,1096
179
179
  returnn/tf/frontend_layers/__init__.py,sha256=MGUn7rv6fOefbtkX-5pq6fC1T6Y5h0oh1uOPSEcv1_I,506
180
- returnn/tf/frontend_layers/_backend.py,sha256=U7rbRY9XgMkxxyWY2D8KG-KesSOEGLCxn-Gl6dgwmPc,47277
180
+ returnn/tf/frontend_layers/_backend.py,sha256=igo147YCTVdNuUBm2euEwjAhpH5yDHyQAf5T4jcCrLM,47331
181
181
  returnn/tf/frontend_layers/_utils.py,sha256=ijByaDOqPDod5mZC9EoTkt8PHBEODXHsWbkwDOF9XW4,4205
182
182
  returnn/tf/frontend_layers/cond.py,sha256=yQ2h5W0sgMZndJdrWv2EE9k9yIcspQ1U0HwBSh3hOKE,14830
183
183
  returnn/tf/frontend_layers/config_entry_points.py,sha256=t01RWOiaZohzuqPXX-MLV0P5yCOfE0dz-9dZ77_pK4c,5751
@@ -193,13 +193,13 @@ returnn/tf/frontend_low_level/__init__.py,sha256=34469k3KzMUIGowxReOZnbf6WdTjxY7
193
193
  returnn/tf/frontend_low_level/_backend.py,sha256=JwwRRIGnElqBC4bTImdB7w3U1u_SJESeZHYLmq86wog,24479
194
194
  returnn/tf/layers/__init__.py,sha256=Ngu-X84nWFgz7ndDu88DqoZ-5lUMMTQWH4g7N8pSoCg,72
195
195
  returnn/tf/layers/base.py,sha256=KcADpZUxqLkoFpQPMe_l9thRC7rpyBJIZCHITmnOd7M,153169
196
- returnn/tf/layers/basic.py,sha256=la0EwaHVzAbL6JOXs6QXnYQ74F3R16piYpT55VwVFT4,611063
196
+ returnn/tf/layers/basic.py,sha256=7eefkCNa8aqh96Hl2Tr8b6rqpE0cudgyyQCuQK-QNKU,615168
197
197
  returnn/tf/layers/rec.py,sha256=K9vvyDJeDApYQDKabz7PaOTGHeSTloInkecxKTbqeTU,548357
198
198
  returnn/tf/layers/segmental_model.py,sha256=wUyDZGr-eTVIIQWcsHLML0wtOxuWn_NFKOIrUKQcvoI,21515
199
199
  returnn/tf/layers/signal_processing.py,sha256=vRlkN7k7otk9_Qdv0qr_l6V0VT5Q6dO2MxwZWb2HH2M,52693
200
200
  returnn/tf/layers/variable.py,sha256=G1dIEoq0iQsXp-uOAUPTaBKHSOQfx7Sn-spD8MRv0HM,11446
201
201
  returnn/tf/util/__init__.py,sha256=mEg5jNVbQBLO2TGwO4Ff2F5qQN5_Zg4hAAQfX5taeec,92
202
- returnn/tf/util/basic.py,sha256=F1-3Huh4mdoLgBCYbh4z5rDFz2meWMwsGQc3B87wOXg,302811
202
+ returnn/tf/util/basic.py,sha256=8c0xEQNcsIvts2ydwZdUvqk4HsTJFyH_xYPQzaZbV6M,302941
203
203
  returnn/tf/util/data.py,sha256=AlSa0r_IaXtjKG1q1vxUybFazpjt4lUX8LYq0STJv-w,29471
204
204
  returnn/tf/util/gradient_checkpoint.py,sha256=_1NGAmNZ5NiGhFYVRWvBV5yejt-EZWbbvxNWHbESp5Q,7426
205
205
  returnn/tf/util/ken_lm.py,sha256=R60UAoywriuDIeQ2Hk3Vm_waf2Hxxc88ofzEw6X6Sd4,17313
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
207
207
  returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
208
208
  returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
209
209
  returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
210
- returnn/torch/engine.py,sha256=sU9A96icaj65uaEkX4i4aUK3IrB2S19_Fb9_sueB_JE,77426
210
+ returnn/torch/engine.py,sha256=2FLLb2m4sWFwYOQGREDSxQCheCKd_osnFJCdLa_4TzE,76400
211
211
  returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
212
212
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
213
213
  returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
216
216
  returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
217
217
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
218
218
  returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
219
- returnn/torch/frontend/_backend.py,sha256=rFCoCnzZoBtHPg7mWpO3yJOJMVesuWuA3_6GGSKMc5k,101452
219
+ returnn/torch/frontend/_backend.py,sha256=SKxxpIM0rXEcZ92p-Um5thfC7vmDoZmda13SMAXVYL0,101771
220
220
  returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
221
221
  returnn/torch/frontend/bridge.py,sha256=Z2_UW8AagezC7zsXDc5PKcd8G9WwisV7j9SWGHU0m4U,7840
222
222
  returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
233
233
  returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
234
234
  returnn/torch/util/scaled_gradient.py,sha256=3585VuNypBty-pW6r3BKK047H3MqZQSdMjXeYAb4cmU,3192
235
235
  returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
236
- returnn/util/basic.py,sha256=Iynt9ATEs_8DaZsX5z6weMyaO2xW9o3gaywq6X7mbEc,142380
236
+ returnn/util/basic.py,sha256=eLlzR-ARGWJoiyRb5-SH5v2zx1jgR_5vuQ5jwYO5Cww,142380
237
237
  returnn/util/better_exchook.py,sha256=MVMnuu6KoyqgvlMeQLQNTfdspcPR9MwigCXOpeTVqCI,62956
238
238
  returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
239
239
  returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250304.10039.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250304.10039.dist-info/METADATA,sha256=et7Z9NstTVvnWjiIMXhquw3eiMnMxYMfnEEbVc755xQ,5214
258
- returnn-1.20250304.10039.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
- returnn-1.20250304.10039.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250304.10039.dist-info/RECORD,,
256
+ returnn-1.20250304.113330.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250304.113330.dist-info/METADATA,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
258
+ returnn-1.20250304.113330.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
259
+ returnn-1.20250304.113330.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250304.113330.dist-info/RECORD,,