returnn-1.20250508.93313-py3-none-any.whl → returnn-1.20250513.145447-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry flags this version of returnn as potentially problematic.

Files changed (67)
  1. returnn/PKG-INFO +1 -1
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/datasets/basic.py +24 -25
  4. returnn/datasets/cached.py +4 -3
  5. returnn/datasets/distrib_files.py +1 -2
  6. returnn/datasets/generating.py +20 -20
  7. returnn/datasets/hdf.py +9 -9
  8. returnn/datasets/lm.py +25 -13
  9. returnn/datasets/meta.py +39 -38
  10. returnn/datasets/normalization_data.py +1 -1
  11. returnn/datasets/postprocessing.py +20 -13
  12. returnn/datasets/sprint.py +8 -7
  13. returnn/datasets/util/strings.py +0 -1
  14. returnn/datasets/util/vocabulary.py +3 -3
  15. returnn/extern/graph_editor/subgraph.py +1 -2
  16. returnn/extern/graph_editor/transform.py +1 -2
  17. returnn/extern/graph_editor/util.py +1 -2
  18. returnn/frontend/_backend.py +4 -3
  19. returnn/frontend/_utils.py +1 -1
  20. returnn/frontend/audio/mel.py +0 -1
  21. returnn/frontend/const.py +3 -3
  22. returnn/frontend/device.py +0 -1
  23. returnn/frontend/dropout.py +1 -1
  24. returnn/frontend/encoder/e_branchformer.py +1 -1
  25. returnn/frontend/loop.py +3 -3
  26. returnn/frontend/loss.py +0 -1
  27. returnn/frontend/matmul.py +0 -1
  28. returnn/frontend/run_ctx.py +9 -9
  29. returnn/frontend/signal.py +0 -1
  30. returnn/frontend/types.py +2 -4
  31. returnn/native_op.py +13 -0
  32. returnn/sprint/cache.py +2 -4
  33. returnn/sprint/interface.py +3 -4
  34. returnn/tensor/_dim_extra.py +9 -9
  35. returnn/tensor/_tensor_extra.py +20 -19
  36. returnn/tensor/_tensor_op_overloads.py +0 -1
  37. returnn/tensor/tensor.py +1 -1
  38. returnn/tensor/tensor_dict.py +9 -9
  39. returnn/tf/engine.py +60 -65
  40. returnn/tf/frontend_layers/_backend.py +3 -3
  41. returnn/tf/frontend_layers/cond.py +6 -6
  42. returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
  43. returnn/tf/frontend_layers/layer.py +12 -12
  44. returnn/tf/frontend_layers/loop.py +3 -3
  45. returnn/tf/frontend_layers/make_layer.py +0 -1
  46. returnn/tf/layers/base.py +56 -49
  47. returnn/tf/layers/basic.py +60 -65
  48. returnn/tf/layers/rec.py +74 -74
  49. returnn/tf/native_op.py +1 -3
  50. returnn/tf/network.py +60 -57
  51. returnn/tf/updater.py +3 -3
  52. returnn/tf/util/basic.py +24 -23
  53. returnn/torch/data/extern_data.py +4 -5
  54. returnn/torch/data/pipeline.py +3 -4
  55. returnn/torch/engine.py +16 -16
  56. returnn/torch/frontend/_backend.py +15 -15
  57. returnn/torch/frontend/bridge.py +3 -3
  58. returnn/torch/updater.py +8 -9
  59. returnn/torch/util/debug_inf_nan.py +0 -2
  60. returnn/torch/util/exception_helper.py +1 -1
  61. returnn/torch/util/scaled_gradient.py +0 -1
  62. returnn/util/basic.py +1 -2
  63. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/METADATA +1 -1
  64. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/RECORD +67 -67
  65. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/LICENSE +0 -0
  66. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/WHEEL +0 -0
  67. {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/top_level.txt +0 -0
returnn/tf/layers/basic.py CHANGED
@@ -4,7 +4,7 @@ Many canonical basic layers.
 
  from __future__ import annotations
 
- from typing import Optional, Union, Sequence, List, Tuple, Dict
+ from typing import Callable, Optional, Union, Sequence, List, Tuple, Dict
  import typing
  import tensorflow as tf
  import contextlib
@@ -126,7 +126,7 @@ def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpec
  data.placeholder = tf.concat(
  axis=data.feature_dim_axis, values=[layer_data.placeholder for layer_data in layers_data]
  )
- axes_split_info = [None] * data.batch_ndim # type: typing.List[typing.Optional[typing.List[int]]]
+ axes_split_info: List[Optional[List[int]]] = [None] * data.batch_ndim
  axes_split_info[data.feature_dim_axis] = [layer_data.dim for layer_data in layers_data]
  tf_util.set_param_axes_split_info(data.placeholder, axes_split_info)
  # Note: We will loose this info for any further op (e.g. dropout, activation, etc). Should be better...
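Two mechanical patterns account for almost all of the changes in this file. The first, visible in the two hunks above, migrates comment-based type hints to inline variable annotations (PEP 526), which is also why Callable is now imported from typing at module level. A minimal sketch of the pattern, with illustrative names loosely based on the hunk above rather than RETURNN's actual code:

from typing import List, Optional


class Example:
    def __init__(self, batch_ndim: int):
        # Old style: the type lived in a trailing "# type:" comment, e.g.
        #   self.axes_split_info = [None] * batch_ndim  # type: typing.List[typing.Optional[typing.List[int]]]
        # New style: an inline annotation that type checkers and IDEs read directly.
        self.axes_split_info: List[Optional[List[int]]] = [None] * batch_ndim


Example(batch_ndim=3)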
@@ -294,7 +294,7 @@ class _ConcatInputLayer(LayerBase):
  elif mask == "dropout":
  assert dropout > 0
  self.dropout = dropout
- self.input_data = None # type: typing.Optional[Data]
+ self.input_data: Optional[Data] = None
  if self.sources:
  self.input_data = concat_sources_with_opt_dropout(
  self.sources,
@@ -509,9 +509,7 @@ class ConcatLayer(LayerBase):
  assert sources
  sources, axes = zip(*sources) # unzip
  axes_int = [layer.output.get_axis_from_description(axis) for (layer, axis) in zip(sources, axes)]
- concat_dim_tags = [
- layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)
- ] # type: typing.List[Dim]
+ concat_dim_tags: List[Dim] = [layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)]
  if any(tag.dimension is None for tag in concat_dim_tags):
  dimension = None
  else:
@@ -707,8 +705,8 @@ class SelectSearchSourcesLayer(InternalLayer):
  self.output = src.output.copy_as_batch_major()
  self.rec_vars_outputs = src.rec_vars_outputs.copy()
  src_search_choices = src.get_search_choices()
- self.transform_func = None # type: typing.Optional[typing.Callable[[tf.Tensor],tf.Tensor]]
- self.search_choices_seq = None # type: typing.Optional[typing.List[SearchChoices]]
+ self.transform_func: Optional[Callable[[tf.Tensor], tf.Tensor]] = None
+ self.search_choices_seq: Optional[List[SearchChoices]] = None
  if not search_choices:
  assert not src_search_choices
  assert not self.output.beam
@@ -726,13 +724,7 @@ class SelectSearchSourcesLayer(InternalLayer):
  assert src_search_choices in search_choices_seq, self.network.debug_search_choices(
  self.search_choices_layer
  ) or (
- (
- "%s: No common search base:\n"
- "from layer %s\n"
- "search choices %s,\n"
- "to layer %s\n"
- "search choices\n%s."
- )
+ "%s: No common search base:\nfrom layer %s\nsearch choices %s,\nto layer %s\nsearch choices\n%s."
  % (self, src, src_search_choices, self.search_choices_layer, pformat(search_choices_seq))
  )
  search_choices_seq = search_choices_seq[: search_choices_seq.index(src_search_choices)]
@@ -4436,12 +4428,13 @@ class MergeDimsLayer(_ConcatInputLayer):
  :rtype: list[int]
  """
  if keep_order:
- assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance(
- axes, str
- ), "%s: axes %r must be a list or tuple, to have a well defined order in input %s" % (
- name,
- axes,
- input_data,
+ assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance(axes, str), (
+ "%s: axes %r must be a list or tuple, to have a well defined order in input %s"
+ % (
+ name,
+ axes,
+ input_data,
+ )
  )
  axes_ = []
  for axis in axes:
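The second pattern, which covers most of the remaining hunks in basic.py and rec.py, is purely cosmetic: long assert statements that used to split the condition across several lines now keep the condition on one line and wrap only the message in parentheses. Behaviour is identical either way; attributing the reflow to any particular formatter version would be a guess. A small self-contained illustration, with made-up names and the message text borrowed from the RepeatLayer hunk just below:

def check_axis_match(input_size: int, repetitions_size: int) -> None:
    # The condition now stays on one line; only the message is parenthesized.
    # The old layout split the condition itself across lines as "assert ( ... ), msg".
    assert repetitions_size == input_size, (
        "Axis mismatch between input (%i) and repetitions (%i)" % (input_size, repetitions_size)
    )


check_axis_match(3, 3)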
@@ -5562,11 +5555,12 @@ class RepeatLayer(_ConcatInputLayer):
  repetitions_data = repetitions_data.copy_add_dim_by_tag(axis_dim_tag, unbroadcast=True)
  repetitions_axis = repetitions_data.get_axis_from_description(axis, allow_int=False)
  assert repetitions_data.ndim == 1, "Repetitions %r must only have at most one non-batch axis" % repetitions
- assert (
- repetitions_data.batch_shape[repetitions_axis] == self.input_data.batch_shape[input_axis]
- ), "Axis mismatch between input (%i) and repetitions (%i)" % (
- self.input_data.batch_shape[input_axis],
- repetitions_data.batch_shape[repetitions_axis],
+ assert repetitions_data.batch_shape[repetitions_axis] == self.input_data.batch_shape[input_axis], (
+ "Axis mismatch between input (%i) and repetitions (%i)"
+ % (
+ self.input_data.batch_shape[input_axis],
+ repetitions_data.batch_shape[repetitions_axis],
+ )
  )
 
  assert self.output.have_batch_axis() == (
@@ -6267,9 +6261,9 @@ class ConvLayer(_ConcatInputLayer):
  from returnn.util import BehaviorVersion
 
  padding = padding.upper() if isinstance(padding, str) else padding
- assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(
- padding, (int, tuple, list)
- ), f"{self}: got unsupported padding {padding}"
+ assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(padding, (int, tuple, list)), (
+ f"{self}: got unsupported padding {padding}"
+ )
  assert "out_type" not in kwargs, "don't set out_type explicitly for this layer"
  assert len(filter_size) in (1, 2, 3), "only 1D conv, 2D conv or 3D conv supported"
  super(ConvLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
@@ -6285,9 +6279,9 @@ class ConvLayer(_ConcatInputLayer):
  assert len(dilation_rate) == len(filter_size)
  assert not self.input_data.sparse
  assert self.input_data.have_batch_axis()
- assert (
- self.input_data.have_feature_axis()
- ), "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ assert self.input_data.have_feature_axis(), (
+ "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ )
  input_data, num_batch_dims = self.transform_input(
  self.input_data,
  network=self.network,
@@ -7117,9 +7111,9 @@ class PoolLayer(_ConcatInputLayer):
  super(PoolLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
  assert not self.input_data.sparse
  assert self.input_data.have_batch_axis()
- assert (
- self.input_data.have_feature_axis()
- ), "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ assert self.input_data.have_feature_axis(), (
+ "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ )
  if in_dim and out_dim:
  assert in_dim == out_dim
  elif in_dim:
@@ -7381,9 +7375,9 @@ class TransposedConvLayer(_ConcatInputLayer):
  out_dim # noqa # via get_out_data_from_opts
  assert not self.input_data.sparse
  assert self.input_data.have_batch_axis()
- assert (
- self.input_data.have_feature_axis()
- ), "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ assert self.input_data.have_feature_axis(), (
+ "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+ )
  input_data, num_batch_dims = ConvLayer.transform_input(
  self.input_data,
  network=self.network,
@@ -7404,14 +7398,15 @@ class TransposedConvLayer(_ConcatInputLayer):
  remove_padding = [remove_padding] * len(spatial_axes)
  if not isinstance(output_padding, (list, tuple)):
  output_padding = [output_padding] * len(spatial_axes)
- assert (
- len(spatial_axes) == len(filter_size) == len(strides) == len(remove_padding) == len(output_padding)
- ), "%s: expected %i-D transposed-conv for input %r but got filter %r and strides %r" % (
- self,
- len(spatial_axes),
- input_data,
- filter_size,
- strides,
+ assert len(spatial_axes) == len(filter_size) == len(strides) == len(remove_padding) == len(output_padding), (
+ "%s: expected %i-D transposed-conv for input %r but got filter %r and strides %r"
+ % (
+ self,
+ len(spatial_axes),
+ input_data,
+ filter_size,
+ strides,
+ )
  )
  assert len(spatial_axes) in [1, 2], "%s: %i-D not yet implemented..." % (self, len(spatial_axes))
  x = input_data.placeholder
@@ -8775,9 +8770,9 @@ class DotLayer(LayerBase):
  red1,
  red2,
  )
- assert len(a_reduce_axes) == len(
- b_reduce_axes
- ), "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (self, self.sources, red1, red2)
+ assert len(a_reduce_axes) == len(b_reduce_axes), (
+ "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (self, self.sources, red1, red2)
+ )
  if (
  (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified))
  or var1 == "auto"
@@ -9150,9 +9145,9 @@ class DotLayer(LayerBase):
  raise Exception(
  "%s %r: " % (cls.__name__, name) + "%s not found in sources %r" % (red_axis_desc, sources)
  )
- assert len(a_reduce_axes) == len(
- b_reduce_axes
- ), "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (name, sources, red1, red2)
+ assert len(a_reduce_axes) == len(b_reduce_axes), (
+ "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (name, sources, red1, red2)
+ )
  if (
  (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified))
  or var1 == "auto"
@@ -10178,9 +10173,9 @@ class CondLayer(LayerBase):
  self.condition_desc = condition
  self.condition_layer = self._make_layer("condition", self.condition_desc)
  self.true_layer_desc = true_layer
- self.true_layer = None # type: typing.Optional[LayerBase]
+ self.true_layer: Optional[LayerBase] = None
  self.false_layer_desc = false_layer
- self.false_layer = None # type: typing.Optional[LayerBase]
+ self.false_layer: Optional[LayerBase] = None
  assert self.condition_layer.output.batch_ndim == 0 and self.condition_layer.output.dtype == "bool"
  self._extra_out_templates = {k: v[0] for k, v in _extra_out.items()}
  x, extra_out, sizes = tf_util.cond(
@@ -12070,7 +12065,7 @@ class HDFDumpLayer(LayerBase):
  for (key, output) in extra.items()
  }
  extra = {key: output.copy_as_batch_spatial_major() for (key, output) in extra.items()}
- self.extra = extra # type: typing.Dict[str,Data]
+ self.extra: Dict[str, Data] = extra
  self.dump_whole_batches = dump_whole_batches
  self.num_seqs_written = 0
  ndim = data.ndim
@@ -12454,9 +12449,9 @@ class BinaryCrossEntropyLoss(Loss):
 
  def _check_init(self):
  assert self.target is not None
- assert (
- self.target.batch_ndim == self.output.batch_ndim
- ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+ assert self.target.batch_ndim == self.output.batch_ndim, (
+ "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+ )
 
  def get_value(self):
  """
@@ -13020,7 +13015,7 @@ class ExpectedLoss(Loss):
  self.divide_beam_size = divide_beam_size
  self.subtract_average_loss = subtract_average_loss
  self.loss_correction_grad_only = loss_correction_grad_only
- self.search_choices = None # type: typing.Optional[SearchChoices]
+ self.search_choices: Optional[SearchChoices] = None
 
  @classmethod
  def transform_config_dict(cls, d, network, get_layer):
@@ -13120,9 +13115,9 @@ class DeepClusteringLoss(Loss):
  Does some checks on self.target and self.output, e.g. if the dense shapes matches.
  You can overwrite this if those checks don't make sense for your derived loss class.
  """
- assert (
- self.target.ndim_dense == self.output.ndim_dense
- ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+ assert self.target.ndim_dense == self.output.ndim_dense, (
+ "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+ )
  expected_output_dim = self._embedding_dimension * (self.target.shape[1] // self._nr_of_sources)
  assert expected_output_dim == self.output.dim, "Expected output dim is %i but the output has dim %r. " % (
  expected_output_dim,
@@ -13822,9 +13817,9 @@ class SamplingBasedLoss(Loss):
  else:
  loss_fn = tf.nn.sampled_softmax_loss
 
- assert (
- self.layer.params["W"].shape[0] == self.target.dim
- ), "Expect weight matrix of shape [num_classes, dim]"
+ assert self.layer.params["W"].shape[0] == self.target.dim, (
+ "Expect weight matrix of shape [num_classes, dim]"
+ )
  out = loss_fn(
  weights=self.layer.params["W"].read_value(), # (num_classes,D).
  biases=self.layer.params["b"].read_value(), # (num_classes).
returnn/tf/layers/rec.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
 
  import contextlib
  import typing
+ from typing import Dict, Optional, Tuple, Union
  import tensorflow as tf
  import returnn.tf.compat as tf_compat
 
@@ -1037,7 +1038,8 @@ class RecLayer(_ConcatInputLayer):
  scope=tf_compat.v1.get_variable_scope(),
  )
  elif rnn_contrib and isinstance(
- cell, (rnn_contrib.FusedRNNCell, rnn_contrib.LSTMBlockWrapper) # noqa # e.g. LSTMBlockFusedCell
+ cell,
+ (rnn_contrib.FusedRNNCell, rnn_contrib.LSTMBlockWrapper), # noqa # e.g. LSTMBlockFusedCell
  ):
  # Will get (time,batch,ydim).
  assert self._max_seq_len is None
@@ -1280,9 +1282,9 @@ class RecLayer(_ConcatInputLayer):
  :param str|int|None key:
  :rtype: tf.Tensor
  """
- assert (
- self._last_hidden_state is not None
- ), "last-hidden-state not implemented/supported for this layer-type. try another unit. see the code."
+ assert self._last_hidden_state is not None, (
+ "last-hidden-state not implemented/supported for this layer-type. try another unit. see the code."
+ )
  return RnnCellLayer.get_state_by_key(self._last_hidden_state, key=key)
 
  @classmethod
@@ -1431,9 +1433,7 @@ class _SubnetworkRecCell:
  )
  self._last_frames = {} # type: typing.Dict[str,Data]
  self._initial_outputs = None # type: typing.Optional[typing.Dict[str,tf.Tensor]]
- self._initial_extra_outputs = (
- None
- ) # type: typing.Optional[typing.Dict[str,typing.Dict[str,typing.Union[tf.Tensor,typing.Tuple[tf.Tensor,...]]]]] # nopep8
+ self._initial_extra_outputs: Optional[Dict[str, Dict[str, Union[tf.Tensor, Tuple[tf.Tensor, ...]]]]] = None
 
  # input_layers_moved_out, output_layers_moved_out and layers_in_loop include (used) sub-layers as separate
  # entries, this way in- and outputting them to the loop via TensorArrays will be handled just as for normal
@@ -1608,14 +1608,9 @@ class _SubnetworkRecCell:
  while parent and parent.parent:
  parent_names.insert(0, parent.parent_name or "?")
  parent = parent.parent
- return (
- "<RecLayer construct template GetLayer>("
- "allow_uninitialized_template %r, "
- "parents %r)"
- % (
- lself.allow_uninitialized_template,
- " <- ".join(parent_names) or None,
- )
+ return "<RecLayer construct template GetLayer>(allow_uninitialized_template %r, parents %r)" % (
+ lself.allow_uninitialized_template,
+ " <- ".join(parent_names) or None,
  )
 
  def _add_uninitialized_count(self):
@@ -2141,16 +2136,17 @@ class _SubnetworkRecCell:
  layer = self.input_layers_net.layers[layer_name]
  assert isinstance(layer, LayerBase)
  if layer_name not in inputs_moved_out_tas:
- assert not layer.output.mark_same_time(
- self._time_dim_tags
- ), "%s does not expect to have matching time dim to %s" % (layer, self.parent_rec_layer)
- assert (
- name != "output" and not prev
- ), "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)." % (
- self.parent_rec_layer,
- self.parent_rec_layer.output.get_time_dim_tag(),
- layer,
- layer.output.get_time_dim_tag(),
+ assert not layer.output.mark_same_time(self._time_dim_tags), (
+ "%s does not expect to have matching time dim to %s" % (layer, self.parent_rec_layer)
+ )
+ assert name != "output" and not prev, (
+ "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)."
+ % (
+ self.parent_rec_layer,
+ self.parent_rec_layer.output.get_time_dim_tag(),
+ layer,
+ layer.output.get_time_dim_tag(),
+ )
  )
  return layer
  output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
@@ -2376,9 +2372,9 @@ class _SubnetworkRecCell:
  assert output_template.output.dim == self.parent_rec_layer.output.dim
  assert self.parent_rec_layer.output.time_dim_axis == 0
  assert not output_template.output.has_axis(self.time_dim_tag)
- assert (
- output_template.output.batch_shape == self.parent_rec_layer.output.batch_shape[1:]
- ), "see RecLayer.get_out_data_from_opts()"
+ assert output_template.output.batch_shape == self.parent_rec_layer.output.batch_shape[1:], (
+ "see RecLayer.get_out_data_from_opts()"
+ )
 
  def get_init_loop_vars(self):
  """
@@ -3014,9 +3010,9 @@ class _SubnetworkRecCell:
  needed_outputs.add("end")
  assert tf.as_dtype(end_template.output.dtype) is tf.bool
  else:
- assert (
- have_known_seq_len
- ), "You need to have an 'end' layer in your rec subnet if the generated seq len is unknown."
+ assert have_known_seq_len, (
+ "You need to have an 'end' layer in your rec subnet if the generated seq len is unknown."
+ )
 
  # noinspection PyProtectedMember
  if self.parent_rec_layer._optimize_move_layers_out:
@@ -3358,11 +3354,12 @@ class _SubnetworkRecCell:
  from .basic import SelectSearchSourcesLayer
 
  prev_end_layer = choices.translate_to_this_search_beam(prev_end_layer)
- assert isinstance(
- prev_end_layer, SelectSearchSourcesLayer
- ), "unexpected search choices: cur end %r, prev end %r" % (
- choices,
- prev_end_layer.get_search_choices(),
+ assert isinstance(prev_end_layer, SelectSearchSourcesLayer), (
+ "unexpected search choices: cur end %r, prev end %r"
+ % (
+ choices,
+ prev_end_layer.get_search_choices(),
+ )
  )
  prev_end_flag = prev_end_layer.output.placeholder
  with tf.name_scope("dyn_seq_len"):
@@ -3475,14 +3472,15 @@ class _SubnetworkRecCell:
  assert fixed_seq_len is not None
  seq_len = fixed_seq_len
  if output_beam:
- assert (
- not input_beam or input_beam == output_beam
- ), "%s: input beam %r, output beam %r, sources %r, target %r" % (
- self.parent_rec_layer,
- input_beam,
- output_beam,
- self.parent_rec_layer.sources,
- self.parent_rec_layer.target,
+ assert not input_beam or input_beam == output_beam, (
+ "%s: input beam %r, output beam %r, sources %r, target %r"
+ % (
+ self.parent_rec_layer,
+ input_beam,
+ output_beam,
+ self.parent_rec_layer.sources,
+ self.parent_rec_layer.target,
+ )
  )
  assert output_template.output.batch.beam == output_beam
  time_dim_tag = time_dim_tag.get_for_batch_ctx(
@@ -3791,9 +3789,9 @@ class _SubnetworkRecCell:
  if end_layer_choice.name.startswith("prev:"):
  # Logic from maybe_transform. It would be translated to the current beam.
  end_layer_choice = self.net.layers[end_layer_choice.name[len("prev:") :]]
- assert (
- end_layer_choice in choice_seq_in_frame
- ), "End layer must not have a beam independent from output layer '{}'.".format(layer_name)
+ assert end_layer_choice in choice_seq_in_frame, (
+ "End layer must not have a beam independent from output layer '{}'.".format(layer_name)
+ )
 
  end_layer_choice_index = choice_seq_in_frame.index(end_layer_choice)
  choices_seq_until_end_layer = choice_seq_in_frame[:end_layer_choice_index]
@@ -5856,12 +5854,13 @@ class RecUnstackLayer(LayerBase):
  if out_dim.is_dim_known(): # usually the case except at template construction
  assert out_dim != rec_time_dim # rec_time_dim is unknown, so it cannot be the same
  if out_dim != rec_time_dim:
- assert (
- declare_rec_time
- ), "%s %r: must either set known axis on rec %s or enable declare_rec_time" % (
- cls.__name__,
- name,
- rec_time_dim,
+ assert declare_rec_time, (
+ "%s %r: must either set known axis on rec %s or enable declare_rec_time"
+ % (
+ cls.__name__,
+ name,
+ rec_time_dim,
+ )
  )
  rec_time_dim.declare_same_as(out_dim)
  out.mark_same_time(out_dim, must_match=True)
@@ -6132,12 +6131,13 @@ class ChoiceLayer(BaseChoiceLayer):
  base_beam_in = tf.shape(scores_base)[1] # 1 in first frame, then beam_in
  scores_beam_in = tf.shape(scores_in)[0] // net_batch_dim
  beam_in = self.sources[0].output.beam.beam_size
- assert (
- beam_in == base_search_choices.beam_size
- ), "%r: source %r beam-size unexpected from base choice %r" % (
- self,
- self.sources[0],
- base_search_choices,
+ assert beam_in == base_search_choices.beam_size, (
+ "%r: source %r beam-size unexpected from base choice %r"
+ % (
+ self,
+ self.sources[0],
+ base_search_choices,
+ )
  )
  # About incoming beam size:
  # base_beam_in - 1 in first frame, then beam_in
@@ -7510,9 +7510,9 @@ class GenericAttentionLayer(AttentionBaseLayer):
  base_rem_axes = base.get_axes(exclude_batch=True, exclude_time=True)
  base_rem_axes.remove(base.feature_dim_axis)
  weights_rem_axes = weights.get_axes(exclude_batch=True)
- assert (
- weights.time_dim_axis is not None
- ), f"{exception_prefix}: base {base}, weights {weights}, need time_dim_axis in weights"
+ assert weights.time_dim_axis is not None, (
+ f"{exception_prefix}: base {base}, weights {weights}, need time_dim_axis in weights"
+ )
  weights_axis_to_reduce = cls._weights_time_axis_to_reduce(weights=weights, base=base)
  assert weights.batch_shape[weights_axis_to_reduce] == base.batch_shape[base.time_dim_axis]
  weights_rem_axes.remove(weights_axis_to_reduce)
@@ -9088,13 +9088,13 @@ class MaskedComputationLayer(LayerBase):
  new_size, new_time, idxs = None, None, None
  if mask:
  if self.network.is_inside_rec_layer():
- assert (
- mask.output.shape == () and mask.output.dtype == "bool"
- ), "%s: invalid mask %s (inside rec loop)" % (self, mask)
+ assert mask.output.shape == () and mask.output.dtype == "bool", (
+ "%s: invalid mask %s (inside rec loop)" % (self, mask)
+ )
  else:
- assert (
- mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool"
- ), "%s: invalid mask %s (outside rec loop)" % (self, mask)
+ assert mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool", (
+ "%s: invalid mask %s (outside rec loop)" % (self, mask)
+ )
  assert in_spatial_dim and out_spatial_dim
  mask_data = mask.output.copy_as_time_major()
  mask_t = where_bc(mask_data.placeholder, mask_data.get_sequence_mask(), tf.convert_to_tensor(False))
@@ -9785,9 +9785,9 @@ class UnmaskLayer(LayerBase):
  with same_control_flow_ctx(src_layer.output.placeholder):
  src = src_layer.output.copy_as_bt_or_tb_major()
  mask_out = self.mask.output
- assert (
- mask_out.shape == () and mask_out.batch_shape == (None,) and mask_out.dtype == "bool"
- ), "%s: invalid mask %s (inside rec loop)" % (self, self.mask)
+ assert mask_out.shape == () and mask_out.batch_shape == (None,) and mask_out.dtype == "bool", (
+ "%s: invalid mask %s (inside rec loop)" % (self, self.mask)
+ )
  prev_t = self._rec_previous_layer.rec_vars_outputs["t"] # [B]
  t = prev_t + tf.cast(mask_out.placeholder, tf.int32) # [B]
  self.rec_vars_outputs["t"] = t
@@ -11192,9 +11192,9 @@ class RelativePositionalEncodingLayer(_ConcatInputLayer):
  and is_axis_from_description_recurrent(key_value_spatial_dim, network=self.network, data=self.input_data)
  ):
  length = self.network.get_rec_step_index() + 1
- assert (
- key_value_spatial_dim_.dimension is None
- ), f"{self}: unexpected kv spatial dim {key_value_spatial_dim_}"
+ assert key_value_spatial_dim_.dimension is None, (
+ f"{self}: unexpected kv spatial dim {key_value_spatial_dim_}"
+ )
  assert key_value_spatial_dim_.dyn_size_ext is not None
  # See CumConcatLayer for similar logic
  if key_value_spatial_dim_.dyn_size_ext.placeholder is None:
returnn/tf/native_op.py CHANGED
@@ -283,9 +283,7 @@ class OpMaker:
  // otherwise it will trigger an assertion.
  if (IsRefType(context->input_dtype({in_idx})))
  context->forward_ref_input_to_ref_output({in_idx}, {out_idx});
- """.format(
- in_idx=in_idx, out_idx=out_idx
- )
+ """.format(in_idx=in_idx, out_idx=out_idx)
  code_set_io = ""
  for in_idx, v in enumerate(in_info):
  ndim = len(v["shape"])
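The native_op.py hunk applies the same kind of reflow to a str.format call on a triple-quoted C++ template: the keyword arguments move onto the same line as .format(...), with no change to the generated string. A reduced sketch; the snippet below is illustrative, not the actual OpMaker code:

in_idx, out_idx = 0, 0

# Before, the arguments to .format() were spread over their own lines;
# after, the call is collapsed onto a single line.
code = """
if (IsRefType(context->input_dtype({in_idx})))
  context->forward_ref_input_to_ref_output({in_idx}, {out_idx});
""".format(in_idx=in_idx, out_idx=out_idx)

print(code)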