returnn 1.20250819.10249__py3-none-any.whl → 1.20250820.123936__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/updater.py +10 -8
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.123936.dist-info}/METADATA +1 -1
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.123936.dist-info}/RECORD +8 -8
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.123936.dist-info}/LICENSE +0 -0
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.123936.dist-info}/WHEEL +0 -0
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.123936.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250819.10249'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250820.123936'
|
|
2
|
+
long_version = '1.20250820.123936+git.9c74169'
|
returnn/torch/updater.py
CHANGED
|
@@ -5,11 +5,10 @@ and model param update logic in general.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable,
|
|
8
|
+
from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Set, Dict, List, Tuple
|
|
9
9
|
import os
|
|
10
10
|
import gc
|
|
11
11
|
import torch
|
|
12
|
-
import typing
|
|
13
12
|
|
|
14
13
|
import returnn
|
|
15
14
|
from returnn.log import log
|
|
@@ -130,8 +129,8 @@ class Updater:
|
|
|
130
129
|
else:
|
|
131
130
|
raise NotImplementedError("not implemented for not callable dynamic_learning_rate")
|
|
132
131
|
|
|
133
|
-
self._optimizer_opts = None
|
|
134
|
-
self.optimizer
|
|
132
|
+
self._optimizer_opts: Optional[Dict[str, Any]] = None
|
|
133
|
+
self.optimizer: Optional[torch.optim.Optimizer] = None
|
|
135
134
|
|
|
136
135
|
self._grad_clip = self.config.float("gradient_clip", 0.0)
|
|
137
136
|
self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
|
|
@@ -481,7 +480,7 @@ class Updater:
|
|
|
481
480
|
|
|
482
481
|
def _get_optimizer_param_groups(
|
|
483
482
|
self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
|
|
484
|
-
) -> Union[
|
|
483
|
+
) -> Union[Iterable[Dict[str, Any]], Iterable[torch.nn.Parameter]]:
|
|
485
484
|
"""
|
|
486
485
|
The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
|
|
487
486
|
This function creates a blacklist of network modules and splits the optimizer groups in two:
|
|
@@ -514,10 +513,13 @@ class Updater:
|
|
|
514
513
|
if custom_param_groups is not None:
|
|
515
514
|
assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
|
|
516
515
|
rf_model = wrapped_pt_module_to_rf_module(self.network)
|
|
517
|
-
|
|
516
|
+
custom_param_groups_ = custom_param_groups(
|
|
518
517
|
model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
|
|
519
518
|
)
|
|
520
|
-
|
|
519
|
+
assert isinstance(custom_param_groups_, Iterable) and all(
|
|
520
|
+
isinstance(group, dict) for group in custom_param_groups_
|
|
521
|
+
), f"invalid param_groups_custom {custom_param_groups!r} result {custom_param_groups_!r} type"
|
|
522
|
+
return custom_param_groups_
|
|
521
523
|
|
|
522
524
|
network_params = self.network.parameters()
|
|
523
525
|
|
|
@@ -545,7 +547,7 @@ class Updater:
|
|
|
545
547
|
# Parameters without weight decay: biases + LayerNorm/Embedding layers.
|
|
546
548
|
wd_params = set()
|
|
547
549
|
no_wd_params = set()
|
|
548
|
-
blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
|
|
550
|
+
blacklist_wd_modules: Any = optimizer_opts.pop("weight_decay_modules_blacklist", None)
|
|
549
551
|
if blacklist_wd_modules is None:
|
|
550
552
|
blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
|
551
553
|
else:
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=UEqWo5ULLq_sbbrAc2dmq5fjH29qlfAXGYB_hw8vpZQ,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
|
|
|
208
208
|
returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
|
|
209
209
|
returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
|
|
210
210
|
returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
|
|
211
|
-
returnn/torch/updater.py,sha256=
|
|
211
|
+
returnn/torch/updater.py,sha256=9Ju_LuEdXO7LMxt9rs9_6ReePG5y1h36N3coN696rVI,30285
|
|
212
212
|
returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
|
|
213
213
|
returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
|
|
214
214
|
returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250819.10249.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250819.10249.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250819.10249.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250819.10249.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250820.123936.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250820.123936.dist-info/METADATA,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
|
|
258
|
+
returnn-1.20250820.123936.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250820.123936.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250820.123936.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|