returnn-1.20260109.93428-py3-none-any.whl → returnn-1.20260113.134416-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20260109.93428
+Version: 1.20260113.134416
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/__old_mod_loader__.py CHANGED
@@ -17,7 +17,7 @@ This is supported as well.
 import sys
 import os
 import types
-import typing
+from typing import Any, Dict
 import importlib
 
 old_to_new_mod_mapping = {
@@ -122,7 +122,7 @@ class _LazyLoader(types.ModuleType):
         fn = "%s/%s/__init__.py" % (_base_dir, full_mod_name.replace(".", "/"))
         assert os.path.exists(fn), "_LazyLoader: mod %r not found in %r" % (full_mod_name, _base_dir)
         self.__file__ = fn
-        self._lazy_mod_config = dict(full_mod_name=full_mod_name, **kwargs)  # type: typing.Dict[str]
+        self._lazy_mod_config: Dict[str, Any] = dict(full_mod_name=full_mod_name, **kwargs)
 
     def _load(self):
         full_mod_name = self.__name__
@@ -172,6 +172,30 @@ class _LazyLoader(types.ModuleType):
        return super(_LazyLoader, self).__getattribute__(item)
 
     def __getattr__(self, item):
+        if item == "torch":
+            # torch.compile Dynamo hashing can trigger this, when it uses pickle to serialize some object state,
+            # which iterates through sys.modules and does getattr on each module.
+            # In this case, it searches for torch.
+            #   File ".../torch/_inductor/codecache.py", line 607 in dumps
+            #   File ".../torch/_inductor/codecache.py", line 622 in get_hash
+            #   File ".../torch/_inductor/codecache.py", line 961 in compiled_fx_graph_hash
+            #   ...
+            # Unfortunately, Pickler.dump is native code, so we cannot easily check whether that is the parent frame.
+            # The C stacktrace looks like:
+            #   ...
+            #   7  Python                        0x0000000102e7d504 call_attribute + 80
+            #   8  Python                        0x0000000102e7d400 _Py_slot_tp_getattr_hook + 576
+            #   9  Python                        0x0000000102e507a0 PyObject_GetOptionalAttr + 248
+            #   10 _pickle.cpython-313-darwin.so 0x0000000102d24fb4 get_deep_attribute + 104
+            #   11 _pickle.cpython-313-darwin.so 0x0000000102d250b8 _checkmodule + 88
+            #   12 _pickle.cpython-313-darwin.so 0x0000000102d22588 save_global + 3024
+            #   13 _pickle.cpython-313-darwin.so 0x0000000102d1eddc save + 3424
+            #   ...
+            # Right now, we just check for `item == "torch"` as a heuristic,
+            # which should never exist for any of the old-style wrapped modules here.
+            # We could maybe also check sys._getframe(1).f_code or so and add some other heuristics...
+            raise AttributeError(f"module {self.__name__} has no attribute {item} (lazy loading skipped)")
+
         module = self._load()
         return getattr(module, item)
 
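Why the guard is needed: pickle's module lookup can probe every entry in sys.modules with getattr, which would force-load each _LazyLoader module. A minimal sketch of that mechanism (the demo module and function names are invented; in CPython's pickle, whichmodule only falls back to scanning sys.modules when it cannot read __module__ from the object, and exact behavior can vary by version):

import pickle
import sys
import types


class LoudModule(types.ModuleType):
    """Stand-in for _LazyLoader: reports attribute probes instead of importing."""

    def __getattr__(self, item):
        print(f"pickle probed this module for {item!r}")
        raise AttributeError(item)


sys.modules["loud_demo"] = LoudModule("loud_demo")


def f():
    return 42


f.__module__ = None  # force pickle to scan sys.modules for the defining module
pickle.dumps(f)  # triggers a probe on loud_demo (and every other module)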
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20260109.093428'
-long_version = '1.20260109.093428+git.68426d7'
+version = '1.20260113.134416'
+long_version = '1.20260113.134416+git.8c8a566'
returnn/datasets/lm.py CHANGED
@@ -86,6 +86,7 @@ class LmDataset(CachedDataset2):
         delayed_seq_data_start_symbol="[START]",
         dtype: Optional[str] = None,
         tag_prefix: Optional[str] = None,
+        _debug_limit_line_count: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -138,6 +139,8 @@ class LmDataset(CachedDataset2):
            delayed_seq_data_start_symbol + original_sequence[:-1].
         :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data.
         :param dtype: explicit dtype. if not given, automatically determined based on the number of labels.
+        :param tag_prefix: prefix for sequence tags. by default "line-".
+        :param _debug_limit_line_count:
         """
         super(LmDataset, self).__init__(**kwargs)
 
@@ -316,6 +319,10 @@ class LmDataset(CachedDataset2):
         self.num_skipped = 0
         self.num_unknown = 0
 
+        if _debug_limit_line_count is None:
+            _debug_limit_line_count = _get_debug_limit_line_count()
+        self._debug_limit_line_count = _debug_limit_line_count
+
     def _lazy_init(self):
         if self._orths_offsets_and_lens is not None:
             return
@@ -340,6 +347,9 @@ class LmDataset(CachedDataset2):
         lens_per_corpus_file = []
         start_time = time.time()
         last_print_time = start_time
+        debug_limit_line_count = self._debug_limit_line_count
+        debug_limit_est_total = 0
+        debug_limit_hit = False
 
         def _init_tmp_file():
             nonlocal tmp_file, tmp_file_orth_files_index
@@ -368,13 +378,16 @@ class LmDataset(CachedDataset2):
 
             if time.time() - last_print_time > 10:
                 print(
-                    f" ... loaded {len(self._orths_offsets_and_lens)} sequences,"
+                    f" ... loaded {len(orths)} sequences,"
                     f" {human_bytes_size(total_bytes_read)},"
                    f" after {hms(time.time() - start_time)}",
                     file=log.v4,
                 )
                 last_print_time = time.time()
 
+            if debug_limit_line_count is not None and len(orths) - prev_orth_len >= debug_limit_line_count:
+                raise _ReachedDebugLimitLineCount()
+
         # If a list of files is provided, concatenate all.
         if isinstance(corpus_file, str):
             corpus_file = [corpus_file]
@@ -383,37 +396,46 @@ class LmDataset(CachedDataset2):
         for file_name in corpus_file:
             if self._use_cache_manager:
                 file_name = cf(file_name)
-            if _is_bliss(file_name):
-                _init_tmp_file()
-                _iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
-            elif file_name.endswith(".gz"):
-                _init_tmp_file()
-                _iter_txt(
-                    filename=file_name,
-                    callback=_tmp_file_add_line,
-                    skip_empty_lines=self._skip_empty_lines,
-                    decode=False,
-                )
-            else:  # Raw txt file
-                # Directly mmap the file.
-                # We just need to scan once through it to find line offsets.
-                file = open(file_name, "rb")
-                file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
-                file_index = len(self._orth_files)
-                self._orth_files.append(file)
-                self._orth_mmaps.append(file_mmap)
-
-                pos = 0
-                while True:
-                    next_new_line = file_mmap.find(b"\n", pos)
-                    if next_new_line == -1:
-                        break
-                    line_len = next_new_line - pos
-                    if line_len or not self._skip_empty_lines:
-                        orths.append((file_index, pos, line_len))
-                    total_bytes_read += line_len + 1
-                    pos = next_new_line + 1
-                    _maybe_report_status()
+
+            try:
+                if _is_bliss(file_name):
+                    _init_tmp_file()
+                    _iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
+                elif file_name.endswith(".gz"):
+                    _init_tmp_file()
+                    _iter_txt(
+                        filename=file_name,
+                        callback=_tmp_file_add_line,
+                        skip_empty_lines=self._skip_empty_lines,
+                        decode=False,
+                    )
+                else:  # Raw txt file
+                    # Directly mmap the file.
+                    # We just need to scan once through it to find line offsets.
+                    file = open(file_name, "rb")
+                    file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
+                    file_index = len(self._orth_files)
+                    self._orth_files.append(file)
+                    self._orth_mmaps.append(file_mmap)
+
+                    pos = 0
+                    while True:
+                        next_new_line = file_mmap.find(b"\n", pos)
+                        if next_new_line == -1:
+                            break
+                        line_len = next_new_line - pos
+                        if line_len or not self._skip_empty_lines:
+                            orths.append((file_index, pos, line_len))
+                        total_bytes_read += line_len + 1
+                        pos = next_new_line + 1
+                        _maybe_report_status()
+
+            except _ReachedDebugLimitLineCount as exc:
+                assert exc.estimated_total_num_seqs is not None  # currently only for _iter_txt implemented
+                debug_limit_est_total += exc.estimated_total_num_seqs
+                debug_limit_hit = True
+            else:  # iteration completed without hitting debug limit
+                debug_limit_est_total += len(orths) - prev_orth_len
 
             lens_per_corpus_file.append(len(orths) - prev_orth_len)
             prev_orth_len = len(orths)
@@ -447,6 +469,18 @@ class LmDataset(CachedDataset2):
             file=log.v4,
         )
 
+        if debug_limit_hit:
+            est_frac_loaded = len(self._orths_offsets_and_lens) / debug_limit_est_total
+            new_partition_epoch = max(int(self.partition_epoch * est_frac_loaded), 1)
+            print(
+                f"LmDataset: debug limit of {debug_limit_line_count} lines (per file) hit,"
+                f" estimated total num seqs {debug_limit_est_total},"
+                f" loaded {len(self._orths_offsets_and_lens)}, {est_frac_loaded:.2%},"
+                f" adjusting partition_epoch from {self.partition_epoch} to {new_partition_epoch}",
+                file=log.v4,
+            )
+            self.partition_epoch = new_partition_epoch
+
         # It's only estimated because we might filter some out or so.
         self._estimated_num_seqs = len(self._orths_offsets_and_lens) // self.partition_epoch
 
@@ -784,19 +818,34 @@ def _iter_txt(
     :param decode:
     """
     f = open(filename, "rb")
+    f_ = f
     if filename.endswith(".gz"):
         f = gzip.GzipFile(fileobj=f)
 
-    for line in f:
-        if decode:
-            try:
-                line = line.decode("utf8")
-            except UnicodeDecodeError:
-                line = line.decode("latin_1")  # or iso8859_15?
-        line = line.strip()
-        if skip_empty_lines and not line:
-            continue
-        callback(line)
+    count = 0
+    try:
+        for line in f:
+            if decode:
+                try:
+                    line = line.decode("utf8")
+                except UnicodeDecodeError:
+                    line = line.decode("latin_1")  # or iso8859_15?
+            line = line.strip()
+            if skip_empty_lines and not line:
+                continue
+            count += 1
+            callback(line)
+
+    except _ReachedDebugLimitLineCount as exc:
+        print(f"Reached debug limit line count for {filename}, stopping early", file=log.v4)
+        pos = f_.tell()
+        f_.seek(0, os.SEEK_END)
+        size = f_.tell()
+        print(f" stopped at byte {human_bytes_size(pos)} / {human_bytes_size(size)}", file=log.v4)
+        estimated_num_seqs = int(count * (size / pos))
+        print(f" estimated total num seqs: {estimated_num_seqs}", file=log.v4)
+        exc.estimated_total_num_seqs = estimated_num_seqs
+        raise
 
 
 def iter_corpus(
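The early-stop estimate above extrapolates linearly over raw byte positions: f_ is the underlying binary file object (so for .gz input the ratio is over compressed bytes), and estimated_num_seqs = int(count * (size / pos)). For example, hitting the limit after count=2000 lines at pos=5 MiB of a size=50 MiB file yields an estimate of int(2000 * 50/5) = 20000 sequences.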
@@ -2517,6 +2566,25 @@ def get_post_processor_function(opts):
     return chained_post_processors
 
 
+def _get_debug_limit_line_count() -> Optional[int]:
+    """
+    :return: if set, limit to this many lines for debugging
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return None
+
+    return config.int("lm_dataset_debug_limit_line_count", None)
+
+
+class _ReachedDebugLimitLineCount(Exception):
+    """internal exception to signal reached debug limit line count"""
+
+    estimated_total_num_seqs: Optional[int] = None
+
+
 def _main():
     from returnn.util import better_exchook
 
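A hedged usage sketch for the new debug limit. The global config key matches config.int("lm_dataset_debug_limit_line_count", None) above; the dataset dict is a minimal hypothetical example, not a complete LmDataset config:

# In a RETURNN config (a plain Python file):
train = {
    "class": "LmDataset",
    "corpus_file": "/data/train.txt.gz",  # hypothetical path
}

# Read by _get_debug_limit_line_count(): stop reading each corpus file after
# ~1000 lines, estimate the total, and rescale partition_epoch accordingly.
lm_dataset_debug_limit_line_count = 1000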
returnn/frontend/_native/__init__.py CHANGED
@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
 
 
 _is_set_up = False
+_enabled = True
+
+
+def set_enabled(enabled: bool):
+    """
+    Enable or disable the native code setup.
+
+    :param enabled:
+    """
+    global _enabled
+    _enabled = enabled
+
+
+def is_set_up() -> bool:
+    """
+    :return: whether the native code is set up
+    """
+    return _is_set_up
 
 
 def setup():
@@ -76,6 +94,8 @@ def setup():
     global _is_set_up
     if _is_set_up:
         return
+    if not _enabled:
+        return
     _is_set_up = True  # only try once
 
     from returnn.tensor import Tensor, Dim
@@ -177,6 +197,8 @@ def setup_torch():
     global _is_set_up_torch
     if _is_set_up_torch:
         return
+    if not _enabled:
+        return
     _is_set_up_torch = True  # only try once
 
     import torch
returnn/frontend/_utils.py CHANGED
@@ -110,7 +110,7 @@ def bin_op_out_template(
             all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
         else:
             all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
-    if all(set(x.dims) != set(all_dims) for x in (a, b)):
+    if all([set(x.dims) != set(all_dims) for x in (a, b)]):
         if allow_broadcast_all_sources is False:
             raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
         elif allow_broadcast_all_sources is None:
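The change above is an instance of a pattern applied throughout this release: generator expressions inside all()/any()/min()/tuple()/join() become list comprehensions (see also the hunks below). A plausible motivation, though not stated in the diff, is torch.compile: TorchDynamo traces list comprehensions more reliably than generator objects, which have tended to cause graph breaks. The two forms agree in result; the list form merely gives up short-circuiting:

    any(d.need_masking() for d in dims)    # before: lazy, short-circuits
    any([d.need_masking() for d in dims])  # after: builds the full list first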
returnn/frontend/array_.py CHANGED
@@ -195,7 +195,7 @@ def merge_dims(
     if out_dim is None:
         from returnn.util.basic import prod
 
-        if any(d.need_masking() for d in dims[1:]):
+        if any([d.need_masking() for d in dims[1:]]):
             # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
             # This would then potentially discard some of the data in the tensor in subsequent operations,
             # when masking is applied.
@@ -910,7 +910,7 @@ def scatter(
     else:
         raise ValueError(f"scatter: invalid mode {mode!r}")
     indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
-    if any(dim.need_masking() for dim in indices_dim):
+    if any([dim.need_masking() for dim in indices_dim]):
         if use_mask is None:
             use_mask = rf.use_mask_default(
                 default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
returnn/tensor/_dim_extra.py CHANGED
@@ -858,7 +858,7 @@ class _DimMixin:
         self._make_extra()
         dim_order_default = self.dyn_size_ext.dims + (self,)
         if dim_order is not None:
-            dim_order = tuple(d for d in dim_order if d in dim_order_default)  # filter
+            dim_order = tuple([d for d in dim_order if d in dim_order_default])  # filter
         else:
             dim_order = dim_order_default
         cache_key = (device, dim_order)
@@ -2484,16 +2484,16 @@ _BinOpStrs = {
 
 def _math_get_dim_via_bin_op(dims: Sequence[Union[Dim, int]], op_kind: str) -> Dim:
     dims = [d if isinstance(d, _d.Dim) else _make_constant_static_dim(d) for d in dims]
-    if all(d.dimension is not None for d in dims):
+    if all([d.dimension is not None for d in dims]):
         op = _BinOps[op_kind]
         dim_value = dims[0].dimension
         for d in dims[1:]:
             dim_value = op(dim_value, d.dimension)
     else:
         dim_value = None
-    if all(d.is_constant_static_dim() for d in dims):
+    if all([d.is_constant_static_dim() for d in dims]):
         return _make_constant_static_dim(dim_value, kind=_get_merged_dim_kind(dims))
-    desc = _BinOpStrs[op_kind].join(_get_description(d) for d in dims)
+    desc = _BinOpStrs[op_kind].join([_get_description(d) for d in dims])
     if op_kind.startswith("ceildiv"):
         desc = f"⌈{desc}⌉"
     return _d.Dim(
@@ -2676,16 +2676,16 @@ def _get_description(dim, brackets=True):
 
 
 def _get_merged_dim_kind(dim_tags: Sequence[Dim]) -> Entity:
-    if any(tag.is_batch_dim() for tag in dim_tags):
+    if any([tag.is_batch_dim() for tag in dim_tags]):
         return DimTypes.Batch
-    elif any(tag.is_feature_dim() for tag in dim_tags):
+    elif any([tag.is_feature_dim() for tag in dim_tags]):
         return DimTypes.Feature
     else:
         return DimTypes.Spatial
 
 
 def _representative_tag(terms: Sequence[Dim]) -> Optional[Dim]:
-    if any(not term_.auto_generated for term_ in terms):
+    if any([not term_.auto_generated for term_ in terms]):
         # Always prefer non-auto-generated.
         terms = [term_ for term_ in terms if not term_.auto_generated]
     # First find any dynamic.
returnn/tensor/_tensor_extra.py CHANGED
@@ -32,8 +32,8 @@ class _TensorExtra:
         tensor: Tensor,
         time_dim_axis=NotSpecified,
         available_for_inference=True,
-        batch=None,
-        beam=None,
+        batch: Optional[BatchInfo] = None,
+        beam: Optional[SearchBeam] = None,
         control_flow_ctx=None,
     ):
         """
@@ -41,8 +41,8 @@
         :param int|None|NotSpecified time_dim_axis: where we have the time dim axis, after we added the batch-dim.
             this is often 1. however, can be None if there is no time-dim.
         :param bool available_for_inference: e.g. the extern data "classes" is usually not available for inference
-        :param BatchInfo|None batch:
-        :param SearchBeam|None beam: the batch-dim could be extended by a beam-size,
+        :param batch:
+        :param beam: the batch-dim could be extended by a beam-size,
             such that it represents the merged dims [batch, beam_size].
         :param ControlFlowContext|None control_flow_ctx:
         """
@@ -668,11 +668,11 @@ class _TensorMixin(_TensorMixinBase):
         if not perm:
             return self.copy()
         if allow_int and isinstance(perm[0], int):
-            assert all(isinstance(a, int) for a in perm), f"{self}: invalid perm {perm!r} types"
+            assert all([isinstance(a, int) for a in perm]), f"{self}: invalid perm {perm!r} types"
             assert set(perm) == set(range(len(perm))), f"{self}: invalid perm {perm!r}"
             return self._copy_compatible_to_dims_with_perm([self._dims[i] for i in perm], perm)
         else:
-            assert all(isinstance(a, Dim) for a in perm), f"{self}: invalid perm {perm!r} types"
+            assert all([isinstance(a, Dim) for a in perm]), f"{self}: invalid perm {perm!r} types"
             return self.copy_compatible_to_dims(perm)
 
     def copy_move_axis(self, old_axis, new_axis) -> _t.Tensor:
@@ -1155,7 +1155,7 @@ class _TensorMixin(_TensorMixinBase):
         )
 
         assert v.batch_ndim == data.batch_ndim
-        assert all(mapped_axes[ax] == ax for ax in range(v.batch_ndim))
+        assert all([mapped_axes[ax] == ax for ax in range(v.batch_ndim)])
 
         if self.version == 1:
             # Ensure time_dim_axis and feature_dim_axis is same as in data
@@ -1702,7 +1702,7 @@ class _TensorMixin(_TensorMixinBase):
         """
         :return: shape with added batch-dim. e.g. (batch,time,feat) = (None,None,128)
         """
-        return tuple(tag.dimension for tag in self.dim_tags)
+        return tuple([tag.dimension for tag in self.dim_tags])
 
     # noinspection PyShadowingNames
     def get_batch_shape(self, batch_dim):
@@ -3214,7 +3214,7 @@ class _TensorMixin(_TensorMixinBase):
         if len(sources) == 1:
             return sources[0].copy_template()
         max_ndim = max([s.batch_ndim for s in sources])
-        if any(src.batch for src in sources):
+        if any([src.batch for src in sources]):
            from returnn.tf.util.data import BatchInfo
 
            common_batch = BatchInfo.get_common_batch_info([src.batch for src in sources if src.batch])
@@ -3254,7 +3254,7 @@ class _TensorMixin(_TensorMixinBase):
         else:
             axis = common.get_default_new_axis_for_dim_tag(dim_tag)
             common = common.copy_add_dim_by_tag(dim_tag, unbroadcast=True, axis=axis)
-        if all(s.batch_ndim < common.batch_ndim for s in sources):
+        if all([s.batch_ndim < common.batch_ndim for s in sources]):
            from returnn.util.basic import validate_broadcast_all_sources
 
            validate_broadcast_all_sources(
returnn/torch/engine.py CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@ class Engine(EngineBase):
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@ class Engine(EngineBase):
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-                    self._updater.step(grad_scaler=self._grad_scaler)
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -582,10 +591,19 @@ class Engine(EngineBase):
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield
 
     def _run_step(
@@ -1734,3 +1753,101 @@ def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str,
         return safetensors_load(filename, device=device)
 
     return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
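A hedged sketch of how the new torch_profile option could be set in a config; the option name and the dict/bool handling follow _opt_torch_profiler_from_opts() above, while the concrete values are only examples:

# Enable with the defaults set above (schedule skip_first=10, wait=5, warmup=3,
# active=3, repeat=1; with repeat > 0 the run stops after exporting
# torch_profile.json and torch_memory_profile.html, then exits):
torch_profile = True

# Or pass torch.profiler.profile() kwargs; "schedule" may stay a plain dict:
torch_profile = {
    "schedule": {"skip_first": 2, "wait": 1, "warmup": 1, "active": 3, "repeat": 1},
    "with_stack": False,  # cheaper profiling
}

Note the related change in returnn/util/basic.py further down: with torch_profile enabled, should_write_to_disk() returns False, so a profiling run does not write checkpoints.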
returnn/torch/frontend/_backend.py CHANGED
@@ -275,7 +275,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -884,7 +884,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
             return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1788,7 +1788,7 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
-        if any(in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)):
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
             # unbroadcast
             in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
returnn/torch/frontend/compile_helper.py ADDED
@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
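A hedged usage sketch for the new module (assumptions: setup() must run before returnn.frontend._native is set up, per the assert above, and rf.relu stands in for any RF op):

import torch
from returnn.torch.frontend import compile_helper

compile_helper.setup()  # register Tensor/Dim as pytree nodes, disable _native

import returnn.frontend as rf
from returnn.tensor import Tensor


@torch.compile
def step(x: Tensor) -> Tensor:
    return rf.relu(x)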
returnn/util/basic.py CHANGED
@@ -3816,6 +3816,8 @@ def should_write_to_disk(config):
         return False
     if config.is_true("dry_run"):
         return False
+    if config.is_true("torch_profile"):
+        return False
     return True
 
 
returnn/util/debug.py CHANGED
@@ -309,6 +309,7 @@ def _get_native_signal_handler_lib_filename() -> str:
     old_signal_handler[SIGILL] = signal(SIGILL, signal_handler);
     old_signal_handler[SIGABRT] = signal(SIGABRT, signal_handler);
     old_signal_handler[SIGFPE] = signal(SIGFPE, signal_handler);
+    old_signal_handler[SIGUSR1] = signal(SIGUSR1, signal_handler);
 }
 """
     ),
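With SIGUSR1 added to the installed handlers, the native signal handler can presumably also be triggered from outside the process, e.g. to get a stack dump of a running job (a hedged sketch; returnn_pid is a hypothetical PID, and whether the process continues afterwards depends on the handler chain):

import os
import signal

os.kill(returnn_pid, signal.SIGUSR1)  # returnn_pid: PID of the running RETURNN process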
returnn-1.20260109.93428.dist-info/METADATA → returnn-1.20260113.134416.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20260109.93428
+Version: 1.20260113.134416
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20260109.93428.dist-info/RECORD → returnn-1.20260113.134416.dist-info/RECORD RENAMED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=8G2OFR-V5IlE98f0vmLneA27jg9-B7eN973G7vJpj0I,5215
+returnn/PKG-INFO,sha256=jhNOEgbBWBglgqkHqni28aMhOK1nHC1dJlBiKkaWfX0,5216
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
-returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
+returnn/__old_mod_loader__.py,sha256=-XAtilhq87CqmWmK2awbfGLoPAwjLGVu8t4QAxCw0fQ,9436
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=RHjC4xFQRTza5prANYrhwttWMqAEZoLiKwzMpCmll80,77
+returnn/_setup_info_generated.py,sha256=OVoyfxrF7cQ0OBpiMfzvCidyV9ia6hJFPW-TrKd9BYE,77
 returnn/config.py,sha256=JK8EjDsUdyY2c90s0KY1rLD1kesVfz6vRT0gxy_AQ5I,29142
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -21,7 +21,7 @@ returnn/datasets/distrib_files.py,sha256=48edqdf7YpnPJ-TOis3Mz5U9A2DSxfiYT1HCMSt
 returnn/datasets/generating.py,sha256=o9-JZ2s5QKssux6GcSaM3oivf_PE6nhSOeytRyGB7pQ,99574
 returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
 returnn/datasets/huggingface.py,sha256=ls9WMR6gUcMgGksl80g0An1az5Xjya_V3ojbbbsZqrU,20047
-returnn/datasets/lm.py,sha256=CXl_g-Z28RWlBTzx35uC4r_GCwOP05LIsUp0iSi6JG4,100652
+returnn/datasets/lm.py,sha256=riDa7rkwOuPX53_0y9wgQ_s2A9453BX0gWGV0HX29_M,103614
 returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
 returnn/datasets/meta.py,sha256=hTtfwINIxP2S4JQ5IQXzvTh2MixwxzeF06pPTW36yl0,101456
 returnn/datasets/multi_proc.py,sha256=BClXq0fActi1XQa4vcMhHmhYF0Q-fnnDzlIlbBM6_DM,22614
@@ -80,8 +80,8 @@ returnn/frontend/_backend.py,sha256=MVZn2HSkF3tsqchYvy2QM9pA4ILdKq07kj-_AAHGUy0,
 returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,8550
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
-returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
-returnn/frontend/array_.py,sha256=bZwTgNkMsGiSP6TVgI7bxY6zZMjcs9TVsHlajYrHUoA,56791
+returnn/frontend/_utils.py,sha256=LTwYQJBT9XjRdC2kVvHy29eUN5qARNSLGMJk90a8PjI,12076
+returnn/frontend/array_.py,sha256=2VQYtlB6OiKdpkU9H_w_jIUrb7mlxizz7KKOHjnYaeo,56795
 returnn/frontend/attention.py,sha256=bFD9Ei6GxSi-BC1OfueDyTIE-51a3dKKZOWdSIbz7l8,46633
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -120,7 +120,7 @@ returnn/frontend/state.py,sha256=EePdrx6PtWL4mJ2XZmGlh5dl4nq6G9wZpqP4hdDEzfY,293
 returnn/frontend/stepwise_scheduler.py,sha256=fMOTR7npGCDXrXDmSQ4VwmudoHEbY3Yr-QGyjFdQJSc,927
 returnn/frontend/tensor_array.py,sha256=Ej7CHtvpY0yBROlAk5vFe3CTXh-iAuqu9qcXS3Qxt2I,4328
 returnn/frontend/types.py,sha256=r-QsxPQyFSr9WwCRzqTn_X5jQLbjthrtjHavY8XIDmk,1099
-returnn/frontend/_native/__init__.py,sha256=fVjazAujt0rdICXZL-GgW1sjFeL1HB4NPuy2m5rmMsc,6480
+returnn/frontend/_native/__init__.py,sha256=VVK0x6Z7OZa3Sb4QDSz9sRrBhX8FfYdvrwhAg4W9-cc,6839
 returnn/frontend/_native/backend.cpp,sha256=MeHczHypwj_ncntOxRqanK8SqGyV9Eq1X0cpMWb_WII,4768
 returnn/frontend/_native/backend.hpp,sha256=Wq80dcEzXfRNxGOXFnIgHllkiv1rDi3KpHK-xxJsSDI,791
 returnn/frontend/_native/module.cpp,sha256=9BCUoDTZDJ6hlXp4pUus1BlN7-oxcRy6tK9ctyCkwk0,15709
@@ -155,8 +155,8 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
 returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
 returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
 returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
-returnn/tensor/_dim_extra.py,sha256=tHE3N6hUKqbzedJ8RNhn9aJHxvhTQuI9JckCLsPbKKI,116776
-returnn/tensor/_tensor_extra.py,sha256=1UPNisRAbljkvfMcrEXaPAF-2Dz7AdgC3jAKVVAnAO8,165084
+returnn/tensor/_dim_extra.py,sha256=8HLTvgEnThCp7GdtB714Tvs4ad939jZmhpS3qab03sU,116790
+returnn/tensor/_tensor_extra.py,sha256=ClwZBfaOavDtapXYpYRhDTGE85bzvRqox5mF_OnEHds,165112
 returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
 returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
 returnn/tensor/control_flow_ctx.py,sha256=L9e32AfYDUDgsEDHL07thSFyYFqwhyVSqzE_bM03Y4M,5252
@@ -208,7 +208,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
 returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
-returnn/torch/engine.py,sha256=XaJhVpF181sf8M1iXAs3u0zr37VVUG3SW81-DIZgg3g,81280
+returnn/torch/engine.py,sha256=JnoGrAakIUIsSXVEIVIXqTOVcDYJASVoRNZQrOPNrdA,85368
 returnn/torch/updater.py,sha256=nNd1mBPQyvIB096BEFi0KKmRI-U3jnRETzb743p2B9c,32064
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
@@ -217,9 +217,10 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
 returnn/torch/data/returnn_dataset_wrapper.py,sha256=fMahf05G0SPYm6HxSQpVm8JhsIHons-i1Ce4aQv4IjM,8332
 returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
 returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
-returnn/torch/frontend/_backend.py,sha256=8EBRGN0jY5rl9Z5-wd4kvoDesssWcVDVXNl25-bG8cA,108882
+returnn/torch/frontend/_backend.py,sha256=wsmalFnT_p2NDADL8N-6AHHCyv2yBe8nKM-0tKAh1cs,108888
 returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
 returnn/torch/frontend/bridge.py,sha256=RBtAIlYWn_AC-GaHWperrOncPjMLWAOrU30pWk2789A,9775
+returnn/torch/frontend/compile_helper.py,sha256=ax8ax5mjC8PDHtwTQzHYWUNRoKjZMuYHF6me9VdxiSY,2969
 returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
 returnn/torch/optim/README.md,sha256=0iH5FiKb7iDrVK5n8V6yCh4ciCFG2YSbyh7lPneT5ik,360
 returnn/torch/optim/__init__.py,sha256=yxdbnOkXAHzZ_t6cHi6zn5x_DQNlLZJ-KxZByHTIg1U,29
@@ -234,11 +235,11 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
 returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
 returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
-returnn/util/basic.py,sha256=rFeg3XwjNcNDbBgjkhisStbjTFA8CEfIrdwHjfdkJKw,143212
+returnn/util/basic.py,sha256=Pa2cAdvOJMKK7gR3heAVTol-zYVbThr9b9slVVAaH3M,143273
 returnn/util/better_exchook.py,sha256=hOKazwv2q2-d0XMfxkJXMbLZyNTtraV3jPHplFcrMsg,71014
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/collect_outputs_dict.py,sha256=CjpsftoMgmvyE4wNKTO6F-QQ_44QHXcOZIXMUMQVZ-8,2637
-returnn/util/debug.py,sha256=0ED4etMKG9lVqU0HPKEiCK-HoS8hBgnQza444QCE6ec,28576
+returnn/util/debug.py,sha256=Ndq5nz-tMEG9ZNwZTbgOkQYB9JSvAwF8r0o53Gf2EbM,28653
 returnn/util/debug_helpers.py,sha256=0EINLK4uLtoSt5_kHs1M2NIFpMd0S7i4c4rx90U4fJk,2914
 returnn/util/file_cache.py,sha256=8xE4zMQi38g7ZIGwNohd13_CgjzpIs18ILxFCKttzxE,29439
 returnn/util/fsa.py,sha256=k2lJ8tyf_g44Xk1EPVLwDwpP4spoMTqIigDVOWocQHY,59177
@@ -255,8 +256,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=7Dz7Nvi_1-o5pDv9OZYdAnlJw6OSvgbYUmQ72P0Fgkw,26002
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20260109.93428.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20260109.93428.dist-info/METADATA,sha256=8G2OFR-V5IlE98f0vmLneA27jg9-B7eN973G7vJpj0I,5215
-returnn-1.20260109.93428.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20260109.93428.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20260109.93428.dist-info/RECORD,,
+returnn-1.20260113.134416.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20260113.134416.dist-info/METADATA,sha256=jhNOEgbBWBglgqkHqni28aMhOK1nHC1dJlBiKkaWfX0,5216
+returnn-1.20260113.134416.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20260113.134416.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20260113.134416.dist-info/RECORD,,