returnn 1.20250206.144022__py3-none-any.whl → 1.20250206.151011__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250206.144022
+Version: 1.20250206.151011
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250206.144022'
-long_version = '1.20250206.144022+git.550e757'
+version = '1.20250206.151011'
+long_version = '1.20250206.151011+git.6fa4b38'
returnn/torch/engine.py CHANGED
@@ -3,7 +3,7 @@ Main engine for PyTorch
 """

 from __future__ import annotations
-from typing import Optional, Any, Union, Callable, Dict, Set
+from typing import Optional, Any, Union, Callable, Dict, Set, Tuple
 from contextlib import nullcontext, ExitStack, contextmanager

 import gc
@@ -371,6 +371,7 @@ class Engine(EngineBase):
         total_data_size_packed = NumbersDict()
         total_data_size_padded = NumbersDict()

+        report_prefix = f"ep {self.epoch} train"
         try:
             while True:
                 with torch.no_grad():
@@ -398,21 +399,13 @@
                     {k: int(util.prod(extern_data_raw[k].shape[:2])) for k in keys_w_seq_len},
                 )

-                num_seqs_ = (
-                    int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
+                num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
+                    report_prefix=report_prefix,
+                    extern_data_raw=extern_data_raw,
+                    step_idx=step_idx,
+                    prev_num_seqs=num_seqs,
+                    prev_last_seq_idx=last_seq_idx,
                 )
-                # Note: The batches might have been shuffled,
-                # thus we cannot really assert that the seq_idx is always increasing.
-                last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), last_seq_idx)
-                if step_idx == 0:
-                    if num_seqs_ >= 0:
-                        print(f"Epoch {self.epoch} num_seqs: {num_seqs_}", file=log.v5)
-                        num_seqs = num_seqs_
-                elif num_seqs_ >= 0:
-                    assert num_seqs_ == num_seqs
-                del num_seqs_
-                if num_seqs is not None:
-                    assert last_seq_idx < num_seqs
                 epoch_continuous = (self.epoch - 1 + (last_seq_idx + 1) / num_seqs) if num_seqs is not None else None

                 # clear the gradients when every gradient accumulation loop starts
@@ -485,7 +478,7 @@
                 accumulated_inv_norm_factors_dict += inv_norm_factors_dict
                 eval_info = self._maybe_extend_losses_info(losses_dict / inv_norm_factors_dict)
                 _print_process(
-                    f"ep {self.epoch} train",
+                    report_prefix,
                     step=step_idx,
                     eval_info=dict(eval_info),
                     step_duration=step_duration,
@@ -1276,6 +1269,8 @@
             new_dim.dyn_size_ext = _get_tensor_wo_batch_numpy(dim.dyn_size_ext)
             return new_dim

+        num_seqs = None
+        last_seq_idx = 0
         report_prefix = f"ep {self.epoch} {dataset.name} forward"
         with torch.no_grad():
             callback.init(model=self._orig_model)
@@ -1283,6 +1278,15 @@
             step_idx = 0
             for extern_data_raw in data_loader:
                 step_begin_time = time.monotonic()
+
+                num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx(
+                    report_prefix=report_prefix,
+                    extern_data_raw=extern_data_raw,
+                    step_idx=step_idx,
+                    prev_num_seqs=num_seqs,
+                    prev_last_seq_idx=last_seq_idx,
+                )
+
                 if self._forward_step_expected_outputs:
                     # Also resets any dyn dims, which might have been set in the prev step.
                     self._forward_step_expected_outputs.reset_content()
@@ -1319,11 +1323,19 @@
                         model_outputs_per_batch.data[k] = _get_tensor_wo_batch_numpy(v)
                     callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch)

-                elapsed_computation_time += time.monotonic() - step_begin_time
+                step_end_time = time.monotonic()
+                step_duration = step_end_time - step_begin_time
+                elapsed_computation_time += step_duration
+
                 _print_process(
                     report_prefix,
                     step=step_idx,
                     eval_info=None,
+                    step_duration=step_duration,
+                    start_elapsed=step_end_time - epoch_start_time,
+                    seq_idx=last_seq_idx,
+                    num_seqs=num_seqs,
+                    batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
                 step_idx += 1
@@ -1601,3 +1613,27 @@ def _get_total_grad_norm(model: torch.nn.Module, p: float) -> float:
             p=p,
         ).item()
     )
+
+
+def _get_num_seqs_last_seq_idx(
+    *,
+    report_prefix: str,
+    extern_data_raw: Dict[str, Any],
+    step_idx: int,
+    prev_num_seqs: Optional[int],
+    prev_last_seq_idx: int,
+) -> Tuple[Optional[int], int]:
+    num_seqs = prev_num_seqs
+    num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs", None) is not None else -1
+    # Note: The batches might have been shuffled,
+    # thus we cannot really assert that the seq_idx is always increasing.
+    last_seq_idx = max(int(extern_data_raw["seq_idx"].max()), prev_last_seq_idx)
+    if step_idx == 0:
+        if num_seqs_ >= 0:
+            print(f"{report_prefix} num_seqs: {num_seqs_}", file=log.v5)
+            num_seqs = num_seqs_
+    elif num_seqs_ >= 0:
+        assert num_seqs_ == num_seqs
+    if num_seqs is not None:
+        assert last_seq_idx < num_seqs
+    return num_seqs, last_seq_idx
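
Note: the engine.py change above factors the num_seqs / last_seq_idx bookkeeping, previously inlined in the training loop, into the shared helper _get_num_seqs_last_seq_idx, and reuses it in the forward loop so that forward progress reporting also gets seq_idx, num_seqs, step duration, and batch size info. Below is a minimal sketch of the calling pattern, using a simplified stand-in for the helper (plain Python lists instead of tensors, print instead of log.v5; the _sketch names and dummy batches are hypothetical, not part of the package):

from typing import Any, Dict, List, Optional, Tuple


def _get_num_seqs_last_seq_idx_sketch(
    *,
    report_prefix: str,
    extern_data_raw: Dict[str, Any],
    step_idx: int,
    prev_num_seqs: Optional[int],
    prev_last_seq_idx: int,
) -> Tuple[Optional[int], int]:
    """Stand-in for the helper above; "seq_idx" is a list of ints here, not a tensor."""
    num_seqs = prev_num_seqs
    num_seqs_ = int(extern_data_raw["num_seqs"]) if extern_data_raw.get("num_seqs") is not None else -1
    # Batches may be shuffled, so seq_idx is not necessarily increasing across steps.
    last_seq_idx = max(max(extern_data_raw["seq_idx"]), prev_last_seq_idx)
    if step_idx == 0:
        if num_seqs_ >= 0:
            print(f"{report_prefix} num_seqs: {num_seqs_}")
            num_seqs = num_seqs_
    elif num_seqs_ >= 0:
        assert num_seqs_ == num_seqs  # the dataset-reported total must stay constant
    if num_seqs is not None:
        assert last_seq_idx < num_seqs
    return num_seqs, last_seq_idx


# Calling pattern shared by the train and forward loops: carry the previous values
# into each step and take back the updated ones.
num_seqs: Optional[int] = None
last_seq_idx = 0
dummy_batches: List[Dict[str, Any]] = [
    {"num_seqs": 10, "seq_idx": [0, 1]},
    {"num_seqs": 10, "seq_idx": [3, 2]},  # shuffled: highest seq_idx seen so far is 3
]
for step_idx, batch in enumerate(dummy_batches):
    num_seqs, last_seq_idx = _get_num_seqs_last_seq_idx_sketch(
        report_prefix="ep 1 train",
        extern_data_raw=batch,
        step_idx=step_idx,
        prev_num_seqs=num_seqs,
        prev_last_seq_idx=last_seq_idx,
    )
print(num_seqs, last_seq_idx)  # -> 10 3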
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250206.144022
+Version: 1.20250206.151011
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=vBdT0ayV-Q8OjPdp1xlJt0CiopUZKNWKwKuadpiKHDk,5215
+returnn/PKG-INFO,sha256=BbQPkoVha1AYEcED8txzZyyyDiJt3J29FBKlYy1rTYc,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=uAkEz6DVwoN42Nh2WLNsoE4lJ0BtlRznPPlXMWKxJQo,77
+returnn/_setup_info_generated.py,sha256=9T1yfQUP7ASjffpzcwvOLEGNWMdiwS4EmjqY_bI2EdY,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
 returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=i13cUVjI7GxpO0TAresrNyCM0ZBAaf-cXNr09Fmg_2k,6266
-returnn/torch/engine.py,sha256=Zd3ePKFSi5fkvV1FxaYn0QGgu5cag_ocKPwFmKglf3I,75095
+returnn/torch/engine.py,sha256=neM-AL7XQLpZ3V1K4ziqVmij19ey1k2MpLCaFXATOpg,76301
 returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=_uT_9_gd5HIh1IoRsrebVG-nufSnb7fgC5jyU05GxJg,7580
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250206.144022.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250206.144022.dist-info/METADATA,sha256=vBdT0ayV-Q8OjPdp1xlJt0CiopUZKNWKwKuadpiKHDk,5215
-returnn-1.20250206.144022.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-returnn-1.20250206.144022.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250206.144022.dist-info/RECORD,,
+returnn-1.20250206.151011.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250206.151011.dist-info/METADATA,sha256=BbQPkoVha1AYEcED8txzZyyyDiJt3J29FBKlYy1rTYc,5215
+returnn-1.20250206.151011.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+returnn-1.20250206.151011.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250206.151011.dist-info/RECORD,,