returnn 1.20250422.113157__py3-none-any.whl → 1.20250423.105638__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/engine.py +34 -4
- {returnn-1.20250422.113157.dist-info → returnn-1.20250423.105638.dist-info}/METADATA +1 -1
- {returnn-1.20250422.113157.dist-info → returnn-1.20250423.105638.dist-info}/RECORD +8 -8
- {returnn-1.20250422.113157.dist-info → returnn-1.20250423.105638.dist-info}/LICENSE +0 -0
- {returnn-1.20250422.113157.dist-info → returnn-1.20250423.105638.dist-info}/WHEEL +0 -0
- {returnn-1.20250422.113157.dist-info → returnn-1.20250423.105638.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250422.113157'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250423.105638'
|
|
2
|
+
long_version = '1.20250423.105638+git.767de47'
|
returnn/torch/engine.py
CHANGED
|
@@ -82,15 +82,17 @@ class Engine(EngineBase):
|
|
|
82
82
|
) # type: Union[int,float,Dict[str,int],NumbersDict]
|
|
83
83
|
self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]]
|
|
84
84
|
self._pt_model = None # type: Optional[torch.nn.Module]
|
|
85
|
-
self.
|
|
86
|
-
self.
|
|
87
|
-
self.
|
|
85
|
+
self._epoch_start_func: Optional[Callable] = self.config.typed_value("epoch_start")
|
|
86
|
+
self._epoch_end_func: Optional[Callable] = self.config.typed_value("epoch_end")
|
|
87
|
+
self._train_step_func: Optional[Callable] = None
|
|
88
|
+
self._forward_step_func: Optional[Callable] = self.config.typed_value("forward_step")
|
|
89
|
+
self._forward_step_expected_outputs: Optional[TensorDict] = None
|
|
88
90
|
if self.config.typed_value("model_outputs") is not None:
|
|
89
91
|
self._forward_step_expected_outputs = TensorDict()
|
|
90
92
|
self._forward_step_expected_outputs.update(self.config.typed_value("model_outputs"), auto_convert=True)
|
|
91
93
|
self._save_model_epoch_interval = 1
|
|
92
94
|
self._ignore_param_set: Set[str] = set() # for the updater and for saving the model checkpoint
|
|
93
|
-
self._updater
|
|
95
|
+
self._updater: Optional[Updater] = None
|
|
94
96
|
|
|
95
97
|
self._use_autocast = False
|
|
96
98
|
self._autocast_dtype = None # type: Optional[str]
|
|
@@ -319,6 +321,26 @@ class Engine(EngineBase):
|
|
|
319
321
|
]
|
|
320
322
|
print(f"Memory usage ({self._device}):", " ".join(stats), file=log.v1)
|
|
321
323
|
|
|
324
|
+
def _on_epoch_start(self, *, dataset_name: str):
|
|
325
|
+
if self._epoch_start_func:
|
|
326
|
+
self._epoch_start_func(
|
|
327
|
+
epoch=self.epoch,
|
|
328
|
+
step=self.global_train_step,
|
|
329
|
+
model=self._orig_model,
|
|
330
|
+
dataset_name=dataset_name,
|
|
331
|
+
**util.get_fwd_compat_kwargs(),
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
def _on_epoch_end(self, *, dataset_name: str):
|
|
335
|
+
if self._epoch_end_func:
|
|
336
|
+
self._epoch_end_func(
|
|
337
|
+
epoch=self.epoch,
|
|
338
|
+
step=self.global_train_step,
|
|
339
|
+
model=self._orig_model,
|
|
340
|
+
dataset_name=dataset_name,
|
|
341
|
+
**util.get_fwd_compat_kwargs(),
|
|
342
|
+
)
|
|
343
|
+
|
|
322
344
|
def train_epoch(self):
|
|
323
345
|
"""
|
|
324
346
|
train one (sub)epoch
|
|
@@ -346,6 +368,8 @@ class Engine(EngineBase):
|
|
|
346
368
|
self._maybe_reset_dev_memory_caches()
|
|
347
369
|
self._reset_dev_memory_stats()
|
|
348
370
|
|
|
371
|
+
self._on_epoch_start(dataset_name="train")
|
|
372
|
+
|
|
349
373
|
if self.config.bool("debug_shell_before_train_loop", False):
|
|
350
374
|
print("debug_shell_before_train_loop", file=log.v1)
|
|
351
375
|
debug_shell(user_ns=locals(), user_global_ns=globals(), exit_afterwards=False)
|
|
@@ -564,6 +588,8 @@ class Engine(EngineBase):
|
|
|
564
588
|
|
|
565
589
|
self._maybe_report_dev_memory_stats()
|
|
566
590
|
|
|
591
|
+
self._on_epoch_end(dataset_name="train")
|
|
592
|
+
|
|
567
593
|
if self.epoch % self._save_model_epoch_interval == 0 or self.epoch == self._final_epoch:
|
|
568
594
|
if self.model_filename:
|
|
569
595
|
self._save_model()
|
|
@@ -612,6 +638,8 @@ class Engine(EngineBase):
|
|
|
612
638
|
|
|
613
639
|
print(f"Evaluating dataset {dataset_name!r}", file=log.v3)
|
|
614
640
|
|
|
641
|
+
self._on_epoch_start(dataset_name=dataset_name)
|
|
642
|
+
|
|
615
643
|
accumulated_losses_dict = NumbersDict()
|
|
616
644
|
accumulated_inv_norm_factors_dict = NumbersDict()
|
|
617
645
|
step_idx = 0
|
|
@@ -685,6 +713,8 @@ class Engine(EngineBase):
|
|
|
685
713
|
_has_data = torch.tensor([False], device="cpu", dtype=torch.int8)
|
|
686
714
|
torch.distributed.broadcast(_has_data, src=0)
|
|
687
715
|
|
|
716
|
+
self._on_epoch_end(dataset_name=dataset_name)
|
|
717
|
+
|
|
688
718
|
if not self._torch_distributed_ctx or self._torch_distributed_ctx.rank() == 0:
|
|
689
719
|
print(
|
|
690
720
|
f"Epoch {self.epoch} evaluation:",
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=7FZF3sif_hKT_9vHgl4snASut9Lc4uPAOGJfqaOxr_A,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=_2mdcTTEJ__VlmLagOoes0sVhrZN1SeK64wJQZ4crrY,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -207,7 +207,7 @@ returnn/tf/util/open_fst.py,sha256=sZRDw4TbxvhGqpGdUJWy1ebvlZm4_RPhygpRw9uLAOQ,1
|
|
|
207
207
|
returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
|
|
208
208
|
returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
|
|
209
209
|
returnn/torch/distributed.py,sha256=skFyutdVztxgTEk3HHJ8S83qRWbNpkNT8Tj16Ic0_hE,6981
|
|
210
|
-
returnn/torch/engine.py,sha256=
|
|
210
|
+
returnn/torch/engine.py,sha256=yfqP9jzOH1OjiETqoBh20YOeEaX_kyr_kwPUkhSFxiI,77833
|
|
211
211
|
returnn/torch/updater.py,sha256=GqtBvZpElPVMm0lq84JPl4NVLFFETZAzAbR0rTomSao,28249
|
|
212
212
|
returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
|
|
213
213
|
returnn/torch/data/extern_data.py,sha256=OSoy3x1KiyiJCr7DfF5uPFAu09We2N2WbA0yo-pYXxM,7601
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250422.113157.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250422.113157.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250422.113157.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250422.113157.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250423.105638.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250423.105638.dist-info/METADATA,sha256=7FZF3sif_hKT_9vHgl4snASut9Lc4uPAOGJfqaOxr_A,5215
|
|
258
|
+
returnn-1.20250423.105638.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250423.105638.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250423.105638.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|