returnn 1.20260105.192646__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. returnn/PKG-INFO +1 -1
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +110 -42
  5. returnn/frontend/__init__.py +1 -0
  6. returnn/frontend/_backend.py +41 -0
  7. returnn/frontend/_native/__init__.py +22 -0
  8. returnn/frontend/_numpy_backend.py +7 -0
  9. returnn/frontend/_utils.py +1 -1
  10. returnn/frontend/array_.py +6 -5
  11. returnn/frontend/assert_.py +35 -0
  12. returnn/frontend/device.py +14 -1
  13. returnn/frontend/encoder/conformer.py +19 -0
  14. returnn/frontend/loss.py +183 -3
  15. returnn/frontend/math_.py +54 -14
  16. returnn/native_op.cpp +104 -174
  17. returnn/native_op.py +36 -31
  18. returnn/tensor/_dim_extra.py +7 -7
  19. returnn/tensor/_tensor_extra.py +10 -10
  20. returnn/tensor/utils.py +1 -1
  21. returnn/tf/frontend_layers/_backend.py +3 -1
  22. returnn/tf/layers/basic.py +13 -2
  23. returnn/tf/native_op.py +16 -5
  24. returnn/tf/util/basic.py +7 -201
  25. returnn/torch/engine.py +120 -3
  26. returnn/torch/frontend/_backend.py +166 -22
  27. returnn/torch/frontend/bridge.py +61 -0
  28. returnn/torch/frontend/compile_helper.py +106 -0
  29. returnn/torch/util/array_.py +30 -0
  30. returnn/torch/util/assert_.py +122 -0
  31. returnn/torch/util/native_op.py +885 -0
  32. returnn/torch/util/native_op_code_compiler.py +308 -0
  33. returnn/util/basic.py +3 -1
  34. returnn/util/cuda_env.py +332 -0
  35. returnn/util/debug.py +1 -0
  36. returnn/util/fsa.py +17 -13
  37. returnn/util/native_code_compiler.py +104 -47
  38. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
  39. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
  40. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  41. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  42. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20260105.192646
+Version: 1.20260119.15400
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/__old_mod_loader__.py CHANGED
@@ -17,7 +17,7 @@ This is supported as well.
 import sys
 import os
 import types
-import typing
+from typing import Any, Dict
 import importlib
 
 old_to_new_mod_mapping = {
@@ -122,7 +122,7 @@ class _LazyLoader(types.ModuleType):
         fn = "%s/%s/__init__.py" % (_base_dir, full_mod_name.replace(".", "/"))
         assert os.path.exists(fn), "_LazyLoader: mod %r not found in %r" % (full_mod_name, _base_dir)
         self.__file__ = fn
-        self._lazy_mod_config = dict(full_mod_name=full_mod_name, **kwargs)  # type: typing.Dict[str]
+        self._lazy_mod_config: Dict[str, Any] = dict(full_mod_name=full_mod_name, **kwargs)
 
     def _load(self):
         full_mod_name = self.__name__
@@ -172,6 +172,30 @@ class _LazyLoader(types.ModuleType):
         return super(_LazyLoader, self).__getattribute__(item)
 
     def __getattr__(self, item):
+        if item == "torch":
+            # torch.compile Dynamo hashing can trigger this, when it uses pickle to serialize some object state,
+            # which iterates through sys.modules and does getattr on each module.
+            # In this case, it searches for torch.
+            #   File ".../torch/_inductor/codecache.py", line 607 in dumps
+            #   File ".../torch/_inductor/codecache.py", line 622 in get_hash
+            #   File ".../torch/_inductor/codecache.py", line 961 in compiled_fx_graph_hash
+            #   ...
+            # Unfortunately, Pickler.dump is native code, so we cannot easily check whether that is the parent frame.
+            # The C stacktrace looks like:
+            #   ...
+            #   7   Python                        0x0000000102e7d504 call_attribute + 80
+            #   8   Python                        0x0000000102e7d400 _Py_slot_tp_getattr_hook + 576
+            #   9   Python                        0x0000000102e507a0 PyObject_GetOptionalAttr + 248
+            #   10  _pickle.cpython-313-darwin.so 0x0000000102d24fb4 get_deep_attribute + 104
+            #   11  _pickle.cpython-313-darwin.so 0x0000000102d250b8 _checkmodule + 88
+            #   12  _pickle.cpython-313-darwin.so 0x0000000102d22588 save_global + 3024
+            #   13  _pickle.cpython-313-darwin.so 0x0000000102d1eddc save + 3424
+            #   ...
+            # Right now, we just check for `item == "torch"` as a heuristic,
+            # which should never exist for any of the old-style wrapped modules here.
+            # We could maybe also check sys._getframe(1).f_code or so and add some other heuristics...
+            raise AttributeError(f"module {self.__name__} has no attribute {item} (lazy loading skipped)")
+
         module = self._load()
         return getattr(module, item)
 
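For illustration, a minimal standalone sketch (hypothetical module name, not part of the package) of the pickle behavior this guard addresses: pickle's module lookup walks sys.modules and getattr-probes every entry, which would otherwise force each _LazyLoader to fully load.

import sys
import types


class _DemoLazyModule(types.ModuleType):
    """Stand-in for _LazyLoader; any attribute probe would normally trigger _load()."""

    def __getattr__(self, item):
        if item == "torch":
            # same guard as in the diff: answer pickle's probe cheaply, skip loading
            raise AttributeError(f"module {self.__name__} has no attribute {item}")
        print(f"probe for {item!r} would trigger a full load of {self.__name__!r}")
        raise AttributeError(item)


sys.modules["demo_old_mod"] = _DemoLazyModule("demo_old_mod")

# What pickle's save_global/whichmodule effectively does while torch.compile
# hashes a graph: walk sys.modules and probe every module for the name "torch".
for mod in list(sys.modules.values()):
    getattr(mod, "torch", None)  # the guard turns this into a cheap AttributeError
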
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20260105.192646'
-long_version = '1.20260105.192646+git.1201db0'
+version = '1.20260119.015400'
+long_version = '1.20260119.015400+git.5c6a8c0'
returnn/datasets/lm.py CHANGED
@@ -86,6 +86,7 @@ class LmDataset(CachedDataset2):
         delayed_seq_data_start_symbol="[START]",
         dtype: Optional[str] = None,
         tag_prefix: Optional[str] = None,
+        _debug_limit_line_count: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -138,6 +139,8 @@
             delayed_seq_data_start_symbol + original_sequence[:-1].
         :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data.
         :param dtype: explicit dtype. if not given, automatically determined based on the number of labels.
+        :param tag_prefix: prefix for sequence tags. by default "line-".
+        :param _debug_limit_line_count:
         """
         super(LmDataset, self).__init__(**kwargs)
 
@@ -316,6 +319,10 @@
         self.num_skipped = 0
         self.num_unknown = 0
 
+        if _debug_limit_line_count is None:
+            _debug_limit_line_count = _get_debug_limit_line_count()
+        self._debug_limit_line_count = _debug_limit_line_count
+
     def _lazy_init(self):
         if self._orths_offsets_and_lens is not None:
             return
@@ -340,6 +347,9 @@
         lens_per_corpus_file = []
         start_time = time.time()
         last_print_time = start_time
+        debug_limit_line_count = self._debug_limit_line_count
+        debug_limit_est_total = 0
+        debug_limit_hit = False
 
         def _init_tmp_file():
             nonlocal tmp_file, tmp_file_orth_files_index
@@ -368,13 +378,16 @@
 
             if time.time() - last_print_time > 10:
                 print(
-                    f" ... loaded {len(self._orths_offsets_and_lens)} sequences,"
+                    f" ... loaded {len(orths)} sequences,"
                     f" {human_bytes_size(total_bytes_read)},"
                     f" after {hms(time.time() - start_time)}",
                     file=log.v4,
                 )
                 last_print_time = time.time()
 
+            if debug_limit_line_count is not None and len(orths) - prev_orth_len >= debug_limit_line_count:
+                raise _ReachedDebugLimitLineCount()
+
         # If a list of files is provided, concatenate all.
         if isinstance(corpus_file, str):
             corpus_file = [corpus_file]
@@ -383,37 +396,46 @@
         for file_name in corpus_file:
             if self._use_cache_manager:
                 file_name = cf(file_name)
-            if _is_bliss(file_name):
-                _init_tmp_file()
-                _iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
-            elif file_name.endswith(".gz"):
-                _init_tmp_file()
-                _iter_txt(
-                    filename=file_name,
-                    callback=_tmp_file_add_line,
-                    skip_empty_lines=self._skip_empty_lines,
-                    decode=False,
-                )
-            else:  # Raw txt file
-                # Directly mmap the file.
-                # We just need to scan once through it to find line offsets.
-                file = open(file_name, "rb")
-                file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
-                file_index = len(self._orth_files)
-                self._orth_files.append(file)
-                self._orth_mmaps.append(file_mmap)
-
-                pos = 0
-                while True:
-                    next_new_line = file_mmap.find(b"\n", pos)
-                    if next_new_line == -1:
-                        break
-                    line_len = next_new_line - pos
-                    if line_len or not self._skip_empty_lines:
-                        orths.append((file_index, pos, line_len))
-                    total_bytes_read += line_len + 1
-                    pos = next_new_line + 1
-                    _maybe_report_status()
+
+            try:
+                if _is_bliss(file_name):
+                    _init_tmp_file()
+                    _iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
+                elif file_name.endswith(".gz"):
+                    _init_tmp_file()
+                    _iter_txt(
+                        filename=file_name,
+                        callback=_tmp_file_add_line,
+                        skip_empty_lines=self._skip_empty_lines,
+                        decode=False,
+                    )
+                else:  # Raw txt file
+                    # Directly mmap the file.
+                    # We just need to scan once through it to find line offsets.
+                    file = open(file_name, "rb")
+                    file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
+                    file_index = len(self._orth_files)
+                    self._orth_files.append(file)
+                    self._orth_mmaps.append(file_mmap)
+
+                    pos = 0
+                    while True:
+                        next_new_line = file_mmap.find(b"\n", pos)
+                        if next_new_line == -1:
+                            break
+                        line_len = next_new_line - pos
+                        if line_len or not self._skip_empty_lines:
+                            orths.append((file_index, pos, line_len))
+                        total_bytes_read += line_len + 1
+                        pos = next_new_line + 1
+                        _maybe_report_status()
+
+            except _ReachedDebugLimitLineCount as exc:
+                assert exc.estimated_total_num_seqs is not None  # currently only for _iter_txt implemented
+                debug_limit_est_total += exc.estimated_total_num_seqs
+                debug_limit_hit = True
+            else:  # iteration completed without hitting debug limit
+                debug_limit_est_total += len(orths) - prev_orth_len
 
             lens_per_corpus_file.append(len(orths) - prev_orth_len)
             prev_orth_len = len(orths)
@@ -447,6 +469,18 @@
             file=log.v4,
         )
 
+        if debug_limit_hit:
+            est_frac_loaded = len(self._orths_offsets_and_lens) / debug_limit_est_total
+            new_partition_epoch = max(int(self.partition_epoch * est_frac_loaded), 1)
+            print(
+                f"LmDataset: debug limit of {debug_limit_line_count} lines (per file) hit,"
+                f" estimated total num seqs {debug_limit_est_total},"
+                f" loaded {len(self._orths_offsets_and_lens)}, {est_frac_loaded:.2%},"
+                f" adjusting partition_epoch from {self.partition_epoch} to {new_partition_epoch}",
+                file=log.v4,
+            )
+            self.partition_epoch = new_partition_epoch
+
         # It's only estimated because we might filter some out or so.
         self._estimated_num_seqs = len(self._orths_offsets_and_lens) // self.partition_epoch
 
@@ -784,19 +818,34 @@ def _iter_txt(
     :param decode:
     """
    f = open(filename, "rb")
+    f_ = f
    if filename.endswith(".gz"):
        f = gzip.GzipFile(fileobj=f)
 
-    for line in f:
-        if decode:
-            try:
-                line = line.decode("utf8")
-            except UnicodeDecodeError:
-                line = line.decode("latin_1")  # or iso8859_15?
-        line = line.strip()
-        if skip_empty_lines and not line:
-            continue
-        callback(line)
+    count = 0
+    try:
+        for line in f:
+            if decode:
+                try:
+                    line = line.decode("utf8")
+                except UnicodeDecodeError:
+                    line = line.decode("latin_1")  # or iso8859_15?
+            line = line.strip()
+            if skip_empty_lines and not line:
+                continue
+            count += 1
+            callback(line)
+
+    except _ReachedDebugLimitLineCount as exc:
+        print(f"Reached debug limit line count for {filename}, stopping early", file=log.v4)
+        pos = f_.tell()
+        f_.seek(0, os.SEEK_END)
+        size = f_.tell()
+        print(f" stopped at byte {human_bytes_size(pos)} / {human_bytes_size(size)}", file=log.v4)
+        estimated_num_seqs = int(count * (size / pos))
+        print(f" estimated total num seqs: {estimated_num_seqs}", file=log.v4)
+        exc.estimated_total_num_seqs = estimated_num_seqs
+        raise
 
 
 def iter_corpus(
@@ -2517,6 +2566,25 @@ def get_post_processor_function(opts):
     return chained_post_processors
 
 
+def _get_debug_limit_line_count() -> Optional[int]:
+    """
+    :return: if set, limit to this many lines for debugging
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return None
+
+    return config.int("lm_dataset_debug_limit_line_count", None)
+
+
+class _ReachedDebugLimitLineCount(Exception):
+    """internal exception to signal reached debug limit line count"""
+
+    estimated_total_num_seqs: Optional[int] = None
+
+
def _main():
    from returnn.util import better_exchook
 
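A hedged usage sketch of the new debug option from a RETURNN config (the key is the one read by _get_debug_limit_line_count above; the file path and the rest of the dataset dict are illustrative):

# RETURNN config sketch (hypothetical path/options, real config key):
lm_dataset_debug_limit_line_count = 1000  # stop after ~1000 lines per corpus file

train = {
    "class": "LmDataset",
    "corpus_file": "/path/to/train.txt",
    # ... remaining LmDataset options as usual ...
}

The estimate in _iter_txt is a linear extrapolation over bytes: stopping after count lines at byte pos of size gives int(count * (size / pos)) total lines, and partition_epoch is then scaled by the loaded fraction, e.g. from 100 down to 1 when roughly 1% of the corpus was read.
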
returnn/frontend/__init__.py CHANGED
@@ -19,6 +19,7 @@ from .state import *
 
 # Now the rest, in alphabetical order.
 from .array_ import *
+from .assert_ import *
 from .attention import *
 from .backend import *
 from .build_from_dict import *
returnn/frontend/_backend.py CHANGED
@@ -42,6 +42,11 @@ class Backend(Generic[T]):
         """
         raise NotImplementedError
 
+    @staticmethod
+    def assert_(condition: Tensor, message: str):
+        """assert"""
+        raise NotImplementedError
+
     @staticmethod
     def get_tensor_dependencies(x: Tensor) -> Sequence[Tensor]:
         """
@@ -624,12 +629,48 @@
         targets_spatial_dim: Dim,
         blank_index: int,
         max_approx: bool = False,
+        use_native_op: Optional[bool] = None,
+        label_loop: bool = True,
     ) -> Tensor:
         """
         Calculates the CTC loss.
         """
         raise NotImplementedError
 
+    @staticmethod
+    def ctc_best_path(
+        *,
+        logits: Tensor,
+        logits_normalized: bool = False,
+        targets: Tensor,
+        input_spatial_dim: Dim,
+        targets_spatial_dim: Dim,
+        blank_index: int,
+        label_loop: bool = True,
+    ) -> Tensor:
+        """
+        Calculates the CTC best path.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def have_edit_distance() -> bool:
+        """
+        :return: whether we have an edit_distance implementation
+        """
+        return False
+
+    @staticmethod
+    def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
+        """
+        :param a: [B,Ta]
+        :param a_spatial_dim: Ta
+        :param b: [B,Tb]
+        :param b_spatial_dim: Tb
+        :return: [B]
+        """
+        raise NotImplementedError
+
     @staticmethod
     def have_sequence_mask_raw() -> bool:
         """
returnn/frontend/_native/__init__.py CHANGED
@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
 
 
 _is_set_up = False
+_enabled = True
+
+
+def set_enabled(enabled: bool):
+    """
+    Enable or disable the native code setup.
+
+    :param enabled:
+    """
+    global _enabled
+    _enabled = enabled
+
+
+def is_set_up() -> bool:
+    """
+    :return: whether the native code is set up
+    """
+    return _is_set_up
 
 
 def setup():
@@ -76,6 +94,8 @@
     global _is_set_up
     if _is_set_up:
         return
+    if not _enabled:
+        return
     _is_set_up = True  # only try once
 
     from returnn.tensor import Tensor, Dim
@@ -177,6 +197,8 @@ def setup_torch():
     global _is_set_up_torch
     if _is_set_up_torch:
         return
+    if not _enabled:
+        return
     _is_set_up_torch = True  # only try once
 
     import torch
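A usage sketch for the new toggle, with only the functions shown in this diff: disabling must happen before anything triggers setup() or setup_torch(), since both only try once.

from returnn.frontend import _native

# e.g. on a machine without a working compiler toolchain:
_native.set_enabled(False)  # setup() / setup_torch() now return early
# ... import and use returnn.frontend as usual; pure-Python paths are taken ...
assert not _native.is_set_up()
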
returnn/frontend/_numpy_backend.py CHANGED
@@ -26,6 +26,13 @@ class NumpyBackend(Backend[numpy.ndarray]):
         """executing eagerly"""
         return True
 
+    @staticmethod
+    def assert_(condition: Tensor, message: str):
+        """assert"""
+        assert condition.dims == (), "condition for assert must be a scalar"
+        if not condition.raw_tensor.item():
+            raise AssertionError(message)
+
     @staticmethod
     def get_dtype_name_raw(raw_tensor: numpy.ndarray) -> str:
         """
returnn/frontend/_utils.py CHANGED
@@ -110,7 +110,7 @@ def bin_op_out_template(
             all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
         else:
             all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
-    if all(set(x.dims) != set(all_dims) for x in (a, b)):
+    if all([set(x.dims) != set(all_dims) for x in (a, b)]):
         if allow_broadcast_all_sources is False:
             raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
         elif allow_broadcast_all_sources is None:
returnn/frontend/array_.py CHANGED
@@ -195,7 +195,7 @@ def merge_dims(
     if out_dim is None:
         from returnn.util.basic import prod
 
-        if any(d.need_masking() for d in dims[1:]):
+        if any([d.need_masking() for d in dims[1:]]):
             # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
             # This would then potentially discard some of the data in the tensor in subsequent operations,
             # when masking is applied.
@@ -910,7 +910,7 @@ def scatter(
     else:
         raise ValueError(f"scatter: invalid mode {mode!r}")
     indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
-    if any(dim.need_masking() for dim in indices_dim):
+    if any([dim.need_masking() for dim in indices_dim]):
         if use_mask is None:
             use_mask = rf.use_mask_default(
                 default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
@@ -1367,12 +1367,13 @@ def repeat(
     repeats = repeats.copy_masked(0, dims=[in_spatial_dim])
     idxs = rf.cumsum(repeats, spatial_dim=in_spatial_dim)  # [batch...,in_spatial_dim] -> idx in out_spatial_dim + 1
     new_size = rf.gather(idxs, indices=in_spatial_dim.get_dim_value_tensor() - 1, axis=in_spatial_dim)  # [batch...]
+    dim_dev = rf.get_default_dim_size_device()
     if out_spatial_dim is None:
-        out_spatial_dim = Dim(new_size, name="repeat")
+        out_spatial_dim = Dim(rf.copy_to_device(new_size, dim_dev), name="repeat")
     elif out_spatial_dim.dyn_size_ext is None:
-        out_spatial_dim.dyn_size_ext = new_size
+        out_spatial_dim.dyn_size_ext = rf.copy_to_device(new_size, dim_dev)
     elif out_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext.raw_tensor is None:
-        out_spatial_dim.dyn_size_ext.raw_tensor = new_size.raw_tensor
+        out_spatial_dim.dyn_size_ext.raw_tensor = rf.copy_to_device(new_size, dim_dev).raw_tensor
     out_spatial_dim_ext = out_spatial_dim + 1
     rel_idx_counts = rf.scatter(
         rf.expand_dims(rf.ones((), device=values.device, dtype="int32"), dims=idxs.dims),
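For orientation, the size computation in the repeat hunk above mirrors numpy.repeat along the spatial dim; a small NumPy sketch of why the last cumsum entry is the new dynamic size:

import numpy as np

repeats = np.array([2, 0, 3])  # per-position repeat counts along in_spatial_dim
idxs = np.cumsum(repeats)      # end index of each run in the output: [2, 2, 5]
new_size = idxs[-1]            # == 5, the dynamic out_spatial_dim size
assert new_size == np.repeat(np.array([10, 20, 30]), repeats).shape[0]
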
returnn/frontend/assert_.py ADDED
@@ -0,0 +1,35 @@
+"""
+Assertion utility functions for validating conditions in Python code.
+"""
+
+from __future__ import annotations
+from typing import Union
+import returnn.frontend as rf
+from returnn.tensor import Tensor
+
+
+__all__ = ["assert_"]
+
+
+def assert_(condition: Union[Tensor, bool], message: str):
+    """
+    Asserts that a given condition is True.
+    If the condition is False, raises an AssertionError with the provided message.
+    This runs async on GPU.
+
+    :param condition:
+    :param message:
+    :return: nothing
+    """
+    if isinstance(condition, bool):
+        if not condition:
+            raise AssertionError(message)
+
+    elif isinstance(condition, Tensor):
+        if condition.dims:
+            condition = rf.reduce_all(condition, axis=condition.dims)  # reduce to scalar
+        # noinspection PyProtectedMember
+        condition._raw_backend.assert_(condition, message=message)
+
+    else:
+        raise TypeError(f"Condition must be a boolean or a Tensor, got {type(condition)}")
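A hedged usage sketch for the new rf.assert_ API (the tensor construction via rf.convert_to_tensor is illustrative; the bool and Tensor branches follow the code above):

import numpy
import returnn.frontend as rf

# plain bool branch: raises AssertionError immediately if False
rf.assert_(2 + 2 == 4, "arithmetic broke")

# scalar Tensor branch (NumPy backend here): condition.raw_tensor.item() is checked
cond = rf.convert_to_tensor(numpy.array(True))
rf.assert_(cond, "tensor condition failed")

# non-scalar bool Tensors are first reduced to a scalar via rf.reduce_all;
# on GPU backends the check can run asynchronously, as the docstring notes.
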
returnn/frontend/device.py CHANGED
@@ -8,7 +8,13 @@ from contextlib import contextmanager
 from returnn.tensor import Tensor
 
 
-__all__ = ["copy_to_device", "get_default_device", "set_default_device", "set_default_device_ctx"]
+__all__ = [
+    "copy_to_device",
+    "get_default_device",
+    "set_default_device",
+    "set_default_device_ctx",
+    "get_default_dim_size_device",
+]
 
 
 _default_device: Optional[str] = None
@@ -61,3 +67,10 @@ def set_default_device_ctx(device: Optional[str]):
         yield
     finally:
         _default_device = old_device
+
+
+def get_default_dim_size_device() -> Optional[str]:
+    """
+    :return: default device, where to put new tensors for dim sizes (Dim.dyn_size_ext)
+    """
+    return "cpu"
returnn/frontend/encoder/conformer.py CHANGED
@@ -167,6 +167,25 @@ class ConformerConvSubsample(ISeqDownsamplingEncoder):
         out, _ = rf.merge_dims(x, dims=[self._final_second_spatial_dim, in_dim])
         return out, in_spatial_dims[0]
 
+    def get_out_spatial_dim(self, in_spatial_dim: Dim) -> Dim:
+        """Get output spatial dimension given input spatial dimension."""
+        out_spatial_dim = in_spatial_dim
+        for i, conv_layer in enumerate(self.conv_layers):
+            (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                [out_spatial_dim],
+                filter_size=conv_layer.filter_size[0],
+                strides=conv_layer.strides[0],
+                padding=conv_layer.padding,
+            )
+            if self.pool_sizes and i < len(self.pool_sizes):
+                (out_spatial_dim,) = rf.make_conv_out_spatial_dims(
+                    [out_spatial_dim],
+                    filter_size=self.pool_sizes[i][0],
+                    strides=self.pool_sizes[i][0],
+                    padding="same",
+                )
+        return out_spatial_dim
+
 
 class ConformerEncoderLayer(rf.Module):
     """