returnn 1.20250422.113157__py3-none-any.whl → 1.20250423.105638__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250422.113157
3
+ Version: 1.20250423.105638
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250422.113157'
2
- long_version = '1.20250422.113157+git.731b3f5'
1
+ version = '1.20250423.105638'
2
+ long_version = '1.20250423.105638+git.767de47'
returnn/torch/engine.py CHANGED
@@ -82,15 +82,17 @@ class Engine(EngineBase):
82
82
  ) # type: Union[int,float,Dict[str,int],NumbersDict]
83
83
  self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]]
84
84
  self._pt_model = None # type: Optional[torch.nn.Module]
85
- self._train_step_func = None # type: Optional[Callable]
86
- self._forward_step_func = self.config.typed_value("forward_step") # type: Optional[Callable]
87
- self._forward_step_expected_outputs = None # type: Optional[TensorDict]
85
+ self._epoch_start_func: Optional[Callable] = self.config.typed_value("epoch_start")
86
+ self._epoch_end_func: Optional[Callable] = self.config.typed_value("epoch_end")
87
+ self._train_step_func: Optional[Callable] = None
88
+ self._forward_step_func: Optional[Callable] = self.config.typed_value("forward_step")
89
+ self._forward_step_expected_outputs: Optional[TensorDict] = None
88
90
  if self.config.typed_value("model_outputs") is not None:
89
91
  self._forward_step_expected_outputs = TensorDict()
90
92
  self._forward_step_expected_outputs.update(self.config.typed_value("model_outputs"), auto_convert=True)
91
93
  self._save_model_epoch_interval = 1
92
94
  self._ignore_param_set: Set[str] = set() # for the updater and for saving the model checkpoint
93
- self._updater = None # type: Optional[Updater]
95
+ self._updater: Optional[Updater] = None
94
96
 
95
97
  self._use_autocast = False
96
98
  self._autocast_dtype = None # type: Optional[str]
@@ -319,6 +321,26 @@ class Engine(EngineBase):
319
321
  ]
320
322
  print(f"Memory usage ({self._device}):", " ".join(stats), file=log.v1)
321
323
 
324
+ def _on_epoch_start(self, *, dataset_name: str):
325
+ if self._epoch_start_func:
326
+ self._epoch_start_func(
327
+ epoch=self.epoch,
328
+ step=self.global_train_step,
329
+ model=self._orig_model,
330
+ dataset_name=dataset_name,
331
+ **util.get_fwd_compat_kwargs(),
332
+ )
333
+
334
+ def _on_epoch_end(self, *, dataset_name: str):
335
+ if self._epoch_end_func:
336
+ self._epoch_end_func(
337
+ epoch=self.epoch,
338
+ step=self.global_train_step,
339
+ model=self._orig_model,
340
+ dataset_name=dataset_name,
341
+ **util.get_fwd_compat_kwargs(),
342
+ )
343
+
322
344
  def train_epoch(self):
323
345
  """
324
346
  train one (sub)epoch
@@ -346,6 +368,8 @@ class Engine(EngineBase):
346
368
  self._maybe_reset_dev_memory_caches()
347
369
  self._reset_dev_memory_stats()
348
370
 
371
+ self._on_epoch_start(dataset_name="train")
372
+
349
373
  if self.config.bool("debug_shell_before_train_loop", False):
350
374
  print("debug_shell_before_train_loop", file=log.v1)
351
375
  debug_shell(user_ns=locals(), user_global_ns=globals(), exit_afterwards=False)
@@ -564,6 +588,8 @@ class Engine(EngineBase):
564
588
 
565
589
  self._maybe_report_dev_memory_stats()
566
590
 
591
+ self._on_epoch_end(dataset_name="train")
592
+
567
593
  if self.epoch % self._save_model_epoch_interval == 0 or self.epoch == self._final_epoch:
568
594
  if self.model_filename:
569
595
  self._save_model()
@@ -612,6 +638,8 @@ class Engine(EngineBase):
612
638
 
613
639
  print(f"Evaluating dataset {dataset_name!r}", file=log.v3)
614
640
 
641
+ self._on_epoch_start(dataset_name=dataset_name)
642
+
615
643
  accumulated_losses_dict = NumbersDict()
616
644
  accumulated_inv_norm_factors_dict = NumbersDict()
617
645
  step_idx = 0
@@ -685,6 +713,8 @@ class Engine(EngineBase):
685
713
  _has_data = torch.tensor([False], device="cpu", dtype=torch.int8)
686
714
  torch.distributed.broadcast(_has_data, src=0)
687
715
 
716
+ self._on_epoch_end(dataset_name=dataset_name)
717
+
688
718
  if not self._torch_distributed_ctx or self._torch_distributed_ctx.rank() == 0:
689
719
  print(
690
720
  f"Epoch {self.epoch} evaluation:",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250422.113157
3
+ Version: 1.20250423.105638
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=6zi37tT7m8yIZkqfivfYcN-4Pfq9RPDvVa-XXxDychk,5215
1
+ returnn/PKG-INFO,sha256=7FZF3sif_hKT_9vHgl4snASut9Lc4uPAOGJfqaOxr_A,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=A7xaVfPcI0O9Cc0mYkkSvcQmYEgzf52QpAfjMkx9SjI,77
6
+ returnn/_setup_info_generated.py,sha256=_2mdcTTEJ__VlmLagOoes0sVhrZN1SeK64wJQZ4crrY,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
207
207
  returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
208
208
  returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
209
209
  returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
210
- returnn/torch/engine.py,sha256=jJMVR8ltJMfF70-m0HxR4xNS94wNfO4CSAF0jFlB7_g,76771
210
+ returnn/torch/engine.py,sha256=yfqP9jzOH1OjiETqoBh20YOeEaX_kyr_kwPUkhSFxiI,77833
211
211
  returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
212
212
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
213
213
  returnn/torch/data/extern_data.py,sha256=OSoy3x1KiyiJCr7DfF5uPFAu09We2N2WbA0yo-pYXxM,7601
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250422.113157.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250422.113157.dist-info/METADATA,sha256=6zi37tT7m8yIZkqfivfYcN-4Pfq9RPDvVa-XXxDychk,5215
258
- returnn-1.20250422.113157.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250422.113157.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250422.113157.dist-info/RECORD,,
256
+ returnn-1.20250423.105638.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250423.105638.dist-info/METADATA,sha256=7FZF3sif_hKT_9vHgl4snASut9Lc4uPAOGJfqaOxr_A,5215
258
+ returnn-1.20250423.105638.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250423.105638.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250423.105638.dist-info/RECORD,,