returnn 1.20250508.93313__py3-none-any.whl → 1.20250508.181644__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic.

Files changed (67)
  1. returnn/PKG-INFO +1 -1
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/datasets/basic.py +24 -25
  4. returnn/datasets/cached.py +4 -3
  5. returnn/datasets/distrib_files.py +1 -2
  6. returnn/datasets/generating.py +20 -20
  7. returnn/datasets/hdf.py +9 -9
  8. returnn/datasets/lm.py +25 -13
  9. returnn/datasets/meta.py +39 -38
  10. returnn/datasets/normalization_data.py +1 -1
  11. returnn/datasets/postprocessing.py +9 -9
  12. returnn/datasets/sprint.py +8 -7
  13. returnn/datasets/util/strings.py +0 -1
  14. returnn/datasets/util/vocabulary.py +3 -3
  15. returnn/extern/graph_editor/subgraph.py +1 -2
  16. returnn/extern/graph_editor/transform.py +1 -2
  17. returnn/extern/graph_editor/util.py +1 -2
  18. returnn/frontend/_backend.py +4 -3
  19. returnn/frontend/_utils.py +1 -1
  20. returnn/frontend/audio/mel.py +0 -1
  21. returnn/frontend/const.py +3 -3
  22. returnn/frontend/device.py +0 -1
  23. returnn/frontend/dropout.py +1 -1
  24. returnn/frontend/encoder/e_branchformer.py +1 -1
  25. returnn/frontend/loop.py +3 -3
  26. returnn/frontend/loss.py +0 -1
  27. returnn/frontend/matmul.py +0 -1
  28. returnn/frontend/run_ctx.py +9 -9
  29. returnn/frontend/signal.py +0 -1
  30. returnn/frontend/types.py +2 -4
  31. returnn/native_op.py +13 -0
  32. returnn/sprint/cache.py +2 -4
  33. returnn/sprint/interface.py +3 -4
  34. returnn/tensor/_dim_extra.py +9 -9
  35. returnn/tensor/_tensor_extra.py +20 -19
  36. returnn/tensor/_tensor_op_overloads.py +0 -1
  37. returnn/tensor/tensor.py +1 -1
  38. returnn/tensor/tensor_dict.py +9 -9
  39. returnn/tf/engine.py +60 -65
  40. returnn/tf/frontend_layers/_backend.py +3 -3
  41. returnn/tf/frontend_layers/cond.py +6 -6
  42. returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
  43. returnn/tf/frontend_layers/layer.py +12 -12
  44. returnn/tf/frontend_layers/loop.py +3 -3
  45. returnn/tf/frontend_layers/make_layer.py +0 -1
  46. returnn/tf/layers/base.py +56 -49
  47. returnn/tf/layers/basic.py +60 -65
  48. returnn/tf/layers/rec.py +74 -74
  49. returnn/tf/native_op.py +1 -3
  50. returnn/tf/network.py +60 -57
  51. returnn/tf/updater.py +3 -3
  52. returnn/tf/util/basic.py +24 -23
  53. returnn/torch/data/extern_data.py +4 -5
  54. returnn/torch/data/pipeline.py +3 -4
  55. returnn/torch/engine.py +16 -16
  56. returnn/torch/frontend/_backend.py +15 -15
  57. returnn/torch/frontend/bridge.py +3 -3
  58. returnn/torch/updater.py +8 -9
  59. returnn/torch/util/debug_inf_nan.py +0 -2
  60. returnn/torch/util/exception_helper.py +1 -1
  61. returnn/torch/util/scaled_gradient.py +0 -1
  62. returnn/util/basic.py +1 -2
  63. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/METADATA +1 -1
  64. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/RECORD +67 -67
  65. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/LICENSE +0 -0
  66. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/WHEEL +0 -0
  67. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/top_level.txt +0 -0
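Most of the hunks below are mechanical style changes rather than functional ones: old `# type:` comments are rewritten as inline PEP 526 annotations, implicitly concatenated string literals are merged into one literal, and multi-line assert statements are reflowed. As a minimal sketch of the annotation change (the class and attribute names here are made up for illustration and do not come from RETURNN):

from typing import Dict, List, Optional


class _ExampleState:
    def __init__(self):
        # Old style, as removed in this release:
        #   self.layers = []  # type: typing.List[str]
        #   self.total_loss = None  # type: typing.Optional[float]
        # New style, equivalent at runtime:
        self.layers: List[str] = []
        self.total_loss: Optional[float] = None
        self.losses_by_name: Dict[str, float] = {}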
returnn/tf/network.py CHANGED
@@ -3,7 +3,9 @@ Defines the :class:`TFNetwork` and :class:`ExternData`.
  """
 
  from __future__ import annotations
- from typing import Optional, Any, Protocol, List, Tuple, Dict
+
+ from typing import Callable, List, Optional, Any, Protocol, Tuple, Dict, TYPE_CHECKING, Union
+
  import tensorflow as tf
  import sys
  import re
@@ -19,6 +21,11 @@ from returnn.tensor import Tensor, Dim, TensorDict
  from returnn.tf.util.data import Data
  from returnn.util import basic as util
 
+ if TYPE_CHECKING:
+ from returnn.config import Config
+ from returnn.tf.layers.base import SearchChoices
+ from returnn.tf.util.data import BatchInfo
+
 
  class DataNotFound(Exception):
  """
@@ -39,8 +46,8 @@ class ExternData(TensorDict):
  :param None|dict[str,dict[str]] data: optional init kwargs for Data
  """
  super().__init__()
- self._config = None # type: typing.Optional["returnn.config.Config"]
- self._batch_info = None # type: typing.Optional["returnn.tf.util.data.BatchInfo"]
+ self._config: typing.Optional[Config] = None
+ self._batch_info: typing.Optional[BatchInfo] = None
  self.default_input = default_input
  self.default_target = default_target
  self.extra_added_keys = set() # set[str]
@@ -369,8 +376,7 @@ def _extern_data_types_from_config(config):
  print("Warning: Using extern_data and will ignore num_inputs/num_outputs in config.", file=log.v2)
  else:
  log.print_deprecation_warning(
- "Using num_inputs/num_outputs instead of extern_data is deprecated"
- " and might be removed in future versions"
+ "Using num_inputs/num_outputs instead of extern_data is deprecated and might be removed in future versions"
  )
  num_inputs, num_outputs = _num_inputs_outputs_from_config(config)
  data_dims = num_outputs.copy()
@@ -502,7 +508,7 @@ class _NetworkConstructionStack:
  """
 
  def __init__(self):
- self.layers = [] # type: typing.List[str]
+ self.layers: typing.List[str] = []
  self.in_flat_construct_count = 0
 
  def append(self, layer_name):
@@ -645,33 +651,31 @@ class TFNetwork:
  self.extra_deps_in_extra = False
  self.extra_only_template = False
  self.is_root_in_ctx = not parent_net # default. might be overwritten
- self.extra_nets = {} # type: typing.Dict[str,TFNetwork]
- self.subnets = {} # type: typing.Dict[str,Subnetwork]
+ self.extra_nets: Dict[str, TFNetwork] = {}
+ self.subnets: Dict[str, Subnetwork] = {}
  self._selected_train_layers = None
  self._construction_stack = _NetworkConstructionStack()
  self.layers_desc: Dict[str, Dict[str, Any]] = {}
  self.layers: Dict[str, LayerBase] = {}
- self.losses_dict = {} # type: typing.Dict[str,LossHolder]
- self.total_loss = None # type: typing.Optional[tf.Tensor]
- self.total_constraints = None # type: typing.Optional[tf.Tensor]
- self.total_objective = None # type: typing.Optional[tf.Tensor]
- self._global_train_step = None # type: typing.Optional[tf.Tensor]
- self._global_train_step_var = None # type: typing.Optional[tf.Variable]
+ self.losses_dict: Dict[str, LossHolder] = {}
+ self.total_loss: Optional[tf.Tensor] = None
+ self.total_constraints: Optional[tf.Tensor] = None
+ self.total_objective: Optional[tf.Tensor] = None
+ self._global_train_step: Optional[tf.Tensor] = None
+ self._global_train_step_var: Optional[tf.Variable] = None
  self.epoch_step = None
- self.saver = None # type: typing.Optional[tf.compat.v1.train.Saver]
- self.extra_vars_to_save = [] # type: typing.List[tf.Variable]
+ self.saver: Optional[tf.compat.v1.train.Saver] = None
+ self.extra_vars_to_save: List[tf.Variable] = []
  self.recurrent = False
- self._assigner_cache = {} # type: typing.Dict[tf.Variable,tf_util.VariableAssigner]
+ self._assigner_cache: Dict[tf.Variable, tf_util.VariableAssigner] = {}
  self.concat_sources_dropout_cache: Dict[
  Tuple[Tuple[LayerBase, ...], Dim, float, Optional[Tuple[Optional[int], ...]]], Data
  ] = {}
- self._merge_all_summaries = None # type: typing.Optional[tf.Tensor]
- self._graph_reset_callbacks = [] # type: typing.List[typing.Callable]
- self._run_opts = {} # type: typing.Dict[str, typing.Any]
- self._run_finished_callbacks = [] # type: typing.List[typing.Callable]
- self._map_search_beam_to_search_choices = (
- {}
- ) # type: typing.Dict[tf_util.SearchBeam,"returnn.tf.layers.base.SearchChoices"]
+ self._merge_all_summaries: Optional[tf.Tensor] = None
+ self._graph_reset_callbacks: List[Callable] = []
+ self._run_opts: Dict[str, Any] = {}
+ self._run_finished_callbacks: List[Callable] = []
+ self._map_search_beam_to_search_choices: Dict[tf_util.SearchBeam, SearchChoices] = {}
 
  def __repr__(self):
  s = "TFNetwork %r" % self.name
@@ -1308,15 +1312,16 @@ class TFNetwork:
  layer.output.sanity_check()
  # The axes should not have moved now.
  output_special_axes = layer.output.get_special_axes_dict()
- assert (
- output_template_special_axes == output_special_axes
- ), "%s %r: not equal: %r == %r, from data %r -> %r" % (
- layer_class.__name__,
- name,
- output_template_special_axes,
- output_special_axes,
- output_template,
- layer.output,
+ assert output_template_special_axes == output_special_axes, (
+ "%s %r: not equal: %r == %r, from data %r -> %r"
+ % (
+ layer_class.__name__,
+ name,
+ output_template_special_axes,
+ output_special_axes,
+ output_template,
+ layer.output,
+ )
  )
  except TypeError:
  help_on_type_error_wrong_args(cls=layer_class, kwargs=list(layer_desc.keys()))
@@ -1486,7 +1491,7 @@ class TFNetwork:
  else:
  total_loss = None
  total_constraints = None
- losses_multi_dict = {} # type: typing.Dict[str,typing.List[typing.Tuple[typing.Optional[str],LossHolder]]]
+ losses_multi_dict: Dict[str, List[Tuple[Optional[str], LossHolder]]] = {}
  # self.layers also include extra net layers and sub layers, see add_layer.
  for name, layer in sorted(self.layers.items()):
  assert isinstance(layer, LayerBase)
@@ -1869,14 +1874,15 @@ class TFNetwork:
 
  # All end points must be mapped now.
  for layer in end_points:
- assert (
- layer in mapped_layers
- ), "end point %r not mapped.\n end points:\n%s\n mapped:\n%s\n blacklist:\n%s\n starting points:\n%s" % (
- layer,
- pformat(end_points),
- pformat(mapped_layers),
- pformat(blacklist),
- pformat(starting_points),
+ assert layer in mapped_layers, (
+ "end point %r not mapped.\n end points:\n%s\n mapped:\n%s\n blacklist:\n%s\n starting points:\n%s"
+ % (
+ layer,
+ pformat(end_points),
+ pformat(mapped_layers),
+ pformat(blacklist),
+ pformat(starting_points),
+ )
  )
  # Assign flatten_with_seq_len_mask cache to mapped layers.
  for layer, new_layer in mapped_layers.items():
@@ -2402,9 +2408,7 @@
 
  Note that this excludes auxiliary params.
  """
- layers = {
- layer.get_absolute_name(): layer for layer in self.get_all_layers_deep()
- } # type: typing.Dict[str,LayerBase]
+ layers: Dict[str, LayerBase] = {layer.get_absolute_name(): layer for layer in self.get_all_layers_deep()}
  for layer_name, layer_values_dict in values_dict.items():
  if layer_values_dict:
  if ignore_non_existing and layer_name not in layers:
@@ -4091,9 +4095,9 @@ class LossHolder:
  self._error_value = self._layer._cond_only_on_eval_opt(self.loss.get_error, default_value=0.0)
  else:
  self._error_value = self.loss.get_error()
- assert (
- self._loss_value is not None or self._error_value is not None
- ), "layer %r loss %r return None for loss and error" % (self._layer, self.loss)
+ assert self._loss_value is not None or self._error_value is not None, (
+ "layer %r loss %r return None for loss and error" % (self._layer, self.loss)
+ )
  if self._norm_factor is None:
  self._norm_factor = self.loss.get_normalization_factor()
  loss_value = self._loss_value
@@ -4515,12 +4519,12 @@ class CustomCheckpointLoader:
  # All variables in the checkpoint:
  self.var_ckpt_names = set(self.reader.get_variable_to_shape_map()) # type: typing.Set[str]
  # All variables of the model to be loaded:
- self.var_net_names = {
+ self.var_net_names: Dict[str, Union[tf.Variable, Any]] = {
  self._get_param_name(v): v for v in self.saveable_params
- } # type: typing.Dict[str,typing.Union[tf.Variable,typing.Any]]
+ }
  # Model variables missing in the checkpoint:
- self.missing_var_names = [] # type: typing.List[str]
- self.missing_non_critical_var_names = [] # type: typing.List[str]
+ self.missing_var_names: List[str] = []
+ self.missing_non_critical_var_names: List[str] = []
  for name, v in sorted(self.var_net_names.items()):
  if name in self.var_ckpt_names:
  continue
@@ -4702,10 +4706,10 @@
  "rnn/lstm_cell/bias": "lstm_cell/bias",
  "rnn/lstm_cell/kernel": "lstm_cell/kernel",
  (
- "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/" "cudnn_compatible_lstm_cell/bias"
+ "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias"
  ): "lstm_fused_cell/bias",
  (
- "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/" "cudnn_compatible_lstm_cell/kernel"
+ "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel"
  ): "lstm_fused_cell/kernel",
  }
 
@@ -4877,7 +4881,7 @@
  self.target = target
  self.keys = [target + "bias", target + "kernel"]
  self.prefix = prefix
- self.data = None # type: typing.Optional[typing.Dict[str,numpy.ndarray]]
+ self.data: typing.Optional[typing.Dict[str, numpy.ndarray]] = None
 
  # noinspection PyMethodParameters
  def _load(sself):
@@ -5140,8 +5144,7 @@ class CustomLoadParamFunc(Protocol):
 
  def __call__(
  self, *, name: str, shape: Tuple[int], reader: tf.compat.v1.train.NewCheckpointReader
- ) -> Optional[numpy.ndarray]:
- ...
+ ) -> Optional[numpy.ndarray]: ...
 
 
  def set_custom_post_init(var, func):
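The import hunk at the top of this file also adds a TYPE_CHECKING block, so that names used only inside annotations (Config, SearchChoices, BatchInfo) are imported for type checkers but stay out of the runtime import graph. A minimal sketch of the pattern, reusing returnn.config.Config from the hunk above; the _Holder class is hypothetical:

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from returnn.config import Config


class _Holder:
    def __init__(self):
        # With `from __future__ import annotations`, the annotation is not
        # evaluated at runtime, so Config does not need to be importable here.
        self._config: Optional[Config] = None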
returnn/tf/updater.py CHANGED
@@ -219,9 +219,9 @@ class Updater:
 
  learning_rate_function = self.config.typed_dict.get("dynamic_learning_rate")
  signature = inspect.signature(learning_rate_function)
- assert any(
- [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]
- ), "please specify **kwargs in dynamic_learning_rate for future compatibility"
+ assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), (
+ "please specify **kwargs in dynamic_learning_rate for future compatibility"
+ )
  if "epoch" in signature.parameters:
  raise NotImplementedError("TF updater: dynamic_learning_rate with epoch not supported currently")
  lr = learning_rate_function(
returnn/tf/util/basic.py CHANGED
@@ -1799,7 +1799,7 @@ def dropout(
  x = tf.convert_to_tensor(x, name="x")
  assert isinstance(x, tf.Tensor)
  if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1:
- raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob)
+ raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob)
  # Do nothing if we know keep_prob == 1
  if isinstance(keep_prob, (float, int)) and keep_prob == 1:
  return x
@@ -2492,9 +2492,9 @@ def get_common_shape(values, ignore_axes=(), allow_broadcast_all_sources=NotSpec
  import numpy
 
  assert len(values) > 0
- assert all(
- [isinstance(value, (tf.Tensor, tf.Variable, float, int, numpy.number)) for value in values]
- ), "types %r" % ([type(v) for v in values])
+ assert all([isinstance(value, (tf.Tensor, tf.Variable, float, int, numpy.number)) for value in values]), (
+ "types %r" % ([type(v) for v in values])
+ )
  # Filter out scalars.
  values = [value for value in values if isinstance(value, (tf.Tensor, tf.Variable))]
  assert all([value.shape.ndims is not None for value in values]), "some unknown ndim"
@@ -2523,14 +2523,15 @@
  common_shape[axis] = static_dim
  else: # common_shape is int
  assert isinstance(common_shape[axis], int)
- assert (
- common_shape[axis] == static_dim
- ), "non matching dim %r vs %r in axis %i, value %r of values %r" % (
- common_shape[axis],
- static_dim,
- axis,
- value,
- values,
+ assert common_shape[axis] == static_dim, (
+ "non matching dim %r vs %r in axis %i, value %r of values %r"
+ % (
+ common_shape[axis],
+ static_dim,
+ axis,
+ value,
+ values,
+ )
  )
  # Check validate_broadcast_all_sources
  need_broadcast = {id(value): False for value in values}
@@ -2576,9 +2577,9 @@ def unbroadcast_to_common_shape(value, common_shape, ignore_axes=(), allow_only_
  for axis in ignore_axes:
  assert 0 <= axis < ndim
  tile_multiples[axis] = 1
- assert all(
- [m is not None for m in tile_multiples]
- ), "ignore_axes %r probably missing some axis for common shape %r" % (ignore_axes, common_shape)
+ assert all([m is not None for m in tile_multiples]), (
+ "ignore_axes %r probably missing some axis for common shape %r" % (ignore_axes, common_shape)
+ )
  if all([isinstance(m, int) and m == 1 for m in tile_multiples]):
  # We have a no-op.
  return value
@@ -6611,7 +6612,6 @@ def find_unsupported_devices_in_graph(graph, dev_name, ignore=None):
 
 
  class _DeviceAttrMod:
-
  _tf_mod = None
 
  @classmethod
@@ -7680,13 +7680,14 @@ class FetchHelper:
  _, info = copier(sgv, dst_graph=sgv.graph, dst_scope="", reuse_dst_scope=True)
  assert isinstance(info, graph_editor.TransformerInfo)
  target_op_transformed = info.transformed(target_op)
- assert isinstance(
- target_op_transformed, tf.Operation
- ), "\ntarget_op\n%r,\nfetches\n%r,\nstop_at_ts\n%s,\nops\n%s" % (
- target_op,
- fetches,
- pformat(stop_at_ts),
- pformat(ops),
+ assert isinstance(target_op_transformed, tf.Operation), (
+ "\ntarget_op\n%r,\nfetches\n%r,\nstop_at_ts\n%s,\nops\n%s"
+ % (
+ target_op,
+ fetches,
+ pformat(stop_at_ts),
+ pformat(ops),
+ )
  )
  fetch_helpers = []
  for x in fetch_helper_tensors:
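Many hunks in this release, including several in this file, reflow multi-line asserts: the condition now stays on one line and only the failure message is wrapped in parentheses. Runtime behavior is unchanged, and because the parentheses enclose only the message rather than the whole `condition, message` pair, the assert does not turn into an always-true tuple. A small self-contained sketch with made-up values:

expected_dim = got_dim = 3

# Old layout: the condition itself was split across lines.
assert (
    expected_dim == got_dim
), "non matching dim %r vs %r" % (expected_dim, got_dim)

# New layout: condition on one line, only the message parenthesized.
assert expected_dim == got_dim, (
    "non matching dim %r vs %r" % (expected_dim, got_dim)
)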
returnn/torch/data/extern_data.py CHANGED
@@ -56,9 +56,9 @@ def raw_dict_to_extern_data(
  assert len(raw_tensor.shape) == data.batch_ndim, f"ndim mismatch for {k}: {raw_tensor.shape} vs {data}"
  for i, dim in enumerate(data.dims):
  if dim.dimension is not None:
- assert (
- dim.dimension == raw_tensor.shape[i]
- ), f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}"
+ assert dim.dimension == raw_tensor.shape[i], (
+ f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}"
+ )
  if isinstance(raw_tensor, torch.Tensor):
  if raw_tensor.dtype.is_floating_point and float_dtype:
  raw_tensor = raw_tensor.to(dtype=float_dtype)
@@ -81,8 +81,7 @@
  and (data.dims[1].dyn_size_ext is None or data.dims[1].dyn_size_ext.raw_tensor is None)
  ):
  assert k + ":seq_len" in extern_data_raw, (
- f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, "
- f"check dataset or collate_batch"
+ f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, check dataset or collate_batch"
  )
  size = extern_data_raw[k + ":seq_len"]
  # Sequence lengths have to be on CPU for the later call to rnn.pack_padded_sequence
returnn/torch/data/pipeline.py CHANGED
@@ -123,7 +123,6 @@ class ChunkingIterDataPipe(torch.utils.data.IterDataPipe):
  chunking_data_keys = list(self._chunk_size.keys())
 
  for data_dict in self._dataset:
-
  if not chunking_data_keys:
  chunking_data_keys = list(data_dict.keys()) # use all if not configured separately
  chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs", "epoch", "complete_frac"]
@@ -150,9 +149,9 @@ class ChunkingIterDataPipe(torch.utils.data.IterDataPipe):
  if num_chunks is None:
  num_chunks = len(chunks)
  else:
- assert num_chunks == len(
- chunks
- ), "Chunking resulted in different number of chunks for different data keys."
+ assert num_chunks == len(chunks), (
+ "Chunking resulted in different number of chunks for different data keys."
+ )
 
  data_chunks[data_key] = chunks
 
returnn/torch/engine.py CHANGED
@@ -66,22 +66,22 @@ class Engine(EngineBase):
  self.model_filename = self.config.value("model", None)
  self._mp_manager = torch.multiprocessing.Manager()
  self._epoch_mp_shared = self._mp_manager.Value("i", 0)
- self.train_dataset = None # type: Optional[Dataset]
+ self.train_dataset: Optional[Dataset] = None
  self.eval_datasets = {}
- self.extern_data = None # type: Optional[TensorDict]
- self._train_dataloader = None # type: Optional[DataLoader]
- self._eval_dataloaders = {} # type: Dict[str, DataLoader]
+ self.extern_data: Optional[TensorDict] = None
+ self._train_dataloader: Optional[DataLoader] = None
+ self._eval_dataloaders: Dict[str, DataLoader] = {}
 
- self._start_epoch = None # type: Optional[int]
- self._final_epoch = None # type: Optional[int]
- self._min_seq_length = config.typed_value("min_seq_length", None) or config.int(
+ self._start_epoch: Optional[int] = None
+ self._final_epoch: Optional[int] = None
+ self._min_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value(
  "min_seq_length", None
- ) # type: Union[int,float,Dict[str,int],NumbersDict]
- self._max_seq_length = config.typed_value("max_seq_length", None) or config.int(
+ ) or config.int("min_seq_length", None)
+ self._max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value(
  "max_seq_length", None
- ) # type: Union[int,float,Dict[str,int],NumbersDict]
- self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]]
- self._pt_model = None # type: Optional[torch.nn.Module]
+ ) or config.int("max_seq_length", None)
+ self._orig_model: Optional[Union[rf.Module, torch.nn.Module]] = None
+ self._pt_model: Optional[torch.nn.Module] = None
  self._epoch_start_func: Optional[Callable] = self.config.typed_value("epoch_start")
  self._epoch_end_func: Optional[Callable] = self.config.typed_value("epoch_end")
  self._train_step_func: Optional[Callable] = None
@@ -95,15 +95,15 @@ class Engine(EngineBase):
  self._updater: Optional[Updater] = None
 
  self._use_autocast = False
- self._autocast_dtype = None # type: Optional[str]
- self._grad_scaler = None # type: Optional[amp.GradScaler]
+ self._autocast_dtype: Optional[str] = None
+ self._grad_scaler: Optional[amp.GradScaler] = None
 
  dev_ = get_device_from_config_opt(config.value("device", None))
  self._device = dev_.result
  print("Using device:", self._device, f"({dev_.reason or '?'})", file=log.v2)
 
- self._torch_distributed_ctx = None # type: Optional[DistributedContext]
- self._ddp_pt_model = None # type: Optional[DistributedDataParallel]
+ self._torch_distributed_ctx: Optional[DistributedContext] = None
+ self._ddp_pt_model: Optional[DistributedDataParallel] = None
 
  if config.typed_value("torch_distributed") is not None:
  self._torch_distributed_ctx = dist_get_ctx(config=config)
returnn/torch/frontend/_backend.py CHANGED
@@ -421,9 +421,9 @@ class TorchBackend(Backend[torch.Tensor]):
  else: # not allow_broadcast
  for source, dim in sources:
  templ_dims = other_dims[:axis] + [dim] + other_dims[axis:]
- assert set(templ_dims) == set(
- source.dims
- ), f"concat {source} {dim} not allowed with allow_broadcast=False"
+ assert set(templ_dims) == set(source.dims), (
+ f"concat {source} {dim} not allowed with allow_broadcast=False"
+ )
  source_ = source.copy_transpose(templ_dims)
  sources_raw.append(source_.raw_tensor)
  out = Tensor(
@@ -612,9 +612,9 @@ class TorchBackend(Backend[torch.Tensor]):
  assert axis in logits.dims, "Specified axis not present in logits."
 
  if axis == targets.sparse_dim:
- assert (
- logits.dims_set - {axis} == targets.dims_set
- ), "logits Dims and target Dims have to match (except for implicit sparse_dim)."
+ assert logits.dims_set - {axis} == targets.dims_set, (
+ "logits Dims and target Dims have to match (except for implicit sparse_dim)."
+ )
 
  logits_dim_order = list(targets.dims)
  if len(logits_dim_order) > 0:
@@ -629,9 +629,9 @@ class TorchBackend(Backend[torch.Tensor]):
  targets.raw_tensor = targets.raw_tensor.long()
 
  else:
- assert (
- not targets.sparse_dim
- ), "We expect that cross entropy would always be calculated along the sparse dim, if there is one."
+ assert not targets.sparse_dim, (
+ "We expect that cross entropy would always be calculated along the sparse dim, if there is one."
+ )
  assert logits.dims_set == targets.dims_set, "logits Dims and target Dims have to match."
  assert axis in targets.dims, "Specified axis not present in targets."
 
@@ -1348,12 +1348,12 @@ class TorchBackend(Backend[torch.Tensor]):
  a_dims = a.dims
  b_dims = b.dims
 
- assert all(
- dim in a_dims for dim in reduce
- ), f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
- assert all(
- dim in b_dims for dim in reduce
- ), f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
+ assert all(dim in a_dims for dim in reduce), (
+ f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
+ )
+ assert all(dim in b_dims for dim in reduce), (
+ f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
+ )
 
  if len(reduce) > 1:
  reduce = list(reduce)
returnn/torch/frontend/bridge.py CHANGED
@@ -178,9 +178,9 @@ class RFModuleAsPTModule(torch.nn.Module):
  rf_param = getattr(self._rf_module, name, None)
  if not isinstance(rf_param, rf.Parameter):
  return # just ignore
- assert isinstance(
- param, torch.nn.Parameter
- ), f"{self} register_parameter {name}: did not get a Parameter but {type(param).__name__}"
+ assert isinstance(param, torch.nn.Parameter), (
+ f"{self} register_parameter {name}: did not get a Parameter but {type(param).__name__}"
+ )
  rf_param.raw_tensor = param
 
  def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
returnn/torch/updater.py CHANGED
@@ -39,7 +39,7 @@ def _init_optimizer_classes_dict():
 
 
  def get_optimizer_class(
- class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]]
+ class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]],
  ) -> Type[torch.optim.Optimizer]:
  """
  :param class_name: Optimizer class, either as str (e.g. "adam"), as type (torch.optim.Adam) or callable.
@@ -121,9 +121,9 @@ class Updater:
  import inspect
 
  signature = inspect.signature(self.learning_rate_function)
- assert any(
- [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]
- ), "please specify **kwargs in dynamic_learning_rate for future compatibility"
+ assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), (
+ "please specify **kwargs in dynamic_learning_rate for future compatibility"
+ )
  if "network" in signature.parameters:
  raise ValueError("Torch updater: dynamic_learning_rate network is TF specific")
  else:
@@ -497,10 +497,9 @@ class Updater:
  # Split in parameter groups only if decouple_constraints is set and the optimizer accepts weight_decay.
  cls_init_kwargs = _get_class_init_kwargs(optim_class)
  if "weight_decay" not in cls_init_kwargs:
- assert (
- "weight_decay" not in optimizer_opts
- ), "weight_decay not accepted by the chosen optimizer. Accepted values: %s" % ", ".join(
- "%s" % optim_name for optim_name in cls_init_kwargs
+ assert "weight_decay" not in optimizer_opts, (
+ "weight_decay not accepted by the chosen optimizer. Accepted values: %s"
+ % ", ".join("%s" % optim_name for optim_name in cls_init_kwargs)
  )
  return network_params
 
@@ -564,7 +563,7 @@
 
 
  def _wrap_user_blacklist_wd_modules(
- mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]
+ mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]],
  ) -> Tuple[type, ...]:
  assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
  res = []
returnn/torch/util/debug_inf_nan.py CHANGED
@@ -30,7 +30,6 @@ Also, there might be inf/nan values which are ok, expected, and not a problem
  So we don't stop on the first occurrence but just report all of them.
  """
 
-
  from __future__ import annotations
 
  import sys
@@ -90,7 +89,6 @@ def debug_inf_nan(
  print(f"Caught RuntimeError in backward: {exc}", file=file)
 
  else: # without grad
-
  with trace_ops:
  func()
 
returnn/torch/util/exception_helper.py CHANGED
@@ -79,7 +79,7 @@ def help_on_torch_exception(
 
 
  def _help_data_or_array(
- value: Union[torch.Tensor, np.ndarray, bool, object]
+ value: Union[torch.Tensor, np.ndarray, bool, object],
  ) -> Tuple[str, Tuple[Union[int, float], Union[int, float]]]:
  """
  :param value:
returnn/torch/util/scaled_gradient.py CHANGED
@@ -14,7 +14,6 @@ https://github.com/janfreyberg/pytorch-revgrad/blob/449fa763a76d/src/pytorch_rev
  https://github.com/tadeephuy/GradientReversal/blob/5d9857d63/gradient_reversal/functional.py
  """
 
-
  from __future__ import annotations
  from typing import Optional, Union
  import torch
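Several signature hunks above (get_optimizer_class, _wrap_user_blacklist_wd_modules, _help_data_or_array) only add a trailing comma after the single parameter. This looks like a formatter's trailing-comma convention for parameter lists that are already split across lines; it has no effect on behavior. A hypothetical single-parameter function laid out the same way:

from typing import Sequence, Tuple


def _normalize_names(
    names: Sequence[str],  # trailing comma keeps the signature in expanded form
) -> Tuple[str, ...]:
    # Hypothetical helper, only here to illustrate the layout.
    return tuple(name.strip() for name in names)


assert _normalize_names(["a ", " b"]) == ("a", "b")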
returnn/util/basic.py CHANGED
@@ -705,7 +705,7 @@ def expand_env_vars(s: str) -> str:
  return delim
  if mo.group("invalid") is not None:
  i = mo.start("invalid")
- raise ValueError(f"Invalid placeholder in string: {s[i:i+2]!r}...")
+ raise ValueError(f"Invalid placeholder in string: {s[i : i + 2]!r}...")
  raise ValueError(f"Unrecognized named group in pattern {pattern}")
 
  return pattern_.sub(_convert, s)
@@ -1811,7 +1811,6 @@ def json_remove_comments(string, strip_space=True):
  index = 0
 
  for match in re.finditer(tokenizer, string):
-
  if not (in_multi or in_single):
  tmp = string[index : match.start()]
  if not in_string and strip_space:
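The expand_env_vars hunk above only changes slice spacing: s[i:i+2] becomes s[i : i + 2]. PEP 8 recommends treating the slice colon like a binary operator and spacing it symmetrically when the operands are expressions; the selected substring is identical. A one-line check with made-up data:

text = "abcdef"
i = 2
assert text[i : i + 2] == text[i:i+2] == "cd"  # same slice, only the spacing differs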
{returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250508.93313
+ Version: 1.20250508.181644
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer