returnn 1.20250304.101951-py3-none-any.whl → 1.20250304.113330-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/engine.py +29 -62
- {returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/METADATA +1 -1
- {returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/RECORD +8 -8
- {returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/LICENSE +0 -0
- {returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/WHEEL +0 -0
- {returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250304.
-long_version = '1.20250304.
+version = '1.20250304.113330'
+long_version = '1.20250304.113330+git.acf09da'
returnn/torch/engine.py
CHANGED
@@ -3,7 +3,7 @@ Main engine for PyTorch
 """

 from __future__ import annotations
-from typing import Optional, Any, Union, Callable, Dict, Set
+from typing import Optional, Any, Union, Callable, Dict, Set
 from contextlib import nullcontext, ExitStack, contextmanager

 import gc
@@ -365,8 +365,6 @@ class Engine(EngineBase):
 zero_grad_next_step = True
 cur_count_grad_accum = 0
 extern_data = None
-num_seqs = None
-last_seq_idx = 0

 total_data_size_packed = NumbersDict()
 total_data_size_padded = NumbersDict()
@@ -400,20 +398,8 @@
 )

 complete_frac = float(extern_data_raw["complete_frac"])
-
-
-extern_data_raw=extern_data_raw,
-step_idx=step_idx,
-prev_num_seqs=num_seqs,
-prev_last_seq_idx=last_seq_idx,
-)
-epoch_continuous = (
-self.epoch - 1 + complete_frac
-if complete_frac >= 0.0
-else (self.epoch - 1 + (last_seq_idx + 1) / num_seqs)
-if num_seqs is not None
-else None
-)
+epoch_continuous = self.epoch - 1 + complete_frac if complete_frac >= 0.0 else None
+num_seqs = int(extern_data_raw["num_seqs"])

 # clear the gradients when every gradient accumulation loop starts
 if zero_grad_next_step:
@@ -490,7 +476,7 @@
 eval_info=dict(eval_info),
 step_duration=step_duration,
 start_elapsed=step_end_time - epoch_start_time,
-
+complete_frac=complete_frac,
 num_seqs=num_seqs,
 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
 log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -629,13 +615,18 @@
 accumulated_losses_dict = NumbersDict()
 accumulated_inv_norm_factors_dict = NumbersDict()
 step_idx = 0
+eval_start_time = time.monotonic()

+report_prefix = f"ep {self.epoch} {dataset_name} eval"
 with torch.no_grad():
 for extern_data_raw in data_loader:
 if self._torch_distributed_ctx and step_idx % 100 == 0:
 _has_data = torch.tensor([True], device="cpu", dtype=torch.int8)
 torch.distributed.broadcast(_has_data, src=0)

+complete_frac = float(extern_data_raw["complete_frac"])
+num_seqs = int(extern_data_raw["num_seqs"])
+
 extern_data = extern_data_util.raw_dict_to_extern_data(
 extern_data_raw,
 extern_data_template=self.extern_data,
@@ -644,6 +635,8 @@
 )

 self._run_step(extern_data, train_func=True)
+step_end_time = time.monotonic()
+
 train_ctx = rf.get_run_ctx()

 losses_dict = NumbersDict(
@@ -664,9 +657,12 @@
 accumulated_inv_norm_factors_dict += inv_norm_factors_dict
 eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
 _print_process(
-
+report_prefix,
 step=step_idx,
 eval_info=dict(eval_info),
+complete_frac=complete_frac,
+num_seqs=num_seqs,
+start_elapsed=step_end_time - eval_start_time,
 log_memory_usage_device=self._device if self._log_memory_usage else None,
 )
 step_idx += 1
@@ -1290,8 +1286,6 @@
 new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
 return new_dim

-num_seqs = None
-last_seq_idx = 0
 report_prefix = f"ep {self.epoch} {dataset.name} forward"
 with torch.no_grad():
 callback.init(model=self._orig_model)
@@ -1300,13 +1294,8 @@
 for extern_data_raw in data_loader:
 step_begin_time = time.monotonic()

-
-
-extern_data_raw=extern_data_raw,
-step_idx=step_idx,
-prev_num_seqs=num_seqs,
-prev_last_seq_idx=last_seq_idx,
-)
+complete_frac = float(extern_data_raw["complete_frac"])
+num_seqs = int(extern_data_raw["num_seqs"])

 if self._forward_step_expected_outputs:
 # Also resets any dyn dims, which might have been set in the prev step.
@@ -1354,7 +1343,7 @@
 eval_info=None,
 step_duration=step_duration,
 start_elapsed=step_end_time - epoch_start_time,
-
+complete_frac=complete_frac,
 num_seqs=num_seqs,
 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
 log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -1442,7 +1431,7 @@ def _print_process(
 batch_size_info: Optional[Dict[str, Any]] = None,
 step_duration: Optional[float] = None,
 start_elapsed: Optional[float] = None,
-
+complete_frac: Optional[float] = None,
 num_seqs: Optional[int] = None,
 log_memory_usage_device: Optional[str] = None,
 ):
@@ -1455,11 +1444,14 @@ def _print_process(
 :param batch_size_info:
 :param step_duration: time elapsed for this step (secs)
 :param start_elapsed: time elapsed since epoch start (secs)
-:param
+:param complete_frac: how much of the current epoch is already consumed
+:param num_seqs: total number of seqs this epoch
 :param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
 :return: nothing, will be printed to log
 """
 if log.verbose[5]: # report every minibatch
+if step == 0 and num_seqs is not None and num_seqs >= 0:
+print(f"{report_prefix} num_seqs: {num_seqs}", file=log.v5)
 info = [report_prefix, "step %i" % step]
 if eval_info: # Such as score.
 info += ["%s %s" % (k, _format_score_value(v)) for k, v in eval_info.items()]
@@ -1475,17 +1467,16 @@ def _print_process(
 info += ["%.3f sec/step" % step_duration]
 if start_elapsed is not None:
 info += ["elapsed %s" % hms(start_elapsed)]
-if
-assert
-
-
-total_time_estimated = start_elapsed / complete
+if complete_frac is not None:
+assert 1 >= complete_frac > 0, f"{step} step, {complete_frac} complete_frac"
+assert start_elapsed is not None
+total_time_estimated = start_elapsed / complete_frac
 remaining_estimated = total_time_estimated - start_elapsed
 info += [
 "exp. remaining %s" % hms(remaining_estimated),
-"complete %.02f%%" % (
+"complete %.02f%%" % (complete_frac * 100),
 ]
-if start_elapsed is not None and
+if start_elapsed is not None and complete_frac is None:
 info += ["(unk epoch len)"]
 print(", ".join(filter(None, info)), file=log.v5)

@@ -1634,27 +1625,3 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
 p=p,
 ).item()
 )
-
-
-def _get_num_seqs_last_seq_idx(
-*,
-report_prefix: str,
-extern_data_raw: Dict[str, Any],
-step_idx: int,
-prev_num_seqs: Optional[int],
-prev_last_seq_idx: int,
-) -> Tuple[Optional[int], int]:
-num_seqs = prev_num_seqs
-num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
-# Note: The batches might have been shuffled,
-# thus we cannot really assert that the seq_idx is always increasing.
-last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
-if step_idx == 0:
-if num_seqs_ >= 0:
-print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
-num_seqs = num_seqs_
-elif num_seqs_ >= 0:
-assert num_seqs_ == num_seqs
-if num_seqs is not None:
-assert last_seq_idx < num_seqs
-return num_seqs, last_seq_idx
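For context, the rewritten progress reporting in _print_process derives its remaining-time estimate purely from complete_frac, instead of reconstructing progress from num_seqs and the last seen seq_idx. A minimal standalone sketch of that arithmetic (plain Python, not RETURNN code; the function name is illustrative):

def estimate_remaining(start_elapsed: float, complete_frac: float) -> float:
    # Extrapolate the total epoch duration from the elapsed time and the
    # fraction of the epoch already consumed, then subtract what has passed.
    assert 1 >= complete_frac > 0
    total_time_estimated = start_elapsed / complete_frac
    return total_time_estimated - start_elapsed

# Example: 120 s elapsed at 25% of the epoch -> about 360 s remaining.
print(estimate_remaining(start_elapsed=120.0, complete_frac=0.25))  # 360.0
print("complete %.02f%%" % (0.25 * 100))  # complete 25.00%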
{returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/RECORD
CHANGED

@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=94BElbYUGmjpsoY8BzvfW39RUTXw9Fy3UwlPoEjrkU8,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
 returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
-returnn/torch/engine.py,sha256=
+returnn/torch/engine.py,sha256=2FLLb2m4sWFwYOQGREDSxQCheCKd_osnFJCdLa_4TzE,76400
 returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250304.
-returnn-1.20250304.
-returnn-1.20250304.
-returnn-1.20250304.
-returnn-1.20250304.
+returnn-1.20250304.113330.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250304.113330.dist-info/METADATA,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
+returnn-1.20250304.113330.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+returnn-1.20250304.113330.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250304.113330.dist-info/RECORD,,
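The hash column in these RECORD entries follows the standard wheel RECORD format (PEP 376): the file's SHA-256 digest, urlsafe-base64 encoded with the trailing '=' padding stripped, followed by the file size in bytes. A small sketch to recompute such an entry for comparison (the path is whatever file from the unpacked wheel you want to check):

import base64
import hashlib


def record_entry(path: str) -> str:
    # RECORD line format: <path>,sha256=<urlsafe base64 digest without padding>,<size in bytes>
    with open(path, "rb") as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode("ascii")
    return f"{path},sha256={digest},{len(data)}"


# Compare against the corresponding line in the diff above, e.g.:
# returnn/torch/engine.py,sha256=2FLLb2m4sWFwYOQGREDSxQCheCKd_osnFJCdLa_4TzE,76400
print(record_entry("returnn/torch/engine.py"))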
{returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/LICENSE
File without changes
{returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/WHEEL
File without changes
{returnn-1.20250304.101951.dist-info → returnn-1.20250304.113330.dist-info}/top_level.txt
File without changes