returnn 1.20250819.10249__py3-none-any.whl → 1.20250820.171158__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/updater.py +15 -9
- returnn/util/basic.py +14 -1
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.171158.dist-info}/METADATA +1 -1
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.171158.dist-info}/RECORD +9 -9
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.171158.dist-info}/LICENSE +0 -0
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.171158.dist-info}/WHEEL +0 -0
- {returnn-1.20250819.10249.dist-info → returnn-1.20250820.171158.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250819.10249'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250820.171158'
|
|
2
|
+
long_version = '1.20250820.171158+git.d60d270'
|
returnn/torch/updater.py
CHANGED
|
@@ -5,11 +5,10 @@ and model param update logic in general.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable,
|
|
8
|
+
from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Set, Dict, List, Tuple
|
|
9
9
|
import os
|
|
10
10
|
import gc
|
|
11
11
|
import torch
|
|
12
|
-
import typing
|
|
13
12
|
|
|
14
13
|
import returnn
|
|
15
14
|
from returnn.log import log
|
|
@@ -130,8 +129,8 @@ class Updater:
|
|
|
130
129
|
else:
|
|
131
130
|
raise NotImplementedError("not implemented for not callable dynamic_learning_rate")
|
|
132
131
|
|
|
133
|
-
self._optimizer_opts = None
|
|
134
|
-
self.optimizer
|
|
132
|
+
self._optimizer_opts: Optional[Dict[str, Any]] = None
|
|
133
|
+
self.optimizer: Optional[torch.optim.Optimizer] = None
|
|
135
134
|
|
|
136
135
|
self._grad_clip = self.config.float("gradient_clip", 0.0)
|
|
137
136
|
self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
|
|
@@ -481,7 +480,7 @@ class Updater:
|
|
|
481
480
|
|
|
482
481
|
def _get_optimizer_param_groups(
|
|
483
482
|
self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
|
|
484
|
-
) -> Union[
|
|
483
|
+
) -> Union[Iterable[Dict[str, Any]], Iterable[torch.nn.Parameter]]:
|
|
485
484
|
"""
|
|
486
485
|
The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
|
|
487
486
|
This function creates a blacklist of network modules and splits the optimizer groups in two:
|
|
@@ -514,10 +513,17 @@ class Updater:
|
|
|
514
513
|
if custom_param_groups is not None:
|
|
515
514
|
assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
|
|
516
515
|
rf_model = wrapped_pt_module_to_rf_module(self.network)
|
|
517
|
-
|
|
518
|
-
model=self.network,
|
|
516
|
+
custom_param_groups_ = custom_param_groups(
|
|
517
|
+
model=self.network,
|
|
518
|
+
rf_model=rf_model,
|
|
519
|
+
optimizer_class=optim_class,
|
|
520
|
+
optimizer_opts=optimizer_opts,
|
|
521
|
+
**get_fwd_compat_kwargs(),
|
|
519
522
|
)
|
|
520
|
-
|
|
523
|
+
assert isinstance(custom_param_groups_, Iterable) and all(
|
|
524
|
+
isinstance(group, dict) for group in custom_param_groups_
|
|
525
|
+
), f"invalid param_groups_custom {custom_param_groups!r} result {custom_param_groups_!r} type"
|
|
526
|
+
return custom_param_groups_
|
|
521
527
|
|
|
522
528
|
network_params = self.network.parameters()
|
|
523
529
|
|
|
@@ -545,7 +551,7 @@ class Updater:
|
|
|
545
551
|
# Parameters without weight decay: biases + LayerNorm/Embedding layers.
|
|
546
552
|
wd_params = set()
|
|
547
553
|
no_wd_params = set()
|
|
548
|
-
blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
|
|
554
|
+
blacklist_wd_modules: Any = optimizer_opts.pop("weight_decay_modules_blacklist", None)
|
|
549
555
|
if blacklist_wd_modules is None:
|
|
550
556
|
blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
|
551
557
|
else:
|
returnn/util/basic.py
CHANGED
|
@@ -2459,8 +2459,12 @@ class DictRefKeys(Generic[K, V]):
|
|
|
2459
2459
|
Like `dict`, but hash and equality of the keys
|
|
2460
2460
|
"""
|
|
2461
2461
|
|
|
2462
|
-
def __init__(self):
|
|
2462
|
+
def __init__(self, items: Union[None, Iterable[Tuple[K, V]], Dict[K, V]] = None, /, **kwargs):
|
|
2463
2463
|
self._d = {} # type: Dict[RefIdEq[K], V]
|
|
2464
|
+
if items is not None:
|
|
2465
|
+
self.update(items)
|
|
2466
|
+
if kwargs:
|
|
2467
|
+
self.update(kwargs)
|
|
2464
2468
|
|
|
2465
2469
|
def __repr__(self):
|
|
2466
2470
|
return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()])
|
|
@@ -2489,6 +2493,15 @@ class DictRefKeys(Generic[K, V]):
|
|
|
2489
2493
|
def __contains__(self, item: K):
|
|
2490
2494
|
return RefIdEq(item) in self._d
|
|
2491
2495
|
|
|
2496
|
+
def update(self, other: Union[Dict[K, V], Iterable[Tuple[K, V]]], /):
|
|
2497
|
+
"""
|
|
2498
|
+
:param other: dict or iterable of (key, value) tuples
|
|
2499
|
+
"""
|
|
2500
|
+
if isinstance(other, dict):
|
|
2501
|
+
other = other.items()
|
|
2502
|
+
for k, v in other:
|
|
2503
|
+
self[k] = v
|
|
2504
|
+
|
|
2492
2505
|
|
|
2493
2506
|
def make_dll_name(basename):
|
|
2494
2507
|
"""
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=Cbw-LFRDg3cxVzUdgT7yaNpnAUDTjOw5nRYgq8jjH8A,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=OQkFtzn37F7h2FgPXE84vkON338V4uRqbMen7VsZWR8,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
|
|
|
208
208
|
returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
|
|
209
209
|
returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
|
|
210
210
|
returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
|
|
211
|
-
returnn/torch/updater.py,sha256
|
|
211
|
+
returnn/torch/updater.py,sha256=-v_uY-8jDhreXPmjJYR4cgrlW_7ZI4kt2X2xIZdX_DE,30377
|
|
212
212
|
returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
|
|
213
213
|
returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
|
|
214
214
|
returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
|
|
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
|
|
|
233
233
|
returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
|
|
234
234
|
returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
|
|
235
235
|
returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
|
|
236
|
-
returnn/util/basic.py,sha256=
|
|
236
|
+
returnn/util/basic.py,sha256=UjHujX9pSu_dOgTxozWD0ujj5eSpyj_zD5vFU6bfyms,143096
|
|
237
237
|
returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5w,70843
|
|
238
238
|
returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
|
|
239
239
|
returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250819.10249.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250819.10249.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250819.10249.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250819.10249.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250820.171158.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250820.171158.dist-info/METADATA,sha256=Cbw-LFRDg3cxVzUdgT7yaNpnAUDTjOw5nRYgq8jjH8A,5215
|
|
258
|
+
returnn-1.20250820.171158.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250820.171158.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250820.171158.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|