returnn 1.20250430.145858__py3-none-any.whl → 1.20250508.181644__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +24 -25
- returnn/datasets/cached.py +4 -3
- returnn/datasets/distrib_files.py +1 -2
- returnn/datasets/generating.py +20 -20
- returnn/datasets/hdf.py +9 -9
- returnn/datasets/lm.py +25 -13
- returnn/datasets/meta.py +39 -38
- returnn/datasets/normalization_data.py +1 -1
- returnn/datasets/postprocessing.py +9 -9
- returnn/datasets/sprint.py +8 -7
- returnn/datasets/util/strings.py +0 -1
- returnn/datasets/util/vocabulary.py +3 -3
- returnn/extern/graph_editor/subgraph.py +1 -2
- returnn/extern/graph_editor/transform.py +1 -2
- returnn/extern/graph_editor/util.py +1 -2
- returnn/frontend/_backend.py +4 -3
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/audio/mel.py +0 -1
- returnn/frontend/const.py +3 -3
- returnn/frontend/device.py +0 -1
- returnn/frontend/dropout.py +1 -1
- returnn/frontend/encoder/e_branchformer.py +1 -1
- returnn/frontend/loop.py +3 -3
- returnn/frontend/loss.py +0 -1
- returnn/frontend/matmul.py +0 -1
- returnn/frontend/run_ctx.py +9 -9
- returnn/frontend/signal.py +0 -1
- returnn/frontend/types.py +2 -4
- returnn/native_op.py +13 -0
- returnn/sprint/cache.py +2 -4
- returnn/sprint/interface.py +3 -4
- returnn/tensor/_dim_extra.py +9 -9
- returnn/tensor/_tensor_extra.py +20 -19
- returnn/tensor/_tensor_op_overloads.py +0 -1
- returnn/tensor/tensor.py +1 -1
- returnn/tensor/tensor_dict.py +9 -9
- returnn/tf/engine.py +60 -65
- returnn/tf/frontend_layers/_backend.py +3 -3
- returnn/tf/frontend_layers/cond.py +6 -6
- returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
- returnn/tf/frontend_layers/layer.py +12 -12
- returnn/tf/frontend_layers/loop.py +3 -3
- returnn/tf/frontend_layers/make_layer.py +0 -1
- returnn/tf/layers/base.py +56 -49
- returnn/tf/layers/basic.py +60 -65
- returnn/tf/layers/rec.py +74 -74
- returnn/tf/native_op.py +1 -3
- returnn/tf/network.py +60 -57
- returnn/tf/updater.py +3 -3
- returnn/tf/util/basic.py +24 -23
- returnn/torch/data/extern_data.py +4 -5
- returnn/torch/data/pipeline.py +3 -4
- returnn/torch/engine.py +16 -16
- returnn/torch/frontend/_backend.py +15 -15
- returnn/torch/frontend/bridge.py +3 -3
- returnn/torch/updater.py +8 -9
- returnn/torch/util/debug_inf_nan.py +0 -2
- returnn/torch/util/exception_helper.py +1 -1
- returnn/torch/util/scaled_gradient.py +0 -1
- returnn/util/basic.py +1 -2
- returnn/util/better_exchook.py +14 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/METADATA +1 -1
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/RECORD +68 -68
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/LICENSE +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/WHEEL +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/top_level.txt +0 -0
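Most of the hunks shown below follow two mechanical patterns rather than behavioral changes: bare attribute statements and trailing `# type:` comments are replaced by inline PEP 526 annotations, and multi-line `assert` statements are reflowed into a single-expression condition with the message kept in parentheses. The sketch below is illustrative only; the class and attribute names are placeholders, not RETURNN's actual API, and it simply shows the before/after shape of both patterns.

```python
from typing import Dict, List, Optional


class Data:
    """Placeholder for RETURNN's Data class, only to keep the sketch self-contained."""


class ExampleLayer:
    def __init__(self, sources: Optional[List[Data]] = None):
        # Old style (removed in this release): a bare statement with a trailing
        # type comment, e.g.  self.input_data = None  # type: typing.Optional[Data]
        # New style (added): an inline PEP 526 annotation.
        self.input_data: Optional[Data] = None
        self.extra: Dict[str, Data] = {}
        if sources:
            self.input_data = sources[0]
        # Asserts move from a wrapped multi-line condition to a single-expression
        # condition, with the message kept in parentheses on the following lines.
        assert self.input_data is not None or not sources, (
            "%s: invalid sources %r" % (self, sources)
        )
```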
returnn/tf/layers/basic.py
CHANGED
@@ -4,7 +4,7 @@ Many canonical basic layers.
 
 from __future__ import annotations
 
-from typing import Optional, Union, Sequence, List, Tuple, Dict
+from typing import Callable, Optional, Union, Sequence, List, Tuple, Dict
 import typing
 import tensorflow as tf
 import contextlib
@@ -126,7 +126,7 @@ def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpec
         data.placeholder = tf.concat(
             axis=data.feature_dim_axis, values=[layer_data.placeholder for layer_data in layers_data]
         )
-        axes_split_info = [None] * data.batch_ndim
+        axes_split_info: List[Optional[List[int]]] = [None] * data.batch_ndim
         axes_split_info[data.feature_dim_axis] = [layer_data.dim for layer_data in layers_data]
         tf_util.set_param_axes_split_info(data.placeholder, axes_split_info)
         # Note: We will loose this info for any further op (e.g. dropout, activation, etc). Should be better...
@@ -294,7 +294,7 @@ class _ConcatInputLayer(LayerBase):
         elif mask == "dropout":
             assert dropout > 0
         self.dropout = dropout
-        self.input_data
+        self.input_data: Optional[Data] = None
         if self.sources:
             self.input_data = concat_sources_with_opt_dropout(
                 self.sources,
@@ -509,9 +509,7 @@ class ConcatLayer(LayerBase):
         assert sources
         sources, axes = zip(*sources)  # unzip
         axes_int = [layer.output.get_axis_from_description(axis) for (layer, axis) in zip(sources, axes)]
-        concat_dim_tags = [
-            layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)
-        ]  # type: typing.List[Dim]
+        concat_dim_tags: List[Dim] = [layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)]
         if any(tag.dimension is None for tag in concat_dim_tags):
             dimension = None
         else:
@@ -707,8 +705,8 @@ class SelectSearchSourcesLayer(InternalLayer):
         self.output = src.output.copy_as_batch_major()
         self.rec_vars_outputs = src.rec_vars_outputs.copy()
         src_search_choices = src.get_search_choices()
-        self.transform_func
-        self.search_choices_seq
+        self.transform_func: Optional[Callable[[tf.Tensor], tf.Tensor]] = None
+        self.search_choices_seq: Optional[List[SearchChoices]] = None
         if not search_choices:
             assert not src_search_choices
             assert not self.output.beam
@@ -726,13 +724,7 @@ class SelectSearchSourcesLayer(InternalLayer):
         assert src_search_choices in search_choices_seq, self.network.debug_search_choices(
             self.search_choices_layer
         ) or (
-
-            "%s: No common search base:\n"
-            "from layer %s\n"
-            "search choices %s,\n"
-            "to layer %s\n"
-            "search choices\n%s."
-        )
+            "%s: No common search base:\nfrom layer %s\nsearch choices %s,\nto layer %s\nsearch choices\n%s."
             % (self, src, src_search_choices, self.search_choices_layer, pformat(search_choices_seq))
         )
         search_choices_seq = search_choices_seq[: search_choices_seq.index(src_search_choices)]
@@ -4436,12 +4428,13 @@ class MergeDimsLayer(_ConcatInputLayer):
         :rtype: list[int]
         """
         if keep_order:
-            assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance(
-                axes,
-
-
-
-
+            assert isinstance(axes, (tuple, list, typing.Sequence)) and not isinstance(axes, str), (
+                "%s: axes %r must be a list or tuple, to have a well defined order in input %s"
+                % (
+                    name,
+                    axes,
+                    input_data,
+                )
             )
             axes_ = []
             for axis in axes:
@@ -5562,11 +5555,12 @@ class RepeatLayer(_ConcatInputLayer):
         repetitions_data = repetitions_data.copy_add_dim_by_tag(axis_dim_tag, unbroadcast=True)
         repetitions_axis = repetitions_data.get_axis_from_description(axis, allow_int=False)
         assert repetitions_data.ndim == 1, "Repetitions %r must only have at most one non-batch axis" % repetitions
-        assert (
-
-
-
-
+        assert repetitions_data.batch_shape[repetitions_axis] == self.input_data.batch_shape[input_axis], (
+            "Axis mismatch between input (%i) and repetitions (%i)"
+            % (
+                self.input_data.batch_shape[input_axis],
+                repetitions_data.batch_shape[repetitions_axis],
+            )
         )

         assert self.output.have_batch_axis() == (
@@ -6267,9 +6261,9 @@ class ConvLayer(_ConcatInputLayer):
         from returnn.util import BehaviorVersion

         padding = padding.upper() if isinstance(padding, str) else padding
-        assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(
-
-        )
+        assert padding in ["SAME", "VALID", "SAME_STATIC"] or isinstance(padding, (int, tuple, list)), (
+            f"{self}: got unsupported padding {padding}"
+        )
         assert "out_type" not in kwargs, "don't set out_type explicitly for this layer"
         assert len(filter_size) in (1, 2, 3), "only 1D conv, 2D conv or 3D conv supported"
         super(ConvLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
@@ -6285,9 +6279,9 @@ class ConvLayer(_ConcatInputLayer):
         assert len(dilation_rate) == len(filter_size)
         assert not self.input_data.sparse
         assert self.input_data.have_batch_axis()
-        assert (
-
-        )
+        assert self.input_data.have_feature_axis(), (
+            "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+        )
         input_data, num_batch_dims = self.transform_input(
             self.input_data,
             network=self.network,
@@ -7117,9 +7111,9 @@ class PoolLayer(_ConcatInputLayer):
         super(PoolLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
         assert not self.input_data.sparse
         assert self.input_data.have_batch_axis()
-        assert (
-
-        )
+        assert self.input_data.have_feature_axis(), (
+            "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+        )
         if in_dim and out_dim:
             assert in_dim == out_dim
         elif in_dim:
@@ -7381,9 +7375,9 @@ class TransposedConvLayer(_ConcatInputLayer):
         out_dim  # noqa # via get_out_data_from_opts
         assert not self.input_data.sparse
         assert self.input_data.have_batch_axis()
-        assert (
-
-        )
+        assert self.input_data.have_feature_axis(), (
+            "this should be our single input feature dim now. otherwise use input_add_feature_dim"
+        )
         input_data, num_batch_dims = ConvLayer.transform_input(
             self.input_data,
             network=self.network,
@@ -7404,14 +7398,15 @@ class TransposedConvLayer(_ConcatInputLayer):
         remove_padding = [remove_padding] * len(spatial_axes)
         if not isinstance(output_padding, (list, tuple)):
             output_padding = [output_padding] * len(spatial_axes)
-        assert (
-
-
-
-
-
-
+        assert len(spatial_axes) == len(filter_size) == len(strides) == len(remove_padding) == len(output_padding), (
+            "%s: expected %i-D transposed-conv for input %r but got filter %r and strides %r"
+            % (
+                self,
+                len(spatial_axes),
+                input_data,
+                filter_size,
+                strides,
+            )
         )
         assert len(spatial_axes) in [1, 2], "%s: %i-D not yet implemented..." % (self, len(spatial_axes))
         x = input_data.placeholder
@@ -8775,9 +8770,9 @@ class DotLayer(LayerBase):
             red1,
             red2,
         )
-        assert len(a_reduce_axes) == len(
-
-        )
+        assert len(a_reduce_axes) == len(b_reduce_axes), (
+            "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (self, self.sources, red1, red2)
+        )
         if (
             (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified))
             or var1 == "auto"
@@ -9150,9 +9145,9 @@
             raise Exception(
                 "%s %r: " % (cls.__name__, name) + "%s not found in sources %r" % (red_axis_desc, sources)
             )
-        assert len(a_reduce_axes) == len(
-
-        )
+        assert len(a_reduce_axes) == len(b_reduce_axes), (
+            "%s: sources %r, red1 %r, red2 %r, reduce axes must match in count" % (name, sources, red1, red2)
+        )
         if (
             (BehaviorVersion.get() >= 3 and (var1 is NotSpecified or var2 is NotSpecified))
             or var1 == "auto"
@@ -10178,9 +10173,9 @@ class CondLayer(LayerBase):
         self.condition_desc = condition
         self.condition_layer = self._make_layer("condition", self.condition_desc)
         self.true_layer_desc = true_layer
-        self.true_layer
+        self.true_layer: Optional[LayerBase] = None
         self.false_layer_desc = false_layer
-        self.false_layer
+        self.false_layer: Optional[LayerBase] = None
         assert self.condition_layer.output.batch_ndim == 0 and self.condition_layer.output.dtype == "bool"
         self._extra_out_templates = {k: v[0] for k, v in _extra_out.items()}
         x, extra_out, sizes = tf_util.cond(
@@ -12070,7 +12065,7 @@ class HDFDumpLayer(LayerBase):
             for (key, output) in extra.items()
         }
         extra = {key: output.copy_as_batch_spatial_major() for (key, output) in extra.items()}
-        self.extra
+        self.extra: Dict[str, Data] = extra
         self.dump_whole_batches = dump_whole_batches
         self.num_seqs_written = 0
         ndim = data.ndim
@@ -12454,9 +12449,9 @@ class BinaryCrossEntropyLoss(Loss):

     def _check_init(self):
         assert self.target is not None
-        assert (
-            self.target
-        )
+        assert self.target.batch_ndim == self.output.batch_ndim, (
+            "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+        )

     def get_value(self):
         """
@@ -13020,7 +13015,7 @@ class ExpectedLoss(Loss):
         self.divide_beam_size = divide_beam_size
         self.subtract_average_loss = subtract_average_loss
         self.loss_correction_grad_only = loss_correction_grad_only
-        self.search_choices
+        self.search_choices: Optional[SearchChoices] = None

     @classmethod
     def transform_config_dict(cls, d, network, get_layer):
@@ -13120,9 +13115,9 @@ class DeepClusteringLoss(Loss):
         Does some checks on self.target and self.output, e.g. if the dense shapes matches.
         You can overwrite this if those checks don't make sense for your derived loss class.
         """
-        assert (
-            self.target
-        )
+        assert self.target.ndim_dense == self.output.ndim_dense, (
+            "Number of dimensions mismatch. Target: %s, output: %s" % (self.target, self.output)
+        )
         expected_output_dim = self._embedding_dimension * (self.target.shape[1] // self._nr_of_sources)
         assert expected_output_dim == self.output.dim, "Expected output dim is %i but the output has dim %r. " % (
             expected_output_dim,
@@ -13822,9 +13817,9 @@ class SamplingBasedLoss(Loss):
         else:
             loss_fn = tf.nn.sampled_softmax_loss

-        assert (
-
-        )
+        assert self.layer.params["W"].shape[0] == self.target.dim, (
+            "Expect weight matrix of shape [num_classes, dim]"
+        )
         out = loss_fn(
             weights=self.layer.params["W"].read_value(),  # (num_classes,D).
             biases=self.layer.params["b"].read_value(),  # (num_classes).
returnn/tf/layers/rec.py
CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations

 import contextlib
 import typing
+from typing import Dict, Optional, Tuple, Union
 import tensorflow as tf
 import returnn.tf.compat as tf_compat

@@ -1037,7 +1038,8 @@ class RecLayer(_ConcatInputLayer):
                 scope=tf_compat.v1.get_variable_scope(),
             )
         elif rnn_contrib and isinstance(
-            cell,
+            cell,
+            (rnn_contrib.FusedRNNCell, rnn_contrib.LSTMBlockWrapper),  # noqa # e.g. LSTMBlockFusedCell
         ):
             # Will get (time,batch,ydim).
             assert self._max_seq_len is None
@@ -1280,9 +1282,9 @@
         :param str|int|None key:
         :rtype: tf.Tensor
         """
-        assert (
-
-        )
+        assert self._last_hidden_state is not None, (
+            "last-hidden-state not implemented/supported for this layer-type. try another unit. see the code."
+        )
         return RnnCellLayer.get_state_by_key(self._last_hidden_state, key=key)

     @classmethod
@@ -1431,9 +1433,7 @@ class _SubnetworkRecCell:
         )
         self._last_frames = {}  # type: typing.Dict[str,Data]
         self._initial_outputs = None  # type: typing.Optional[typing.Dict[str,tf.Tensor]]
-        self._initial_extra_outputs =
-            None
-        )  # type: typing.Optional[typing.Dict[str,typing.Dict[str,typing.Union[tf.Tensor,typing.Tuple[tf.Tensor,...]]]]]  # nopep8
+        self._initial_extra_outputs: Optional[Dict[str, Dict[str, Union[tf.Tensor, Tuple[tf.Tensor, ...]]]]] = None

         # input_layers_moved_out, output_layers_moved_out and layers_in_loop include (used) sub-layers as separate
         # entries, this way in- and outputting them to the loop via TensorArrays will be handled just as for normal
@@ -1608,14 +1608,9 @@
             while parent and parent.parent:
                 parent_names.insert(0, parent.parent_name or "?")
                 parent = parent.parent
-            return (
-
-                "
-                "parents %r)"
-                % (
-                    lself.allow_uninitialized_template,
-                    " <- ".join(parent_names) or None,
-                )
+            return "<RecLayer construct template GetLayer>(allow_uninitialized_template %r, parents %r)" % (
+                lself.allow_uninitialized_template,
+                " <- ".join(parent_names) or None,
             )

         def _add_uninitialized_count(self):
@@ -2141,16 +2136,17 @@
                 layer = self.input_layers_net.layers[layer_name]
                 assert isinstance(layer, LayerBase)
                 if layer_name not in inputs_moved_out_tas:
-                    assert not layer.output.mark_same_time(
-                        self.
-                    )
-                    assert (
-
-
-
-
-
-
+                    assert not layer.output.mark_same_time(self._time_dim_tags), (
+                        "%s does not expect to have matching time dim to %s" % (layer, self.parent_rec_layer)
+                    )
+                    assert name != "output" and not prev, (
+                        "Time dim does not match: RecLayer %s (%r) vs sub layer %s (%r)."
+                        % (
+                            self.parent_rec_layer,
+                            self.parent_rec_layer.output.get_time_dim_tag(),
+                            layer,
+                            layer.output.get_time_dim_tag(),
+                        )
                     )
                     return layer
                 output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
@@ -2376,9 +2372,9 @@
         assert output_template.output.dim == self.parent_rec_layer.output.dim
         assert self.parent_rec_layer.output.time_dim_axis == 0
         assert not output_template.output.has_axis(self.time_dim_tag)
-        assert (
-
-        )
+        assert output_template.output.batch_shape == self.parent_rec_layer.output.batch_shape[1:], (
+            "see RecLayer.get_out_data_from_opts()"
+        )

     def get_init_loop_vars(self):
         """
@@ -3014,9 +3010,9 @@
             needed_outputs.add("end")
             assert tf.as_dtype(end_template.output.dtype) is tf.bool
         else:
-            assert (
-
-            )
+            assert have_known_seq_len, (
+                "You need to have an 'end' layer in your rec subnet if the generated seq len is unknown."
+            )

         # noinspection PyProtectedMember
         if self.parent_rec_layer._optimize_move_layers_out:
@@ -3358,11 +3354,12 @@
                 from .basic import SelectSearchSourcesLayer

                 prev_end_layer = choices.translate_to_this_search_beam(prev_end_layer)
-                assert isinstance(
-
-
-
-
+                assert isinstance(prev_end_layer, SelectSearchSourcesLayer), (
+                    "unexpected search choices: cur end %r, prev end %r"
+                    % (
+                        choices,
+                        prev_end_layer.get_search_choices(),
+                    )
                 )
                 prev_end_flag = prev_end_layer.output.placeholder
                 with tf.name_scope("dyn_seq_len"):
@@ -3475,14 +3472,15 @@
             assert fixed_seq_len is not None
             seq_len = fixed_seq_len
             if output_beam:
-                assert (
-
-
-
-
-
-
+                assert not input_beam or input_beam == output_beam, (
+                    "%s: input beam %r, output beam %r, sources %r, target %r"
+                    % (
+                        self.parent_rec_layer,
+                        input_beam,
+                        output_beam,
+                        self.parent_rec_layer.sources,
+                        self.parent_rec_layer.target,
+                    )
                 )
                 assert output_template.output.batch.beam == output_beam
                 time_dim_tag = time_dim_tag.get_for_batch_ctx(
@@ -3791,9 +3789,9 @@
             if end_layer_choice.name.startswith("prev:"):
                 # Logic from maybe_transform. It would be translated to the current beam.
                 end_layer_choice = self.net.layers[end_layer_choice.name[len("prev:") :]]
-            assert (
-
-            )
+            assert end_layer_choice in choice_seq_in_frame, (
+                "End layer must not have a beam independent from output layer '{}'.".format(layer_name)
+            )

             end_layer_choice_index = choice_seq_in_frame.index(end_layer_choice)
             choices_seq_until_end_layer = choice_seq_in_frame[:end_layer_choice_index]
@@ -5856,12 +5854,13 @@ class RecUnstackLayer(LayerBase):
             if out_dim.is_dim_known():  # usually the case except at template construction
                 assert out_dim != rec_time_dim  # rec_time_dim is unknown, so it cannot be the same
             if out_dim != rec_time_dim:
-                assert (
-                    declare_rec_time
-
-
-
-
+                assert declare_rec_time, (
+                    "%s %r: must either set known axis on rec %s or enable declare_rec_time"
+                    % (
+                        cls.__name__,
+                        name,
+                        rec_time_dim,
+                    )
                 )
                 rec_time_dim.declare_same_as(out_dim)
             out.mark_same_time(out_dim, must_match=True)
@@ -6132,12 +6131,13 @@ class ChoiceLayer(BaseChoiceLayer):
             base_beam_in = tf.shape(scores_base)[1]  # 1 in first frame, then beam_in
             scores_beam_in = tf.shape(scores_in)[0] // net_batch_dim
             beam_in = self.sources[0].output.beam.beam_size
-            assert (
-
-
-
-
-
+            assert beam_in == base_search_choices.beam_size, (
+                "%r: source %r beam-size unexpected from base choice %r"
+                % (
+                    self,
+                    self.sources[0],
+                    base_search_choices,
+                )
             )
             # About incoming beam size:
             # base_beam_in - 1 in first frame, then beam_in
@@ -7510,9 +7510,9 @@ class GenericAttentionLayer(AttentionBaseLayer):
         base_rem_axes = base.get_axes(exclude_batch=True, exclude_time=True)
         base_rem_axes.remove(base.feature_dim_axis)
         weights_rem_axes = weights.get_axes(exclude_batch=True)
-        assert (
-            weights
-        )
+        assert weights.time_dim_axis is not None, (
+            f"{exception_prefix}: base {base}, weights {weights}, need time_dim_axis in weights"
+        )
         weights_axis_to_reduce = cls._weights_time_axis_to_reduce(weights=weights, base=base)
         assert weights.batch_shape[weights_axis_to_reduce] == base.batch_shape[base.time_dim_axis]
         weights_rem_axes.remove(weights_axis_to_reduce)
@@ -9088,13 +9088,13 @@ class MaskedComputationLayer(LayerBase):
         new_size, new_time, idxs = None, None, None
         if mask:
             if self.network.is_inside_rec_layer():
-                assert (
-                    mask
-                )
+                assert mask.output.shape == () and mask.output.dtype == "bool", (
+                    "%s: invalid mask %s (inside rec loop)" % (self, mask)
+                )
             else:
-                assert (
-                    mask
-                )
+                assert mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool", (
+                    "%s: invalid mask %s (outside rec loop)" % (self, mask)
+                )
             assert in_spatial_dim and out_spatial_dim
             mask_data = mask.output.copy_as_time_major()
             mask_t = where_bc(mask_data.placeholder, mask_data.get_sequence_mask(), tf.convert_to_tensor(False))
@@ -9785,9 +9785,9 @@ class UnmaskLayer(LayerBase):
         with same_control_flow_ctx(src_layer.output.placeholder):
             src = src_layer.output.copy_as_bt_or_tb_major()
             mask_out = self.mask.output
-            assert (
-
-            )
+            assert mask_out.shape == () and mask_out.batch_shape == (None,) and mask_out.dtype == "bool", (
+                "%s: invalid mask %s (inside rec loop)" % (self, self.mask)
+            )
             prev_t = self._rec_previous_layer.rec_vars_outputs["t"]  # [B]
             t = prev_t + tf.cast(mask_out.placeholder, tf.int32)  # [B]
             self.rec_vars_outputs["t"] = t
@@ -11192,9 +11192,9 @@ class RelativePositionalEncodingLayer(_ConcatInputLayer):
             and is_axis_from_description_recurrent(key_value_spatial_dim, network=self.network, data=self.input_data)
         ):
             length = self.network.get_rec_step_index() + 1
-            assert (
-
-            )
+            assert key_value_spatial_dim_.dimension is None, (
+                f"{self}: unexpected kv spatial dim {key_value_spatial_dim_}"
+            )
             assert key_value_spatial_dim_.dyn_size_ext is not None
             # See CumConcatLayer for similar logic
             if key_value_spatial_dim_.dyn_size_ext.placeholder is None:
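The `_SubnetworkRecCell` hunk above is the same annotation migration applied to a deeply nested container type. A minimal sketch of just that change (the class name `_CellSketch` is hypothetical; tensorflow is imported only to mirror the real annotation):

```python
from typing import Dict, Optional, Tuple, Union

import tensorflow as tf


class _CellSketch:
    def __init__(self):
        # Before: a bare assignment followed by an over-long "# type:" comment
        # that needed a "# nopep8" escape hatch.
        # After: the inline annotation carries the full nested type directly.
        self._initial_extra_outputs: Optional[Dict[str, Dict[str, Union[tf.Tensor, Tuple[tf.Tensor, ...]]]]] = None
```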
returnn/tf/native_op.py
CHANGED
@@ -283,9 +283,7 @@ class OpMaker:
      // otherwise it will trigger an assertion.
      if (IsRefType(context->input_dtype({in_idx})))
        context->forward_ref_input_to_ref_output({in_idx}, {out_idx});
-    """.format(
-        in_idx=in_idx, out_idx=out_idx
-    )
+    """.format(in_idx=in_idx, out_idx=out_idx)
             code_set_io = ""
             for in_idx, v in enumerate(in_info):
                 ndim = len(v["shape"])