returnn-1.20250508.93313-py3-none-any.whl → returnn-1.20250513.145447-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +24 -25
- returnn/datasets/cached.py +4 -3
- returnn/datasets/distrib_files.py +1 -2
- returnn/datasets/generating.py +20 -20
- returnn/datasets/hdf.py +9 -9
- returnn/datasets/lm.py +25 -13
- returnn/datasets/meta.py +39 -38
- returnn/datasets/normalization_data.py +1 -1
- returnn/datasets/postprocessing.py +20 -13
- returnn/datasets/sprint.py +8 -7
- returnn/datasets/util/strings.py +0 -1
- returnn/datasets/util/vocabulary.py +3 -3
- returnn/extern/graph_editor/subgraph.py +1 -2
- returnn/extern/graph_editor/transform.py +1 -2
- returnn/extern/graph_editor/util.py +1 -2
- returnn/frontend/_backend.py +4 -3
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/audio/mel.py +0 -1
- returnn/frontend/const.py +3 -3
- returnn/frontend/device.py +0 -1
- returnn/frontend/dropout.py +1 -1
- returnn/frontend/encoder/e_branchformer.py +1 -1
- returnn/frontend/loop.py +3 -3
- returnn/frontend/loss.py +0 -1
- returnn/frontend/matmul.py +0 -1
- returnn/frontend/run_ctx.py +9 -9
- returnn/frontend/signal.py +0 -1
- returnn/frontend/types.py +2 -4
- returnn/native_op.py +13 -0
- returnn/sprint/cache.py +2 -4
- returnn/sprint/interface.py +3 -4
- returnn/tensor/_dim_extra.py +9 -9
- returnn/tensor/_tensor_extra.py +20 -19
- returnn/tensor/_tensor_op_overloads.py +0 -1
- returnn/tensor/tensor.py +1 -1
- returnn/tensor/tensor_dict.py +9 -9
- returnn/tf/engine.py +60 -65
- returnn/tf/frontend_layers/_backend.py +3 -3
- returnn/tf/frontend_layers/cond.py +6 -6
- returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
- returnn/tf/frontend_layers/layer.py +12 -12
- returnn/tf/frontend_layers/loop.py +3 -3
- returnn/tf/frontend_layers/make_layer.py +0 -1
- returnn/tf/layers/base.py +56 -49
- returnn/tf/layers/basic.py +60 -65
- returnn/tf/layers/rec.py +74 -74
- returnn/tf/native_op.py +1 -3
- returnn/tf/network.py +60 -57
- returnn/tf/updater.py +3 -3
- returnn/tf/util/basic.py +24 -23
- returnn/torch/data/extern_data.py +4 -5
- returnn/torch/data/pipeline.py +3 -4
- returnn/torch/engine.py +16 -16
- returnn/torch/frontend/_backend.py +15 -15
- returnn/torch/frontend/bridge.py +3 -3
- returnn/torch/updater.py +8 -9
- returnn/torch/util/debug_inf_nan.py +0 -2
- returnn/torch/util/exception_helper.py +1 -1
- returnn/torch/util/scaled_gradient.py +0 -1
- returnn/util/basic.py +1 -2
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/METADATA +1 -1
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/RECORD +67 -67
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/LICENSE +0 -0
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/WHEEL +0 -0
- {returnn-1.20250508.93313.dist-info → returnn-1.20250513.145447.dist-info}/top_level.txt +0 -0
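
Most of the code changes in the file sections that follow apply one mechanical pattern: attribute types that were previously declared through trailing `# type:` comments (or not declared at all) are rewritten as inline PEP 526 annotations on the assignment itself. Below is a minimal standalone sketch of the two styles; the `RunnerSketch` class is hypothetical, and the attribute names merely echo ones that appear in the diff.

    from typing import Dict, Optional


    class RunnerSketch:
        """Illustrates the annotation style change seen in the hunks below; not actual RETURNN code."""

        def __init__(self):
            # Old style: value assigned, type given in a trailing comment.
            self.results = {}  # type: Dict[str, float]
            self.device_crash_batch = None  # type: Optional[int]

            # New style (as in the updated files): inline variable annotations.
            self.score: Dict[str, float] = {}
            self.report_prefix: Optional[str] = None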
returnn/tf/engine.py
CHANGED
@@ -12,7 +12,7 @@ See :ref:`tech_overview` for an overview how it fits all together.
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import Callable, Dict, List, Optional, Union
 import typing
 import os
 import sys
@@ -101,31 +101,29 @@ class Runner:
 self.store_tf_profile = engine.config.bool("store_tf_profile", False)
 self.store_metadata_mod_step = engine.config.int("store_metadata_mod_step", 0)
 self.reset_updater_vars_mod_step = engine.config.int("reset_updater_vars_mod_step", 0)
-assert not (
-
-)
+assert not (self.store_tf_profile and self.store_metadata_mod_step), (
+"Cannot use store_tf_profile and store_metadata_mod_step at the same time"
+)
 self.finalized = False
 self.cancel_flag = False
 self.run_exception = None
 self.num_steps = None
-self.device_crash_batch
+self.device_crash_batch: Optional[int] = None
 self.start_time = None
 self.elapsed = None
-self.report_prefix
+self.report_prefix: Optional[str] = None
 self._results_accumulated = NumbersDict() # entries like "cost:output" or "loss"
 self._inv_norm_accumulated = NumbersDict() # entries like "output"
 self.num_frames_accumulated = NumbersDict() # for each data key (eg. "classes"), corresponding number of frames
-self.results
-self.score
-self.error
-self.stats =
-{}
-) # type: typing.Dict[str,typing.Union[float,numpy.ndarray,'Util.Stats']] # entries like "stats:..."
+self.results: Dict[str, float] = {} # entries like "cost:output" or "loss"
+self.score: Dict[str, float] = {} # entries like "cost:output"
+self.error: Dict[str, float] = {} # entries like "error:output"
+self.stats: Dict[str, Union[float, numpy.ndarray, "util.Stats"]] = {} # entries like "stats:..."
 self.extra_fetches = extra_fetches
 if extra_fetches is not None:
 assert extra_fetches_callback
 self.extra_fetches_callback = extra_fetches_callback
-self._step_start_time
+self._step_start_time: Optional[float] = None
 self._horovod_last_param_sync_time = time.time() # we assume it is synced right now
 self._horovod_stopped_runner = False
 self._horovod_finish_all = False
@@ -133,9 +131,7 @@ class Runner:
 self._horovod_finish_all = True
 # With Horovod, during the main session.run, if reduce_type != grad or not training,
 # the following tensors are enough to ensure that we are in sync.
-self._horovod_collected_reduce_inputs = (
-{}
-) # type: typing.Dict[str,(tf.Tensor,tf.Tensor)] # name -> (input,output)
+self._horovod_collected_reduce_inputs: Dict[str, (tf.Tensor, tf.Tensor)] = {} # name -> (input,output)
 
 from returnn.util.basic import terminal_size
 
@@ -196,9 +192,9 @@ class Runner:
 d["extra:%s" % k] = v
 continue
 assert isinstance(v, Data)
-d[
-
-
+d["extra:%s" % k] = (
+v.placeholder
+) # see _maybe_handle_extra_fetches, it will transform to batch-major there
 for i, s in v.size_placeholder.items():
 d["extra:%s:size_%i" % (k, i)] = s
 
@@ -732,9 +728,9 @@ class Runner:
 run_options_.MergeFrom(run_options)
 # We could use tfdbg.add_debug_tensor_watch here.
 session_run_start_time = time.time()
-fetches_results = sess.run(
+fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
 fetches_dict, feed_dict=feed_dict, options=run_options_, run_metadata=run_metadata
-)
+)
 elapsed_time_tf += time.time() - session_run_start_time
 writer.add_summary(fetches_results["summary"], step + step_offset)
 writer.add_run_metadata(run_metadata, "step_{:04d}".format(step + step_offset))
@@ -746,13 +742,13 @@ class Runner:
 session_run_start_time = time.time()
 if self.store_tf_profile:
 with tf.profiler.experimental.Trace(name=report_prefix, step_num=step + step_offset):
-fetches_results = sess.run(
+fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
 fetches_dict, feed_dict=feed_dict, options=run_options
-)
+)
 else:
-fetches_results = sess.run(
+fetches_results: Dict[str, Union[numpy.ndarray, str]] = sess.run(
 fetches_dict, feed_dict=feed_dict, options=run_options
-)
+)
 elapsed_time_tf += time.time() - session_run_start_time
 if writer and "summary" in fetches_results:
 writer.add_summary(fetches_results["summary"], step + step_offset)
@@ -891,27 +887,27 @@ class Engine(EngineBase):
 BackendEngine.select_engine(default_fallback_engine=default_fallback_engine, config=self.config)
 assert BackendEngine.is_tensorflow_selected()
 self.orig_config = {} # see _maybe_update_config
-self.custom_get_net_dict
+self.custom_get_net_dict: Optional[Callable] = None
 self._have_rf_get_model_func = False
 self._check_devices()
-self.tf_session
-self.network
-self.updater
+self.tf_session: Optional[tf.compat.v1.Session] = None
+self.network: Optional[TFNetwork] = None
+self.updater: Optional[Updater] = None
 self._checked_uninitialized_vars = False
 self._merge_all_summaries = None
-self.dataset_batches
-self.dataset_provider
-self.train_data
-self.eval_datasets
-self.start_epoch
-self._num_trained_epochs = 0 #
-self._num_net_reinit = 0
+self.dataset_batches: Dict[str, BatchSetGenerator] = {}
+self.dataset_provider: Optional[DatasetDataProvider] = None
+self.train_data: Optional[Dataset] = None
+self.eval_datasets: Dict[str, Dataset] = {}
+self.start_epoch: Optional[int] = None
+self._num_trained_epochs: int = 0 # just a counter
+self._num_net_reinit: int = 0
 self.use_dynamic_train_flag = False
 self.use_search_flag = self.config.value("task", None) == "search"
 self.use_eval_flag = self.config.value("task", None) != "forward"
-self._const_cache
-self.preload_from_files
-self.max_seqs
+self._const_cache: Dict[str, tf.Tensor] = {}
+self.preload_from_files: Optional[Dict[str, Dict[str]]] = None
+self.max_seqs: Optional[int] = None
 
 def finalize(self, error_occurred=False):
 """
@@ -1140,7 +1136,7 @@ class Engine(EngineBase):
 self.min_seq_length = config.typed_value("min_seq_length", None) or config.float("min_seq_length", 0)
 self.inc_seq_length = config.float("inc_seq_length", 0)
 if not self.max_seq_length:
-self.max_seq_length
+self.max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = sys.maxsize
 if isinstance(self.max_seq_length, dict):
 self.max_seq_length = NumbersDict(self.max_seq_length)
 assert isinstance(self.max_seq_length, (int, float, NumbersDict))
@@ -1630,7 +1626,7 @@ class Engine(EngineBase):
 assert isinstance(self.start_epoch, int)
 epoch = self.start_epoch # Epochs start at 1.
 while epoch <= final_epoch:
-self.epoch = epoch
+self.epoch: int = epoch
 if isinstance(self.max_seq_length, int) and self.max_seq_length != sys.maxsize:
 if int(self.max_seq_length + self.inc_seq_length) != int(self.max_seq_length):
 print("increasing sequence lengths to", int(self.max_seq_length + self.inc_seq_length), file=log.v3)
@@ -1878,9 +1874,9 @@ class Engine(EngineBase):
 # We update the model params in-place.
 # In training, we don't want that, because it should not use the validation data.
 # We could reset it later when continuing the training, but it's not implemented.
-assert (
-self.config.value("task",
-)
+assert self.config.value("task", "train") != "train", (
+"task %r should be just 'eval' or so. training will break." % self.config.value("task", None)
+)
 if not self.updater:
 self.updater = Updater(
 config=self.config, network=self.network, initial_learning_rate=self.initial_learning_rate
@@ -1928,11 +1924,12 @@ class Engine(EngineBase):
 allowed_outputs = {"seq_tag", "seq_len", "score", "error", "pos_score", "pos_error"}
 
 assert isinstance(output_per_seq_format, (tuple, list)), "provide output_per_seq_format"
-assert (
-
-
-
-
+assert set(output_per_seq_format) - allowed_outputs == set(), (
+"Only %r are allowed in function eval_model as output_per_seq_format, but got: %r "
+% (
+allowed_outputs,
+output_per_seq_format,
+)
 )
 
 # always fetch seq_tag to map loss values to the corresponding line
@@ -1968,12 +1965,10 @@ class Engine(EngineBase):
 if "pos_error" in output_per_seq_format:
 extra_fetches["pos_error"] = loss_holder.get_error_value_per_pos()
 
-seq_idx_to_tag =
-
-
-
-{}
-) # type: typing.Dict[str,typing.Dict[str,typing.Union[float,str,int]]] # seq_tag -> dict. Results of fetches will be written in this dict # nopep8
+seq_idx_to_tag: Dict[int, str] = {} # we need this in order to write the results in the correct order later
+results_per_seq: Dict[
+str, Dict[str, Union[float, str, int]]
+] = {} # seq_tag -> dict. Results of fetches will be written in this dict
 
 # function to save the return values of each callback to the dict `results_per_seq`
 # noinspection PyShadowingNames
@@ -2012,7 +2007,7 @@ class Engine(EngineBase):
 
 if output_per_seq_file:
 assert len(self.get_eval_datasets()) == 1, (
-"output per sequence is only supported for one dataset (dev or eval),
+"output per sequence is only supported for one dataset (dev or eval),provided datasets are %r"
 ) % list(self.get_eval_datasets().keys())
 # try to sort dataset to minimize zero-padding
 dataset = list(self.get_eval_datasets().values())[0]
@@ -2453,9 +2448,9 @@ class Engine(EngineBase):
 )
 
 max_seq_length = self.config.typed_value("max_seq_length", None) or self.config.float("max_seq_length", 0)
-assert (
-
-)
+assert not max_seq_length, (
+"Set max_seq_length = 0 for search (i.e. no maximal length). We want to keep all source sentences."
+)
 
 dataset.init_seq_order(epoch=self.epoch)
 batches = dataset.generate_batches(
@@ -2552,8 +2547,8 @@ class Engine(EngineBase):
 outputs[output_layer_idx] = bytearray(outputs[output_layer_idx]).decode("utf8")
 
 # Create lists with serialized data. All of length num_output_layers.
-serialized_outputs
-serialized_targets
+serialized_outputs: List[Optional[Union[str, numpy.ndarray]]] = []
+serialized_targets: List[Optional[Union[str, numpy.ndarray]]] = []
 # noinspection PyShadowingNames
 for output_layer_idx in range(num_output_layers):
 if output_layers[output_layer_idx].output.sparse:
@@ -2572,8 +2567,8 @@ class Engine(EngineBase):
 ]
 else:
 serialized_output = None
-assert not output_file,
-output_layer_names[output_layer_idx]
+assert not output_file, (
+"Unable to serialize sparse output of layer '%s'." % (output_layer_names[output_layer_idx])
 )
 else:
 # Output dense layers as-is
@@ -2594,8 +2589,8 @@ class Engine(EngineBase):
 ]
 else:
 serialized_target = None
-assert not output_file,
-target_keys[output_layer_idx]
+assert not output_file, (
+"Unable to serialize sparse target '%s'." % (target_keys[output_layer_idx])
 )
 else:
 serialized_target = targets[output_layer_idx]
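
A second pattern recurs in the assert statements throughout these files: the condition is folded back onto the `assert` line and the error message is kept in a parenthesized continuation block. The following is a small, self-contained sketch of that layout; the condition and message are taken from the Runner hunk above, but the wrapping `check_profile_options` helper is hypothetical.

    def check_profile_options(store_tf_profile: bool, store_metadata_mod_step: int) -> None:
        """Hypothetical helper; shows the assert layout used in the updated code."""
        # Condition stays on the assert line; the message sits in its own parenthesized block.
        assert not (store_tf_profile and store_metadata_mod_step), (
            "Cannot use store_tf_profile and store_metadata_mod_step at the same time"
        )


    # Passes: only one of the two options is enabled.
    check_profile_options(store_tf_profile=True, store_metadata_mod_step=0)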
returnn/tf/frontend_layers/_backend.py
CHANGED
@@ -510,9 +510,9 @@ class ReturnnLayersBackend(Backend[Layer]):
 # We could also maybe move out all the dependencies.
 # However, it's not clear whether this is always safe.
 for dep in value.raw_tensor.get_tensor_dependencies():
-assert (
-dep
-)
+assert dep.parent.can_access_children_from_root, (
+f"dep {dep} of moved value {value} is not accessible"
+)
 param.raw_tensor.layer_dict["init_by_layer"] = value
 else:
 param.raw_tensor.layer_dict.pop("init_by_layer", None)
returnn/tf/frontend_layers/cond.py
CHANGED
@@ -181,9 +181,9 @@ class Cond(Generic[T]):
 After this, self.result is available.
 """
 assert self._entered, f"{self} you need to be in the context scope"
-assert (
-self
-)
+assert self._entered_state is False, (
+f"{self} you need to be in the False branch, have assigned :func:`true` before"
+)
 assert not self._false_value_set
 nest.assert_same_structure(self._true_value, false_value)
 # This needs to match the true() setter logic.
@@ -198,9 +198,9 @@ class Cond(Generic[T]):
 if false_v is None: # see above
 false_v = rf.zeros((), dtype="int32") # dummy value
 else:
-assert isinstance(
-
-)
+assert isinstance(false_v, Tensor), (
+f"unexpected {false_value!r}, only expects tensors, got {type(false_v)}"
+)
 assert true_v.raw_tensor.parent is self.true_branch_name_ctx
 name = true_v.raw_tensor.name
 assert name not in self.false_branch_name_ctx.children
returnn/tf/frontend_layers/layer.py
CHANGED
@@ -1104,13 +1104,13 @@ class _NetDictBuilderCtx:
 # If dyn_size_ext is not set yet, try to complete it.
 if dim.dyn_size_ext is None:
 dim.complete_dyn_size()
-assert (
-dim
-)
+assert dim.dyn_size_ext is not None, (
+f"{sub_name_ctx}: need {dim} to be defined to be able to know about implicit dims"
+)
 dim_tags.extend(data_template.dim_tags_set_implicit_only_wrapped)
-assert len(dim_tags) == len(
-
-)
+assert len(dim_tags) == len(set((d, d.match_priority if isinstance(d, Dim) else 0) for d in dim_tags)), (
+f"duplicate dims in {sub_name_ctx} {sub_name_ctx.tensor}"
+)
 if len(dim_tags) == len(set(dim_tags)): # might not be unique without match_priority
 # For some layer classes, the out_shape would be redundant.
 if layer_dict["class"] not in {"constant", "variable", "random", "subnetwork", "transpose"}:
@@ -1135,9 +1135,9 @@ class _NetDictBuilderCtx:
 
 sub_layer_abs_name_scope = self._expected_layer_abs_name_scope(sub_name_ctx)
 if sub_name_ctx.layer_dict["class"] == "variable":
-assert (
-
-)
+assert sub_layer_abs_name_scope, (
+f"VariableLayer {sub_name_ctx} must have a unique name in {self.root_module}"
+)
 if sub_layer_abs_name_scope is not None:
 if (
 layer_abs_name_scope_default != sub_layer_abs_name_scope
@@ -1153,9 +1153,9 @@ class _NetDictBuilderCtx:
 
 def _map_elem_resolve(obj: Any) -> Any:
 if isinstance(obj, Tensor):
-assert isinstance(
-obj.raw_tensor, rfl.Layer
-)
+assert isinstance(obj.raw_tensor, rfl.Layer), (
+f"unexpected tensor {obj} with raw tensor type {type(obj.raw_tensor)}, expected rfl.Layer"
+)
 obj: Tensor[rfl.Layer]
 assert obj.raw_tensor.parent or net.name_ctx == obj.raw_tensor
 return obj.raw_tensor.get_name_in_ctx(ctx=net.name_ctx)
returnn/tf/frontend_layers/loop.py
CHANGED
@@ -415,9 +415,9 @@ class _LoopState:
 tensor.raw_tensor.make_all_sub_networks_and_optimize()
 
 layer_ctx_list = tensor.raw_tensor.get_abs_name_ctx_list()
-assert (
-
-)
+assert self.loop.name_ctx in layer_ctx_list, (
+f"Loop state {name_ctx} should get a value inside the loop but got {tensor}"
+)
 # We need some special logic for MaskedComputation but maybe also for others later.
 # This is currently not nice, but I'm not sure about better solutions.
 for i in range(layer_ctx_list.index(self.loop.name_ctx) + 1, len(layer_ctx_list) - 1):
returnn/tf/layers/base.py
CHANGED
@@ -4,8 +4,9 @@ This module contains the layer base class :class:`LayerBase`.
 
 from __future__ import annotations
 
-from typing import Optional, Dict, List
+from typing import Optional, Dict, List, Union
 import typing
+from typing import TYPE_CHECKING
 import contextlib
 import numpy
 import tensorflow as tf
@@ -17,6 +18,9 @@ from returnn.tf.util.data import Data, FeatureDim, Dim
 from returnn.tf.util.basic import OutputWithActivation, CustomUpdate, reuse_name_scope
 from returnn.log import log
 
+if TYPE_CHECKING:
+from tensorflow.python.training.saver import BaseSaverBuilder
+
 
 class LayerBase:
 """
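
The two import hunks above add a `typing.TYPE_CHECKING` guard: `BaseSaverBuilder` is imported only for static type checkers, and the annotations that mention it use a quoted (string) name, so nothing extra is imported at runtime. Below is a minimal standalone sketch of this idiom with placeholder module and class names (not the actual TensorFlow import).

    from __future__ import annotations

    from typing import TYPE_CHECKING, Dict, Union

    if TYPE_CHECKING:
        # Evaluated only by type checkers; avoids the runtime import cost.
        from some_heavy_module import HeavyType  # placeholder, not a real module


    class Holder:
        def __init__(self):
            # The quoted name keeps the annotation usable even though
            # HeavyType is never imported at runtime.
            self.replacements: Dict[str, Union["HeavyType", None]] = {}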
|
188
192
|
self.name = name
|
|
189
193
|
self.network = network
|
|
190
194
|
self._register_layer()
|
|
191
|
-
self.kwargs
|
|
195
|
+
self.kwargs: Optional[Dict[str]] = None # set via self.post_init
|
|
192
196
|
self.target = None
|
|
193
197
|
self.targets = None
|
|
194
198
|
if target:
|
|
@@ -219,12 +223,12 @@ class LayerBase:
|
|
|
219
223
|
"%s: out_dim handling not implemented correctly for this layer" % self
|
|
220
224
|
)
|
|
221
225
|
out_shape # noqa # not used here but in fixup_out_data
|
|
222
|
-
self.output_before_activation
|
|
223
|
-
self.output_loss
|
|
226
|
+
self.output_before_activation: Optional[OutputWithActivation] = None
|
|
227
|
+
self.output_loss: Optional[tf.Tensor] = None
|
|
224
228
|
if copy_output_loss_from_source_idx is not None:
|
|
225
229
|
self.output_loss = sources[copy_output_loss_from_source_idx].output_loss
|
|
226
|
-
self.rec_vars_outputs
|
|
227
|
-
self.search_choices
|
|
230
|
+
self.rec_vars_outputs: Dict[str, tf.Tensor] = {}
|
|
231
|
+
self.search_choices: Optional[SearchChoices] = None
|
|
228
232
|
self._src_common_search_choices = _src_common_search_choices
|
|
229
233
|
self._initial_output = initial_output
|
|
230
234
|
self.need_last = need_last
|
|
@@ -237,14 +241,14 @@ class LayerBase:
|
|
|
237
241
|
# Note that this check is somewhat incomplete
|
|
238
242
|
# (does not check multiple sources, see _ConcatInputLayer)
|
|
239
243
|
# and there is no guarantee that a specific layer really uses this correctly.
|
|
240
|
-
assert sources[0].output.have_dim_tag(
|
|
241
|
-
in_dim,
|
|
242
|
-
)
|
|
244
|
+
assert sources[0].output.have_dim_tag(in_dim, unique=True), (
|
|
245
|
+
"%s: in_dim %s not found or unique in input %s" % (self, in_dim, sources[0])
|
|
246
|
+
)
|
|
243
247
|
self.have_params = False
|
|
244
|
-
self.params
|
|
245
|
-
self.saveable_param_replace
|
|
246
|
-
|
|
247
|
-
|
|
248
|
+
self.params: Dict[str, tf.Variable] = {}
|
|
249
|
+
self.saveable_param_replace: Dict[
|
|
250
|
+
tf.Variable, Union["BaseSaverBuilder.SaveableObject", None]
|
|
251
|
+
] = {} # see get_saveable_params_dict()
|
|
248
252
|
self.reuse_params = reuse_params
|
|
249
253
|
self.name_scope = name_scope
|
|
250
254
|
self.param_device = param_device
|
|
@@ -264,7 +268,7 @@ class LayerBase:
|
|
|
264
268
|
self.control_dependencies_on_output = control_dependencies_on_output
|
|
265
269
|
self.register_as_extern_data = register_as_extern_data
|
|
266
270
|
# Stats will be collected by the engine.
|
|
267
|
-
self.stats
|
|
271
|
+
self.stats: Dict[str, tf.Tensor] = {}
|
|
268
272
|
self._set_prev_state(state)
|
|
269
273
|
|
|
270
274
|
def _set_prev_state(self, state):
|
|
@@ -516,9 +520,9 @@ class LayerBase:
|
|
|
516
520
|
# Special case: Input feature or sparse dim looks the same, so overtake it.
|
|
517
521
|
out_dim = sources_data.feature_dim_or_sparse_dim
|
|
518
522
|
if out_dim:
|
|
519
|
-
assert (
|
|
520
|
-
out_dim
|
|
521
|
-
)
|
|
523
|
+
assert out_dim.dimension == output.dim, (
|
|
524
|
+
f"Layer {name!r} out_dim {out_dim} does not match Data {output} via out_type {out_type}"
|
|
525
|
+
)
|
|
522
526
|
if output.sparse:
|
|
523
527
|
output.sparse_dim = out_dim
|
|
524
528
|
else:
|
|
@@ -850,9 +854,9 @@ class LayerBase:
|
|
|
850
854
|
loss_scale = d.pop("loss_scale", 1.0)
|
|
851
855
|
if loss_scale != 1.0:
|
|
852
856
|
if "scale" in loss_opts:
|
|
853
|
-
assert (
|
|
854
|
-
|
|
855
|
-
)
|
|
857
|
+
assert loss_opts["scale"] == loss_scale, (
|
|
858
|
+
"do not use loss_scale and loss with 'scale' option together"
|
|
859
|
+
)
|
|
856
860
|
loss_opts["scale"] = loss_scale
|
|
857
861
|
d["loss"] = cls._make_loss(
|
|
858
862
|
class_name=d.pop("loss", None), opts=loss_opts, network=network, get_layer=get_layer
|
|
@@ -2099,9 +2103,9 @@ class LayerBase:
|
|
|
2099
2103
|
src_output = src.output.copy()
|
|
2100
2104
|
if src_output.placeholder is not None:
|
|
2101
2105
|
zeroed_src_shape = tf_util.get_shape(src_output.placeholder)
|
|
2102
|
-
zeroed_src_shape = [
|
|
2106
|
+
zeroed_src_shape: List[Union[tf.Tensor, int]] = [
|
|
2103
2107
|
zeroed_src_shape[i] for i in range(src_output.batch_ndim)
|
|
2104
|
-
]
|
|
2108
|
+
]
|
|
2105
2109
|
else:
|
|
2106
2110
|
zeroed_src_shape = []
|
|
2107
2111
|
for i, d in enumerate(src_output.batch_shape):
|
|
@@ -2550,9 +2554,9 @@ class ReuseParams:
|
|
|
2550
2554
|
:rtype: tf.Variable|tf.Tensor
|
|
2551
2555
|
"""
|
|
2552
2556
|
if self.shape is not None:
|
|
2553
|
-
assert tuple(shape) == tuple(
|
|
2554
|
-
|
|
2555
|
-
)
|
|
2557
|
+
assert tuple(shape) == tuple(d.dimension for d in self.shape), (
|
|
2558
|
+
"%s: unexpected shape %r for param %r, expected %r" % (self, shape, name, self.shape)
|
|
2559
|
+
)
|
|
2556
2560
|
abs_scope_prefix = base_layer.get_absolute_name_scope_prefix()
|
|
2557
2561
|
assert not abs_scope_prefix or abs_scope_prefix.endswith("/")
|
|
2558
2562
|
assert name.startswith(abs_scope_prefix)
|
|
@@ -2609,10 +2613,10 @@ class SearchChoices:
|
|
|
2609
2613
|
assert beam_size is not None
|
|
2610
2614
|
self.owner = owner
|
|
2611
2615
|
self._done_src_layer = False
|
|
2612
|
-
self._src_layer
|
|
2613
|
-
self.src_beams
|
|
2616
|
+
self._src_layer: Optional[LayerBase] = None
|
|
2617
|
+
self.src_beams: Optional[tf.Tensor] = None # src beam index, (batch, beam)
|
|
2614
2618
|
self.beam_size = beam_size
|
|
2615
|
-
self.beam_scores
|
|
2619
|
+
self.beam_scores: Optional[tf.Tensor] = None # (batch, beam)
|
|
2616
2620
|
self.is_decided = is_decided
|
|
2617
2621
|
self.keep_raw = keep_raw
|
|
2618
2622
|
if not owner.output.beam:
|
|
@@ -2872,22 +2876,22 @@ class Loss:
|
|
|
2872
2876
|
"""
|
|
2873
2877
|
self.base_network = base_network
|
|
2874
2878
|
self.use_flatten_frames = use_flatten_frames
|
|
2875
|
-
self.layer
|
|
2879
|
+
self.layer: Optional[LayerBase] = None
|
|
2876
2880
|
# All are initialized in self.init().
|
|
2877
|
-
self.output
|
|
2878
|
-
self.output_with_activation
|
|
2879
|
-
self.output_seq_lens
|
|
2880
|
-
self.target
|
|
2881
|
-
self.target_seq_lens
|
|
2882
|
-
self.output_flat
|
|
2883
|
-
self.output_before_softmax_flat
|
|
2881
|
+
self.output: Optional[Data] = None
|
|
2882
|
+
self.output_with_activation: Optional[OutputWithActivation] = None
|
|
2883
|
+
self.output_seq_lens: Optional[tf.Tensor] = None
|
|
2884
|
+
self.target: Optional[Data] = None
|
|
2885
|
+
self.target_seq_lens: Optional[tf.Tensor] = None
|
|
2886
|
+
self.output_flat: Optional[tf.Tensor] = None
|
|
2887
|
+
self.output_before_softmax_flat: Optional[tf.Tensor] = None
|
|
2884
2888
|
if _check_output_before_softmax is not None:
|
|
2885
2889
|
self._check_output_before_softmax = _check_output_before_softmax
|
|
2886
|
-
self.target_flat
|
|
2890
|
+
self.target_flat: Optional[tf.Tensor] = None
|
|
2887
2891
|
# Maybe make configurable. For now, same as in our Theano behavior.
|
|
2888
2892
|
# The loss_norm_factor is used by Runner._normalize_loss both for normalization per epoch and per batch.
|
|
2889
2893
|
# It is e.g. set to 1/sum(target_seq_len), and logic of accumulation is handled in the Runner.
|
|
2890
|
-
self.loss_norm_factor
|
|
2894
|
+
self.loss_norm_factor: Optional[tf.Tensor] = None
|
|
2891
2895
|
self.use_normalized_loss = use_normalized_loss # for the optimizer, per batch
|
|
2892
2896
|
self.custom_norm_factor = custom_norm_factor
|
|
2893
2897
|
self.custom_inv_norm_factor = custom_inv_norm_factor
|
|
@@ -3132,18 +3136,21 @@ class Loss:
|
|
|
3132
3136
|
self.output,
|
|
3133
3137
|
self.target,
|
|
3134
3138
|
)
|
|
3135
|
-
assert (
|
|
3136
|
-
self.target
|
|
3137
|
-
)
|
|
3139
|
+
assert self.target.ndim_dense == self.output.ndim_dense, (
|
|
3140
|
+
"Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
|
|
3141
|
+
)
|
|
3138
3142
|
expected_output_dim = self.get_auto_output_layer_dim(self.target.feature_dim_or_sparse_dim)
|
|
3139
|
-
assert (
|
|
3140
|
-
|
|
3141
|
-
|
|
3142
|
-
|
|
3143
|
-
|
|
3144
|
-
|
|
3145
|
-
|
|
3146
|
-
|
|
3143
|
+
assert expected_output_dim.dimension == self.output.dim, (
|
|
3144
|
+
"Expected output dim is %r but the output has dim %r. "
|
|
3145
|
+
% (
|
|
3146
|
+
expected_output_dim,
|
|
3147
|
+
self.output.feature_dim_or_sparse_dim,
|
|
3148
|
+
)
|
|
3149
|
+
+ "Target: %s, output: %s"
|
|
3150
|
+
% (
|
|
3151
|
+
self.target,
|
|
3152
|
+
self.output,
|
|
3153
|
+
)
|
|
3147
3154
|
)
|
|
3148
3155
|
if self.base_network.get_config().bool("debug_runtime_sanity_checks", False):
|
|
3149
3156
|
with tf.name_scope("Loss_debug_runtime_sanity_checks"):
|