returnn 1.20250828.2732-py3-none-any.whl → 1.20250829.151916-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/_cache.py +4 -2
- returnn/frontend/array_.py +76 -34
- returnn/frontend/conv.py +7 -4
- returnn/frontend/dims.py +26 -10
- returnn/frontend/hooks.py +3 -3
- returnn/frontend/normalization.py +1 -1
- returnn/frontend/signal.py +1 -1
- returnn/tensor/_dim_extra.py +34 -6
- returnn/util/basic.py +8 -6
- {returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/METADATA +1 -1
- {returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/RECORD +16 -16
- {returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/LICENSE +0 -0
- {returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/WHEEL +0 -0
- {returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.
-long_version = '1.
+version = '1.20250829.151916'
+long_version = '1.20250829.151916+git.687fa49'
returnn/frontend/_cache.py
CHANGED
@@ -6,7 +6,7 @@ One use case example is :func:`sinusoidal_positional_encoding` and :func:`relati
 """

 from __future__ import annotations
-from typing import Optional, Union, Any, Type, Callable, Tuple, Dict
+from typing import Optional, Union, Any, Type, Callable, Tuple, Dict, List
 from weakref import ref
 import tree
 from returnn.util.lru_cache import lru_cache
@@ -59,6 +59,8 @@ class Cache:
            if isinstance(key_item_orig, DimWrapper):
                assert isinstance(key_item, DimWrapper)
                dim_orig = key_item_orig.dim_ref()
+                if dim_orig is None:  # orig dim could be dead. but then it would not be used anyway
+                    continue
                dim = key_item.dim_ref()
                assert isinstance(dim_orig, Dim) and isinstance(dim, Dim)
                dim_map[dim_orig] = dim
@@ -103,7 +105,7 @@ def _transform_key(
    key: Any, *, finalize_callback: Optional[Callable] = None, collected_dim_map: Optional[Dict[Dim, DimWrapper]] = None
 ) -> Tuple[Union[Type[Backend], ref[rf.RunCtx], _KeyItemType], ...]:
    backend = _get_backend(key)
-    keys_flat = [backend]
+    keys_flat: List[Any] = [backend]
    if not backend.executing_eagerly():
        # See comment above: If graph-mode, the cached value becomes invalid
        # when the current run ctx goes out of scope.
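The new `if dim_orig is None: continue` guards the case where the weakly referenced original Dim of a cache key has already been garbage collected. A minimal standalone sketch of that weakref behavior (plain Python standard library; the Obj class here is illustrative, not part of RETURNN):

import gc
import weakref


class Obj:
    """Stand-in for a Dim that the cache only holds via a weak reference."""


obj = Obj()
obj_ref = weakref.ref(obj)
assert obj_ref() is obj  # referent still alive

del obj
gc.collect()
assert obj_ref() is None  # dead reference; the cache code skips such entries via `continue`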
returnn/frontend/array_.py
CHANGED
@@ -188,22 +188,18 @@ def merge_dims(
            return source, dims[0]
        return rf.replace_dim(source, in_dim=dims[0], out_dim=out_dim)
    if out_dim is None:
-
-
-        for d in dims[1:]:
-            reset_dyn_size |= d.need_masking() and out_dim.capacity != 1
-            out_dim = out_dim * d
-        if reset_dyn_size:
+        from returnn.util.basic import prod
+
+        if any(d.need_masking() for d in dims[1:]):
            # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
            # This would then potentially discard some of the data in the tensor in subsequent operations,
            # when masking is applied.
            # Thus, discard the dynamic sizes, and just treat it as a flat dim with scalar dynamic size.
            # https://github.com/rwth-i6/returnn/issues/1694
-
-                for d in dims
-
-
-            out_dim.dyn_size_ext = out_dim_size
+            # See also similar logic in :func:`concat`.
+            out_dim = Dim(prod(d.get_dim_value_tensor() for d in dims), name="merged")
+        else:
+            out_dim = prod(dims)
    # noinspection PyProtectedMember
    return source._raw_backend.merge_dims(source, dims=dims, out_dim=out_dim), out_dim

@@ -345,7 +341,9 @@ def window(
    """
    if spatial_dim.need_masking():
        if use_mask is None:
-            use_mask = rf.use_mask_default(
+            use_mask = rf.use_mask_default(
+                default=True, default_false_for_behavior_version_up_to=22, func_name="window"
+            )
        if use_mask:
            source = source.copy_masked(0, dims=[spatial_dim])
    assert window_dim.dimension is not None
@@ -427,28 +425,39 @@ def concat(
        dims = sources[0][0].dims_set - {sources[0][1]}
        for src, dim in sources:
            assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
+    need_handle_dynamic_dims = False
+    for src, dim in sources[:-1]:
+        if dim.need_masking():
+            need_handle_dynamic_dims = True
+    if handle_dynamic_dims is None:
+        handle_dynamic_dims = need_handle_dynamic_dims
    if not out_dim:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-        mask_concat = sources[0][0]._raw_backend.concat(
-            *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_dim
+        if handle_dynamic_dims or not need_handle_dynamic_dims:
+            out_dim = sum(d for _, d in sources)
+        else:  # not handle_dynamic_dims but need_handle_dynamic_dims
+            # There are dynamic dims, but we don't want to handle them.
+            # So, summing the dims would be incorrect.
+            # Just add the dim values.
+            out_dim = Dim(sum(d.get_dim_value_tensor() for _, d in sources if d.dimension is not None), name="concat")
+    if handle_dynamic_dims:
+        out_non_masked_dim = Dim(sum(d.get_dim_value_tensor() for _, d in sources))
+        # noinspection PyProtectedMember
+        out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_non_masked_dim)
+        masks = []
+        for _, dim in sources:
+            masks.append(
+                dim.get_mask(dim_order=(dim,) + dim.dyn_size_ext.dims, device=out.device)
+                if dim.need_masking()
+                else rf.constant(True, dims=[dim], device=out.device)
            )
-
+        # noinspection PyProtectedMember
+        mask_concat = sources[0][0]._raw_backend.concat(
+            *[(mask, dim) for (_, dim), mask in zip(sources, masks)], allow_broadcast=True, out_dim=out_non_masked_dim
+        )
+        out, _ = rf.masked_select(out, mask=mask_concat, dims=[out_non_masked_dim], out_dim=out_dim)
+    else:
+        # noinspection PyProtectedMember
+        out = sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim)
    return out, out_dim


@@ -494,7 +503,12 @@ def pad(
    if handle_dynamic_dims is None:
        handle_dynamic_dims = _pad_handle_dynamic_dims_default(axes, padding, mode=mode)
    if not out_dims:
-        out_dims = [
+        out_dims = [
+            (left + middle + right)
+            if handle_dynamic_dims or not _pad_need_dyn_dim_handling(middle, left, right, mode=mode)
+            else _pad_sum_dims_no_dyn_dim_handling(middle, left, right)
+            for middle, (left, right) in zip(axes, padding)
+        ]
    # noinspection PyProtectedMember
    return (
        source._raw_backend.pad(
@@ -560,6 +574,32 @@ def _pad_need_dyn_dim_handling(
    return True


+def _pad_sum_dims_no_dyn_dim_handling(
+    middle: Dim, left: Union[Dim, int, Tensor], right: Union[Dim, int, Tensor]
+) -> Dim:
+    """
+    This gets called when we need to handle dyn dims, but handle_dynamic_dims=False.
+    See also the same logic in :func:`concat`.
+    """
+    if isinstance(left, Dim):
+        left = left.get_dim_value_tensor()
+    elif isinstance(left, int):
+        pass
+    elif isinstance(left, Tensor):
+        assert left.dims == ()  # scalar
+    else:
+        raise TypeError(f"invalid left pad {left}")
+    if isinstance(right, Dim):
+        right = right.get_dim_value_tensor()
+    elif isinstance(right, int):
+        pass
+    elif isinstance(right, Tensor):
+        assert right.dims == ()  # scalar
+    else:
+        raise TypeError(f"invalid right pad {right}")
+    return Dim(left + middle.get_dim_value_tensor() + right, name="pad")
+
+
 def cum_concat_step(
    source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Optional[Dim] = None
 ) -> Tuple[Tensor, Dim]:
@@ -867,7 +907,9 @@ def scatter(
    indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
    if any(dim.need_masking() for dim in indices_dim):
        if use_mask is None:
-            use_mask = rf.use_mask_default(
+            use_mask = rf.use_mask_default(
+                default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
+            )
        if use_mask:
            source = source.copy_masked(fill_value, dims=indices_dim)
    else:
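The concat and pad changes above distinguish dim math on the per-sequence lengths from simply adding the padded dim values when handle_dynamic_dims=False. A small numpy illustration of the difference (illustration only, not the RETURNN API):

import numpy as np

# Batch of 2 sequences, each tensor padded to its own max length.
a = np.array([[1, 2, 3], [4, 5, 0]])  # per-sequence lengths [3, 2], padded to 3
b = np.array([[6, 7], [8, 9]])  # per-sequence lengths [2, 2], padded to 2

# Plain concatenation along the time axis keeps the padding of `a` in the middle:
out = np.concatenate([a, b], axis=1)  # shape (2, 5): the padded dim values add up (3 + 2)
assert out.shape[1] == a.shape[1] + b.shape[1]

# Dim math on the per-sequence lengths would give [3 + 2, 2 + 2] = [5, 4], which only matches
# the data if the padding of `a` is removed first via masking; that is what the
# handle_dynamic_dims=True branch does with the mask concat plus masked_select.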
returnn/frontend/conv.py
CHANGED
@@ -223,7 +223,7 @@ def conv(
    """
    if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
        if use_mask is None:
-            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="conv")
        if use_mask:
            source = source.copy_masked(0, dims=in_spatial_dims)
    for in_spatial_dim in in_spatial_dims:
@@ -391,7 +391,9 @@ def transposed_conv(
    """transposed conv"""
    if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
        if use_mask is None:
-            use_mask = rf.use_mask_default(
+            use_mask = rf.use_mask_default(
+                default=True, default_false_for_behavior_version_up_to=22, func_name="transposed_conv"
+            )
        if use_mask:
            source = source.copy_masked(0, dims=in_spatial_dims)
    if padding == "same" and _any_is_non_default(strides, default=1) and _should_use_consistent_same_padding():
@@ -503,7 +505,7 @@ def pool(

    if any(in_spatial_dim.need_masking() for in_spatial_dim in in_spatial_dims):
        if use_mask is None:
-            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="pool")
        if use_mask:
            source = source.copy_masked({"max": float("-inf"), "avg": 0}[mode], dims=in_spatial_dims)
    else:
@@ -862,8 +864,9 @@ def _consistent_same_padding(
        pad_right = (s - 1) * d - pad_left
        paddings.append((pad_left, pad_right))
    # We expect that masking was already done before (or we don't care about it), thus handle_dynamic_dims=False.
+    out_dims = [(left + middle + right) for middle, (left, right) in zip(in_spatial_dims, paddings)]
    source, in_spatial_dims = rf.pad(
-        source, axes=in_spatial_dims, padding=paddings, value=pad_value, handle_dynamic_dims=False
+        source, axes=in_spatial_dims, padding=paddings, value=pad_value, handle_dynamic_dims=False, out_dims=out_dims
    )
    return source, in_spatial_dims, 0

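_consistent_same_padding now also passes explicit out_dims to rf.pad, computed as left + middle + right per spatial axis. The pad amounts themselves follow the usual "same" padding arithmetic visible in the pad_right = (s - 1) * d - pad_left line; a hedged sketch of that arithmetic (the function name and the exact left/right split are illustrative, not necessarily RETURNN's choice):

def same_padding(s: int, d: int = 1):
    """Padding for filter size s and dilation d that keeps the length for stride 1."""
    total = (s - 1) * d
    pad_left = total // 2
    pad_right = total - pad_left  # == (s - 1) * d - pad_left
    return pad_left, pad_right


assert same_padding(3) == (1, 1)  # kernel 3: one frame on each side
assert same_padding(4) == (1, 2)  # even kernel: asymmetric split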
returnn/frontend/dims.py
CHANGED
@@ -3,7 +3,7 @@ Utilities for dimension tags, dimensions, axes.
 """

 from __future__ import annotations
-from typing import Optional, Union, TypeVar, Sequence, Tuple
+from typing import TYPE_CHECKING, Optional, Union, TypeVar, Sequence, Tuple
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 from ._backend import get_backend_by_tensor, global_backend
@@ -25,6 +25,9 @@ __all__ = [
    "use_mask_default",
 ]

+if TYPE_CHECKING:
+    from returnn.config import Config
+

 def range_over_dim(dim: Dim, *, dtype: Optional[str] = None, device: Optional[str] = None) -> Tensor[T]:
    """
@@ -309,7 +312,10 @@ def last_frame_position_of_dim(


 def use_mask_default(
-    *,
+    *,
+    default: Optional[bool] = None,
+    default_false_for_behavior_version_up_to: Optional[int] = None,
+    func_name: Optional[str] = None,
 ) -> Optional[bool]:
    """
    Check the global RETURNN config for the ``rf_use_mask``
@@ -324,20 +330,20 @@ def use_mask_default(
    and if this is set, and the behavior version is less or equal,
    then return False by default, i.e. do not use the mask by default, if it is not defined in the config.
    This takes precedence over `default`.
+    :param func_name: if specified, also check
    :return: what to use for the ``use_mask`` argument by default
    """
    from returnn.config import get_global_config

    config = get_global_config(raise_exception=False)
-    config_value = None
    if config:
-
-
-
-
-        config_value = config
-
-
+        config_value = _get_opt_bool_from_config(config, "rf_use_mask")
+        if config_value is not None:
+            return config_value
+        if func_name:
+            config_value = _get_opt_bool_from_config(config, f"rf_use_mask_{func_name}")
+            if config_value is not None:
+                return config_value

    if default_false_for_behavior_version_up_to is not None:
        from returnn.util.basic import BehaviorVersion
@@ -345,3 +351,13 @@ def use_mask_default(
        if BehaviorVersion.get() <= default_false_for_behavior_version_up_to:
            return False
    return default
+
+
+def _get_opt_bool_from_config(config: Config, key: str) -> Optional[bool]:
+    if key in config.typed_dict:
+        config_value = config.typed_dict[key]
+        assert config_value is None or isinstance(config_value, bool)
+        return config_value
+    elif key in config.dict:
+        return config.bool(key, None)
+    return None
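With the new func_name parameter, use_mask_default first consults the global rf_use_mask config key and, only if that is unset, a per-function key rf_use_mask_{func_name}. A minimal sketch of how this could look in a RETURNN config file (the concrete values are illustrative only):

# RETURNN config (plain Python). Leave rf_use_mask unset so the per-function key is consulted.
rf_use_mask_conv = False  # rf.conv (func_name="conv") skips masking; other functions keep their defaults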
returnn/frontend/hooks.py
CHANGED
@@ -16,7 +16,7 @@ T = TypeVar("T")


 def setup_post_hook_on_method(
-    obj:
+    obj: T,
    attr: str,
    hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
    *,
@@ -40,7 +40,7 @@ class MethodWithHooks:
    """

    @classmethod
-    def get(cls, obj:
+    def get(cls, obj: T, attr: str) -> MethodWithHooks:
        """get existing or init new :class:`MethodWithHooks`"""
        method = getattr(obj, attr)
        if not isinstance(method, MethodWithHooks):
@@ -56,7 +56,7 @@ class MethodWithHooks:
        method.setup()
        return method

-    def __init__(self, obj:
+    def __init__(self, obj: T, attr: str):
        """
        :param obj:
        :param attr:
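setup_post_hook_on_method and MethodWithHooks install a post hook with the signature Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], i.e. the hook receives the object, the call args, the kwargs and the result. A minimal standalone sketch of that wrapping pattern (plain Python, not the MethodWithHooks implementation):

from typing import Any, Callable, Dict, Optional, Tuple


def wrap_with_post_hook(obj: Any, attr: str, hook: Callable[[Any, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]):
    """Replace obj.<attr> by a wrapper that calls hook(obj, args, kwargs, result) after the original method.
    A non-None hook return value replaces the result."""
    orig = getattr(obj, attr)

    def wrapper(*args, **kwargs):
        res = orig(*args, **kwargs)
        hook_res = hook(obj, args, kwargs, res)
        return res if hook_res is None else hook_res

    setattr(obj, attr, wrapper)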
returnn/frontend/normalization.py
CHANGED

@@ -218,7 +218,7 @@ class BatchNorm(rf.Module):

        if any(d.need_masking() for d in source.dims if d != self.in_dim):
            if self.use_mask is None:
-                use_mask = rf.use_mask_default(default=True)
+                use_mask = rf.use_mask_default(default=True, func_name="BatchNorm")
            else:
                use_mask = self.use_mask
        else:
returnn/frontend/signal.py
CHANGED
@@ -71,7 +71,7 @@ def stft(
    """
    if in_spatial_dim.need_masking():
        if use_mask is None:
-            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22)
+            use_mask = rf.use_mask_default(default=True, default_false_for_behavior_version_up_to=22, func_name="stft")
        if use_mask:
            x = x.copy_masked(0, dims=[in_spatial_dim])
    fft_length = fft_length or frame_length
returnn/tensor/_dim_extra.py
CHANGED
@@ -1264,7 +1264,6 @@ class _DimMixin:
                raise TypeError(f"complete_dyn_size: _relu: unexpected type {type(a)}")

            y: Optional[_t.Tensor] = None  # resulting dyn size
-            y_max_value: Optional[_t.Tensor] = None  # resulting dyn size max value
            inputs = list(op.inputs)
            assert inputs
            for x_dim in inputs:
@@ -1275,8 +1274,6 @@ class _DimMixin:
                if x_dim.dyn_size_ext is None and x_dim.dimension is None:
                    return
                y = _bin_op(y, x_dim.dimension if x_dim.dimension is not None else x_dim.dyn_size_ext)
-                if not template_only and y.raw_tensor is not None:
-                    y_max_value = _bin_op(y_max_value, x_dim.get_dim_value_tensor())
            assert y is not None, f"op {op}?"
            if self.dyn_size_ext is not None:
                assert self.dyn_size_ext.dim_tags == y.dim_tags
@@ -1286,9 +1283,14 @@ class _DimMixin:
            else:
                self.batch = y.batch
                self.dyn_size_ext = y
-            if not template_only and
-
-
+            if not template_only and y.raw_tensor is not None:
+                # Note: Earlier, we had this wrong.
+                # It is not correct to replicate the same math (bin ops)
+                # on the dim values (_dyn_size_max_value of each dim).
+                # Consider sizes1=[2,3], sizes2=[5,4], and the op is "add".
+                # Then the result sizes would be [7,7], thus its max is 7,
+                # but max(sizes1)+max(sizes2)=3+5=8.
+                self._dyn_size_max_value = rf.reduce_max(y, axis=y.dims) if y.dims else y
            if tf and y.placeholder is not None:
                self.set_tag_on_size_tensor(y.placeholder)

@@ -2080,6 +2082,8 @@ class _DimMixin:
        :return: self + other. note that this is not commutative, i.e. different from other + self.
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 0:
+            return self
        cache_key = ("add", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2098,6 +2102,8 @@ class _DimMixin:
        :return: other + self
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 0:
+            return self
        cache_key = ("add_left", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2115,6 +2121,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 0:
+            return self
        return self.sub_right(other)

    def sub_right(self: Dim, other):
@@ -2123,6 +2131,8 @@ class _DimMixin:
        :return: self - other
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 0:
+            return self
        cache_key = ("sub", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2141,6 +2151,8 @@ class _DimMixin:
        :return: (-other) + self
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 0:
+            return self
        cache_key = ("sub_left", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2158,6 +2170,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("mul", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2175,6 +2189,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("mul_left", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2192,6 +2208,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("floordiv", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2209,6 +2227,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        return self.div_right(other)

    def div_left(self: Dim, other):
@@ -2216,6 +2236,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("truediv_left", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2233,6 +2255,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("truediv", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2250,6 +2274,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("ceildiv_left", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
@@ -2267,6 +2293,8 @@ class _DimMixin:
        :param Dim|int other:
        :rtype: Dim
        """
+        if isinstance(other, int) and other == 1:
+            return self
        cache_key = ("ceildiv", other)
        cache = self.get_same_base()._make_extra().cache_dim_math
        cache_entry = cache.get(cache_key, None)
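Two points from the hunks above are worth spelling out. First, the removed y_max_value bookkeeping replicated the binary op on the per-dim maxima, which the new comment shows to be wrong: the max of an elementwise sum is in general smaller than the sum of the maxima, so the new code takes reduce_max over the resulting sizes instead. Second, the new int shortcuts make dim + 0, dim - 0, dim * 1 and the division variants return the identical Dim object without touching the dim-math cache. A small numeric check of the first point (plain numpy, illustration only):

import numpy as np

sizes1 = np.array([2, 3])  # dyn sizes of the first dim, per batch entry
sizes2 = np.array([5, 4])  # dyn sizes of the second dim, per batch entry

summed = sizes1 + sizes2  # [7, 7]: dyn sizes of the added dim
assert summed.max() == 7  # correct _dyn_size_max_value: reduce over the result
assert sizes1.max() + sizes2.max() == 8  # replicating the op on the maxima overestimates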
returnn/util/basic.py
CHANGED
@@ -1693,15 +1693,17 @@ def inplace_increment(x: numpy.ndarray, idx: numpy.ndarray, y: Union[numpy.ndarr
    raise NotImplementedError("This feature was removed with dropped Theano support")


-def prod(ls):
+def prod(ls: Union[Iterable[T], numpy.ndarray]) -> Union[int, T, float]:
    """
-    :param
-    :
+    :param ls:
+    :return: ls[0] * ls[1] * ...
    """
-
+    it = iter(ls)
+    try:
+        x = next(it)
+    except StopIteration:
        return 1
-
-    for y in ls[1:]:
+    for y in it:
        x = x * y  # *= doesn't work because x might be a tensor, and for e.g. torch.Tensor this op is in-place
    return x

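prod is now typed and iterator-based, so it accepts generators (as used by the new merge_dims code above) and still returns 1 for an empty input. A short usage sketch (assuming returnn is importable):

from returnn.util.basic import prod

assert prod([]) == 1  # empty input: multiplicative identity
assert prod([2, 3, 4]) == 24  # sequences as before
assert prod(x * x for x in range(1, 4)) == 36  # generators work too: 1 * 4 * 9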
{returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/RECORD
CHANGED

@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=69VkJGV1pqtymng35zrww_x1nRV1ItxPtwRUITKi0OA,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=i5XkrJfU4e2pYDZFMN1NXjv0difTuwrUKEBqaiMDWr0,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -76,11 +76,11 @@ returnn/extern/graph_editor/transform.py,sha256=qMGSenpbAnGqdG6QP6iWjlm6_ccySYJa
 returnn/extern/graph_editor/util.py,sha256=HfRbyQPmQ6_n5-O-096n0KeJtllQXFtaurpeJS_URZ0,18706
 returnn/frontend/__init__.py,sha256=2aS7nbxXniIrBp2DODl0xN0f3IJ_dX4Bi9ZlR7W5_DE,1472
 returnn/frontend/_backend.py,sha256=39l5MC1DaT0MPklMM8HXAW9nqisIIZQ9g2QSHOOtPQE,50741
-returnn/frontend/_cache.py,sha256=
+returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,8550
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
 returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
-returnn/frontend/array_.py,sha256=
+returnn/frontend/array_.py,sha256=7uX5-Os2OyYUfC5soprIUx7rr-371yKf9DcckRKONXY,53855
 returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -88,14 +88,14 @@ returnn/frontend/cond.py,sha256=gh6wg0aSbAJQfKRv4BQAu-EfPWtWPLFjgc8IaPPFmwg,1023
 returnn/frontend/const.py,sha256=A5fP9w6Akv56d89pPvdoZaXvC9ZTYcexepnS9O2clOc,3945
 returnn/frontend/container.py,sha256=wF3OlQN7WlOVmmdapUth_Unha3DVf6h1B7okBJAuJDA,8011
 returnn/frontend/control_flow_ctx.py,sha256=v17CsNwRnZYe8GdMtGJt2ftibfxMCGK1i0l-GX5ILu0,699
-returnn/frontend/conv.py,sha256=
+returnn/frontend/conv.py,sha256=4Mrq7MFc0f7SJ8g-wJEv4Lg3Stmju-fMwD09qKv6CuQ,32174
 returnn/frontend/device.py,sha256=Sjara0EmFLhu9O55cN_p6OwU0NgdNCCQjyAuQhiWpGw,1437
-returnn/frontend/dims.py,sha256=
+returnn/frontend/dims.py,sha256=_HDU-Kxn3pApicFkm0F4Fs-ZAuF1gKXG8rroQHCFQQI,13073
 returnn/frontend/dropout.py,sha256=TjqZCKDIOBeHr14-NCemOm9m3p84LxQuPH1DvRAYg88,5028
 returnn/frontend/dtype.py,sha256=Ooc5BrcNrTp6XShuFEV9g5V6-niuy4ImP_Lt_Qgq3jE,1886
 returnn/frontend/gradient.py,sha256=G-Qv4gKGHYEeB92Zwco9ao4qjd6umZPUzQC4J-fbYWo,4033
 returnn/frontend/graph.py,sha256=PIv901WZ1rfTV0QGkyzBv6UxfWk9NsLGxdoJ5x9-8Xg,1818
-returnn/frontend/hooks.py,sha256=
+returnn/frontend/hooks.py,sha256=L7ITrlEQ6JUy8fEBE0SXg1dzFNkLrgb8gxZm88fxryU,5501
 returnn/frontend/init.py,sha256=bVB7bpghaY8DI_HL0mkB_9z95onWnIX2zlW4hlMYnRw,7494
 returnn/frontend/label_smoothing.py,sha256=lxmaowNr61sCMzMewqHhu1r0CcklYfhLXlFnBu8DeAU,5676
 returnn/frontend/linear.py,sha256=xRUjnkD3MTWDezSaYATBYJQ2fa1RhKMNrTuhC54hhVs,2252
@@ -105,7 +105,7 @@ returnn/frontend/math_.py,sha256=A_RkZ5lH2uXMchfPIH3itraWtMNNCVckQHHpf7aIIZQ,172
 returnn/frontend/matmul.py,sha256=xkueyxzSDz8MsYaWxPSjmV2Yy-tcaiOQDXbFt1IQM2A,1944
 returnn/frontend/module.py,sha256=219rh5mE0CD0-NdxXLsKyhv3BNtOI9jSyiI1Rb8MOyU,10700
 returnn/frontend/nested.py,sha256=P84u_cjoYdYRJ_0Cbt0vlKXxskmXTDfsnw_vFCCNKtU,15107
-returnn/frontend/normalization.py,sha256
+returnn/frontend/normalization.py,sha256=NrIIaZ3c2yf-WH2R9lPaL2TAq4IcNQc4OE5kFYdoihw,14139
 returnn/frontend/parameter.py,sha256=zvrkhSYC1c_O9kVwgHvOtOnWNurl5J28lkS0i1LQpWU,10627
 returnn/frontend/parametrizations.py,sha256=ptNgBw5IiPXVpB3QGse7AGAhdXp8X1rCqYUl2Mae8aI,2876
 returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
@@ -114,7 +114,7 @@ returnn/frontend/rand.py,sha256=2x7AHSYH_tZkzTk_q3t3GA_yYRNeKsVbJjw2InqSGDk,1354
 returnn/frontend/rec.py,sha256=6YSsSG7fdtfvvg24vmexSg8R2aVCcKHBdGLh-Mgn9Co,8037
 returnn/frontend/reduce.py,sha256=gRSvBJZNHa757IqBxGw4hu5eiO3pjie_ptEwUXHLSCs,10340
 returnn/frontend/run_ctx.py,sha256=yyOMUCKTOe19C4z2Nfly4YCLBmQ9ihip6nGrkW-Y6qg,23789
-returnn/frontend/signal.py,sha256=
+returnn/frontend/signal.py,sha256=iBRO2ywpJOjIUfVveJaqX4NT59013VCoE49IHkVn6p8,4429
 returnn/frontend/state.py,sha256=EePdrx6PtWL4mJ2XZmGlh5dl4nq6G9wZpqP4hdDEzfY,2935
 returnn/frontend/stepwise_scheduler.py,sha256=fMOTR7npGCDXrXDmSQ4VwmudoHEbY3Yr-QGyjFdQJSc,927
 returnn/frontend/tensor_array.py,sha256=Ej7CHtvpY0yBROlAk5vFe3CTXh-iAuqu9qcXS3Qxt2I,4328
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
 returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
 returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
 returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
-returnn/tensor/_dim_extra.py,sha256=
+returnn/tensor/_dim_extra.py,sha256=VN7Smn1Q0Y0DO7GSPM-aJUhp_jy5pzSMJbPkCk6JnqY,123448
 returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
 returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
 returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
 returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
 returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
-returnn/util/basic.py,sha256=
+returnn/util/basic.py,sha256=S2ABKcP0pf2UexuMXDNHGcfAu7GDSD2mr6OIByM152M,143168
 returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5w,70843
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
+returnn-1.20250829.151916.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250829.151916.dist-info/METADATA,sha256=69VkJGV1pqtymng35zrww_x1nRV1ItxPtwRUITKi0OA,5215
+returnn-1.20250829.151916.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250829.151916.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250829.151916.dist-info/RECORD,,
{returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/LICENSE
File without changes

{returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/WHEEL
File without changes

{returnn-1.20250828.2732.dist-info → returnn-1.20250829.151916.dist-info}/top_level.txt
File without changes