returnn 1.20250430.145858__py3-none-any.whl → 1.20250508.181644__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +24 -25
- returnn/datasets/cached.py +4 -3
- returnn/datasets/distrib_files.py +1 -2
- returnn/datasets/generating.py +20 -20
- returnn/datasets/hdf.py +9 -9
- returnn/datasets/lm.py +25 -13
- returnn/datasets/meta.py +39 -38
- returnn/datasets/normalization_data.py +1 -1
- returnn/datasets/postprocessing.py +9 -9
- returnn/datasets/sprint.py +8 -7
- returnn/datasets/util/strings.py +0 -1
- returnn/datasets/util/vocabulary.py +3 -3
- returnn/extern/graph_editor/subgraph.py +1 -2
- returnn/extern/graph_editor/transform.py +1 -2
- returnn/extern/graph_editor/util.py +1 -2
- returnn/frontend/_backend.py +4 -3
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/audio/mel.py +0 -1
- returnn/frontend/const.py +3 -3
- returnn/frontend/device.py +0 -1
- returnn/frontend/dropout.py +1 -1
- returnn/frontend/encoder/e_branchformer.py +1 -1
- returnn/frontend/loop.py +3 -3
- returnn/frontend/loss.py +0 -1
- returnn/frontend/matmul.py +0 -1
- returnn/frontend/run_ctx.py +9 -9
- returnn/frontend/signal.py +0 -1
- returnn/frontend/types.py +2 -4
- returnn/native_op.py +13 -0
- returnn/sprint/cache.py +2 -4
- returnn/sprint/interface.py +3 -4
- returnn/tensor/_dim_extra.py +9 -9
- returnn/tensor/_tensor_extra.py +20 -19
- returnn/tensor/_tensor_op_overloads.py +0 -1
- returnn/tensor/tensor.py +1 -1
- returnn/tensor/tensor_dict.py +9 -9
- returnn/tf/engine.py +60 -65
- returnn/tf/frontend_layers/_backend.py +3 -3
- returnn/tf/frontend_layers/cond.py +6 -6
- returnn/tf/frontend_layers/debug_eager_mode.py +0 -1
- returnn/tf/frontend_layers/layer.py +12 -12
- returnn/tf/frontend_layers/loop.py +3 -3
- returnn/tf/frontend_layers/make_layer.py +0 -1
- returnn/tf/layers/base.py +56 -49
- returnn/tf/layers/basic.py +60 -65
- returnn/tf/layers/rec.py +74 -74
- returnn/tf/native_op.py +1 -3
- returnn/tf/network.py +60 -57
- returnn/tf/updater.py +3 -3
- returnn/tf/util/basic.py +24 -23
- returnn/torch/data/extern_data.py +4 -5
- returnn/torch/data/pipeline.py +3 -4
- returnn/torch/engine.py +16 -16
- returnn/torch/frontend/_backend.py +15 -15
- returnn/torch/frontend/bridge.py +3 -3
- returnn/torch/updater.py +8 -9
- returnn/torch/util/debug_inf_nan.py +0 -2
- returnn/torch/util/exception_helper.py +1 -1
- returnn/torch/util/scaled_gradient.py +0 -1
- returnn/util/basic.py +1 -2
- returnn/util/better_exchook.py +14 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/METADATA +1 -1
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/RECORD +68 -68
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/LICENSE +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/WHEEL +0 -0
- {returnn-1.20250430.145858.dist-info → returnn-1.20250508.181644.dist-info}/top_level.txt +0 -0
returnn/frontend/const.py
CHANGED
@@ -54,9 +54,9 @@ def full(
             "Use rf.convert_to_tensor to convert an arbitrary array to a tensor."
         )
     if isinstance(fill_value, Tensor):
-        assert (
-            fill_value.dims
-        )
+        assert fill_value.dims == (), (
+            f"full/fill/constant: expect scalar fill_value, got tensor with shape {fill_value.dims}."
+        )
     return global_backend.full(
         dims, fill_value, dtype=dtype, device=device, sparse_dim=sparse_dim, feature_dim=feature_dim
     )
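Most of the hunks in this release are pure reformatting of multi-line assert statements: the message now sits in parentheses after the condition instead of the condition being wrapped across lines. This looks like the output of a newer auto-formatter run (for example Black 24.x or ruff format); the diff itself does not name the tool. A minimal, self-contained sketch of the before/after pattern, with made-up variable names:

# Hypothetical example, not RETURNN code: the same assert in both layouts.
shape = ()  # stand-in for fill_value.dims

# Old layout: the condition is wrapped, the message hangs off the closing paren.
assert (
    shape == ()
), f"expected a scalar, got shape {shape}"

# New layout: the condition stays on one line, the message is parenthesized below.
assert shape == (), (
    f"expected a scalar, got shape {shape}"
)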
returnn/frontend/device.py
CHANGED
returnn/frontend/dropout.py
CHANGED
@@ -50,7 +50,7 @@ def dropout(
         raise ValueError(f"dropout axis {axis} not in source {source}")
 
     if isinstance(keep_prob, (float, int)) and not 0 < keep_prob <= 1:
-        raise ValueError("keep_prob must be a scalar tensor or a float in the
+        raise ValueError("keep_prob must be a scalar tensor or a float in the range (0, 1], got %g" % keep_prob)
 
     # Do nothing if we know keep_prob == 1
     if isinstance(keep_prob, (float, int)) and keep_prob == 1:

@@ -268,7 +268,7 @@ class Merge(rf.Module):
 
 
 def _make_activation(
-    activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module]
+    activation: Union[Callable[[Tensor], Tensor], Dict[str, Any], rf.Module],
 ) -> Union[Callable[[Tensor], Tensor], rf.Module]:
     if isinstance(activation, dict):
         activation = rf.build_from_dict(activation)
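The second dropout.py hunk only adds a trailing comma to a single-parameter signature (the "magic trailing comma"), which tells formatters such as Black to keep the parameter list exploded one per line. A small sketch with simplified types (the real signature uses Tensor/rf.Module unions):

from typing import Callable

def _make_activation(
    activation: Callable[[float], float],  # trailing comma keeps this exploded
) -> Callable[[float], float]:
    return activation

print(_make_activation(lambda x: x * 2)(3.0))  # 6.0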
returnn/frontend/loop.py
CHANGED
@@ -273,9 +273,9 @@ def _check_matching_loop_var_templates(loop_var_templates: S, loop_vars: S):
             x._push_back_delayed_check()
 
         else:  # other cases: just check same type
-            assert type(template) is type(
-                x
-            )
+            assert type(template) is type(x), (
+                f"loop var {path} template type {type(template)} does not match var type {type(x)}"
+            )
             assert not isinstance(x, Tensor), f"loop var {path} is a Tensor but should not be"
 
     tree.map_structure_with_path(_check, loop_var_templates, loop_vars)
returnn/frontend/loss.py
CHANGED
@@ -137,7 +137,6 @@ def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim,
     # We are going diagonal over (Ta+1) and (Tb+1). (Similar as RETURNN native EditDistanceOp.)
     # You need to draw the grid on paper to understand all the index math...
     for u in range(1, n_a_max_len + n_b_max_len + 1):
-
         prev2_dist, _ = rf.slice(
             buffer, axis=buffer_dim, start=buffer_offsets[u % 3], size=b_spatial_dim1, out_dim=b_spatial_dim1
         )  # [Tb+1,B]
returnn/frontend/matmul.py
CHANGED
returnn/frontend/run_ctx.py
CHANGED
@@ -306,19 +306,19 @@ class RunCtx:
         assert self.stage == "forward_step"
 
         if self.expected_outputs is not None:
-            assert (
-                name
-            )
+            assert name in self.expected_outputs.data, (
+                f"mark_as_output: unexpected output {name!r}, we expect outputs: {self.expected_outputs}"
+            )
         expected_output = self.expected_outputs.data[name] if self.expected_outputs else None
-        assert dims is None or (
-
-        )
+        assert dims is None or (isinstance(dims, (list, tuple)) and all(isinstance(dim, Dim) for dim in dims)), (
+            f"dims should be a tuple of Dims, got {dims}"
+        )
         if dims is None and expected_output is not None:
             dims = expected_output.dims
         if dims is not None and expected_output is not None:
-            assert expected_output.dims == tuple(
-                dims
-            )
+            assert expected_output.dims == tuple(dims), (
+                f"mark_as_output: {name!r} dims mismatch from expected output, given {dims}, expected {expected_output}"
+            )
 
         if not isinstance(tensor, Tensor):
             assert isinstance(tensor, _backend.global_backend.RawTensorType)
returnn/frontend/signal.py
CHANGED
returnn/frontend/types.py
CHANGED
@@ -19,15 +19,13 @@ ItemKeyType = Union[RawTensorTypes, Tensor, slice, Sequence[Union[RawTensorTypes
 class GetModelFunc(Protocol):
     """get model func"""
 
-    def __call__(self, *, epoch: int, step: int) -> rf.Module:
-        ...
+    def __call__(self, *, epoch: int, step: int) -> rf.Module: ...
 
 
 class StepFunc(Protocol):
     """step func"""
 
-    def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None:
-        ...
+    def __call__(self, *, model: rf.Module, extern_data: TensorDict) -> None: ...
 
 
 def get_raw_tensor_type() -> Type:
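The types.py change collapses Protocol stub bodies so the `...` sits on the same line as the signature, again a formatter-style change (Black's 2024 stable style formats dummy `...` bodies this way). A self-contained sketch with a hypothetical protocol, not the RETURNN one:

from typing import Protocol

class Greeter(Protocol):
    """Hypothetical protocol to show the one-line stub body."""

    def __call__(self, *, name: str) -> str: ...

def greet(fn: Greeter) -> str:
    return fn(name="world")

print(greet(lambda *, name: f"hello {name}"))  # hello world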
returnn/native_op.py
CHANGED
@@ -291,6 +291,7 @@ class LstmGenericBase(NativeOpGenBase):
     :param H: gates and cell state. 3d (time,batch,dim*4)
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {
             "name": "Z",

@@ -542,6 +543,7 @@ class LstmLowMem(NativeOpGenBase):
     :param C: cell states. 3d (time,batch,dim). gradient ignored!
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True},
         {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True},

@@ -994,6 +996,7 @@ class NativeLstm2(NativeOpGenBase):
     :param H: cell-in + gates. 3d (time,batch,dim*4). gradient ignored!
     :param d: final cell state. 2d (batch,dim)
     """
+
     in_info = (
         {"name": "X", "ndim": 3, "shape": (None, None, None), "need_contiguous": True},
         {"name": "W", "ndim": 2, "shape": (None, None), "need_contiguous": True},

@@ -1423,6 +1426,7 @@ class TwoDLSTM(NativeOpGenBase):
     :param H: gates and cell state. 4d (timeS,timeT,batch,dim*5) ?
     :param d: final cell state. 3d (timeT,batch,dim)
     """
+
     in_info = (
         {
             "name": "X",

@@ -3198,6 +3202,7 @@ class FastBaumWelchOp(NativeOpGenBase):
     outputs:
     :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
     """
+
     in_info = (
         {
             "name": "am_scores",

@@ -3620,6 +3625,7 @@ class MultiEndFastBaumWelchOp(NativeOpGenBase):
     outputs:
     :param output: Baum-Welch alignment, scores in -log space. 3d (time,batch,dim), like am_scores
     """
+
     in_info = (
         {
             "name": "am_scores",

@@ -4497,6 +4503,7 @@ class FastViterbiOp(NativeOpGenBase):
     :param output: Viterbi (hard) alignment, scores in +log space. 2d (time,batch)
     :param scores: (batch,)
     """
+
     in_info = (
         {
             "name": "am_scores",

@@ -4865,6 +4872,7 @@ class GetCtcFsaFastBwOp(NativeOpGenBase):
     `num_edges` should be `n_batch * (5 * (n_time - 1) + 10)`
     (see construction in kernel why that number).
     """
+
     in_info = (
         {
             "name": "targets",

@@ -5229,6 +5237,7 @@ class EditDistanceOp(NativeOpGenBase):
     outputs:
     :param output: 1d (batch,), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",

@@ -5414,6 +5423,7 @@ class OptimalCompletionEditDistanceOp(NativeOpGenBase):
     outputs:
     :param output: 1d (batch,), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",

@@ -5610,6 +5620,7 @@ class OptimalCompletionEditDistancePerSuccessorOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,num_labels), int32, unnormalized edit distance
     """
+
     in_info = (
         {
             "name": "a",

@@ -5880,6 +5891,7 @@ class NextEditDistanceRowOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,b_time + 1), int32, next (unnormalized) edit distance row
     """
+
     in_info = (
         {
             "name": "last_row",

@@ -6039,6 +6051,7 @@ class NextEditDistanceReduceOp(NativeOpGenBase):
     outputs:
     :param output: 2d (batch,n_labels), int32, next (unnormalized) (maybe optional) edit distance
     """
+
     in_info = (
         {
             "name": "last_row",
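Every native_op.py hunk is identical in spirit: a blank line is inserted between a class docstring and the `in_info` class attribute that follows it; no behavior changes. The shape of the change, shown on a hypothetical class rather than the real NativeOpGenBase subclasses:

class ExampleOp:
    """
    The inputs/outputs description would go here.
    """

    # The inserted blank line above separates the docstring from the attribute.
    in_info = ({"name": "X", "ndim": 3, "shape": (None, None, None)},)

print(ExampleOp.in_info[0]["name"])  # X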
returnn/sprint/cache.py
CHANGED
@@ -7,7 +7,7 @@ This module is about reading (maybe later also writing) the Sprint archive forma
 """
 
 from __future__ import annotations
-from typing import List
+from typing import List, Optional, Tuple
 import sys
 import os
 import typing

@@ -904,9 +904,7 @@ class MixtureSet:
             self.densities[n, 1] = cov_idx
 
         self.num_mixtures = self.read_u32()
-        self.mixtures = [
-            None
-        ] * self.num_mixtures  # type: typing.List[typing.Optional[typing.Tuple[typing.List[int],typing.List[float]]]]  # nopep8
+        self.mixtures: List[Optional[Tuple[List[int], List[float]]]] = [None] * self.num_mixtures
         for n in range(self.num_mixtures):
             num_densities = self.read_u32()
             dns_idx = []
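Besides the widened `typing` import, the cache.py hunk replaces an old `# type:` comment with an inline variable annotation. A minimal sketch of that conversion, using a made-up size instead of `self.num_mixtures`:

from typing import List, Optional, Tuple

num_mixtures = 4  # stand-in for the value read from the archive

# Old: annotation in a trailing `# type:` comment (plus `# nopep8` to quiet linters).
# mixtures = [None] * num_mixtures  # type: typing.List[typing.Optional[...]]  # nopep8

# New: a regular inline annotation.
mixtures: List[Optional[Tuple[List[int], List[float]]]] = [None] * num_mixtures
print(len(mixtures))  # 4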
returnn/sprint/interface.py
CHANGED
@@ -820,9 +820,9 @@ def _prepare_forwarding():
     assert engine
     assert config
     # Should already be set via setTargetMode().
-    assert config.list("extract") == [
-        "posteriors"
-
+    assert config.list("extract") == ["posteriors"], (
+        "You need to have extract = posteriors in your RETURNN config. You have: %s" % config.list("extract")
+    )
 
     # Load network.
     engine.init_network_from_config(config)

@@ -870,7 +870,6 @@ def _train(segment_name, features, targets=None):
     # The CRNN train thread started via start() will do the actual training.
 
     if TargetMode == "criterion-by-sprint":
-
         # TODO...
         make_criterion_class()
 
returnn/tensor/_dim_extra.py
CHANGED
@@ -1067,13 +1067,14 @@ class _DimMixin:
                 )
             )
         if batch and getattr(x, "_RETURNN_dyn_size_beam", None):
-            assert batch.beam == getattr(
-
-
-
-
-
-
+            assert batch.beam == getattr(x, "_RETURNN_dyn_size_beam"), (
+                "%s: dyn size %s has unexpected batch %s, expected %s"
+                % (
+                    self,
+                    x,
+                    batch,
+                    getattr(x, "_RETURNN_dyn_size_beam"),
+                )
             )
         if self.batch and batch:
             assert self.batch == batch

@@ -1359,8 +1360,7 @@ class _DimMixin:
             # Only auto-generated dim tags are allowed to be treated as broadcastable.
             # This was another suggestion from here: https://github.com/rwth-i6/returnn/issues/666
             # It was not implemented like this because the auto_generated flag was only introduced later.
-            (self.dimension == 1 and self.auto_generated)
-            or (other.dimension == 1 and other.auto_generated)
+            (self.dimension == 1 and self.auto_generated) or (other.dimension == 1 and other.auto_generated)
         ):
             pass  # pass on
         else:
returnn/tensor/_tensor_extra.py
CHANGED
@@ -335,9 +335,9 @@ class _TensorMixin(_TensorMixinBase):
             if tag.dyn_size_ext.placeholder is None:
                 tag.complete_dyn_size()
                 if self.placeholder is not None:
-                    assert (
-
-                    )
+                    assert tag.dyn_size_ext.placeholder is not None, (
+                        "%s sanity_check: dynamic dim %s value unknown" % (self, tag)
+                    )
             assert tag.is_dim_known()
 
     def get_runtime_sanity_check_op(self: Tensor):

@@ -2494,8 +2494,7 @@ class _TensorMixin(_TensorMixinBase):
                 if res_tag.match_priority > tag.match_priority:
                     continue
                 raise Exception(
-                    f"{self}: get_axis_from_description({axis}) not unique."
-                    f" use match_priority to resolve ambiguity"
+                    f"{self}: get_axis_from_description({axis}) not unique. use match_priority to resolve ambiguity"
                 )
         if res_idx is None:
             raise Exception(f"{self}: get_axis_from_description({axis}) not found")

@@ -2646,12 +2645,13 @@ class _TensorMixin(_TensorMixinBase):
             return self.batch_shape[self.time_dim_axis_excluding_batch] is None
         if self.time_dim_axis_excluding_batch in self.size_placeholder:
             return True
-        assert isinstance(
-
-
-
-
-
+        assert isinstance(self.shape[self.time_dim_axis_excluding_batch], int), (
+            "%s: dynamic time axis dim (None) (axis %i) but size_placeholder %r misses information"
+            % (
+                self,
+                self.time_dim_axis,
+                self.size_placeholder,
+            )
         )
         return False
 

@@ -3307,14 +3307,15 @@ class _TensorMixin(_TensorMixinBase):
                 if self_axis not in taken_self_axes
             ]
             if opt == "unknown_spatial_matches":
-                assert (
-
-
-
-
-
-
-
+                assert len(matching) <= 1, (
+                    "cannot match axes %s from %s to %s, failed at other %s, not unique after %s"
+                    % (
+                        other_axes,
+                        other,
+                        self,
+                        other_axis,
+                        opt,
+                    )
                 )
             if matching:
                 break
returnn/tensor/tensor.py
CHANGED
@@ -187,7 +187,7 @@ class Tensor(_TensorMixin, _TensorOpOverloadsMixin, Generic[RawTensorType]):
         if not backend.executing_eagerly():
             backend.set_known_shape_raw(value, self.batch_shape)
         assert backend.get_dtype_name_raw(value) == self.dtype, (
-            f"{self} dtype {self.dtype} does not match
+            f"{self} dtype {self.dtype} does not match raw tensor dtype {backend.get_dtype_name_raw(value)}"
         )
         self._raw_tensor = value
 
returnn/tensor/tensor_dict.py
CHANGED
@@ -91,9 +91,9 @@ class TensorDict:
         out = {}
         for key, value in self.data.items():
             assert key not in out
-            assert isinstance(
-                value.raw_tensor, expected_value_type
-            )
+            assert isinstance(value.raw_tensor, expected_value_type), (
+                f"key {key} {value}: unexpected {type(value.raw_tensor)}, expected {expected_value_type}"
+            )
             out[key] = value.raw_tensor
             for i, dim in enumerate(value.dims):
                 if exclude_duplicate_dims and dim in visited_dims:

@@ -103,9 +103,9 @@
                 if dim.is_batch_dim() and (dim.dyn_size_ext is None or dim.dyn_size_ext.raw_tensor is None):
                     if include_scalar_dyn_sizes:
                         dim_value = dim.get_dim_value()
-                        assert isinstance(
-                            dim_value, expected_value_type
-                        )
+                        assert isinstance(dim_value, expected_value_type), (
+                            f"key {key_} {dim}: unexpected {type(dim_value)}, expected {expected_value_type}"
+                        )
                         out[key_] = dim_value
                 elif dim.dyn_size_ext is not None:
                     if include_scalar_dyn_sizes or dim.dyn_size_ext.dims:

@@ -116,9 +116,9 @@
                         out[key_] = dim.dyn_size_ext.raw_tensor
                 elif dim.size is not None:
                     if include_scalar_dyn_sizes and include_const_sizes:
-                        assert isinstance(
-                            dim.size, expected_value_type
-                        )
+                        assert isinstance(dim.size, expected_value_type), (
+                            f"key {key_} {dim}: unexpected {type(dim.size)}, expected {expected_value_type}"
+                        )
                         out[key_] = dim.size
                 else:
                     raise Exception(f"cannot handle dim: {dim}")