returnn 1.20250508.93313__py3-none-any.whl → 1.20250508.181644__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

Files changed (67)
  1. returnn/PKG-INFO +1 -1
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/datasets/basic.py +24 -25
  4. returnn/datasets/cached.py +4 -3
  5. returnn/datasets/distrib_files.py +1 -2
  6. returnn/datasets/generating.py +20 -20
  7. returnn/datasets/hdf.py +9 -9
  8. returnn/datasets/lm.py +25 -13
  9. returnn/datasets/meta.py +39 -38
  10. returnn/datasets/normalization_data.py +1 -1
  11. returnn/datasets/postprocessing.py +9 -9
  12. returnn/datasets/sprint.py +8 -7
  13. returnn/datasets/util/strings.py +0 -1
  14. returnn/datasets/util/vocabulary.py +3 -3
  15. returnn/extern/graph_editor/subgraph.py +1 -2
  16. returnn/extern/graph_editor/transform.py +1 -2
  17. returnn/extern/graph_editor/util.py +1 -2
  18. returnn/frontend/_backend.py +4 -3
  19. returnn/frontend/_utils.py +1 -1
  20. returnn/frontend/audio/mel.py +0 -1
  21. returnn/frontend/const.py +3 -3
  22. returnn/frontend/device.py +0 -1
  23. returnn/frontend/dropout.py +1 -1
  24. returnn/frontend/encoder/e_branchformer.py +1 -1
  25. returnn/frontend/loop.py +3 -3
  26. returnn/frontend/loss.py +0 -1
  27. returnn/frontend/matmul.py +0 -1
  28. returnn/frontend/run_ctx.py +9 -9
  29. returnn/frontend/signal.py +0 -1
  30. returnn/frontend/types.py +2 -4
  31. returnn/native_op.py +13 -0
  32. returnn/sprint/cache.py +2 -4
  33. returnn/sprint/interface.py +3 -4
  34. returnn/tensor/_dim_extra.py +9 -9
  35. returnn/tensor/_tensor_extra.py +20 -19
  36. returnn/tensor/_tensor_op_overloads.py +0 -1
  37. returnn/tensor/tensor.py +1 -1
  38. returnn/tensor/tensor_dict.py +9 -9
  39. returnn/tf/engine.py +60 -65
  40. returnn/tf/frontend_layers/_backend.py +3 -3
  41. returnn/tf/frontend_layers/cond.py +6 -6
  42. returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
  43. returnn/tf/frontend_layers/layer.py +12 -12
  44. returnn/tf/frontend_layers/loop.py +3 -3
  45. returnn/tf/frontend_layers/make_layer.py +0 -1
  46. returnn/tf/layers/base.py +56 -49
  47. returnn/tf/layers/basic.py +60 -65
  48. returnn/tf/layers/rec.py +74 -74
  49. returnn/tf/native_op.py +1 -3
  50. returnn/tf/network.py +60 -57
  51. returnn/tf/updater.py +3 -3
  52. returnn/tf/util/basic.py +24 -23
  53. returnn/torch/data/extern_data.py +4 -5
  54. returnn/torch/data/pipeline.py +3 -4
  55. returnn/torch/engine.py +16 -16
  56. returnn/torch/frontend/_backend.py +15 -15
  57. returnn/torch/frontend/bridge.py +3 -3
  58. returnn/torch/updater.py +8 -9
  59. returnn/torch/util/debug_inf_nan.py +0 -2
  60. returnn/torch/util/exception_helper.py +1 -1
  61. returnn/torch/util/scaled_gradient.py +0 -1
  62. returnn/util/basic.py +1 -2
  63. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/METADATA +1 -1
  64. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/RECORD +67 -67
  65. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/LICENSE +0 -0
  66. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/WHEEL +0 -0
  67. {returnn-1.20250508.93313.dist-info → returnn-1.20250508.181644.dist-info}/top_level.txt +0 -0
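Most of the changes in this release are mechanical rather than behavioral: old-style "# type:" comments are replaced by inline PEP 526 variable annotations (with the corresponding names imported from typing), and multi-line assert statements are rewrapped. A minimal sketch of the annotation pattern, simplified from two of the Runner attributes changed in returnn/tf/engine.py below:

from typing import Dict, Optional


class Runner:
    def __init__(self):
        # Old style, removed in this release (type comments):
        #   self.results = {}  # type: typing.Dict[str,float]
        #   self.report_prefix = None  # type: typing.Optional[str]
        # New style, added in this release (inline PEP 526 annotations):
        self.results: Dict[str, float] = {}
        self.report_prefix: Optional[str] = None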
returnn/tf/engine.py CHANGED
@@ -12,7 +12,7 @@ See :ref:`tech_overview` for an overview how it fits all together.
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import Callable, Dict, List, Optional, Union
 import typing
 import os
 import sys
@@ -101,31 +101,29 @@ class Runner:
         self.store_tf_profile = engine.config.bool("store_tf_profile", False)
         self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0)
         self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0)
-        assert not (
-            self.store_tf_profile and self.store_metadata_mod_step
-        ), "Cannot use store_tf_profile and store_metadata_mod_step at the same time"
+        assert not (self.store_tf_profile and self.store_metadata_mod_step), (
+            "Cannot use store_tf_profile and store_metadata_mod_step at the same time"
+        )
         self.finalized = False
         self.cancel_flag = False
         self.run_exception = None
         self.num_steps = None
-        self.device_crash_batch = None  # type: typing.Optional[int]
+        self.device_crash_batch: Optional[int] = None
         self.start_time = None
         self.elapsed = None
-        self.report_prefix = None  # type: typing.Optional[str]
+        self.report_prefix: Optional[str] = None
         self._results_accumulated = NumbersDict()  # entries like "cost:output" or "loss"
         self._inv_norm_accumulated = NumbersDict()  # entries like "output"
         self.num_frames_accumulated = NumbersDict()  # for each data key (eg. "classes"), corresponding number of frames
-        self.results = {}  # type: typing.Dict[str,float]  # entries like "cost:output" or "loss"
-        self.score = {}  # type: typing.Dict[str,float]  # entries like "cost:output"
-        self.error = {}  # type: typing.Dict[str,float]  # entries like "error:output"
-        self.stats = (
-            {}
-        )  # type: typing.Dict[str,typing.Union[float,numpy.ndarray,'Util.Stats']]  # entries like "stats:..."
+        self.results: Dict[str, float] = {}  # entries like "cost:output" or "loss"
+        self.score: Dict[str, float] = {}  # entries like "cost:output"
+        self.error: Dict[str, float] = {}  # entries like "error:output"
+        self.stats: Dict[str, Union[float, numpy.ndarray, "util.Stats"]] = {}  # entries like "stats:..."
         self.extra_fetches = extra_fetches
         if extra_fetches is not None:
             assert extra_fetches_callback
         self.extra_fetches_callback = extra_fetches_callback
-        self._step_start_time = None  # type: typing.Optional[float]
+        self._step_start_time: Optional[float] = None
         self._horovod_last_param_sync_time = time.time()  # we assume it is synced right now
         self._horovod_stopped_runner = False
         self._horovod_finish_all = False
@@ -133,9 +131,7 @@ class Runner:
             self._horovod_finish_all = True
         # With Horovod, during the main session.run, if reduce_type != grad or not training,
         # the following tensors are enough to ensure that we are in sync.
-        self._horovod_collected_reduce_inputs = (
-            {}
-        )  # type: typing.Dict[str,(tf.Tensor,tf.Tensor)]  # name -> (input,output)
+        self._horovod_collected_reduce_inputs: Dict[str, (tf.Tensor, tf.Tensor)] = {}  # name -> (input,output)
 
         from returnn.util.basic import terminal_size
 
@@ -196,9 +192,9 @@ class Runner:
                 d["extra:%s" % k] = v
                 continue
             assert isinstance(v, Data)
-            d[
-                "extra:%s" % k
-            ] = v.placeholder  # see _maybe_handle_extra_fetches, it will transform to batch-major there
+            d["extra:%s" % k] = (
+                v.placeholder
+            )  # see _maybe_handle_extra_fetches, it will transform to batch-major there
             for i, s in v.size_placeholder.items():
                 d["extra:%s:size_%i" % (k, i)] = s
 
@@ -732,9 +728,9 @@ class Runner:
                     run_options_.MergeFrom(run_options)
                     # We could use tfdbg.add_debug_tensor_watch here.
                     session_run_start_time = time.time()
-                    fetches_results = sess.run(
+                    fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
                         fetches_dict, feed_dict=feed_dict, options=run_options_, run_metadata=run_metadata
-                    )  # type: typing.Dict[str,typing.Union[numpy.ndarray,str]]
+                    )
                     elapsed_time_tf += time.time() - session_run_start_time
                     writer.add_summary(fetches_results["summary"], step + step_offset)
                     writer.add_run_metadata(run_metadata, "step_{:04d}".format(step + step_offset))
@@ -746,13 +742,13 @@ class Runner:
                     session_run_start_time = time.time()
                     if self.store_tf_profile:
                         with tf.profiler.experimental.Trace(name=report_prefix, step_num=step + step_offset):
-                            fetches_results = sess.run(
+                            fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
                                 fetches_dict, feed_dict=feed_dict, options=run_options
-                            )  # type: typing.Dict[str,typing.Union[numpy.ndarray,str]]
+                            )
                     else:
-                        fetches_results = sess.run(
+                        fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
                             fetches_dict, feed_dict=feed_dict, options=run_options
-                        )  # type: typing.Dict[str,typing.Union[numpy.ndarray,str]]
+                        )
                     elapsed_time_tf += time.time() - session_run_start_time
                     if writer and "summary" in fetches_results:
                         writer.add_summary(fetches_results["summary"], step + step_offset)
@@ -891,27 +887,27 @@ class Engine(EngineBase):
         BackendEngine.select_engine(default_fallback_engine=default_fallback_engine, config=self.config)
         assert BackendEngine.is_tensorflow_selected()
         self.orig_config = {}  # see _maybe_update_config
-        self.custom_get_net_dict = None  # type: typing.Optional[typing.Callable]
+        self.custom_get_net_dict: Optional[Callable] = None
         self._have_rf_get_model_func = False
         self._check_devices()
-        self.tf_session = None  # type: typing.Optional[tf.compat.v1.Session]
-        self.network = None  # type: typing.Optional[TFNetwork]
-        self.updater = None  # type: typing.Optional[Updater]
+        self.tf_session: Optional[tf.compat.v1.Session] = None
+        self.network: Optional[TFNetwork] = None
+        self.updater: Optional[Updater] = None
         self._checked_uninitialized_vars = False
         self._merge_all_summaries = None
-        self.dataset_batches = {}  # type: typing.Dict[str,BatchSetGenerator]
-        self.dataset_provider = None  # type: typing.Optional[DatasetDataProvider]
-        self.train_data = None  # type: typing.Optional[Dataset]
-        self.eval_datasets = {}  # type: typing.Dict[str,Dataset]
-        self.start_epoch = None  # type: typing.Optional[int]
-        self._num_trained_epochs = 0  # type: int  # just a counter
-        self._num_net_reinit = 0  # type: int
+        self.dataset_batches: Dict[str, BatchSetGenerator] = {}
+        self.dataset_provider: Optional[DatasetDataProvider] = None
+        self.train_data: Optional[Dataset] = None
+        self.eval_datasets: Dict[str, Dataset] = {}
+        self.start_epoch: Optional[int] = None
+        self._num_trained_epochs: int = 0  # just a counter
+        self._num_net_reinit: int = 0
         self.use_dynamic_train_flag = False
         self.use_search_flag = self.config.value("task", None) == "search"
         self.use_eval_flag = self.config.value("task", None) != "forward"
-        self._const_cache = {}  # type: typing.Dict[str,tf.Tensor]
-        self.preload_from_files = None  # type: typing.Optional[typing.Dict[str,typing.Dict[str]]]
-        self.max_seqs = None  # type: typing.Optional[int]
+        self._const_cache: Dict[str, tf.Tensor] = {}
+        self.preload_from_files: Optional[Dict[str, Dict[str]]] = None
+        self.max_seqs: Optional[int] = None
 
     def finalize(self, error_occurred=False):
         """
@@ -1140,7 +1136,7 @@ class Engine(EngineBase):
         self.min_seq_length = config.typed_value("min_seq_length", None) or config.float("min_seq_length", 0)
         self.inc_seq_length = config.float("inc_seq_length", 0)
         if not self.max_seq_length:
-            self.max_seq_length = sys.maxsize  # type: typing.Union[int,float,typing.Dict[str,int],NumbersDict]
+            self.max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = sys.maxsize
         if isinstance(self.max_seq_length, dict):
             self.max_seq_length = NumbersDict(self.max_seq_length)
         assert isinstance(self.max_seq_length, (int, float, NumbersDict))
@@ -1630,7 +1626,7 @@ class Engine(EngineBase):
         assert isinstance(self.start_epoch, int)
         epoch = self.start_epoch  # Epochs start at 1.
         while epoch <= final_epoch:
-            self.epoch = epoch  # type: int
+            self.epoch: int = epoch
             if isinstance(self.max_seq_length, int) and self.max_seq_length != sys.maxsize:
                 if int(self.max_seq_length + self.inc_seq_length) != int(self.max_seq_length):
                     print("increasing sequence lengths to", int(self.max_seq_length + self.inc_seq_length), file=log.v3)
@@ -1878,9 +1874,9 @@ class Engine(EngineBase):
             # We update the model params in-place.
             # In training, we don't want that, because it should not use the validation data.
             # We could reset it later when continuing the training, but it's not implemented.
-            assert (
-                self.config.value("task", "train") != "train"
-            ), "task %r should be just 'eval' or so. training will break." % self.config.value("task", None)
+            assert self.config.value("task", "train") != "train", (
+                "task %r should be just 'eval' or so. training will break." % self.config.value("task", None)
+            )
             if not self.updater:
                 self.updater = Updater(
                     config=self.config, network=self.network, initial_learning_rate=self.initial_learning_rate
@@ -1928,11 +1924,12 @@ class Engine(EngineBase):
             allowed_outputs = {"seq_tag", "seq_len", "score", "error", "pos_score", "pos_error"}
 
             assert isinstance(output_per_seq_format, (tuple, list)), "provide output_per_seq_format"
-            assert (
-                set(output_per_seq_format) - allowed_outputs == set()
-            ), "Only %r are allowed in function eval_model as output_per_seq_format, but got: %r " % (
-                allowed_outputs,
-                output_per_seq_format,
+            assert set(output_per_seq_format) - allowed_outputs == set(), (
+                "Only %r are allowed in function eval_model as output_per_seq_format, but got: %r "
+                % (
+                    allowed_outputs,
+                    output_per_seq_format,
+                )
             )
 
             # always fetch seq_tag to map loss values to the corresponding line
@@ -1968,12 +1965,10 @@ class Engine(EngineBase):
             if "pos_error" in output_per_seq_format:
                 extra_fetches["pos_error"] = loss_holder.get_error_value_per_pos()
 
-            seq_idx_to_tag = (
-                {}
-            )  # type: typing.Dict[int,str]  # we need this in order to write the results in the correct order later  # nopep8
-            results_per_seq = (
-                {}
-            )  # type: typing.Dict[str,typing.Dict[str,typing.Union[float,str,int]]]  # seq_tag -> dict. Results of fetches will be written in this dict  # nopep8
+            seq_idx_to_tag: Dict[int, str] = {}  # we need this in order to write the results in the correct order later
+            results_per_seq: Dict[
+                str, Dict[str, Union[float, str, int]]
+            ] = {}  # seq_tag -> dict. Results of fetches will be written in this dict
 
             # function to save the return values of each callback to the dict `results_per_seq`
             # noinspection PyShadowingNames
@@ -2012,7 +2007,7 @@ class Engine(EngineBase):
 
         if output_per_seq_file:
             assert len(self.get_eval_datasets()) == 1, (
-                "output per sequence is only supported for one dataset (dev or eval)," "provided datasets are %r"
+                "output per sequence is only supported for one dataset (dev or eval),provided datasets are %r"
             ) % list(self.get_eval_datasets().keys())
             # try to sort dataset to minimize zero-padding
             dataset = list(self.get_eval_datasets().values())[0]
@@ -2453,9 +2448,9 @@ class Engine(EngineBase):
         )
 
         max_seq_length = self.config.typed_value("max_seq_length", None) or self.config.float("max_seq_length", 0)
-        assert (
-            not max_seq_length
-        ), "Set max_seq_length = 0 for search (i.e. no maximal length). We want to keep all source sentences."
+        assert not max_seq_length, (
+            "Set max_seq_length = 0 for search (i.e. no maximal length). We want to keep all source sentences."
+        )
 
         dataset.init_seq_order(epoch=self.epoch)
         batches = dataset.generate_batches(
@@ -2552,8 +2547,8 @@ class Engine(EngineBase):
                     outputs[output_layer_idx] = bytearray(outputs[output_layer_idx]).decode("utf8")
 
             # Create lists with serialized data. All of length num_output_layers.
-            serialized_outputs = []  # type: typing.List[typing.Optional[typing.Union[str,numpy.ndarray]]]
-            serialized_targets = []  # type: typing.List[typing.Optional[typing.Union[str,numpy.ndarray]]]
+            serialized_outputs: List[Optional[Union[str, numpy.ndarray]]] = []
+            serialized_targets: List[Optional[Union[str, numpy.ndarray]]] = []
             # noinspection PyShadowingNames
             for output_layer_idx in range(num_output_layers):
                 if output_layers[output_layer_idx].output.sparse:
@@ -2572,8 +2567,8 @@ class Engine(EngineBase):
                         ]
                     else:
                         serialized_output = None
-                        assert not output_file, "Unable to serialize sparse output of layer '%s'." % (
-                            output_layer_names[output_layer_idx]
+                        assert not output_file, (
+                            "Unable to serialize sparse output of layer '%s'." % (output_layer_names[output_layer_idx])
                         )
                 else:
                     # Output dense layers as-is
@@ -2594,8 +2589,8 @@ class Engine(EngineBase):
                         ]
                     else:
                         serialized_target = None
-                        assert not output_file, "Unable to serialize sparse target '%s'." % (
-                            target_keys[output_layer_idx]
+                        assert not output_file, (
+                            "Unable to serialize sparse target '%s'." % (target_keys[output_layer_idx])
                        )
                 else:
                     serialized_target = targets[output_layer_idx]
returnn/tf/frontend_layers/_backend.py CHANGED
@@ -510,9 +510,9 @@ class ReturnnLayersBackend(Backend[Layer]):
            # We could also maybe move out all the dependencies.
            # However, it's not clear whether this is always safe.
            for dep in value.raw_tensor.get_tensor_dependencies():
-                assert (
-                    dep.parent.can_access_children_from_root
-                ), f"dep {dep} of moved value {value} is not accessible"
+                assert dep.parent.can_access_children_from_root, (
+                    f"dep {dep} of moved value {value} is not accessible"
+                )
            param.raw_tensor.layer_dict["init_by_layer"] = value
        else:
            param.raw_tensor.layer_dict.pop("init_by_layer", None)
returnn/tf/frontend_layers/cond.py CHANGED
@@ -181,9 +181,9 @@ class Cond(Generic[T]):
         After this, self.result is available.
         """
         assert self._entered, f"{self} you need to be in the context scope"
-        assert (
-            self._entered_state is False
-        ), f"{self} you need to be in the False branch, have assigned :func:`true` before"
+        assert self._entered_state is False, (
+            f"{self} you need to be in the False branch, have assigned :func:`true` before"
+        )
         assert not self._false_value_set
         nest.assert_same_structure(self._true_value, false_value)
         # This needs to match the true() setter logic.
@@ -198,9 +198,9 @@ class Cond(Generic[T]):
            if false_v is None:  # see above
                false_v = rf.zeros((), dtype="int32")  # dummy value
            else:
-                assert isinstance(
-                    false_v, Tensor
-                ), f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}"
+                assert isinstance(false_v, Tensor), (
+                    f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}"
+                )
            assert true_v.raw_tensor.parent is self.true_branch_name_ctx
            name = true_v.raw_tensor.name
            assert name not in self.false_branch_name_ctx.children
returnn/tf/frontend_layers/debug_eager_mode.py CHANGED
@@ -2,7 +2,6 @@
 Debug eager mode
 """
 
-
 _debug_eager_mode_enabled = False
 
 
returnn/tf/frontend_layers/layer.py CHANGED
@@ -1104,13 +1104,13 @@ class _NetDictBuilderCtx:
                    # If dyn_size_ext is not set yet, try to complete it.
                    if dim.dyn_size_ext is None:
                        dim.complete_dyn_size()
-                        assert (
-                            dim.dyn_size_ext is not None
-                        ), f"{sub_name_ctx}: need {dim} to be defined to be able to know about implicit dims"
+                        assert dim.dyn_size_ext is not None, (
+                            f"{sub_name_ctx}: need {dim} to be defined to be able to know about implicit dims"
+                        )
                dim_tags.extend(data_template.dim_tags_set_implicit_only_wrapped)
-                assert len(dim_tags) == len(
-                    set((d, d.match_priority if isinstance(d, Dim) else 0) for d in dim_tags)
-                ), f"duplicate dims in {sub_name_ctx} {sub_name_ctx.tensor}"
+                assert len(dim_tags) == len(set((d, d.match_priority if isinstance(d, Dim) else 0) for d in dim_tags)), (
+                    f"duplicate dims in {sub_name_ctx} {sub_name_ctx.tensor}"
+                )
                if len(dim_tags) == len(set(dim_tags)):  # might not be unique without match_priority
                    # For some layer classes, the out_shape would be redundant.
                    if layer_dict["class"] not in {"constant", "variable", "random", "subnetwork", "transpose"}:
@@ -1135,9 +1135,9 @@ class _NetDictBuilderCtx:
 
        sub_layer_abs_name_scope = self._expected_layer_abs_name_scope(sub_name_ctx)
        if sub_name_ctx.layer_dict["class"] == "variable":
-            assert (
-                sub_layer_abs_name_scope
-            ), f"VariableLayer {sub_name_ctx} must have a unique name in {self.root_module}"
+            assert sub_layer_abs_name_scope, (
+                f"VariableLayer {sub_name_ctx} must have a unique name in {self.root_module}"
+            )
        if sub_layer_abs_name_scope is not None:
            if (
                layer_abs_name_scope_default != sub_layer_abs_name_scope
@@ -1153,9 +1153,9 @@ class _NetDictBuilderCtx:
 
        def _map_elem_resolve(obj: Any) -> Any:
            if isinstance(obj, Tensor):
-                assert isinstance(
-                    obj.raw_tensor, rfl.Layer
-                ), f"unexpected tensor {obj} with raw tensor type {type(obj.raw_tensor)}, expected rfl.Layer"
+                assert isinstance(obj.raw_tensor, rfl.Layer), (
+                    f"unexpected tensor {obj} with raw tensor type {type(obj.raw_tensor)}, expected rfl.Layer"
+                )
                obj: Tensor[rfl.Layer]
                assert obj.raw_tensor.parent or net.name_ctx == obj.raw_tensor
                return obj.raw_tensor.get_name_in_ctx(ctx=net.name_ctx)
returnn/tf/frontend_layers/loop.py CHANGED
@@ -415,9 +415,9 @@ class _LoopState:
            tensor.raw_tensor.make_all_sub_networks_and_optimize()
 
            layer_ctx_list = tensor.raw_tensor.get_abs_name_ctx_list()
-            assert (
-                self.loop.name_ctx in layer_ctx_list
-            ), f"Loop state {name_ctx} should get a value inside the loop but got {tensor}"
+            assert self.loop.name_ctx in layer_ctx_list, (
+                f"Loop state {name_ctx} should get a value inside the loop but got {tensor}"
+            )
            # We need some special logic for MaskedComputation but maybe also for others later.
            # This is currently not nice, but I'm not sure about better solutions.
            for i in range(layer_ctx_list.index(self.loop.name_ctx) + 1, len(layer_ctx_list) - 1):
returnn/tf/frontend_layers/make_layer.py CHANGED
@@ -74,7 +74,6 @@ def make_layer(
                raise TypeError(f"{layer}: unexpected type {type(value)} in layer_dict: {layer_dict}")
 
    try:
-
        if out is not None:
            assert isinstance(out, Tensor)
        elif predefined_out_data is not None:
returnn/tf/layers/base.py CHANGED
@@ -4,8 +4,9 @@ This module contains the layer base class :class:`LayerBase`.
 
 from __future__ import annotations
 
-from typing import Optional, Dict, List
+from typing import Optional, Dict, List, Union
 import typing
+from typing import TYPE_CHECKING
 import contextlib
 import numpy
 import tensorflow as tf
@@ -17,6 +18,9 @@ from returnn.tf.util.data import Data, FeatureDim, Dim
 from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope
 from returnn.log import log
 
+if TYPE_CHECKING:
+    from tensorflow.python.training.saver import BaseSaverBuilder
+
 
 class LayerBase:
     """
@@ -188,7 +192,7 @@ class LayerBase:
         self.name = name
         self.network = network
         self._register_layer()
-        self.kwargs = None  # type: typing.Optional[typing.Dict[str]]  # set via self.post_init
+        self.kwargs: Optional[Dict[str]] = None  # set via self.post_init
         self.target = None
         self.targets = None
         if target:
@@ -219,12 +223,12 @@ class LayerBase:
                 "%s: out_dim handling not implemented correctly for this layer" % self
             )
         out_shape  # noqa  # not used here but in fixup_out_data
-        self.output_before_activation = None  # type: typing.Optional[OutputWithActivation]
-        self.output_loss = None  # type: typing.Optional[tf.Tensor]
+        self.output_before_activation: Optional[OutputWithActivation] = None
+        self.output_loss: Optional[tf.Tensor] = None
         if copy_output_loss_from_source_idx is not None:
             self.output_loss = sources[copy_output_loss_from_source_idx].output_loss
-        self.rec_vars_outputs = {}  # type: typing.Dict[str,tf.Tensor]
-        self.search_choices = None  # type: typing.Optional[SearchChoices]
+        self.rec_vars_outputs: Dict[str, tf.Tensor] = {}
+        self.search_choices: Optional[SearchChoices] = None
         self._src_common_search_choices = _src_common_search_choices
         self._initial_output = initial_output
         self.need_last = need_last
@@ -237,14 +241,14 @@ class LayerBase:
             # Note that this check is somewhat incomplete
             # (does not check multiple sources, see _ConcatInputLayer)
             # and there is no guarantee that a specific layer really uses this correctly.
-            assert sources[0].output.have_dim_tag(
-                in_dim, unique=True
-            ), "%s: in_dim %s not found or unique in input %s" % (self, in_dim, sources[0])
+            assert sources[0].output.have_dim_tag(in_dim, unique=True), (
+                "%s: in_dim %s not found or unique in input %s" % (self, in_dim, sources[0])
+            )
         self.have_params = False
-        self.params = {}  # type: typing.Dict[str,tf.Variable]
-        self.saveable_param_replace = (
-            {}
-        )  # type: typing.Dict[tf.Variable,typing.Union['tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject',None]]  # see get_saveable_params_dict()  # nopep8
+        self.params: Dict[str, tf.Variable] = {}
+        self.saveable_param_replace: Dict[
+            tf.Variable, Union["BaseSaverBuilder.SaveableObject", None]
+        ] = {}  # see get_saveable_params_dict()
         self.reuse_params = reuse_params
         self.name_scope = name_scope
         self.param_device = param_device
@@ -264,7 +268,7 @@ class LayerBase:
         self.control_dependencies_on_output = control_dependencies_on_output
         self.register_as_extern_data = register_as_extern_data
         # Stats will be collected by the engine.
-        self.stats = {}  # type: typing.Dict[str,tf.Tensor]
+        self.stats: Dict[str, tf.Tensor] = {}
         self._set_prev_state(state)
 
     def _set_prev_state(self, state):
@@ -516,9 +520,9 @@ class LayerBase:
             # Special case: Input feature or sparse dim looks the same, so overtake it.
             out_dim = sources_data.feature_dim_or_sparse_dim
             if out_dim:
-                assert (
-                    out_dim.dimension == output.dim
-                ), f"Layer {name!r} out_dim {out_dim} does not match Data {output} via out_type {out_type}"
+                assert out_dim.dimension == output.dim, (
+                    f"Layer {name!r} out_dim {out_dim} does not match Data {output} via out_type {out_type}"
+                )
                 if output.sparse:
                     output.sparse_dim = out_dim
                 else:
@@ -850,9 +854,9 @@ class LayerBase:
         loss_scale = d.pop("loss_scale", 1.0)
         if loss_scale != 1.0:
             if "scale" in loss_opts:
-                assert (
-                    loss_opts["scale"] == loss_scale
-                ), "do not use loss_scale and loss with 'scale' option together"
+                assert loss_opts["scale"] == loss_scale, (
+                    "do not use loss_scale and loss with 'scale' option together"
+                )
             loss_opts["scale"] = loss_scale
         d["loss"] = cls._make_loss(
             class_name=d.pop("loss", None), opts=loss_opts, network=network, get_layer=get_layer
@@ -2099,9 +2103,9 @@ class LayerBase:
         src_output = src.output.copy()
         if src_output.placeholder is not None:
             zeroed_src_shape = tf_util.get_shape(src_output.placeholder)
-            zeroed_src_shape = [
+            zeroed_src_shape: List[Union[tf.Tensor, int]] = [
                 zeroed_src_shape[i] for i in range(src_output.batch_ndim)
-            ]  # type: typing.List[typing.Union[tf.Tensor,int]]
+            ]
         else:
             zeroed_src_shape = []
             for i, d in enumerate(src_output.batch_shape):
@@ -2550,9 +2554,9 @@ class ReuseParams:
         :rtype: tf.Variable|tf.Tensor
         """
         if self.shape is not None:
-            assert tuple(shape) == tuple(
-                d.dimension for d in self.shape
-            ), "%s: unexpected shape %r for param %r, expected %r" % (self, shape, name, self.shape)
+            assert tuple(shape) == tuple(d.dimension for d in self.shape), (
+                "%s: unexpected shape %r for param %r, expected %r" % (self, shape, name, self.shape)
+            )
         abs_scope_prefix = base_layer.get_absolute_name_scope_prefix()
         assert not abs_scope_prefix or abs_scope_prefix.endswith("/")
         assert name.startswith(abs_scope_prefix)
@@ -2609,10 +2613,10 @@ class SearchChoices:
         assert beam_size is not None
         self.owner = owner
         self._done_src_layer = False
-        self._src_layer = None  # type: typing.Optional[LayerBase]
-        self.src_beams = None  # type: typing.Optional[tf.Tensor]  # src beam index, (batch, beam)
+        self._src_layer: Optional[LayerBase] = None
+        self.src_beams: Optional[tf.Tensor] = None  # src beam index, (batch, beam)
         self.beam_size = beam_size
-        self.beam_scores = None  # type: typing.Optional[tf.Tensor]  # (batch, beam)
+        self.beam_scores: Optional[tf.Tensor] = None  # (batch, beam)
         self.is_decided = is_decided
         self.keep_raw = keep_raw
         if not owner.output.beam:
@@ -2872,22 +2876,22 @@ class Loss:
         """
         self.base_network = base_network
         self.use_flatten_frames = use_flatten_frames
-        self.layer = None  # type: typing.Optional[LayerBase]
+        self.layer: Optional[LayerBase] = None
         # All are initialized in self.init().
-        self.output = None  # type: typing.Optional[Data]
-        self.output_with_activation = None  # type: typing.Optional[OutputWithActivation]
-        self.output_seq_lens = None  # type: typing.Optional[tf.Tensor]
-        self.target = None  # type: typing.Optional[Data]
-        self.target_seq_lens = None  # type: typing.Optional[tf.Tensor]
-        self.output_flat = None  # type: typing.Optional[tf.Tensor]
-        self.output_before_softmax_flat = None  # type: typing.Optional[tf.Tensor]
+        self.output: Optional[Data] = None
+        self.output_with_activation: Optional[OutputWithActivation] = None
+        self.output_seq_lens: Optional[tf.Tensor] = None
+        self.target: Optional[Data] = None
+        self.target_seq_lens: Optional[tf.Tensor] = None
+        self.output_flat: Optional[tf.Tensor] = None
+        self.output_before_softmax_flat: Optional[tf.Tensor] = None
         if _check_output_before_softmax is not None:
             self._check_output_before_softmax = _check_output_before_softmax
-        self.target_flat = None  # type: typing.Optional[tf.Tensor]
+        self.target_flat: Optional[tf.Tensor] = None
         # Maybe make configurable. For now, same as in our Theano behavior.
         # The loss_norm_factor is used by Runner._normalize_loss both for normalization per epoch and per batch.
         # It is e.g. set to 1/sum(target_seq_len), and logic of accumulation is handled in the Runner.
-        self.loss_norm_factor = None  # type: typing.Optional[tf.Tensor]
+        self.loss_norm_factor: Optional[tf.Tensor] = None
         self.use_normalized_loss = use_normalized_loss  # for the optimizer, per batch
         self.custom_norm_factor = custom_norm_factor
         self.custom_inv_norm_factor = custom_inv_norm_factor
@@ -3132,18 +3136,21 @@ class Loss:
                 self.output,
                 self.target,
             )
-            assert (
-                self.target.ndim_dense == self.output.ndim_dense
-            ), "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+            assert self.target.ndim_dense == self.output.ndim_dense, (
+                "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+            )
             expected_output_dim = self.get_auto_output_layer_dim(self.target.feature_dim_or_sparse_dim)
-            assert (
-                expected_output_dim.dimension == self.output.dim
-            ), "Expected output dim is %r but the output has dim %r. " % (
-                expected_output_dim,
-                self.output.feature_dim_or_sparse_dim,
-            ) + "Target: %s, output: %s" % (
-                self.target,
-                self.output,
+            assert expected_output_dim.dimension == self.output.dim, (
+                "Expected output dim is %r but the output has dim %r. "
+                % (
+                    expected_output_dim,
+                    self.output.feature_dim_or_sparse_dim,
+                )
+                + "Target: %s, output: %s"
+                % (
+                    self.target,
+                    self.output,
+                )
             )
         if self.base_network.get_config().bool("debug_runtime_sanity_checks", False):
             with tf.name_scope("Loss_debug_runtime_sanity_checks"):
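The assert rewrapping visible throughout both files only changes how the statements are formatted, not what they check. A simplified sketch of the before/after shape, reduced from the Runner.__init__ assert in returnn/tf/engine.py above (plain local variables stand in for the engine config attributes):

store_tf_profile = True
store_metadata_mod_step = 0

# Old wrapping, removed in this release: the condition was split across lines,
# with the message after the closing parenthesis.
# assert not (
#     store_tf_profile and store_metadata_mod_step
# ), "Cannot use store_tf_profile and store_metadata_mod_step at the same time"

# New wrapping, added in this release: condition on one line, message parenthesized.
assert not (store_tf_profile and store_metadata_mod_step), (
    "Cannot use store_tf_profile and store_metadata_mod_step at the same time"
)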