returnn 1.20260105.192646__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +110 -42
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +6 -5
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +19 -0
- returnn/frontend/loss.py +183 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +104 -174
- returnn/native_op.py +36 -31
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +1 -1
- returnn/tf/frontend_layers/_backend.py +3 -1
- returnn/tf/layers/basic.py +13 -2
- returnn/tf/native_op.py +16 -5
- returnn/tf/util/basic.py +7 -201
- returnn/torch/engine.py +120 -3
- returnn/torch/frontend/_backend.py +166 -22
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +3 -1
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +1 -0
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/__old_mod_loader__.py
CHANGED
|
@@ -17,7 +17,7 @@ This is supported as well.
|
|
|
17
17
|
import sys
|
|
18
18
|
import os
|
|
19
19
|
import types
|
|
20
|
-
import
|
|
20
|
+
from typing import Any, Dict
|
|
21
21
|
import importlib
|
|
22
22
|
|
|
23
23
|
old_to_new_mod_mapping = {
|
|
@@ -122,7 +122,7 @@ class _LazyLoader(types.ModuleType):
|
|
|
122
122
|
fn = "%s/%s/__init__.py" % (_base_dir, full_mod_name.replace(".", "/"))
|
|
123
123
|
assert os.path.exists(fn), "_LazyLoader: mod %r not found in %r" % (full_mod_name, _base_dir)
|
|
124
124
|
self.__file__ = fn
|
|
125
|
-
self._lazy_mod_config = dict(full_mod_name=full_mod_name, **kwargs)
|
|
125
|
+
self._lazy_mod_config: Dict[str, Any] = dict(full_mod_name=full_mod_name, **kwargs)
|
|
126
126
|
|
|
127
127
|
def _load(self):
|
|
128
128
|
full_mod_name = self.__name__
|
|
@@ -172,6 +172,30 @@ class _LazyLoader(types.ModuleType):
|
|
|
172
172
|
return super(_LazyLoader, self).__getattribute__(item)
|
|
173
173
|
|
|
174
174
|
def __getattr__(self, item):
|
|
175
|
+
if item == "torch":
|
|
176
|
+
# torch.compile Dynamo hashing can trigger this, when it uses pickle to serialize some object state,
|
|
177
|
+
# which iterates through sys.modules and does getattr on each module.
|
|
178
|
+
# In this case, it searches for torch.
|
|
179
|
+
# File ".../torch/_inductor/codecache.py", line 607 in dumps
|
|
180
|
+
# File ".../torch/_inductor/codecache.py", line 622 in get_hash
|
|
181
|
+
# File ".../torch/_inductor/codecache.py", line 961 in compiled_fx_graph_hash
|
|
182
|
+
# ...
|
|
183
|
+
# Unfortunately, Pickler.dump is native code, so we cannot easily check whether that is the parent frame.
|
|
184
|
+
# The C stacktrace looks like:
|
|
185
|
+
# ...
|
|
186
|
+
# 7 Python 0x0000000102e7d504 call_attribute + 80
|
|
187
|
+
# 8 Python 0x0000000102e7d400 _Py_slot_tp_getattr_hook + 576
|
|
188
|
+
# 9 Python 0x0000000102e507a0 PyObject_GetOptionalAttr + 248
|
|
189
|
+
# 10 _pickle.cpython-313-darwin.so 0x0000000102d24fb4 get_deep_attribute + 104
|
|
190
|
+
# 11 _pickle.cpython-313-darwin.so 0x0000000102d250b8 _checkmodule + 88
|
|
191
|
+
# 12 _pickle.cpython-313-darwin.so 0x0000000102d22588 save_global + 3024
|
|
192
|
+
# 13 _pickle.cpython-313-darwin.so 0x0000000102d1eddc save + 3424
|
|
193
|
+
# ...
|
|
194
|
+
# Right now, we just check for `item == "torch"` as a heuristic,
|
|
195
|
+
# which should never exist for any of the old-style wrapped modules here.
|
|
196
|
+
# We could maybe also check sys._getframe(1).f_code or so and add some other heuristics...
|
|
197
|
+
raise AttributeError(f"module {self.__name__} has no attribute {item} (lazy loading skipped)")
|
|
198
|
+
|
|
175
199
|
module = self._load()
|
|
176
200
|
return getattr(module, item)
|
|
177
201
|
|
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20260105.192646'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20260119.015400'
|
|
2
|
+
long_version = '1.20260119.015400+git.5c6a8c0'
|
returnn/datasets/lm.py
CHANGED
|
@@ -86,6 +86,7 @@ class LmDataset(CachedDataset2):
|
|
|
86
86
|
delayed_seq_data_start_symbol="[START]",
|
|
87
87
|
dtype: Optional[str] = None,
|
|
88
88
|
tag_prefix: Optional[str] = None,
|
|
89
|
+
_debug_limit_line_count: Optional[int] = None,
|
|
89
90
|
**kwargs,
|
|
90
91
|
):
|
|
91
92
|
"""
|
|
@@ -138,6 +139,8 @@ class LmDataset(CachedDataset2):
|
|
|
138
139
|
delayed_seq_data_start_symbol + original_sequence[:-1].
|
|
139
140
|
:param str delayed_seq_data_start_symbol: used for add_delayed_seq_data.
|
|
140
141
|
:param dtype: explicit dtype. if not given, automatically determined based on the number of labels.
|
|
142
|
+
:param tag_prefix: prefix for sequence tags. by default "line-".
|
|
143
|
+
:param _debug_limit_line_count:
|
|
141
144
|
"""
|
|
142
145
|
super(LmDataset, self).__init__(**kwargs)
|
|
143
146
|
|
|
@@ -316,6 +319,10 @@ class LmDataset(CachedDataset2):
|
|
|
316
319
|
self.num_skipped = 0
|
|
317
320
|
self.num_unknown = 0
|
|
318
321
|
|
|
322
|
+
if _debug_limit_line_count is None:
|
|
323
|
+
_debug_limit_line_count = _get_debug_limit_line_count()
|
|
324
|
+
self._debug_limit_line_count = _debug_limit_line_count
|
|
325
|
+
|
|
319
326
|
def _lazy_init(self):
|
|
320
327
|
if self._orths_offsets_and_lens is not None:
|
|
321
328
|
return
|
|
@@ -340,6 +347,9 @@ class LmDataset(CachedDataset2):
|
|
|
340
347
|
lens_per_corpus_file = []
|
|
341
348
|
start_time = time.time()
|
|
342
349
|
last_print_time = start_time
|
|
350
|
+
debug_limit_line_count = self._debug_limit_line_count
|
|
351
|
+
debug_limit_est_total = 0
|
|
352
|
+
debug_limit_hit = False
|
|
343
353
|
|
|
344
354
|
def _init_tmp_file():
|
|
345
355
|
nonlocal tmp_file, tmp_file_orth_files_index
|
|
@@ -368,13 +378,16 @@ class LmDataset(CachedDataset2):
|
|
|
368
378
|
|
|
369
379
|
if time.time() - last_print_time > 10:
|
|
370
380
|
print(
|
|
371
|
-
f" ... loaded {len(
|
|
381
|
+
f" ... loaded {len(orths)} sequences,"
|
|
372
382
|
f" {human_bytes_size(total_bytes_read)},"
|
|
373
383
|
f" after {hms(time.time() - start_time)}",
|
|
374
384
|
file=log.v4,
|
|
375
385
|
)
|
|
376
386
|
last_print_time = time.time()
|
|
377
387
|
|
|
388
|
+
if debug_limit_line_count is not None and len(orths) - prev_orth_len >= debug_limit_line_count:
|
|
389
|
+
raise _ReachedDebugLimitLineCount()
|
|
390
|
+
|
|
378
391
|
# If a list of files is provided, concatenate all.
|
|
379
392
|
if isinstance(corpus_file, str):
|
|
380
393
|
corpus_file = [corpus_file]
|
|
@@ -383,37 +396,46 @@ class LmDataset(CachedDataset2):
|
|
|
383
396
|
for file_name in corpus_file:
|
|
384
397
|
if self._use_cache_manager:
|
|
385
398
|
file_name = cf(file_name)
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
#
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
399
|
+
|
|
400
|
+
try:
|
|
401
|
+
if _is_bliss(file_name):
|
|
402
|
+
_init_tmp_file()
|
|
403
|
+
_iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
|
|
404
|
+
elif file_name.endswith(".gz"):
|
|
405
|
+
_init_tmp_file()
|
|
406
|
+
_iter_txt(
|
|
407
|
+
filename=file_name,
|
|
408
|
+
callback=_tmp_file_add_line,
|
|
409
|
+
skip_empty_lines=self._skip_empty_lines,
|
|
410
|
+
decode=False,
|
|
411
|
+
)
|
|
412
|
+
else: # Raw txt file
|
|
413
|
+
# Directly mmap the file.
|
|
414
|
+
# We just need to scan once through it to find line offsets.
|
|
415
|
+
file = open(file_name, "rb")
|
|
416
|
+
file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
|
|
417
|
+
file_index = len(self._orth_files)
|
|
418
|
+
self._orth_files.append(file)
|
|
419
|
+
self._orth_mmaps.append(file_mmap)
|
|
420
|
+
|
|
421
|
+
pos = 0
|
|
422
|
+
while True:
|
|
423
|
+
next_new_line = file_mmap.find(b"\n", pos)
|
|
424
|
+
if next_new_line == -1:
|
|
425
|
+
break
|
|
426
|
+
line_len = next_new_line - pos
|
|
427
|
+
if line_len or not self._skip_empty_lines:
|
|
428
|
+
orths.append((file_index, pos, line_len))
|
|
429
|
+
total_bytes_read += line_len + 1
|
|
430
|
+
pos = next_new_line + 1
|
|
431
|
+
_maybe_report_status()
|
|
432
|
+
|
|
433
|
+
except _ReachedDebugLimitLineCount as exc:
|
|
434
|
+
assert exc.estimated_total_num_seqs is not None # currently only for _iter_txt implemented
|
|
435
|
+
debug_limit_est_total += exc.estimated_total_num_seqs
|
|
436
|
+
debug_limit_hit = True
|
|
437
|
+
else: # iteration completed without hitting debug limit
|
|
438
|
+
debug_limit_est_total += len(orths) - prev_orth_len
|
|
417
439
|
|
|
418
440
|
lens_per_corpus_file.append(len(orths) - prev_orth_len)
|
|
419
441
|
prev_orth_len = len(orths)
|
|
@@ -447,6 +469,18 @@ class LmDataset(CachedDataset2):
|
|
|
447
469
|
file=log.v4,
|
|
448
470
|
)
|
|
449
471
|
|
|
472
|
+
if debug_limit_hit:
|
|
473
|
+
est_frac_loaded = len(self._orths_offsets_and_lens) / debug_limit_est_total
|
|
474
|
+
new_partition_epoch = max(int(self.partition_epoch * est_frac_loaded), 1)
|
|
475
|
+
print(
|
|
476
|
+
f"LmDataset: debug limit of {debug_limit_line_count} lines (per file) hit,"
|
|
477
|
+
f" estimated total num seqs {debug_limit_est_total},"
|
|
478
|
+
f" loaded {len(self._orths_offsets_and_lens)}, {est_frac_loaded:.2%},"
|
|
479
|
+
f" adjusting partition_epoch from {self.partition_epoch} to {new_partition_epoch}",
|
|
480
|
+
file=log.v4,
|
|
481
|
+
)
|
|
482
|
+
self.partition_epoch = new_partition_epoch
|
|
483
|
+
|
|
450
484
|
# It's only estimated because we might filter some out or so.
|
|
451
485
|
self._estimated_num_seqs = len(self._orths_offsets_and_lens) // self.partition_epoch
|
|
452
486
|
|
|
@@ -784,19 +818,34 @@ def _iter_txt(
|
|
|
784
818
|
:param decode:
|
|
785
819
|
"""
|
|
786
820
|
f = open(filename, "rb")
|
|
821
|
+
f_ = f
|
|
787
822
|
if filename.endswith(".gz"):
|
|
788
823
|
f = gzip.GzipFile(fileobj=f)
|
|
789
824
|
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
825
|
+
count = 0
|
|
826
|
+
try:
|
|
827
|
+
for line in f:
|
|
828
|
+
if decode:
|
|
829
|
+
try:
|
|
830
|
+
line = line.decode("utf8")
|
|
831
|
+
except UnicodeDecodeError:
|
|
832
|
+
line = line.decode("latin_1") # or iso8859_15?
|
|
833
|
+
line = line.strip()
|
|
834
|
+
if skip_empty_lines and not line:
|
|
835
|
+
continue
|
|
836
|
+
count += 1
|
|
837
|
+
callback(line)
|
|
838
|
+
|
|
839
|
+
except _ReachedDebugLimitLineCount as exc:
|
|
840
|
+
print(f"Reached debug limit line count for {filename}, stopping early", file=log.v4)
|
|
841
|
+
pos = f_.tell()
|
|
842
|
+
f_.seek(0, os.SEEK_END)
|
|
843
|
+
size = f_.tell()
|
|
844
|
+
print(f" stopped at byte {human_bytes_size(pos)} / {human_bytes_size(size)}", file=log.v4)
|
|
845
|
+
estimated_num_seqs = int(count * (size / pos))
|
|
846
|
+
print(f" estimated total num seqs: {estimated_num_seqs}", file=log.v4)
|
|
847
|
+
exc.estimated_total_num_seqs = estimated_num_seqs
|
|
848
|
+
raise
|
|
800
849
|
|
|
801
850
|
|
|
802
851
|
def iter_corpus(
|
|
@@ -2517,6 +2566,25 @@ def get_post_processor_function(opts):
|
|
|
2517
2566
|
return chained_post_processors
|
|
2518
2567
|
|
|
2519
2568
|
|
|
2569
|
+
def _get_debug_limit_line_count() -> Optional[int]:
|
|
2570
|
+
"""
|
|
2571
|
+
:return: if set, limit to this many lines for debugging
|
|
2572
|
+
"""
|
|
2573
|
+
from returnn.config import get_global_config
|
|
2574
|
+
|
|
2575
|
+
config = get_global_config(raise_exception=False)
|
|
2576
|
+
if not config:
|
|
2577
|
+
return None
|
|
2578
|
+
|
|
2579
|
+
return config.int("lm_dataset_debug_limit_line_count", None)
|
|
2580
|
+
|
|
2581
|
+
|
|
2582
|
+
class _ReachedDebugLimitLineCount(Exception):
|
|
2583
|
+
"""internal exception to signal reached debug limit line count"""
|
|
2584
|
+
|
|
2585
|
+
estimated_total_num_seqs: Optional[int] = None
|
|
2586
|
+
|
|
2587
|
+
|
|
2520
2588
|
def _main():
|
|
2521
2589
|
from returnn.util import better_exchook
|
|
2522
2590
|
|
returnn/frontend/__init__.py
CHANGED
returnn/frontend/_backend.py
CHANGED
|
@@ -42,6 +42,11 @@ class Backend(Generic[T]):
|
|
|
42
42
|
"""
|
|
43
43
|
raise NotImplementedError
|
|
44
44
|
|
|
45
|
+
@staticmethod
|
|
46
|
+
def assert_(condition: Tensor, message: str):
|
|
47
|
+
"""assert"""
|
|
48
|
+
raise NotImplementedError
|
|
49
|
+
|
|
45
50
|
@staticmethod
|
|
46
51
|
def get_tensor_dependencies(x: Tensor) -> Sequence[Tensor]:
|
|
47
52
|
"""
|
|
@@ -624,12 +629,48 @@ class Backend(Generic[T]):
|
|
|
624
629
|
targets_spatial_dim: Dim,
|
|
625
630
|
blank_index: int,
|
|
626
631
|
max_approx: bool = False,
|
|
632
|
+
use_native_op: Optional[bool] = None,
|
|
633
|
+
label_loop: bool = True,
|
|
627
634
|
) -> Tensor:
|
|
628
635
|
"""
|
|
629
636
|
Calculates the CTC loss.
|
|
630
637
|
"""
|
|
631
638
|
raise NotImplementedError
|
|
632
639
|
|
|
640
|
+
@staticmethod
|
|
641
|
+
def ctc_best_path(
|
|
642
|
+
*,
|
|
643
|
+
logits: Tensor,
|
|
644
|
+
logits_normalized: bool = False,
|
|
645
|
+
targets: Tensor,
|
|
646
|
+
input_spatial_dim: Dim,
|
|
647
|
+
targets_spatial_dim: Dim,
|
|
648
|
+
blank_index: int,
|
|
649
|
+
label_loop: bool = True,
|
|
650
|
+
) -> Tensor:
|
|
651
|
+
"""
|
|
652
|
+
Calculates the CTC best path.
|
|
653
|
+
"""
|
|
654
|
+
raise NotImplementedError
|
|
655
|
+
|
|
656
|
+
@staticmethod
|
|
657
|
+
def have_edit_distance() -> bool:
|
|
658
|
+
"""
|
|
659
|
+
:return: whether we have an edit_distance implementation
|
|
660
|
+
"""
|
|
661
|
+
return False
|
|
662
|
+
|
|
663
|
+
@staticmethod
|
|
664
|
+
def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
|
|
665
|
+
"""
|
|
666
|
+
:param a: [B,Ta]
|
|
667
|
+
:param a_spatial_dim: Ta
|
|
668
|
+
:param b: [B,Tb]
|
|
669
|
+
:param b_spatial_dim: Tb
|
|
670
|
+
:return: [B]
|
|
671
|
+
"""
|
|
672
|
+
raise NotImplementedError
|
|
673
|
+
|
|
633
674
|
@staticmethod
|
|
634
675
|
def have_sequence_mask_raw() -> bool:
|
|
635
676
|
"""
|
|
@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
_is_set_up = False
|
|
70
|
+
_enabled = True
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def set_enabled(enabled: bool):
|
|
74
|
+
"""
|
|
75
|
+
Enable or disable the native code setup.
|
|
76
|
+
|
|
77
|
+
:param enabled:
|
|
78
|
+
"""
|
|
79
|
+
global _enabled
|
|
80
|
+
_enabled = enabled
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def is_set_up() -> bool:
|
|
84
|
+
"""
|
|
85
|
+
:return: whether the native code is set up
|
|
86
|
+
"""
|
|
87
|
+
return _is_set_up
|
|
70
88
|
|
|
71
89
|
|
|
72
90
|
def setup():
|
|
@@ -76,6 +94,8 @@ def setup():
|
|
|
76
94
|
global _is_set_up
|
|
77
95
|
if _is_set_up:
|
|
78
96
|
return
|
|
97
|
+
if not _enabled:
|
|
98
|
+
return
|
|
79
99
|
_is_set_up = True # only try once
|
|
80
100
|
|
|
81
101
|
from returnn.tensor import Tensor, Dim
|
|
@@ -177,6 +197,8 @@ def setup_torch():
|
|
|
177
197
|
global _is_set_up_torch
|
|
178
198
|
if _is_set_up_torch:
|
|
179
199
|
return
|
|
200
|
+
if not _enabled:
|
|
201
|
+
return
|
|
180
202
|
_is_set_up_torch = True # only try once
|
|
181
203
|
|
|
182
204
|
import torch
|
|
@@ -26,6 +26,13 @@ class NumpyBackend(Backend[numpy.ndarray]):
|
|
|
26
26
|
"""executing eagerly"""
|
|
27
27
|
return True
|
|
28
28
|
|
|
29
|
+
@staticmethod
|
|
30
|
+
def assert_(condition: Tensor, message: str):
|
|
31
|
+
"""assert"""
|
|
32
|
+
assert condition.dims == (), "condition for assert must be a scalar"
|
|
33
|
+
if not condition.raw_tensor.item():
|
|
34
|
+
raise AssertionError(message)
|
|
35
|
+
|
|
29
36
|
@staticmethod
|
|
30
37
|
def get_dtype_name_raw(raw_tensor: numpy.ndarray) -> str:
|
|
31
38
|
"""
|
returnn/frontend/_utils.py
CHANGED
|
@@ -110,7 +110,7 @@ def bin_op_out_template(
|
|
|
110
110
|
all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
|
|
111
111
|
else:
|
|
112
112
|
all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
|
|
113
|
-
if all(set(x.dims) != set(all_dims) for x in (a, b)):
|
|
113
|
+
if all([set(x.dims) != set(all_dims) for x in (a, b)]):
|
|
114
114
|
if allow_broadcast_all_sources is False:
|
|
115
115
|
raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
|
|
116
116
|
elif allow_broadcast_all_sources is None:
|
returnn/frontend/array_.py
CHANGED
|
@@ -195,7 +195,7 @@ def merge_dims(
|
|
|
195
195
|
if out_dim is None:
|
|
196
196
|
from returnn.util.basic import prod
|
|
197
197
|
|
|
198
|
-
if any(d.need_masking() for d in dims[1:]):
|
|
198
|
+
if any([d.need_masking() for d in dims[1:]]):
|
|
199
199
|
# The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
|
|
200
200
|
# This would then potentially discard some of the data in the tensor in subsequent operations,
|
|
201
201
|
# when masking is applied.
|
|
@@ -910,7 +910,7 @@ def scatter(
|
|
|
910
910
|
else:
|
|
911
911
|
raise ValueError(f"scatter: invalid mode {mode!r}")
|
|
912
912
|
indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
|
|
913
|
-
if any(dim.need_masking() for dim in indices_dim):
|
|
913
|
+
if any([dim.need_masking() for dim in indices_dim]):
|
|
914
914
|
if use_mask is None:
|
|
915
915
|
use_mask = rf.use_mask_default(
|
|
916
916
|
default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
|
|
@@ -1367,12 +1367,13 @@ def repeat(
|
|
|
1367
1367
|
repeats = repeats.copy_masked(0, dims=[in_spatial_dim])
|
|
1368
1368
|
idxs = rf.cumsum(repeats, spatial_dim=in_spatial_dim) # [batch...,in_spatial_dim] -> idx in out_spatial_dim + 1
|
|
1369
1369
|
new_size = rf.gather(idxs, indices=in_spatial_dim.get_dim_value_tensor() - 1, axis=in_spatial_dim) # [batch...]
|
|
1370
|
+
dim_dev = rf.get_default_dim_size_device()
|
|
1370
1371
|
if out_spatial_dim is None:
|
|
1371
|
-
out_spatial_dim = Dim(new_size, name="repeat")
|
|
1372
|
+
out_spatial_dim = Dim(rf.copy_to_device(new_size, dim_dev), name="repeat")
|
|
1372
1373
|
elif out_spatial_dim.dyn_size_ext is None:
|
|
1373
|
-
out_spatial_dim.dyn_size_ext = new_size
|
|
1374
|
+
out_spatial_dim.dyn_size_ext = rf.copy_to_device(new_size, dim_dev)
|
|
1374
1375
|
elif out_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext.raw_tensor is None:
|
|
1375
|
-
out_spatial_dim.dyn_size_ext.raw_tensor = new_size.raw_tensor
|
|
1376
|
+
out_spatial_dim.dyn_size_ext.raw_tensor = rf.copy_to_device(new_size, dim_dev).raw_tensor
|
|
1376
1377
|
out_spatial_dim_ext = out_spatial_dim + 1
|
|
1377
1378
|
rel_idx_counts = rf.scatter(
|
|
1378
1379
|
rf.expand_dims(rf.ones((), device=values.device, dtype="int32"), dims=idxs.dims),
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Assertion utility functions for validating conditions in Python code.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from typing import Union
|
|
7
|
+
import returnn.frontend as rf
|
|
8
|
+
from returnn.tensor import Tensor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
__all__ = ["assert_"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def assert_(condition: Union[Tensor, bool], message: str):
|
|
15
|
+
"""
|
|
16
|
+
Asserts that a given condition is True.
|
|
17
|
+
If the condition is False, raises an AssertionError with the provided message.
|
|
18
|
+
This runs async on GPU.
|
|
19
|
+
|
|
20
|
+
:param condition:
|
|
21
|
+
:param message:
|
|
22
|
+
:return: nothing
|
|
23
|
+
"""
|
|
24
|
+
if isinstance(condition, bool):
|
|
25
|
+
if not condition:
|
|
26
|
+
raise AssertionError(message)
|
|
27
|
+
|
|
28
|
+
elif isinstance(condition, Tensor):
|
|
29
|
+
if condition.dims:
|
|
30
|
+
condition = rf.reduce_all(condition, axis=condition.dims) # reduce to scalar
|
|
31
|
+
# noinspection PyProtectedMember
|
|
32
|
+
condition._raw_backend.assert_(condition, message=message)
|
|
33
|
+
|
|
34
|
+
else:
|
|
35
|
+
raise TypeError(f"Condition must be a boolean or a Tensor, got {type(condition)}")
|
returnn/frontend/device.py
CHANGED
|
@@ -8,7 +8,13 @@ from contextlib import contextmanager
|
|
|
8
8
|
from returnn.tensor import Tensor
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
__all__ = [
|
|
11
|
+
__all__ = [
|
|
12
|
+
"copy_to_device",
|
|
13
|
+
"get_default_device",
|
|
14
|
+
"set_default_device",
|
|
15
|
+
"set_default_device_ctx",
|
|
16
|
+
"get_default_dim_size_device",
|
|
17
|
+
]
|
|
12
18
|
|
|
13
19
|
|
|
14
20
|
_default_device: Optional[str] = None
|
|
@@ -61,3 +67,10 @@ def set_default_device_ctx(device: Optional[str]):
|
|
|
61
67
|
yield
|
|
62
68
|
finally:
|
|
63
69
|
_default_device = old_device
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_default_dim_size_device() -> Optional[str]:
|
|
73
|
+
"""
|
|
74
|
+
:return: default device, where to put new tensors for dim sizes (Dim.dyn_size_ext)
|
|
75
|
+
"""
|
|
76
|
+
return "cpu"
|
|
@@ -167,6 +167,25 @@ class ConformerConvSubsample(ISeqDownsamplingEncoder):
|
|
|
167
167
|
out, _ = rf.merge_dims(x, dims=[self._final_second_spatial_dim, in_dim])
|
|
168
168
|
return out, in_spatial_dims[0]
|
|
169
169
|
|
|
170
|
+
def get_out_spatial_dim(self, in_spatial_dim: Dim) -> Dim:
|
|
171
|
+
"""Get output spatial dimension given input spatial dimension."""
|
|
172
|
+
out_spatial_dim = in_spatial_dim
|
|
173
|
+
for i, conv_layer in enumerate(self.conv_layers):
|
|
174
|
+
(out_spatial_dim,) = rf.make_conv_out_spatial_dims(
|
|
175
|
+
[out_spatial_dim],
|
|
176
|
+
filter_size=conv_layer.filter_size[0],
|
|
177
|
+
strides=conv_layer.strides[0],
|
|
178
|
+
padding=conv_layer.padding,
|
|
179
|
+
)
|
|
180
|
+
if self.pool_sizes and i < len(self.pool_sizes):
|
|
181
|
+
(out_spatial_dim,) = rf.make_conv_out_spatial_dims(
|
|
182
|
+
[out_spatial_dim],
|
|
183
|
+
filter_size=self.pool_sizes[i][0],
|
|
184
|
+
strides=self.pool_sizes[i][0],
|
|
185
|
+
padding="same",
|
|
186
|
+
)
|
|
187
|
+
return out_spatial_dim
|
|
188
|
+
|
|
170
189
|
|
|
171
190
|
class ConformerEncoderLayer(rf.Module):
|
|
172
191
|
"""
|