returnn 1.20250820.171158__py3-none-any.whl → 1.20250826.115240__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of returnn might be problematic.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/torch/updater.py +42 -17
- returnn/util/file_cache.py +4 -0
- {returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/METADATA +1 -1
- {returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/RECORD +9 -9
- {returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/LICENSE +0 -0
- {returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/WHEEL +0 -0
- {returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
@@ -1,2 +1,2 @@
-version = '1.
-long_version = '1.
+version = '1.20250826.115240'
+long_version = '1.20250826.115240+git.2baa52f'
returnn/torch/updater.py
CHANGED
@@ -95,6 +95,8 @@ class Updater:
     Wraps a torch.optim.Optimizer, and extends it by some further functionality.
     """

+    _OptimizerParamGroupsExtraOpts = ("learning_rate_multiplier",)
+
     def __init__(self, *, config, network, device, initial_learning_rate=1.0):
         """
         :param returnn.config.Config config: config defining the training conditions.
@@ -131,6 +133,7 @@ class Updater:

         self._optimizer_opts: Optional[Dict[str, Any]] = None
         self.optimizer: Optional[torch.optim.Optimizer] = None
+        self._optimizer_param_groups_extra_opts: Optional[List[Dict[str, Any]]] = None

         self._grad_clip = self.config.float("gradient_clip", 0.0)
         self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
@@ -189,8 +192,15 @@ class Updater:
         )
         self._effective_learning_rate = float(lr)
         if self.optimizer:
-
-
+            if self._optimizer_param_groups_extra_opts:
+                assert len(self.optimizer.param_groups) == len(self._optimizer_param_groups_extra_opts)
+                lr_multiplies = [
+                    opts.get("learning_rate_multiplier", 1.0) for opts in self._optimizer_param_groups_extra_opts
+                ]
+            else:
+                lr_multiplies = [1.0] * len(self.optimizer.param_groups)
+            for i, param_group in enumerate(self.optimizer.param_groups):
+                param_group["lr"] = self._effective_learning_rate * lr_multiplies[i]

     def set_current_train_step(self, *, global_train_step: int, epoch: int, epoch_continuous: Optional[float] = None):
         """
@@ -273,7 +283,7 @@ class Updater:
         if optimizer_opts is None:
             raise ValueError("config field 'optimizer' needs to be set explicitely for the Torch backend")
         self._optimizer_opts = optimizer_opts
-        self.optimizer = self._create_optimizer(optimizer_opts)
+        self.optimizer, self._optimizer_param_groups_extra_opts = self._create_optimizer(optimizer_opts)

     def load_optimizer(self, filename):
         """
@@ -421,21 +431,20 @@ class Updater:
         """
         return self.optimizer

-    def _create_optimizer(self, optimizer_opts):
+    def _create_optimizer(self, optimizer_opts) -> Tuple[torch.optim.Optimizer, Optional[List[Dict[str, Any]]]]:
         """
         Returns a valid optimizer considering the dictionary given by the user in the config.

         :param dict[str]|str optimizer_opts: Optimizer configuration specified by the user.
             If it's a dict, it must contain "class" with the optimizer name or callable.
             If it's a str, it must be the optimizer name.
-        :return:
-        :rtype: torch.optim.Optimizer
+        :return: tuple (optimizer, optional optimizer_param_groups_extra_opts).
         """
         lr = self.learning_rate

         # If the parameter is already a valid optimizer, return it without further processing
         if isinstance(optimizer_opts, torch.optim.Optimizer):
-            return optimizer_opts
+            return optimizer_opts, None
         elif callable(optimizer_opts):
             optimizer_opts: Dict[str, Any] = {"class": optimizer_opts}
         else:
@@ -461,12 +470,23 @@ class Updater:
         lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
         opt_kwargs["lr"] = lr

-
-
+        param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
+        param_groups = list(param_groups)
+        assert len(param_groups) > 0, "got an empty parameter list?"
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{"params": param_groups}]
+        optimizer_param_groups_extra_opts: Optional[List[Dict[str, Any]]] = None
+        if any(any(key in group for key in self._OptimizerParamGroupsExtraOpts) for group in param_groups):
+            param_groups = [dict(group) for group in param_groups]  # copy to make sure we can modify it
+            optimizer_param_groups_extra_opts = [
+                {key: group.pop(key) for key in self._OptimizerParamGroupsExtraOpts if key in group}
+                for group in param_groups
+            ]
+        optimizer = optim_class(param_groups, **opt_kwargs)
         print("Optimizer: %s" % optimizer, file=log.v1)
         assert isinstance(optimizer, torch.optim.Optimizer)

-        return optimizer
+        return optimizer, optimizer_param_groups_extra_opts

     def _create_default_optimizer(self):
         """
@@ -551,11 +571,9 @@ class Updater:
         # Parameters without weight decay: biases + LayerNorm/Embedding layers.
         wd_params = set()
         no_wd_params = set()
-        blacklist_wd_modules
-
-
-        else:
-            blacklist_wd_modules = _wrap_user_blacklist_wd_modules(blacklist_wd_modules)
+        blacklist_wd_modules = wrap_user_blacklist_wd_modules(
+            optimizer_opts.pop("weight_decay_modules_blacklist", None)
+        )
         custom_include_check = optimizer_opts.pop("weight_decay_custom_include_check", None)
         if custom_include_check:
             assert callable(custom_include_check), f"invalid weight_decay_custom_include_check {custom_include_check!r}"
@@ -602,9 +620,16 @@ class Updater:
         return optim_groups


-def
-    mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]],
+def wrap_user_blacklist_wd_modules(
+    mods: Optional[Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]],
 ) -> Tuple[type, ...]:
+    """
+    Wraps the user-provided blacklist_weight_decay_modules into a tuple of types.
+    This supports both pure PyTorch modules (e.g. "torch.nn.LayerNorm")
+    and RF modules (e.g. "rf.LayerNorm"), which can be specified as strings or types.
+    """
+    if mods is None:
+        return torch.nn.LayerNorm, torch.nn.Embedding
     assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
     res = []
     for mod in mods:
returnn/util/file_cache.py
CHANGED
@@ -390,6 +390,10 @@ class FileCache:
             # - https://github.com/rwth-i6/returnn/pull/1709
             os.utime(dst_filename, None)
             os.utime(info_file_name, None)
+            # Ensure we proactively make space for other users
+            # even in case we have all files ready on disk.
+            # See for discussion: https://github.com/rwth-i6/returnn/pull/1752.
+            self.cleanup(need_at_least_free_space_size=0)
             return

         print(f"FileCache: Copy file {src_filename} to cache")
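The file_cache.py change runs a cleanup pass even on the cache-hit path: after refreshing the timestamps of an already-cached file, self.cleanup(need_at_least_free_space_size=0) is called so stale entries get evicted proactively rather than only when new space is needed. Below is a minimal, self-contained sketch of that hit-path pattern with a simplified age-based eviction; the get_cached and cleanup helpers are hypothetical and not FileCache's real interface.

import os
import shutil
import time

def get_cached(src_filename: str, cache_dir: str, max_age_secs: float = 3600.0) -> str:
    """Return a path to a cached copy of src_filename, copying it on a miss."""
    dst = os.path.join(cache_dir, os.path.basename(src_filename))
    if os.path.exists(dst):
        os.utime(dst, None)  # mark as recently used (as in the diff above)
        cleanup(cache_dir, max_age_secs)  # proactively make space, even on a hit
        return dst
    os.makedirs(cache_dir, exist_ok=True)
    shutil.copy2(src_filename, dst)
    cleanup(cache_dir, max_age_secs)
    return dst

def cleanup(cache_dir: str, max_age_secs: float) -> None:
    """Evict cache entries whose mtime is older than max_age_secs."""
    now = time.time()
    for name in os.listdir(cache_dir):
        path = os.path.join(cache_dir, name)
        if os.path.isfile(path) and now - os.path.getmtime(path) > max_age_secs:
            os.remove(path)

Presumably, passing need_at_least_free_space_size=0 to the real FileCache.cleanup means the pass is not required to free any particular amount of space, but still removes whatever entries are eligible for eviction.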
{returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=
+returnn/PKG-INFO,sha256=FML_TBBHEcTCM4--oW_2wEb8Yk4ybXtAH6CQ7tSums0,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=
+returnn/_setup_info_generated.py,sha256=BCJ_PJr520UrgVLYRMgDdj07KiDWctntLtC9aS8cTH0,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
 returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
-returnn/torch/updater.py,sha256
+returnn/torch/updater.py,sha256=7lMoA01Yzp18MY5jjIFncsajTjOD713pK38nU6r-jiE,31999
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
 returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -238,7 +238,7 @@ returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
 returnn/util/debug_helpers.py,sha256=0EINLK4uLtoSt5_kHs1M2NIFpMd0S7i4c4rx90U4fJk,2914
-returnn/util/file_cache.py,sha256=
+returnn/util/file_cache.py,sha256=ERGz6TEWqetGk4odj1x6cMfecfQ5G5G4e5psSrbx03Y,27852
 returnn/util/fsa.py,sha256=k2lJ8tyf_g44Xk1EPVLwDwpP4spoMTqIigDVOWocQHY,59177
 returnn/util/literal_py_to_pickle.py,sha256=3dnjWPeeiDT2xp4bRDgIf9yddx7b1AG7mOKEn_jiSl8,2173
 returnn/util/lru_cache.py,sha256=7Q5H3a8b07E8e1iB7PA9jCpRnxMJZOFS2KO07cy0gqk,11446
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
-returnn-1.
+returnn-1.20250826.115240.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250826.115240.dist-info/METADATA,sha256=FML_TBBHEcTCM4--oW_2wEb8Yk4ybXtAH6CQ7tSums0,5215
+returnn-1.20250826.115240.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250826.115240.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250826.115240.dist-info/RECORD,,
{returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/LICENSE
File without changes
{returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/WHEEL
File without changes
{returnn-1.20250820.171158.dist-info → returnn-1.20250826.115240.dist-info}/top_level.txt
File without changes