returnn 1.20250430.145858__py3-none-any.whl → 1.20250508.181644__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +24 -25
- returnn/datasets/cached.py +4 -3
- returnn/datasets/distrib_files.py +1 -2
- returnn/datasets/generating.py +20 -20
- returnn/datasets/hdf.py +9 -9
- returnn/datasets/lm.py +25 -13
- returnn/datasets/meta.py +39 -38
- returnn/datasets/normalization_data.py +1 -1
- returnn/datasets/postprocessing.py +9 -9
- returnn/datasets/sprint.py +8 -7
- returnn/datasets/util/strings.py +0 -1
- returnn/datasets/util/vocabulary.py +3 -3
- returnn/extern/graph_editor/subgraph.py +1 -2
- returnn/extern/graph_editor/transform.py +1 -2
- returnn/extern/graph_editor/util.py +1 -2
- returnn/frontend/_backend.py +4 -3
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/audio/mel.py +0 -1
- returnn/frontend/const.py +3 -3
- returnn/frontend/device.py +0 -1
- returnn/frontend/dropout.py +1 -1
- returnn/frontend/encoder/e_branchformer.py +1 -1
- returnn/frontend/loop.py +3 -3
- returnn/frontend/loss.py +0 -1
- returnn/frontend/matmul.py +0 -1
- returnn/frontend/run_ctx.py +9 -9
- returnn/frontend/signal.py +0 -1
- returnn/frontend/types.py +2 -4
- returnn/native_op.py +13 -0
- returnn/sprint/cache.py +2 -4
- returnn/sprint/interface.py +3 -4
- returnn/tensor/_dim_extra.py +9 -9
- returnn/tensor/_tensor_extra.py +20 -19
- returnn/tensor/_tensor_op_overloads.py +0 -1
- returnn/tensor/tensor.py +1 -1
- returnn/tensor/tensor_dict.py +9 -9
- returnn/tf/engine.py +60 -65
- returnn/tf/frontend_layers/_backend.py +3 -3
- returnn/tf/frontend_layers/cond.py +6 -6
- returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
- returnn/tf/frontend_layers/layer.py +12 -12
- returnn/tf/frontend_layers/loop.py +3 -3
- returnn/tf/frontend_layers/make_layer.py +0 -1
- returnn/tf/layers/base.py +56 -49
- returnn/tf/layers/basic.py +60 -65
- returnn/tf/layers/rec.py +74 -74
- returnn/tf/native_op.py +1 -3
- returnn/tf/network.py +60 -57
- returnn/tf/updater.py +3 -3
- returnn/tf/util/basic.py +24 -23
- returnn/torch/data/extern_data.py +4 -5
- returnn/torch/data/pipeline.py +3 -4
- returnn/torch/engine.py +16 -16
- returnn/torch/frontend/_backend.py +15 -15
- returnn/torch/frontend/bridge.py +3 -3
- returnn/torch/updater.py +8 -9
- returnn/torch/util/debug_inf_nan.py +0 -2
- returnn/torch/util/exception_helper.py +1 -1
- returnn/torch/util/scaled_gradient.py +0 -1
- returnn/util/basic.py +1 -2
- returnn/util/better_exchook.py +14 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/METADATA +1 -1
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/RECORD +68 -68
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/LICENSE +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/WHEEL +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/top_level.txt +0 -0
returnn/tf/network.py
CHANGED
@@ -3,7 +3,9 @@ Defines the :class:`TFNetwork` and :class:`ExternData`.
 """

 from __future__ import annotations
-
+
+from typing import Callable, List, Optional, Any, Protocol, Tuple, Dict, TYPE_CHECKING, Union
+
 import tensorflow as tf
 import sys
 import re
@@ -19,6 +21,11 @@ from returnn.tensor import Tensor, Dim, TensorDict
 from returnn.tf.util.data import Data
 from returnn.util import basic as util

+if TYPE_CHECKING:
+    from returnn.config import Config
+    from returnn.tf.layers.base import SearchChoices
+    from returnn.tf.util.data import BatchInfo
+

 class DataNotFound(Exception):
     """
@@ -39,8 +46,8 @@ class ExternData(TensorDict):
         :param None|dict[str,dict[str]] data: optional init kwargs for Data
         """
         super().__init__()
-        self._config
-        self._batch_info
+        self._config: typing.Optional[Config] = None
+        self._batch_info: typing.Optional[BatchInfo] = None
         self.default_input = default_input
         self.default_target = default_target
         self.extra_added_keys = set()  # set[str]
@@ -369,8 +376,7 @@ def _extern_data_types_from_config(config):
         print("Warning: Using extern_data and will ignore num_inputs/num_outputs in config.", file=log.v2)
     else:
         log.print_deprecation_warning(
-            "Using num_inputs/num_outputs instead of extern_data is deprecated"
-            " and might be removed in future versions"
+            "Using num_inputs/num_outputs instead of extern_data is deprecated and might be removed in future versions"
         )
         num_inputs, num_outputs = _num_inputs_outputs_from_config(config)
         data_dims = num_outputs.copy()
@@ -502,7 +508,7 @@ class _NetworkConstructionStack:
     """

     def __init__(self):
-        self.layers
+        self.layers: typing.List[str] = []
         self.in_flat_construct_count = 0

     def append(self, layer_name):
@@ -645,33 +651,31 @@ class TFNetwork:
         self.extra_deps_in_extra = False
         self.extra_only_template = False
         self.is_root_in_ctx = not parent_net  # default. might be overwritten
-        self.extra_nets
-        self.subnets
+        self.extra_nets: Dict[str, TFNetwork] = {}
+        self.subnets: Dict[str, Subnetwork] = {}
         self._selected_train_layers = None
         self._construction_stack = _NetworkConstructionStack()
         self.layers_desc: Dict[str, Dict[str, Any]] = {}
         self.layers: Dict[str, LayerBase] = {}
-        self.losses_dict
-        self.total_loss
-        self.total_constraints
-        self.total_objective
-        self._global_train_step
-        self._global_train_step_var
+        self.losses_dict: Dict[str, LossHolder] = {}
+        self.total_loss: Optional[tf.Tensor] = None
+        self.total_constraints: Optional[tf.Tensor] = None
+        self.total_objective: Optional[tf.Tensor] = None
+        self._global_train_step: Optional[tf.Tensor] = None
+        self._global_train_step_var: Optional[tf.Variable] = None
         self.epoch_step = None
-        self.saver
-        self.extra_vars_to_save
+        self.saver: Optional[tf.compat.v1.train.Saver] = None
+        self.extra_vars_to_save: List[tf.Variable] = []
         self.recurrent = False
-        self._assigner_cache
+        self._assigner_cache: Dict[tf.Variable, tf_util.VariableAssigner] = {}
         self.concat_sources_dropout_cache: Dict[
             Tuple[Tuple[LayerBase, ...], Dim, float, Optional[Tuple[Optional[int], ...]]], Data
         ] = {}
-        self._merge_all_summaries
-        self._graph_reset_callbacks
-        self._run_opts
-        self._run_finished_callbacks
-        self._map_search_beam_to_search_choices =
-            {}
-        )  # type: typing.Dict[tf_util.SearchBeam,"returnn.tf.layers.base.SearchChoices"]
+        self._merge_all_summaries: Optional[tf.Tensor] = None
+        self._graph_reset_callbacks: List[Callable] = []
+        self._run_opts: Dict[str, Any] = {}
+        self._run_finished_callbacks: List[Callable] = []
+        self._map_search_beam_to_search_choices: Dict[tf_util.SearchBeam, SearchChoices] = {}

     def __repr__(self):
         s = "TFNetwork %r" % self.name
@@ -1308,15 +1312,16 @@ class TFNetwork:
             layer.output.sanity_check()
             # The axes should not have moved now.
             output_special_axes = layer.output.get_special_axes_dict()
-            assert (
-
-
-
-
-
-
-
+            assert output_template_special_axes == output_special_axes, (
+                "%s %r: not equal: %r == %r, from data %r -> %r"
+                % (
+                    layer_class.__name__,
+                    name,
+                    output_template_special_axes,
+                    output_special_axes,
+                    output_template,
+                    layer.output,
+                )
             )
         except TypeError:
             help_on_type_error_wrong_args(cls=layer_class, kwargs=list(layer_desc.keys()))
@@ -1486,7 +1491,7 @@ class TFNetwork:
         else:
             total_loss = None
             total_constraints = None
-        losses_multi_dict
+        losses_multi_dict: Dict[str, List[Tuple[Optional[str], LossHolder]]] = {}
         # self.layers also include extra net layers and sub layers, see add_layer.
         for name, layer in sorted(self.layers.items()):
             assert isinstance(layer, LayerBase)
@@ -1869,14 +1874,15 @@ class TFNetwork:

         # All end points must be mapped now.
         for layer in end_points:
-            assert (
-
-
-
-
-
-
+            assert layer in mapped_layers, (
+                "end point %r not mapped.\n end points:\n%s\n mapped:\n%s\n blacklist:\n%s\n starting points:\n%s"
+                % (
+                    layer,
+                    pformat(end_points),
+                    pformat(mapped_layers),
+                    pformat(blacklist),
+                    pformat(starting_points),
+                )
             )
         # Assign flatten_with_seq_len_mask cache to mapped layers.
         for layer, new_layer in mapped_layers.items():
@@ -2402,9 +2408,7 @@ class TFNetwork:

         Note that this excludes auxiliary params.
         """
-        layers = {
-            layer.get_absolute_name(): layer for layer in self.get_all_layers_deep()
-        }  # type: typing.Dict[str,LayerBase]
+        layers: Dict[str, LayerBase] = {layer.get_absolute_name(): layer for layer in self.get_all_layers_deep()}
         for layer_name, layer_values_dict in values_dict.items():
             if layer_values_dict:
                 if ignore_non_existing and layer_name not in layers:
@@ -4091,9 +4095,9 @@ class LossHolder:
                 self._error_value = self._layer._cond_only_on_eval_opt(self.loss.get_error, default_value=0.0)
             else:
                 self._error_value = self.loss.get_error()
-            assert (
-
-            )
+            assert self._loss_value is not None or self._error_value is not None, (
+                "layer %r loss %r return None for loss and error" % (self._layer, self.loss)
+            )
             if self._norm_factor is None:
                 self._norm_factor = self.loss.get_normalization_factor()
             loss_value = self._loss_value
@@ -4515,12 +4519,12 @@ class CustomCheckpointLoader:
         # All variables in the checkpoint:
         self.var_ckpt_names = set(self.reader.get_variable_to_shape_map())  # type: typing.Set[str]
         # All variables of the model to be loaded:
-        self.var_net_names = {
+        self.var_net_names: Dict[str, Union[tf.Variable, Any]] = {
             self._get_param_name(v): v for v in self.saveable_params
-        }
+        }
         # Model variables missing in the checkpoint:
-        self.missing_var_names
-        self.missing_non_critical_var_names
+        self.missing_var_names: List[str] = []
+        self.missing_non_critical_var_names: List[str] = []
         for name, v in sorted(self.var_net_names.items()):
             if name in self.var_ckpt_names:
                 continue
@@ -4702,10 +4706,10 @@ class CustomCheckpointLoader:
             "rnn/lstm_cell/bias": "lstm_cell/bias",
             "rnn/lstm_cell/kernel": "lstm_cell/kernel",
             (
-                "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/
+                "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias"
             ): "lstm_fused_cell/bias",
             (
-                "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/
+                "cudnn/params_canonical/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel"
             ): "lstm_fused_cell/kernel",
         }

@@ -4877,7 +4881,7 @@ class CustomCheckpointLoader:
                 self.target = target
                 self.keys = [target + "bias", target + "kernel"]
                 self.prefix = prefix
-                self.data
+                self.data: typing.Optional[typing.Dict[str, numpy.ndarray]] = None

             # noinspection PyMethodParameters
             def _load(sself):
@@ -5140,8 +5144,7 @@ class CustomLoadParamFunc(Protocol):

     def __call__(
         self, *, name: str, shape: Tuple[int], reader: tf.compat.v1.train.NewCheckpointReader
-    ) -> Optional[numpy.ndarray]:
-        ...
+    ) -> Optional[numpy.ndarray]: ...


 def set_custom_post_init(var, func):
returnn/tf/updater.py
CHANGED
@@ -219,9 +219,9 @@ class Updater:

             learning_rate_function = self.config.typed_dict.get("dynamic_learning_rate")
             signature = inspect.signature(learning_rate_function)
-            assert any(
-
-            )
+            assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), (
+                "please specify **kwargs in dynamic_learning_rate for future compatibility"
+            )
             if "epoch" in signature.parameters:
                 raise NotImplementedError("TF updater: dynamic_learning_rate with epoch not supported currently")
             lr = learning_rate_function(
returnn/tf/util/basic.py
CHANGED
@@ -1799,7 +1799,7 @@ def dropout(
     x = tf.convert_to_tensor(x, name="x")
     assert isinstance(x, tf.Tensor)
     if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1:
-        raise ValueError("keep_prob must be a scalar tensor or a float in the
+        raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob)
     # Do nothing if we know keep_prob == 1
     if isinstance(keep_prob, (float, int)) and keep_prob == 1:
         return x
@@ -2492,9 +2492,9 @@ def get_common_shape(values, ignore_axes=(), allow_broadcast_all_sources=NotSpec
     import numpy

     assert len(values) > 0
-    assert all(
-
-    )
+    assert all([isinstance(value, (tf.Tensor, tf.Variable, float, int, numpy.number)) for value in values]), (
+        "types %r" % ([type(v) for v in values])
+    )
     # Filter out scalars.
     values = [value for value in values if isinstance(value, (tf.Tensor, tf.Variable))]
     assert all([value.shape.ndims is not None for value in values]), "some unknown ndim"
@@ -2523,14 +2523,15 @@ def get_common_shape(values, ignore_axes=(), allow_broadcast_all_sources=NotSpec
                 common_shape[axis] = static_dim
             else:  # common_shape is int
                 assert isinstance(common_shape[axis], int)
-                assert (
-
-
-
-
-
-
+                assert common_shape[axis] == static_dim, (
+                    "non matching dim %r vs %r in axis %i, value %r of values %r"
+                    % (
+                        common_shape[axis],
+                        static_dim,
+                        axis,
+                        value,
+                        values,
+                    )
                 )
     # Check validate_broadcast_all_sources
     need_broadcast = {id(value): False for value in values}
@@ -2576,9 +2577,9 @@ def unbroadcast_to_common_shape(value, common_shape, ignore_axes=(), allow_only_
     for axis in ignore_axes:
         assert 0 <= axis < ndim
         tile_multiples[axis] = 1
-    assert all(
-
-    )
+    assert all([m is not None for m in tile_multiples]), (
+        "ignore_axes %r probably missing some axis for common shape %r" % (ignore_axes, common_shape)
+    )
     if all([isinstance(m, int) and m == 1 for m in tile_multiples]):
         # We have a no-op.
         return value
@@ -6611,7 +6612,6 @@ def find_unsupported_devices_in_graph(graph, dev_name, ignore=None):


 class _DeviceAttrMod:
-
     _tf_mod = None

     @classmethod
@@ -7680,13 +7680,14 @@ class FetchHelper:
         _, info = copier(sgv, dst_graph=sgv.graph, dst_scope="", reuse_dst_scope=True)
         assert isinstance(info, graph_editor.TransformerInfo)
         target_op_transformed = info.transformed(target_op)
-        assert isinstance(
-
-
-
-
-
+        assert isinstance(target_op_transformed, tf.Operation), (
+            "\ntarget_op\n%r,\nfetches\n%r,\nstop_at_ts\n%s,\nops\n%s"
+            % (
+                target_op,
+                fetches,
+                pformat(stop_at_ts),
+                pformat(ops),
+            )
         )
         fetch_helpers = []
         for x in fetch_helper_tensors:
returnn/torch/data/extern_data.py
CHANGED
@@ -56,9 +56,9 @@ def raw_dict_to_extern_data(
         assert len(raw_tensor.shape) == data.batch_ndim, f"ndim mismatch for {k}: {raw_tensor.shape} vs {data}"
         for i, dim in enumerate(data.dims):
             if dim.dimension is not None:
-                assert (
-
-                )
+                assert dim.dimension == raw_tensor.shape[i], (
+                    f"shape mismatch for {k}: {raw_tensor.shape} vs {data.batch_shape}"
+                )
         if isinstance(raw_tensor, torch.Tensor):
             if raw_tensor.dtype.is_floating_point and float_dtype:
                 raw_tensor = raw_tensor.to(dtype=float_dtype)
@@ -81,8 +81,7 @@ def raw_dict_to_extern_data(
             and (data.dims[1].dyn_size_ext is None or data.dims[1].dyn_size_ext.raw_tensor is None)
         ):
             assert k + ":seq_len" in extern_data_raw, (
-                f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, "
-                f"check dataset or collate_batch"
+                f"extern_data {data}, dyn spatial dim, missing {k}:seq_len in raw dict, check dataset or collate_batch"
             )
             size = extern_data_raw[k + ":seq_len"]
             # Sequence lengths have to be on CPU for the later call to rnn.pack_padded_sequence
returnn/torch/data/pipeline.py
CHANGED
@@ -123,7 +123,6 @@ class ChunkingIterDataPipe(torch.utils.data.IterDataPipe):
         chunking_data_keys = list(self._chunk_size.keys())

         for data_dict in self._dataset:
-
             if not chunking_data_keys:
                 chunking_data_keys = list(data_dict.keys())  # use all if not configured separately
                 chunking_data_key_black_list = ["seq_tag", "seq_idx", "num_seqs", "epoch", "complete_frac"]
@@ -150,9 +149,9 @@ class ChunkingIterDataPipe(torch.utils.data.IterDataPipe):
                 if num_chunks is None:
                     num_chunks = len(chunks)
                 else:
-                    assert num_chunks == len(
-                        chunks
-                    )
+                    assert num_chunks == len(chunks), (
+                        "Chunking resulted in different number of chunks for different data keys."
+                    )

                 data_chunks[data_key] = chunks

returnn/torch/engine.py
CHANGED
@@ -66,22 +66,22 @@ class Engine(EngineBase):
         self.model_filename = self.config.value("model", None)
         self._mp_manager = torch.multiprocessing.Manager()
         self._epoch_mp_shared = self._mp_manager.Value("i", 0)
-        self.train_dataset
+        self.train_dataset: Optional[Dataset] = None
         self.eval_datasets = {}
-        self.extern_data
-        self._train_dataloader
-        self._eval_dataloaders
+        self.extern_data: Optional[TensorDict] = None
+        self._train_dataloader: Optional[DataLoader] = None
+        self._eval_dataloaders: Dict[str, DataLoader] = {}

-        self._start_epoch
-        self._final_epoch
-        self._min_seq_length
+        self._start_epoch: Optional[int] = None
+        self._final_epoch: Optional[int] = None
+        self._min_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value(
             "min_seq_length", None
-        )
-        self._max_seq_length
+        ) or config.int("min_seq_length", None)
+        self._max_seq_length: Union[int, float, Dict[str, int], NumbersDict] = config.typed_value(
             "max_seq_length", None
-        )
-        self._orig_model
-        self._pt_model
+        ) or config.int("max_seq_length", None)
+        self._orig_model: Optional[Union[rf.Module, torch.nn.Module]] = None
+        self._pt_model: Optional[torch.nn.Module] = None
         self._epoch_start_func: Optional[Callable] = self.config.typed_value("epoch_start")
         self._epoch_end_func: Optional[Callable] = self.config.typed_value("epoch_end")
         self._train_step_func: Optional[Callable] = None
@@ -95,15 +95,15 @@ class Engine(EngineBase):
         self._updater: Optional[Updater] = None

         self._use_autocast = False
-        self._autocast_dtype
-        self._grad_scaler
+        self._autocast_dtype: Optional[str] = None
+        self._grad_scaler: Optional[amp.GradScaler] = None

         dev_ = get_device_from_config_opt(config.value("device", None))
         self._device = dev_.result
         print("Using device:", self._device, f"({dev_.reason or '?'})", file=log.v2)

-        self._torch_distributed_ctx
-        self._ddp_pt_model
+        self._torch_distributed_ctx: Optional[DistributedContext] = None
+        self._ddp_pt_model: Optional[DistributedDataParallel] = None

         if config.typed_value("torch_distributed") is not None:
             self._torch_distributed_ctx = dist_get_ctx(config=config)
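The min_seq_length/max_seq_length change above now reads the value via config.typed_value(...) or config.int(...), and the added annotation (Union[int, float, Dict[str, int], NumbersDict]) suggests which config forms are accepted. A hedged config sketch inferred only from that annotation:

# Hypothetical RETURNN config snippet; the accepted forms are inferred from the
# Union[int, float, Dict[str, int], NumbersDict] annotation in the diff above.
max_seq_length = 75                   # single number, now also picked up via config.int(...)
# max_seq_length = {"classes": 75}    # per-data-key dict form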
returnn/torch/frontend/_backend.py
CHANGED
@@ -421,9 +421,9 @@ class TorchBackend(Backend[torch.Tensor]):
         else:  # not allow_broadcast
             for source, dim in sources:
                 templ_dims = other_dims[:axis] + [dim] + other_dims[axis:]
-                assert set(templ_dims) == set(
-                    source
-                )
+                assert set(templ_dims) == set(source.dims), (
+                    f"concat {source} {dim} not allowed with allow_broadcast=False"
+                )
                 source_ = source.copy_transpose(templ_dims)
                 sources_raw.append(source_.raw_tensor)
             out = Tensor(
@@ -612,9 +612,9 @@ class TorchBackend(Backend[torch.Tensor]):
         assert axis in logits.dims, "Specified axis not present in logits."

         if axis == targets.sparse_dim:
-            assert (
-                logits
-            )
+            assert logits.dims_set - {axis} == targets.dims_set, (
+                "logits Dims and target Dims have to match (except for implicit sparse_dim)."
+            )

             logits_dim_order = list(targets.dims)
             if len(logits_dim_order) > 0:
@@ -629,9 +629,9 @@ class TorchBackend(Backend[torch.Tensor]):
                 targets.raw_tensor = targets.raw_tensor.long()

         else:
-            assert (
-
-            )
+            assert not targets.sparse_dim, (
+                "We expect that cross entropy would always be calculated along the sparse dim, if there is one."
+            )
             assert logits.dims_set == targets.dims_set, "logits Dims and target Dims have to match."
             assert axis in targets.dims, "Specified axis not present in targets."

@@ -1348,12 +1348,12 @@ class TorchBackend(Backend[torch.Tensor]):
         a_dims = a.dims
         b_dims = b.dims

-        assert all(
-
-        )
-        assert all(
-
-        )
+        assert all(dim in a_dims for dim in reduce), (
+            f"'a' does not have the specified reduce dim(s) {reduce} (a dims: {a_dims})"
+        )
+        assert all(dim in b_dims for dim in reduce), (
+            f"'b' does not have the specified reduce dim(s) {reduce} (b dims: {b_dims})"
+        )

         if len(reduce) > 1:
             reduce = list(reduce)
returnn/torch/frontend/bridge.py
CHANGED
@@ -178,9 +178,9 @@ class RFModuleAsPTModule(torch.nn.Module):
         rf_param = getattr(self._rf_module, name, None)
         if not isinstance(rf_param, rf.Parameter):
             return  # just ignore
-        assert isinstance(
-            param
-        )
+        assert isinstance(param, torch.nn.Parameter), (
+            f"{self} register_parameter {name}: did not get a Parameter but {type(param).__name__}"
+        )
         rf_param.raw_tensor = param

     def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
returnn/torch/updater.py
CHANGED
@@ -39,7 +39,7 @@ def _init_optimizer_classes_dict():


 def get_optimizer_class(
-    class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]]
+    class_name: Union[str, Type[torch.optim.Optimizer], Callable[[], Type[torch.optim.Optimizer]]],
 ) -> Type[torch.optim.Optimizer]:
     """
     :param class_name: Optimizer class, either as str (e.g. "adam"), as type (torch.optim.Adam) or callable.
@@ -121,9 +121,9 @@ class Updater:
             import inspect

             signature = inspect.signature(self.learning_rate_function)
-            assert any(
-
-            )
+            assert any([arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]), (
+                "please specify **kwargs in dynamic_learning_rate for future compatibility"
+            )
             if "network" in signature.parameters:
                 raise ValueError("Torch updater: dynamic_learning_rate network is TF specific")
             else:
@@ -497,10 +497,9 @@ class Updater:
         # Split in parameter groups only if decouple_constraints is set and the optimizer accepts weight_decay.
         cls_init_kwargs = _get_class_init_kwargs(optim_class)
         if "weight_decay" not in cls_init_kwargs:
-            assert (
-                "weight_decay
-
-                "%s" % optim_name for optim_name in cls_init_kwargs
+            assert "weight_decay" not in optimizer_opts, (
+                "weight_decay not accepted by the chosen optimizer. Accepted values: %s"
+                % ", ".join("%s" % optim_name for optim_name in cls_init_kwargs)
             )
             return network_params

@@ -564,7 +563,7 @@ class Updater:


 def _wrap_user_blacklist_wd_modules(
-    mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]
+    mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]],
 ) -> Tuple[type, ...]:
     assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
     res = []
returnn/torch/util/debug_inf_nan.py
CHANGED
@@ -30,7 +30,6 @@ Also, there might be inf/nan values which are ok, expected, and not a problem
 So we don't stop on the first occurrence but just report all of them.
 """

-
 from __future__ import annotations

 import sys
@@ -90,7 +89,6 @@ def debug_inf_nan(
         print(f"Caught RuntimeError in backward: {exc}", file=file)

     else:  # without grad
-
         with trace_ops:
             func()

returnn/torch/util/exception_helper.py
CHANGED
@@ -79,7 +79,7 @@ def help_on_torch_exception(


 def _help_data_or_array(
-    value: Union[torch.Tensor, np.ndarray, bool, object]
+    value: Union[torch.Tensor, np.ndarray, bool, object],
 ) -> Tuple[str, Tuple[Union[int, float], Union[int, float]]]:
     """
     :param value:
returnn/torch/util/scaled_gradient.py
CHANGED
@@ -14,7 +14,6 @@ https://github.com/janfreyberg/pytorch-revgrad/blob/449fa763a76d/src/pytorch_rev
 https://github.com/tadeephuy/GradientReversal/blob/5d9857d63/gradient_reversal/functional.py
 """

-
 from __future__ import annotations
 from typing import Optional, Union
 import torch
returnn/util/basic.py
CHANGED
@@ -705,7 +705,7 @@ def expand_env_vars(s: str) -> str:
             return delim
         if mo.group("invalid") is not None:
             i = mo.start("invalid")
-            raise ValueError(f"Invalid placeholder in string: {s[i:i+2]!r}...")
+            raise ValueError(f"Invalid placeholder in string: {s[i : i + 2]!r}...")
         raise ValueError(f"Unrecognized named group in pattern {pattern}")

     return pattern_.sub(_convert, s)
@@ -1811,7 +1811,6 @@ def json_remove_comments(string, strip_space=True):
     index = 0

     for match in re.finditer(tokenizer, string):
-
         if not (in_multi or in_single):
             tmp = string[index : match.start()]
             if not in_string and strip_space:
|