returnn 1.20250304.101951__py3-none-any.whl → 1.20250304.113330__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250304.101951
+ Version: 1.20250304.113330
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
- version = '1.20250304.101951'
- long_version = '1.20250304.101951+git.0fa434e'
+ version = '1.20250304.113330'
+ long_version = '1.20250304.113330+git.acf09da'
returnn/torch/engine.py CHANGED
@@ -3,7 +3,7 @@ Main engine for PyTorch
  """

  from __future__ import annotations
- from typing import Optional, Any, Union, Callable, Dict, Set, Tuple
+ from typing import Optional, Any, Union, Callable, Dict, Set
  from contextlib import nullcontext, ExitStack, contextmanager

  import gc
@@ -365,8 +365,6 @@ class Engine(EngineBase):
  zero_grad_next_step = True
  cur_count_grad_accum = 0
  extern_data = None
- num_seqs = None
- last_seq_idx = 0

  total_data_size_packed = NumbersDict()
  total_data_size_padded = NumbersDict()
@@ -400,20 +398,8 @@ class Engine(EngineBase):
  )

  complete_frac = float(extern_data_raw["complete_frac"])
- num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
- report_prefix=report_prefix,
- extern_data_raw=extern_data_raw,
- step_idx=step_idx,
- prev_num_seqs=num_seqs,
- prev_last_seq_idx=last_seq_idx,
- )
- epoch_continuous = (
- self.epoch - 1 + complete_frac
- if complete_frac >= 0.0
- else (self.epoch - 1 + (last_seq_idx + 1) / num_seqs)
- if num_seqs is not None
- else None
- )
+ epoch_continuous = self.epoch - 1 + complete_frac if complete_frac >= 0.0 else None
+ num_seqs = int(extern_data_raw["num_seqs"])

  # clear the gradients when every gradient accumulation loop starts
  if zero_grad_next_step:
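
The new train-loop code above takes epoch progress directly from the complete_frac value that the data pipeline reports with each batch, rather than reconstructing it from seq_idx and num_seqs. A minimal sketch of the epoch_continuous computation (standalone Python with an assumed raw batch dict; not the actual Engine code):

    from typing import Optional

    def epoch_continuous(epoch: int, extern_data_raw: dict) -> Optional[float]:
        """Continuous epoch position, e.g. 2.25 = 25% into epoch 3; None if progress is unknown."""
        complete_frac = float(extern_data_raw["complete_frac"])
        # complete_frac < 0.0 means the dataset did not report progress for this batch.
        return epoch - 1 + complete_frac if complete_frac >= 0.0 else None

    # e.g. epoch_continuous(3, {"complete_frac": 0.25, "num_seqs": 1000}) -> 2.25
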
@@ -490,7 +476,7 @@ class Engine(EngineBase):
  eval_info=dict(eval_info),
  step_duration=step_duration,
  start_elapsed=step_end_time - epoch_start_time,
- seq_idx=last_seq_idx,
+ complete_frac=complete_frac,
  num_seqs=num_seqs,
  batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -629,13 +615,18 @@ class Engine(EngineBase):
  accumulated_losses_dict = NumbersDict()
  accumulated_inv_norm_factors_dict = NumbersDict()
  step_idx = 0
+ eval_start_time = time.monotonic()

+ report_prefix = f"ep {self.epoch} {dataset_name} eval"
  with torch.no_grad():
  for extern_data_raw in data_loader:
  if self._torch_distributed_ctx and step_idx % 100 == 0:
  _has_data = torch.tensor([True], device="cpu", dtype=torch.int8)
  torch.distributed.broadcast(_has_data, src=0)

+ complete_frac = float(extern_data_raw["complete_frac"])
+ num_seqs = int(extern_data_raw["num_seqs"])
+
  extern_data = extern_data_util.raw_dict_to_extern_data(
  extern_data_raw,
  extern_data_template=self.extern_data,
@@ -644,6 +635,8 @@ class Engine(EngineBase):
  )

  self._run_step(extern_data, train_func=True)
+ step_end_time = time.monotonic()
+
  train_ctx = rf.get_run_ctx()

  losses_dict = NumbersDict(
@@ -664,9 +657,12 @@ class Engine(EngineBase):
  accumulated_inv_norm_factors_dict += inv_norm_factors_dict
  eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
  _print_process(
- f"ep {self.epoch} {dataset_name} eval",
+ report_prefix,
  step=step_idx,
  eval_info=dict(eval_info),
+ complete_frac=complete_frac,
+ num_seqs=num_seqs,
+ start_elapsed=step_end_time - eval_start_time,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
  )
  step_idx += 1
@@ -1290,8 +1286,6 @@ class Engine(EngineBase):
  new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
  return new_dim

- num_seqs = None
- last_seq_idx = 0
  report_prefix = f"ep {self.epoch} {dataset.name} forward"
  with torch.no_grad():
  callback.init(model=self._orig_model)
@@ -1300,13 +1294,8 @@ class Engine(EngineBase):
  for extern_data_raw in data_loader:
  step_begin_time = time.monotonic()

- num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
- report_prefix=report_prefix,
- extern_data_raw=extern_data_raw,
- step_idx=step_idx,
- prev_num_seqs=num_seqs,
- prev_last_seq_idx=last_seq_idx,
- )
+ complete_frac = float(extern_data_raw["complete_frac"])
+ num_seqs = int(extern_data_raw["num_seqs"])

  if self._forward_step_expected_outputs:
  # Also resets any dyn dims, which might have been set in the prev step.
@@ -1354,7 +1343,7 @@ class Engine(EngineBase):
  eval_info=None,
  step_duration=step_duration,
  start_elapsed=step_end_time - epoch_start_time,
- seq_idx=last_seq_idx,
+ complete_frac=complete_frac,
  num_seqs=num_seqs,
  batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
  log_memory_usage_device=self._device if self._log_memory_usage else None,
@@ -1442,7 +1431,7 @@ def _print_process(
  batch_size_info: Optional[Dict[str, Any]] = None,
  step_duration: Optional[float] = None,
  start_elapsed: Optional[float] = None,
- seq_idx: Optional[int] = None,
+ complete_frac: Optional[float] = None,
  num_seqs: Optional[int] = None,
  log_memory_usage_device: Optional[str] = None,
  ):
@@ -1455,11 +1444,14 @@ def _print_process(
  :param batch_size_info:
  :param step_duration: time elapsed for this step (secs)
  :param start_elapsed: time elapsed since epoch start (secs)
- :param num_seqs: total number of sequences for this epoch
+ :param complete_frac: how much of the current epoch is already consumed
+ :param num_seqs: total number of seqs this epoch
  :param log_memory_usage_device: if given, will log memory usage (peak allocated memory)
  :return: nothing, will be printed to log
  """
  if log.verbose[5]: # report every minibatch
+ if step == 0 and num_seqs is not None and num_seqs >= 0:
+ print(f"{report_prefix} num_seqs: {num_seqs}", file=log.v5)
  info = [report_prefix, "step %i" % step]
  if eval_info: # Such as score.
  info += ["%s %s" % (k, _format_score_value(v)) for k, v in eval_info.items()]
@@ -1475,17 +1467,16 @@ def _print_process(
  info += ["%.3f sec/step" % step_duration]
  if start_elapsed is not None:
  info += ["elapsed %s" % hms(start_elapsed)]
- if num_seqs is not None:
- assert seq_idx is not None and start_elapsed is not None # unexpected combination...
- complete = (seq_idx + 1) / num_seqs
- assert 1 >= complete > 0, f"{step} step, {num_seqs} num_seqs"
- total_time_estimated = start_elapsed / complete
+ if complete_frac is not None:
+ assert 1 >= complete_frac > 0, f"{step} step, {complete_frac} complete_frac"
+ assert start_elapsed is not None
+ total_time_estimated = start_elapsed / complete_frac
  remaining_estimated = total_time_estimated - start_elapsed
  info += [
  "exp. remaining %s" % hms(remaining_estimated),
- "complete %.02f%%" % (complete * 100),
+ "complete %.02f%%" % (complete_frac * 100),
  ]
- if start_elapsed is not None and num_seqs is None:
+ if start_elapsed is not None and complete_frac is None:
  info += ["(unk epoch len)"]
  print(", ".join(filter(None, info)), file=log.v5)

@@ -1634,27 +1625,3 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
  p=p,
  ).item()
  )
-
-
- def _get_num_seqs_last_seq_idx(
- *,
- report_prefix: str,
- extern_data_raw: Dict[str, Any],
- step_idx: int,
- prev_num_seqs: Optional[int],
- prev_last_seq_idx: int,
- ) -> Tuple[Optional[int], int]:
- num_seqs = prev_num_seqs
- num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
- # Note: The batches might have been shuffled,
- # thus we cannot really assert that the seq_idx is always increasing.
- last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
- if step_idx == 0:
- if num_seqs_ >= 0:
- print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
- num_seqs = num_seqs_
- elif num_seqs_ >= 0:
- assert num_seqs_ == num_seqs
- if num_seqs is not None:
- assert last_seq_idx < num_seqs
- return num_seqs, last_seq_idx
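
With the seq_idx-based helper removed, progress reporting relies only on complete_frac, and the remaining-time estimate in _print_process reduces to a simple proportion of the elapsed time. A minimal sketch of that arithmetic (a hypothetical standalone helper, not the actual function):

    def estimate_remaining(start_elapsed: float, complete_frac: float) -> float:
        """Estimated seconds left in the epoch, given elapsed seconds and the progress fraction."""
        assert 0.0 < complete_frac <= 1.0
        total_time_estimated = start_elapsed / complete_frac
        return total_time_estimated - start_elapsed

    # e.g. 120.0 s elapsed at complete_frac=0.25 -> ~480 s total, ~360 s remaining
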
returnn-1.20250304.113330.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: returnn
- Version: 1.20250304.101951
+ Version: 1.20250304.113330
  Summary: The RWTH extensible training framework for universal recurrent neural networks
  Home-page: https://github.com/rwth-i6/returnn/
  Author: Albert Zeyer
returnn-1.20250304.113330.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
- returnn/PKG-INFO,sha256=33hja9F4qBUxg8Y8J2At3XpXgjYhstvueyHkyMN3GdI,5215
+ returnn/PKG-INFO,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
- returnn/_setup_info_generated.py,sha256=Sz2yGoXcP-7QpDNHDiaxopoVFqP8i7nfUQMH1Wss9YA,77
+ returnn/_setup_info_generated.py,sha256=94BElbYUGmjpsoY8BzvfW39RUTXw9Fy3UwlPoEjrkU8,77
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
  returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
  returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
  returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
- returnn/torch/engine.py,sha256=sU9A96icaj65uaEkX4i4aUK3IrB2S19_Fb9_sueB_JE,77426
+ returnn/torch/engine.py,sha256=2FLLb2m4sWFwYOQGREDSxQCheCKd_osnFJCdLa_4TzE,76400
  returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
  returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
- returnn-1.20250304.101951.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
- returnn-1.20250304.101951.dist-info/METADATA,sha256=33hja9F4qBUxg8Y8J2At3XpXgjYhstvueyHkyMN3GdI,5215
- returnn-1.20250304.101951.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- returnn-1.20250304.101951.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
- returnn-1.20250304.101951.dist-info/RECORD,,
+ returnn-1.20250304.113330.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+ returnn-1.20250304.113330.dist-info/METADATA,sha256=BmSxZKkRxyL20E4Zsud1muiQ-rth9Ob9PMR-43IrAMw,5215
+ returnn-1.20250304.113330.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ returnn-1.20250304.113330.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+ returnn-1.20250304.113330.dist-info/RECORD,,