returnn-1.20260109.93428-py3-none-any.whl → returnn-1.20260113.134416-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +110 -42
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +2 -2
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/torch/engine.py +120 -3
- returnn/torch/frontend/_backend.py +3 -3
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/util/basic.py +2 -0
- returnn/util/debug.py +1 -0
- {returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/METADATA +1 -1
- {returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/RECORD +19 -18
- {returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/LICENSE +0 -0
- {returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/WHEEL +0 -0
- {returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/__old_mod_loader__.py
CHANGED
@@ -17,7 +17,7 @@ This is supported as well.
 import sys
 import os
 import types
-import
+from typing import Any, Dict
 import importlib
 
 old_to_new_mod_mapping = {
@@ -122,7 +122,7 @@ class _LazyLoader(types.ModuleType):
         fn = "%s/%s/__init__.py" % (_base_dir, full_mod_name.replace(".", "/"))
         assert os.path.exists(fn), "_LazyLoader: mod %r not found in %r" % (full_mod_name, _base_dir)
         self.__file__ = fn
-        self._lazy_mod_config = dict(full_mod_name=full_mod_name, **kwargs)
+        self._lazy_mod_config: Dict[str, Any] = dict(full_mod_name=full_mod_name, **kwargs)
 
     def _load(self):
         full_mod_name = self.__name__
@@ -172,6 +172,30 @@ class _LazyLoader(types.ModuleType):
         return super(_LazyLoader, self).__getattribute__(item)
 
     def __getattr__(self, item):
+        if item == "torch":
+            # torch.compile Dynamo hashing can trigger this, when it uses pickle to serialize some object state,
+            # which iterates through sys.modules and does getattr on each module.
+            # In this case, it searches for torch.
+            # File ".../torch/_inductor/codecache.py", line 607 in dumps
+            # File ".../torch/_inductor/codecache.py", line 622 in get_hash
+            # File ".../torch/_inductor/codecache.py", line 961 in compiled_fx_graph_hash
+            # ...
+            # Unfortunately, Pickler.dump is native code, so we cannot easily check whether that is the parent frame.
+            # The C stacktrace looks like:
+            # ...
+            # 7 Python 0x0000000102e7d504 call_attribute + 80
+            # 8 Python 0x0000000102e7d400 _Py_slot_tp_getattr_hook + 576
+            # 9 Python 0x0000000102e507a0 PyObject_GetOptionalAttr + 248
+            # 10 _pickle.cpython-313-darwin.so 0x0000000102d24fb4 get_deep_attribute + 104
+            # 11 _pickle.cpython-313-darwin.so 0x0000000102d250b8 _checkmodule + 88
+            # 12 _pickle.cpython-313-darwin.so 0x0000000102d22588 save_global + 3024
+            # 13 _pickle.cpython-313-darwin.so 0x0000000102d1eddc save + 3424
+            # ...
+            # Right now, we just check for `item == "torch"` as a heuristic,
+            # which should never exist for any of the old-style wrapped modules here.
+            # We could maybe also check sys._getframe(1).f_code or so and add some other heuristics...
+            raise AttributeError(f"module {self.__name__} has no attribute {item} (lazy loading skipped)")
+
         module = self._load()
         return getattr(module, item)
 
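For context, the scenario this new guard handles: pickle's global-lookup machinery walks sys.modules and probes every module with getattr while locating an object's defining module, and on a lazy module that probe used to force a full load. A minimal sketch of the probe, using a hypothetical lazy module (illustrative, not RETURNN code):

    import sys
    import types

    class LazyModule(types.ModuleType):
        def __getattr__(self, item):
            if item == "torch":
                # cheap refusal, mirroring the heuristic in the diff above
                raise AttributeError(f"module {self.__name__} has no attribute {item}")
            # a real lazy loader would import and delegate here (expensive)
            raise AttributeError(item)

    sys.modules["my_lazy_mod"] = LazyModule("my_lazy_mod")

    # roughly what _pickle's _checkmodule / get_deep_attribute does while hashing:
    for mod in list(sys.modules.values()):
        getattr(mod, "torch", None)  # must stay cheap for lazy modules

The trade-off stated in the diff comments applies: any genuine `torch` attribute on an old-style wrapped module would become unreachable, which the authors judge impossible for these modules.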
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.
-long_version = '1.
+version = '1.20260113.134416'
+long_version = '1.20260113.134416+git.8c8a566'
returnn/datasets/lm.py
CHANGED
@@ -86,6 +86,7 @@ class LmDataset(CachedDataset2):
         delayed_seq_data_start_symbol="[START]",
         dtype: Optional[str] = None,
         tag_prefix: Optional[str] = None,
+        _debug_limit_line_count: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -138,6 +139,8 @@ class LmDataset(CachedDataset2):
             delayed_seq_data_start_symbol + original_sequence[:-1].
         :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data.
         :param dtype: explicit dtype. if not given, automatically determined based on the number of labels.
+        :param tag_prefix: prefix for sequence tags. by default "line-".
+        :param _debug_limit_line_count:
         """
         super(LmDataset, self).__init__(**kwargs)
 
@@ -316,6 +319,10 @@
         self.num_skipped = 0
         self.num_unknown = 0
 
+        if _debug_limit_line_count is None:
+            _debug_limit_line_count = _get_debug_limit_line_count()
+        self._debug_limit_line_count = _debug_limit_line_count
+
     def _lazy_init(self):
         if self._orths_offsets_and_lens is not None:
             return
@@ -340,6 +347,9 @@
         lens_per_corpus_file = []
         start_time = time.time()
         last_print_time = start_time
+        debug_limit_line_count = self._debug_limit_line_count
+        debug_limit_est_total = 0
+        debug_limit_hit = False
 
         def _init_tmp_file():
             nonlocal tmp_file, tmp_file_orth_files_index
@@ -368,13 +378,16 @@
 
             if time.time() - last_print_time > 10:
                 print(
-                    f" ... loaded {len(
+                    f" ... loaded {len(orths)} sequences,"
                     f" {human_bytes_size(total_bytes_read)},"
                     f" after {hms(time.time() - start_time)}",
                     file=log.v4,
                 )
                 last_print_time = time.time()
 
+            if debug_limit_line_count is not None and len(orths) - prev_orth_len >= debug_limit_line_count:
+                raise _ReachedDebugLimitLineCount()
+
         # If a list of files is provided, concatenate all.
         if isinstance(corpus_file, str):
             corpus_file = [corpus_file]
@@ -383,37 +396,46 @@
         for file_name in corpus_file:
             if self._use_cache_manager:
                 file_name = cf(file_name)
-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            try:
+                if _is_bliss(file_name):
+                    _init_tmp_file()
+                    _iter_bliss(filename=file_name, callback=_tmp_file_add_line, decode=False)
+                elif file_name.endswith(".gz"):
+                    _init_tmp_file()
+                    _iter_txt(
+                        filename=file_name,
+                        callback=_tmp_file_add_line,
+                        skip_empty_lines=self._skip_empty_lines,
+                        decode=False,
+                    )
+                else:  # Raw txt file
+                    # Directly mmap the file.
+                    # We just need to scan once through it to find line offsets.
+                    file = open(file_name, "rb")
+                    file_mmap = mmap.mmap(file.fileno(), 0, flags=mmap.MAP_PRIVATE)
+                    file_index = len(self._orth_files)
+                    self._orth_files.append(file)
+                    self._orth_mmaps.append(file_mmap)
+
+                    pos = 0
+                    while True:
+                        next_new_line = file_mmap.find(b"\n", pos)
+                        if next_new_line == -1:
+                            break
+                        line_len = next_new_line - pos
+                        if line_len or not self._skip_empty_lines:
+                            orths.append((file_index, pos, line_len))
+                        total_bytes_read += line_len + 1
+                        pos = next_new_line + 1
+                        _maybe_report_status()
+
+            except _ReachedDebugLimitLineCount as exc:
+                assert exc.estimated_total_num_seqs is not None  # currently only for _iter_txt implemented
+                debug_limit_est_total += exc.estimated_total_num_seqs
+                debug_limit_hit = True
+            else:  # iteration completed without hitting debug limit
+                debug_limit_est_total += len(orths) - prev_orth_len
 
             lens_per_corpus_file.append(len(orths) - prev_orth_len)
             prev_orth_len = len(orths)
@@ -447,6 +469,18 @@
             file=log.v4,
         )
 
+        if debug_limit_hit:
+            est_frac_loaded = len(self._orths_offsets_and_lens) / debug_limit_est_total
+            new_partition_epoch = max(int(self.partition_epoch * est_frac_loaded), 1)
+            print(
+                f"LmDataset: debug limit of {debug_limit_line_count} lines (per file) hit,"
+                f" estimated total num seqs {debug_limit_est_total},"
+                f" loaded {len(self._orths_offsets_and_lens)}, {est_frac_loaded:.2%},"
+                f" adjusting partition_epoch from {self.partition_epoch} to {new_partition_epoch}",
+                file=log.v4,
+            )
+            self.partition_epoch = new_partition_epoch
+
         # It's only estimated because we might filter some out or so.
         self._estimated_num_seqs = len(self._orths_offsets_and_lens) // self.partition_epoch
@@ -784,19 +818,34 @@ def _iter_txt(
     :param decode:
     """
     f = open(filename, "rb")
+    f_ = f
    if filename.endswith(".gz"):
        f = gzip.GzipFile(fileobj=f)
 
-
-
-
-
-
-
-
-
-
-
+    count = 0
+    try:
+        for line in f:
+            if decode:
+                try:
+                    line = line.decode("utf8")
+                except UnicodeDecodeError:
+                    line = line.decode("latin_1")  # or iso8859_15?
+            line = line.strip()
+            if skip_empty_lines and not line:
+                continue
+            count += 1
+            callback(line)
+
+    except _ReachedDebugLimitLineCount as exc:
+        print(f"Reached debug limit line count for {filename}, stopping early", file=log.v4)
+        pos = f_.tell()
+        f_.seek(0, os.SEEK_END)
+        size = f_.tell()
+        print(f" stopped at byte {human_bytes_size(pos)} / {human_bytes_size(size)}", file=log.v4)
+        estimated_num_seqs = int(count * (size / pos))
+        print(f" estimated total num seqs: {estimated_num_seqs}", file=log.v4)
+        exc.estimated_total_num_seqs = estimated_num_seqs
+        raise
 
 
 def iter_corpus(
@@ -2517,6 +2566,25 @@ def get_post_processor_function(opts):
     return chained_post_processors
 
 
+def _get_debug_limit_line_count() -> Optional[int]:
+    """
+    :return: if set, limit to this many lines for debugging
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return None
+
+    return config.int("lm_dataset_debug_limit_line_count", None)
+
+
+class _ReachedDebugLimitLineCount(Exception):
+    """internal exception to signal reached debug limit line count"""
+
+    estimated_total_num_seqs: Optional[int] = None
+
+
 def _main():
     from returnn.util import better_exchook
 
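Given how _get_debug_limit_line_count() reads the global config above, the new hook can be driven either through the private constructor argument or via a config key; a hypothetical config sketch (values illustrative):

    # in the RETURNN config:
    lm_dataset_debug_limit_line_count = 1000  # read only ~1000 lines per corpus file

    train = {
        "class": "LmDataset",
        "corpus_file": "data/train.txt.gz",  # placeholder path
        "partition_epoch": 20,  # scaled down automatically when the limit is hit
        # ... further LmDataset options ...
    }

When the limit is hit, _iter_txt extrapolates the total sequence count from the bytes read so far, and _lazy_init rescales partition_epoch by the estimated fraction actually loaded, so epoch sizes stay roughly comparable to a full run.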
returnn/frontend/_native/__init__.py
CHANGED

@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
 
 
 _is_set_up = False
+_enabled = True
+
+
+def set_enabled(enabled: bool):
+    """
+    Enable or disable the native code setup.
+
+    :param enabled:
+    """
+    global _enabled
+    _enabled = enabled
+
+
+def is_set_up() -> bool:
+    """
+    :return: whether the native code is set up
+    """
+    return _is_set_up
 
 
 def setup():
@@ -76,6 +94,8 @@ def setup():
     global _is_set_up
     if _is_set_up:
         return
+    if not _enabled:
+        return
     _is_set_up = True  # only try once
 
     from returnn.tensor import Tensor, Dim
@@ -177,6 +197,8 @@ def setup_torch():
     global _is_set_up_torch
     if _is_set_up_torch:
         return
+    if not _enabled:
+        return
    _is_set_up_torch = True  # only try once
 
    import torch
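A short sketch of the intended call pattern for the new switch (its first user is the compile_helper module added below); the only constraint is ordering:

    from returnn.frontend import _native

    _native.set_enabled(False)  # must run before anything calls setup()/setup_torch()
    assert not _native.is_set_up()
    # later calls to _native.setup() / _native.setup_torch() now return immediately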
returnn/frontend/_utils.py
CHANGED
@@ -110,7 +110,7 @@ def bin_op_out_template(
             all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
         else:
             all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
-    if all(set(x.dims) != set(all_dims) for x in (a, b)):
+    if all([set(x.dims) != set(all_dims) for x in (a, b)]):
         if allow_broadcast_all_sources is False:
             raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
         elif allow_broadcast_all_sources is None:
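A recurring change in this release (here and in array_.py, _dim_extra.py, _tensor_extra.py, and _backend.py below) replaces generator expressions inside all()/any()/min()/tuple()/join() with list comprehensions. The diff gives no rationale, but together with the new compile_helper module this reads as a torch.compile friendliness change: TorchDynamo has historically graph-broken on generator objects, while list comprehensions trace fine. A minimal sketch of the difference, with a hypothetical function (not from the package):

    import torch

    @torch.compile
    def check(xs):
        # list comprehension: Dynamo can trace this eagerly
        if all([x.shape[0] > 1 for x in xs]):
            return [x * 2 for x in xs]
        # all(x.shape[0] > 1 for x in xs) would build a generator object,
        # which older Dynamo versions cannot trace without a graph break
        return xs

    print(check([torch.ones(2), torch.ones(3)]))

The cost is the loss of short-circuiting (every element is evaluated before all()/any() runs), which is negligible for the small dim/axis sequences involved here.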
returnn/frontend/array_.py
CHANGED
@@ -195,7 +195,7 @@ def merge_dims(
     if out_dim is None:
         from returnn.util.basic import prod
 
-        if any(d.need_masking() for d in dims[1:]):
+        if any([d.need_masking() for d in dims[1:]]):
             # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
             # This would then potentially discard some of the data in the tensor in subsequent operations,
             # when masking is applied.
@@ -910,7 +910,7 @@ def scatter(
     else:
         raise ValueError(f"scatter: invalid mode {mode!r}")
     indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
-    if any(dim.need_masking() for dim in indices_dim):
+    if any([dim.need_masking() for dim in indices_dim]):
         if use_mask is None:
             use_mask = rf.use_mask_default(
                 default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
returnn/tensor/_dim_extra.py
CHANGED
@@ -858,7 +858,7 @@ class _DimMixin:
         self._make_extra()
         dim_order_default = self.dyn_size_ext.dims + (self,)
         if dim_order is not None:
-            dim_order = tuple(d for d in dim_order if d in dim_order_default)  # filter
+            dim_order = tuple([d for d in dim_order if d in dim_order_default])  # filter
         else:
             dim_order = dim_order_default
         cache_key = (device, dim_order)
@@ -2484,16 +2484,16 @@ _BinOpStrs = {
 
 def _math_get_dim_via_bin_op(dims: Sequence[Union[Dim, int]], op_kind: str) -> Dim:
     dims = [d if isinstance(d, _d.Dim) else _make_constant_static_dim(d) for d in dims]
-    if all(d.dimension is not None for d in dims):
+    if all([d.dimension is not None for d in dims]):
         op = _BinOps[op_kind]
         dim_value = dims[0].dimension
         for d in dims[1:]:
             dim_value = op(dim_value, d.dimension)
     else:
         dim_value = None
-    if all(d.is_constant_static_dim() for d in dims):
+    if all([d.is_constant_static_dim() for d in dims]):
         return _make_constant_static_dim(dim_value, kind=_get_merged_dim_kind(dims))
-    desc = _BinOpStrs[op_kind].join(_get_description(d) for d in dims)
+    desc = _BinOpStrs[op_kind].join([_get_description(d) for d in dims])
     if op_kind.startswith("ceildiv"):
         desc = f"⌈{desc}⌉"
     return _d.Dim(
@@ -2676,16 +2676,16 @@ def _get_description(dim, brackets=True):
 
 
 def _get_merged_dim_kind(dim_tags: Sequence[Dim]) -> Entity:
-    if any(tag.is_batch_dim() for tag in dim_tags):
+    if any([tag.is_batch_dim() for tag in dim_tags]):
         return DimTypes.Batch
-    elif any(tag.is_feature_dim() for tag in dim_tags):
+    elif any([tag.is_feature_dim() for tag in dim_tags]):
         return DimTypes.Feature
     else:
         return DimTypes.Spatial
 
 
 def _representative_tag(terms: Sequence[Dim]) -> Optional[Dim]:
-    if any(not term_.auto_generated for term_ in terms):
+    if any([not term_.auto_generated for term_ in terms]):
         # Always prefer non-auto-generated.
         terms = [term_ for term_ in terms if not term_.auto_generated]
         # First find any dynamic.
returnn/tensor/_tensor_extra.py
CHANGED
@@ -32,8 +32,8 @@ class _TensorExtra:
         tensor: Tensor,
         time_dim_axis=NotSpecified,
         available_for_inference=True,
-        batch=None,
-        beam=None,
+        batch: Optional[BatchInfo] = None,
+        beam: Optional[SearchBeam] = None,
         control_flow_ctx=None,
     ):
         """
@@ -41,8 +41,8 @@ class _TensorExtra:
         :param int|None|NotSpecified time_dim_axis: where we have the time dim axis, after we added the batch-dim.
             this is often 1. however, can be None if there is no time-dim.
         :param bool available_for_inference: e.g. the extern data "classes" is usually not available for inference
-        :param
-        :param
+        :param batch:
+        :param beam: the batch-dim could be extended by a beam-size,
             such that it represents the merged dims [batch, beam_size].
         :param ControlFlowContext|None control_flow_ctx:
         """
@@ -668,11 +668,11 @@ class _TensorMixin(_TensorMixinBase):
         if not perm:
             return self.copy()
         if allow_int and isinstance(perm[0], int):
-            assert all(isinstance(a, int) for a in perm), f"{self}: invalid perm {perm!r} types"
+            assert all([isinstance(a, int) for a in perm]), f"{self}: invalid perm {perm!r} types"
             assert set(perm) == set(range(len(perm))), f"{self}: invalid perm {perm!r}"
             return self._copy_compatible_to_dims_with_perm([self._dims[i] for i in perm], perm)
         else:
-            assert all(isinstance(a, Dim) for a in perm), f"{self}: invalid perm {perm!r} types"
+            assert all([isinstance(a, Dim) for a in perm]), f"{self}: invalid perm {perm!r} types"
             return self.copy_compatible_to_dims(perm)
 
     def copy_move_axis(self, old_axis, new_axis) -> _t.Tensor:
@@ -1155,7 +1155,7 @@ class _TensorMixin(_TensorMixinBase):
         )
 
         assert v.batch_ndim == data.batch_ndim
-        assert all(mapped_axes[ax] == ax for ax in range(v.batch_ndim))
+        assert all([mapped_axes[ax] == ax for ax in range(v.batch_ndim)])
 
         if self.version == 1:
             # Ensure time_dim_axis and feature_dim_axis is same as in data
@@ -1702,7 +1702,7 @@ class _TensorMixin(_TensorMixinBase):
         """
         :return: shape with added batch-dim. e.g. (batch,time,feat) = (None,None,128)
         """
-        return tuple(tag.dimension for tag in self.dim_tags)
+        return tuple([tag.dimension for tag in self.dim_tags])
 
     # noinspection PyShadowingNames
     def get_batch_shape(self, batch_dim):
@@ -3214,7 +3214,7 @@ class _TensorMixin(_TensorMixinBase):
         if len(sources) == 1:
             return sources[0].copy_template()
         max_ndim = max([s.batch_ndim for s in sources])
-        if any(src.batch for src in sources):
+        if any([src.batch for src in sources]):
             from returnn.tf.util.data import BatchInfo
 
             common_batch = BatchInfo.get_common_batch_info([src.batch for src in sources if src.batch])
@@ -3254,7 +3254,7 @@ class _TensorMixin(_TensorMixinBase):
         else:
             axis = common.get_default_new_axis_for_dim_tag(dim_tag)
             common = common.copy_add_dim_by_tag(dim_tag, unbroadcast=True, axis=axis)
-        if all(s.batch_ndim < common.batch_ndim for s in sources):
+        if all([s.batch_ndim < common.batch_ndim for s in sources]):
             from returnn.util.basic import validate_broadcast_all_sources
 
             validate_broadcast_all_sources(
returnn/torch/engine.py
CHANGED
@@ -3,9 +3,11 @@ Main engine for PyTorch
 """
 
 from __future__ import annotations
+
 from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager
 
+import sys
 import gc
 import os
 import time
@@ -20,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
+from torch.profiler import record_function
 import numpy as np
 
 import returnn
@@ -404,10 +407,14 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()
 
+        prof = _opt_torch_profiler_from_opts(self.config.opt_typed_value("torch_profile"))
+        if prof:
+            prof.__enter__()
+
         report_prefix = f"ep {self.epoch} train"
         try:
             while True:
-                with torch.no_grad():
+                with torch.no_grad(), record_function("data_loading"):
                     extern_data_raw = next(data_iter, None)
 
                 step_begin_time = time.monotonic()
@@ -485,7 +492,8 @@ class Engine(EngineBase):
                 with (
                     self._ddp_pt_model.no_sync()
                     if (self._ddp_pt_model is not None and not perform_update_step)
-                    else nullcontext()
+                    else nullcontext(),
+                    record_function("backward"),
                 ):
                     if self._grad_scaler is not None:
                         self._grad_scaler.scale(total_loss.raw_tensor).backward()
@@ -500,7 +508,8 @@ class Engine(EngineBase):
 
                 # only update the weights when every gradient accumulation loop ends
                 if perform_update_step:
-
+                    with record_function("optimizer_step"):
+                        self._updater.step(grad_scaler=self._grad_scaler)
                 zero_grad_next_step = perform_update_step
 
                 if self._torch_distributed_ctx:
@@ -582,10 +591,19 @@ class Engine(EngineBase):
                 self._updater.set_current_train_step(
                     global_train_step=self.global_train_step, epoch=self.epoch, epoch_continuous=epoch_continuous
                 )
+
+                if prof:
+                    prof.step()
+
         except Exception as exc:
+            if prof:
+                prof.__exit__(type(exc), exc, exc.__traceback__)
             help_on_torch_exception(exc, step_idx=step_idx, model=self._orig_model, extern_data=extern_data)
             raise
 
+        if prof:
+            prof.__exit__(None, None, None)
+
         elapsed = time.monotonic() - epoch_start_time
         elapsed_computation_percentage = elapsed_computation_time / elapsed
         total_padding_ratio = NumbersDict.constant_like(1.0, total_data_size_packed) - (
@@ -885,6 +903,7 @@ class Engine(EngineBase):
             if self._default_float_dtype:
                 stack.enter_context(rf.set_default_float_dtype_ctx(str(self._default_float_dtype).split(".")[-1]))
                 stack.enter_context(_set_torch_default_dtype_ctx_mgr(self._default_float_dtype))
+            stack.enter_context(record_function("model_step"))
             yield
 
     def _run_step(
@@ -1734,3 +1753,101 @@ def _torch_load(filename: Union[str, os.PathLike], *, device: str) -> Dict[str,
         return safetensors_load(filename, device=device)
 
     return torch.load(filename, map_location=device)
+
+
+class _TorchProfiler:
+    def __init__(self, profiler: torch.profiler.profile, max_step: Optional[int]):
+        self.profiler = profiler
+        self.max_step = max_step
+        self.entered = False
+
+    def __enter__(self):
+        self.profiler.__enter__()
+        self.entered = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not self.entered:
+            return
+        self.entered = False
+        self.profiler.__exit__(exc_type, exc_val, exc_tb)
+
+        if exc_type is None:
+            print(
+                "Torch profiling finished, exporting Chrome trace to torch_profile.json,"
+                " memory timeline to torch_memory_profile.html...",
+                file=log.v2,
+            )
+            self.profiler.export_chrome_trace("torch_profile.json")
+            self.profiler.export_memory_timeline("torch_memory_profile.html")
+
+            print("Exiting program after Torch profiling.", file=log.v2)
+            sys.exit(0)
+
+    def step(self):
+        """step"""
+        self.profiler.step()
+        if self.max_step is not None and self.profiler.step_num > self.max_step:
+            print(f"Reached max profiling step {self.max_step}, stopping Torch profiler.", file=log.v2)
+            self.profiler.stop()
+            self.__exit__(None, None, None)
+
+
+def _opt_torch_profiler_from_opts(
+    opts: Union[None, int, bool, str, Dict[str, Any]],
+) -> Optional[_TorchProfiler]:
+    if isinstance(opts, str):
+        from returnn.util.basic import to_bool
+
+        opts = to_bool(opts)
+
+    if opts is None:
+        return None
+    elif isinstance(opts, (bool, int)):
+        if not opts:
+            return None
+        opts = {}
+    elif isinstance(opts, dict):
+        opts = opts.copy()
+    else:
+        raise TypeError(f"Invalid type for torch_profile {opts!r}: {type(opts)}")
+
+    from torch.profiler import profile, ProfilerActivity, schedule
+
+    print("Using Torch profiler...", file=log.v2)
+
+    prof_max_step = None
+
+    if "activities" not in opts:
+        activities = [ProfilerActivity.CPU]
+        if torch.cuda.is_available():
+            activities += [ProfilerActivity.CUDA]
+        elif torch.xpu.is_available():
+            activities += [ProfilerActivity.XPU]
+        opts["activities"] = activities
+
+    opts.setdefault("profile_memory", True)
+    opts.setdefault("record_shapes", True)
+    opts.setdefault("with_stack", True)
+    opts.setdefault("with_flops", True)
+    # Note: active*repeat are the steps we actually profile.
+    opts.setdefault("schedule", dict(skip_first=10, wait=5, warmup=3, active=3, repeat=1))
+
+    if isinstance(opts["schedule"], dict):
+        schedule_opts: Dict[str, Any] = opts["schedule"]
+        schedule_opts = schedule_opts.copy()
+        schedule_opts.setdefault("repeat", 0)
+        schedule_opts.setdefault("skip_first", 0)
+        schedule_opts.setdefault("skip_first_wait", 0)
+        opts["schedule"] = schedule(**schedule_opts)
+
+        if schedule_opts["repeat"] > 0:
+            prof_max_step = (schedule_opts["wait"] + schedule_opts["warmup"] + schedule_opts["active"]) * schedule_opts[
+                "repeat"
+            ]
+            prof_max_step += schedule_opts["skip_first"]
+            if schedule_opts["skip_first_wait"] != 0:
+                prof_max_step -= schedule_opts["wait"]
+            print(f"Profiling will stop automatically after {prof_max_step} steps.", file=log.v3)
+
+    prof = profile(**opts)
+    return _TorchProfiler(prof, prof_max_step)
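Based on how _opt_torch_profiler_from_opts parses its input, the new profiler hook is driven by a torch_profile config value; a hypothetical config sketch (dict keys other than schedule are forwarded to torch.profiler.profile):

    # simplest form: enable with the defaults set above
    torch_profile = True

    # or spelled out; with repeat > 0 the engine computes the last profiled step,
    # exports torch_profile.json and torch_memory_profile.html, then exits
    torch_profile = {
        "schedule": dict(skip_first=2, wait=1, warmup=2, active=5, repeat=1),
        "record_shapes": False,
    }

The record_function("data_loading"/"backward"/"optimizer_step"/"model_step") scopes added in the training loop label the corresponding regions in the exported Chrome trace.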
returnn/torch/frontend/_backend.py
CHANGED

@@ -275,7 +275,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :return: tensor
         """
         assert len(dims) >= 2
-        first_axis = min(source.dims.index(d) for d in dims)
+        first_axis = min([source.dims.index(d) for d in dims])
         pre_dims = source.dims[:first_axis]
         post_dims = [d for d in source.dims if d not in dims and d not in pre_dims]
         source = source.copy_transpose(tuple(pre_dims) + tuple(dims) + tuple(post_dims), allow_int=False)
@@ -884,7 +884,7 @@ class TorchBackend(Backend[torch.Tensor]):
         :param perm: e.g. [0, 2, 1]
         :return: permuted (transposed) raw tensor; wraps torch.permute
         """
-        if all(p == i for i, p in enumerate(perm)):
+        if all([p == i for i, p in enumerate(perm)]):
             return raw_tensor
         return torch.permute(raw_tensor, tuple(perm))
 
@@ -1788,7 +1788,7 @@ class TorchBackend(Backend[torch.Tensor]):
         remaining_dims = [d for d in tensor.dims if d not in mask.dims]
         tensor_templ_dims = tuple(dims) + tuple(remaining_dims)
         in_raw = tensor.copy_compatible_to_dims_raw(tensor_templ_dims)
-        if any(in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)):
+        if any([in_raw.shape[i] == 1 < d.get_dim_value() for i, d in enumerate(dims)]):
             # unbroadcast
             in_raw = in_raw.expand([d.get_dim_value() for d in tensor_templ_dims])
         if mask.raw_tensor.device.type == "meta":
returnn/torch/frontend/compile_helper.py
ADDED

@@ -0,0 +1,106 @@
+"""
+Helpers to improve torch.compile on RF code.
+"""
+
+from __future__ import annotations
+from typing import Any, Iterable, List, Tuple
+
+import os
+from returnn.tensor import Tensor, Dim
+
+# noinspection PyProtectedMember
+from returnn.frontend import _native
+
+_is_set_up = False
+
+
+def setup():
+    """
+    Set up the torch.compile helpers for RF code, also including :class:`Tensor` and :class:`Dim`.
+    """
+    global _is_set_up
+    if _is_set_up:
+        return
+    _is_set_up = True  # only try once
+
+    assert not _native.is_set_up(), "Call this setup() as early as possible."
+    _native.set_enabled(False)
+
+    # We have lots of dynamic shapes.
+    os.environ["TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS"] = "1"
+
+    # noinspection PyProtectedMember
+    from torch.utils._pytree import register_pytree_node
+
+    register_pytree_node(Tensor, _tensor_flatten, _tensor_unflatten)
+    register_pytree_node(Dim, _dim_flatten, _dim_unflatten)
+
+    Dim.get_dim_value = _dim_get_dim_value
+
+
+def _tensor_flatten(t: Tensor) -> Tuple[List[Any], Any]:
+    """
+    Flatten the tensor for PyTree.
+    """
+    return [t.raw_tensor, t.dims, t.sparse_dim], [
+        t.name,
+        t.dtype,
+        t.version,
+        t.feature_dim_axis_or_unspecified,
+        t.time_dim_axis_or_unspecified,
+    ]
+
+
+def _tensor_unflatten(values: Iterable[Any], metadata: Any) -> Tensor:
+    """
+    Unflatten the tensor from PyTree.
+    """
+    raw_tensor, dims, sparse_dim = values
+    name, dtype, version, feature_dim_axis, time_dim_axis = metadata
+    return Tensor(
+        name=name,
+        dims=dims,
+        dtype=dtype,
+        sparse_dim=sparse_dim,
+        feature_dim_axis=feature_dim_axis,
+        time_dim_axis=time_dim_axis,
+        raw_tensor=raw_tensor,
+        version=version,
+    )
+
+
+def _dim_flatten(d: Dim) -> Tuple[List[Any], Any]:
+    """
+    Flatten the dim for PyTree.
+    """
+    return [d.dyn_size_ext], [d.name, d.dimension, d.size]
+
+
+def _dim_unflatten(values: Iterable[Any], metadata: Any) -> Dim:
+    """
+    Unflatten the dim from PyTree.
+    """
+    (dyn_size_ext,) = values
+    name, dimension, size = metadata
+    # TODO this creates a new instance... this is maybe wrong?
+    return Dim(name=name, dimension=dimension, size=size, dyn_size_ext=dyn_size_ext)
+
+
+def _dim_get_dim_value(self: Dim) -> int:
+    """
+    Infers the dim this axis should have if unbroadcasted.
+    If `self.src_data` has a placeholder, will use the shape from there.
+    Otherwise, uses `self.dimension` (if static) or `self.dyn_size` (if dynamic).
+
+    :return: max(size or dyn_size)
+    """
+    res = self.get_dim_value_tensor()
+    if isinstance(res, Tensor):
+        assert res.dims == ()
+        assert res.raw_tensor is not None
+        # Specifically PyTorch would then treat it as a SymInt in torch.compile,
+        # which is important to have for some torch functions (e.g. torch.tile and others).
+        return int(res.raw_tensor)
+    assert isinstance(res, int)
+    return res
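A sketch of how this new module is presumably meant to be used; the key constraint is encoded in the assert above, i.e. setup() must run before the native Tensor/Dim helpers get installed:

    import torch
    from returnn.torch.frontend import compile_helper

    compile_helper.setup()  # must come first: registers Tensor/Dim pytree nodes, disables _native

    import returnn.frontend as rf

    def step(x: rf.Tensor) -> rf.Tensor:
        return rf.relu(x)

    compiled_step = torch.compile(step)  # Tensor/Dim arguments now flatten/unflatten through pytree

Registering Tensor and Dim as pytree nodes lets Dynamo see through the RF wrapper objects down to the raw torch.Tensor leaves, and the patched get_dim_value returns a plain int that torch.compile can treat as a SymInt for dynamic shapes.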
returnn/util/basic.py
CHANGED
returnn/util/debug.py
CHANGED
@@ -309,6 +309,7 @@ def _get_native_signal_handler_lib_filename() -> str:
   old_signal_handler[SIGILL] = signal(SIGILL, signal_handler);
   old_signal_handler[SIGABRT] = signal(SIGABRT, signal_handler);
   old_signal_handler[SIGFPE] = signal(SIGFPE, signal_handler);
+  old_signal_handler[SIGUSR1] = signal(SIGUSR1, signal_handler);
 }
 """
         ),
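With SIGUSR1 added to the signals the native handler installs, the handler can (assuming the native signal handler library from returnn.util.debug is actually loaded in the process) be triggered externally, e.g. to inspect a hung training job; a hypothetical sketch:

    import os
    import signal

    # self-signal for demonstration; in practice you would run
    # `kill -USR1 <returnn_pid>` from a shell against the training process
    os.kill(os.getpid(), signal.SIGUSR1)

Note that if the native handler is not installed, the default disposition of SIGUSR1 terminates the process, so only send it to processes where this debug setup is active.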
{returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/RECORD
CHANGED

@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=jhNOEgbBWBglgqkHqni28aMhOK1nHC1dJlBiKkaWfX0,5216
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
-returnn/__old_mod_loader__.py,sha256
+returnn/__old_mod_loader__.py,sha256=-XAtilhq87CqmWmK2awbfGLoPAwjLGVu8t4QAxCw0fQ,9436
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=OVoyfxrF7cQ0OBpiMfzvCidyV9ia6hJFPW-TrKd9BYE,77
 returnn/config.py,sha256=JK8EjDsUdyY2c90s0KY1rLD1kesVfz6vRT0gxy_AQ5I,29142
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -21,7 +21,7 @@ returnn/datasets/distrib_files.py,sha256=48edqdf7YpnPJ-TOis3Mz5U9A2DSxfiYT1HCMSt
 returnn/datasets/generating.py,sha256=o9-JZ2s5QKssux6GcSaM3oivf_PE6nhSOeytRyGB7pQ,99574
 returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
 returnn/datasets/huggingface.py,sha256=ls9WMR6gUcMgGksl80g0An1az5Xjya_V3ojbbbsZqrU,20047
-returnn/datasets/lm.py,sha256=
+returnn/datasets/lm.py,sha256=riDa7rkwOuPX53_0y9wgQ_s2A9453BX0gWGV0HX29_M,103614
 returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
 returnn/datasets/meta.py,sha256=hTtfwINIxP2S4JQ5IQXzvTh2MixwxzeF06pPTW36yl0,101456
 returnn/datasets/multi_proc.py,sha256=BClXq0fActi1XQa4vcMhHmhYF0Q-fnnDzlIlbBM6_DM,22614
@@ -80,8 +80,8 @@ returnn/frontend/_backend.py,sha256=MVZn2HSkF3tsqchYvy2QM9pA4ILdKq07kj-_AAHGUy0,
 returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,8550
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
-returnn/frontend/_utils.py,sha256=
-returnn/frontend/array_.py,sha256=
+returnn/frontend/_utils.py,sha256=LTwYQJBT9XjRdC2kVvHy29eUN5qARNSLGMJk90a8PjI,12076
+returnn/frontend/array_.py,sha256=2VQYtlB6OiKdpkU9H_w_jIUrb7mlxizz7KKOHjnYaeo,56795
 returnn/frontend/attention.py,sha256=bFD9Ei6GxSi-BC1OfueDyTIE-51a3dKKZOWdSIbz7l8,46633
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -120,7 +120,7 @@ returnn/frontend/state.py,sha256=EePdrx6PtWL4mJ2XZmGlh5dl4nq6G9wZpqP4hdDEzfY,293
 returnn/frontend/stepwise_scheduler.py,sha256=fMOTR7npGCDXrXDmSQ4VwmudoHEbY3Yr-QGyjFdQJSc,927
 returnn/frontend/tensor_array.py,sha256=Ej7CHtvpY0yBROlAk5vFe3CTXh-iAuqu9qcXS3Qxt2I,4328
 returnn/frontend/types.py,sha256=r-QsxPQyFSr9WwCRzqTn_X5jQLbjthrtjHavY8XIDmk,1099
-returnn/frontend/_native/__init__.py,sha256=
+returnn/frontend/_native/__init__.py,sha256=VVK0x6Z7OZa3Sb4QDSz9sRrBhX8FfYdvrwhAg4W9-cc,6839
 returnn/frontend/_native/backend.cpp,sha256=MeHczHypwj_ncntOxRqanK8SqGyV9Eq1X0cpMWb_WII,4768
 returnn/frontend/_native/backend.hpp,sha256=Wq80dcEzXfRNxGOXFnIgHllkiv1rDi3KpHK-xxJsSDI,791
 returnn/frontend/_native/module.cpp,sha256=9BCUoDTZDJ6hlXp4pUus1BlN7-oxcRy6tK9ctyCkwk0,15709
@@ -155,8 +155,8 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
 returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
 returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
 returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
-returnn/tensor/_dim_extra.py,sha256=
-returnn/tensor/_tensor_extra.py,sha256=
+returnn/tensor/_dim_extra.py,sha256=8HLTvgEnThCp7GdtB714Tvs4ad939jZmhpS3qab03sU,116790
+returnn/tensor/_tensor_extra.py,sha256=ClwZBfaOavDtapXYpYRhDTGE85bzvRqox5mF_OnEHds,165112
 returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
 returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
 returnn/tensor/control_flow_ctx.py,sha256=L9e32AfYDUDgsEDHL07thSFyYFqwhyVSqzE_bM03Y4M,5252
@@ -208,7 +208,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
 returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
-returnn/torch/engine.py,sha256=
+returnn/torch/engine.py,sha256=JnoGrAakIUIsSXVEIVIXqTOVcDYJASVoRNZQrOPNrdA,85368
 returnn/torch/updater.py,sha256=nNd1mBPQyvIB096BEFi0KKmRI-U3jnRETzb743p2B9c,32064
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
@@ -217,9 +217,10 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
 returnn/torch/data/returnn_dataset_wrapper.py,sha256=fMahf05G0SPYm6HxSQpVm8JhsIHons-i1Ce4aQv4IjM,8332
 returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
 returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
-returnn/torch/frontend/_backend.py,sha256=
+returnn/torch/frontend/_backend.py,sha256=wsmalFnT_p2NDADL8N-6AHHCyv2yBe8nKM-0tKAh1cs,108888
 returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
 returnn/torch/frontend/bridge.py,sha256=RBtAIlYWn_AC-GaHWperrOncPjMLWAOrU30pWk2789A,9775
+returnn/torch/frontend/compile_helper.py,sha256=ax8ax5mjC8PDHtwTQzHYWUNRoKjZMuYHF6me9VdxiSY,2969
 returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
 returnn/torch/optim/README.md,sha256=0iH5FiKb7iDrVK5n8V6yCh4ciCFG2YSbyh7lPneT5ik,360
 returnn/torch/optim/__init__.py,sha256=yxdbnOkXAHzZ_t6cHi6zn5x_DQNlLZJ-KxZByHTIg1U,29
@@ -234,11 +235,11 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
 returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
 returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
-returnn/util/basic.py,sha256=
+returnn/util/basic.py,sha256=Pa2cAdvOJMKK7gR3heAVTol-zYVbThr9b9slVVAaH3M,143273
 returnn/util/better_exchook.py,sha256=hOKazwv2q2-d0XMfxkJXMbLZyNTtraV3jPHplFcrMsg,71014
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/collect_outputs_dict.py,sha256=CjpsftoMgmvyE4wNKTO6F-QQ_44QHXcOZIXMUMQVZ-8,2637
-returnn/util/debug.py,sha256=
+returnn/util/debug.py,sha256=Ndq5nz-tMEG9ZNwZTbgOkQYB9JSvAwF8r0o53Gf2EbM,28653
 returnn/util/debug_helpers.py,sha256=0EINLK4uLtoSt5_kHs1M2NIFpMd0S7i4c4rx90U4fJk,2914
 returnn/util/file_cache.py,sha256=8xE4zMQi38g7ZIGwNohd13_CgjzpIs18ILxFCKttzxE,29439
 returnn/util/fsa.py,sha256=k2lJ8tyf_g44Xk1EPVLwDwpP4spoMTqIigDVOWocQHY,59177
@@ -255,8 +256,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=7Dz7Nvi_1-o5pDv9OZYdAnlJw6OSvgbYUmQ72P0Fgkw,26002
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
+returnn-1.20260113.134416.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20260113.134416.dist-info/METADATA,sha256=jhNOEgbBWBglgqkHqni28aMhOK1nHC1dJlBiKkaWfX0,5216
+returnn-1.20260113.134416.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20260113.134416.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20260113.134416.dist-info/RECORD,,
{returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/LICENSE
File without changes
{returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/WHEEL
File without changes
{returnn-1.20260109.93428.dist-info → returnn-1.20260113.134416.dist-info}/top_level.txt
File without changes