returnn 1.20250820.123936__py3-none-any.whl → 1.20250821.93927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250820.123936
+Version: 1.20250821.93927
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250820.123936'
-long_version = '1.20250820.123936+git.9c74169'
+version = '1.20250821.093927'
+long_version = '1.20250821.093927+git.ec56958'
returnn/torch/updater.py CHANGED
@@ -95,6 +95,8 @@ class Updater:
     Wraps a torch.optim.Optimizer, and extends it by some further functionality.
     """

+    _OptimizerParamGroupsExtraOpts = ("learning_rate_multiplier",)
+
     def __init__(self, *, config, network, device, initial_learning_rate=1.0):
         """
         :param returnn.config.Config config: config defining the training conditions.
@@ -131,6 +133,7 @@ class Updater:

         self._optimizer_opts: Optional[Dict[str, Any]] = None
         self.optimizer: Optional[torch.optim.Optimizer] = None
+        self._optimizer_param_groups_extra_opts: Optional[List[Dict[str, Any]]] = None

         self._grad_clip = self.config.float("gradient_clip", 0.0)
         self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
@@ -189,8 +192,15 @@ class Updater:
             )
         self._effective_learning_rate = float(lr)
         if self.optimizer:
-            for param_group in self.optimizer.param_groups:
-                param_group["lr"] = self._effective_learning_rate
+            if self._optimizer_param_groups_extra_opts:
+                assert len(self.optimizer.param_groups) == len(self._optimizer_param_groups_extra_opts)
+                lr_multiplies = [
+                    opts.get("learning_rate_multiplier", 1.0) for opts in self._optimizer_param_groups_extra_opts
+                ]
+            else:
+                lr_multiplies = [1.0] * len(self.optimizer.param_groups)
+            for i, param_group in enumerate(self.optimizer.param_groups):
+                param_group["lr"] = self._effective_learning_rate * lr_multiplies[i]

     def set_current_train_step(self, *, global_train_step: int, epoch: int, epoch_continuous: Optional[float] = None):
         """
@@ -273,7 +283,7 @@ class Updater:
         if optimizer_opts is None:
             raise ValueError("config field 'optimizer' needs to be set explicitely for the Torch backend")
         self._optimizer_opts = optimizer_opts
-        self.optimizer = self._create_optimizer(optimizer_opts)
+        self.optimizer, self._optimizer_param_groups_extra_opts = self._create_optimizer(optimizer_opts)

     def load_optimizer(self, filename):
         """
@@ -421,21 +431,20 @@ class Updater:
         """
         return self.optimizer

-    def _create_optimizer(self, optimizer_opts):
+    def _create_optimizer(self, optimizer_opts) -> Tuple[torch.optim.Optimizer, Optional[List[Dict[str, Any]]]]:
         """
         Returns a valid optimizer considering the dictionary given by the user in the config.

         :param dict[str]|str optimizer_opts: Optimizer configuration specified by the user.
             If it's a dict, it must contain "class" with the optimizer name or callable.
             If it's a str, it must be the optimizer name.
-        :return: A valid optimizer.
-        :rtype: torch.optim.Optimizer
+        :return: tuple (optimizer, optional optimizer_param_groups_extra_opts).
         """
         lr = self.learning_rate

         # If the parameter is already a valid optimizer, return it without further processing
         if isinstance(optimizer_opts, torch.optim.Optimizer):
-            return optimizer_opts
+            return optimizer_opts, None
         elif callable(optimizer_opts):
             optimizer_opts: Dict[str, Any] = {"class": optimizer_opts}
         else:
@@ -461,12 +470,23 @@ class Updater:
         lr = lr * opt_kwargs.pop("learning_rate_multiplier", 1.0)
         opt_kwargs["lr"] = lr

-        params_or_param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
-        optimizer = optim_class(params_or_param_groups, **opt_kwargs)
+        param_groups = self._get_optimizer_param_groups(optim_class, opt_kwargs)
+        param_groups = list(param_groups)
+        assert len(param_groups) > 0, "got an empty parameter list?"
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{"params": param_groups}]
+        optimizer_param_groups_extra_opts: Optional[List[Dict[str, Any]]] = None
+        if any(any(key in group for key in self._OptimizerParamGroupsExtraOpts) for group in param_groups):
+            param_groups = [dict(group) for group in param_groups]  # copy to make sure we can modify it
+            optimizer_param_groups_extra_opts = [
+                {key: group.pop(key) for key in self._OptimizerParamGroupsExtraOpts if key in group}
+                for group in param_groups
+            ]
+        optimizer = optim_class(param_groups, **opt_kwargs)
         print("Optimizer: %s" % optimizer, file=log.v1)
         assert isinstance(optimizer, torch.optim.Optimizer)

-        return optimizer
+        return optimizer, optimizer_param_groups_extra_opts

     def _create_default_optimizer(self):
         """
@@ -514,7 +534,11 @@ class Updater:
             assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
             rf_model = wrapped_pt_module_to_rf_module(self.network)
             custom_param_groups_ = custom_param_groups(
-                model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
+                model=self.network,
+                rf_model=rf_model,
+                optimizer_class=optim_class,
+                optimizer_opts=optimizer_opts,
+                **get_fwd_compat_kwargs(),
             )
             assert isinstance(custom_param_groups_, Iterable) and all(
                 isinstance(group, dict) for group in custom_param_groups_
@@ -547,11 +571,9 @@ class Updater:
         # Parameters without weight decay: biases + LayerNorm/Embedding layers.
         wd_params = set()
         no_wd_params = set()
-        blacklist_wd_modules: Any = optimizer_opts.pop("weight_decay_modules_blacklist", None)
-        if blacklist_wd_modules is None:
-            blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
-        else:
-            blacklist_wd_modules = _wrap_user_blacklist_wd_modules(blacklist_wd_modules)
+        blacklist_wd_modules = wrap_user_blacklist_wd_modules(
+            optimizer_opts.pop("weight_decay_modules_blacklist", None)
+        )
         custom_include_check = optimizer_opts.pop("weight_decay_custom_include_check", None)
         if custom_include_check:
             assert callable(custom_include_check), f"invalid weight_decay_custom_include_check {custom_include_check!r}"
@@ -598,9 +620,16 @@ class Updater:
         return optim_groups


-def _wrap_user_blacklist_wd_modules(
-    mods: Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]],
+def wrap_user_blacklist_wd_modules(
+    mods: Optional[Sequence[Union[str, Type[rf.Module], Type[torch.nn.Module]]]],
 ) -> Tuple[type, ...]:
+    """
+    Wraps the user-provided blacklist_weight_decay_modules into a tuple of types.
+    This supports both pure PyTorch modules (e.g. "torch.nn.LayerNorm")
+    and RF modules (e.g. "rf.LayerNorm"), which can be specified as strings or types.
+    """
+    if mods is None:
+        return torch.nn.LayerNorm, torch.nn.Embedding
     assert isinstance(mods, (list, tuple)), f"invalid blacklist_weight_decay_modules {mods!r}"
     res = []
     for mod in mods:
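Taken together, the updater changes apparently let a RETURNN config attach a per-parameter-group learning rate multiplier: groups produced by the optimizer's param_groups_custom hook may carry a "learning_rate_multiplier" key, which the Updater strips off before building the torch optimizer and re-applies whenever the effective learning rate changes. A hedged config sketch follows; the callable's keyword arguments match the call site in the diff, while the "adamw" class name, the parameter-name check, and the 10x factor are made-up illustrations:

```python
# Hypothetical RETURNN config sketch based on the diff above, not an official example.
def my_param_groups(*, model, rf_model, optimizer_class, optimizer_opts, **_fwd_compat):
    """Split parameters into two groups; give the (assumed) output params a larger LR."""
    fast, slow = [], []
    for name, param in model.named_parameters():
        (fast if name.startswith("output") else slow).append(param)
    return [
        {"params": slow},
        {"params": fast, "learning_rate_multiplier": 10.0},  # stripped off again by the Updater
    ]

optimizer = {
    "class": "adamw",  # optimizer name goes under "class", per the _create_optimizer docstring
    "param_groups_custom": my_param_groups,
}
```

The renamed wrap_user_blacklist_wd_modules additionally accepts None (falling back to torch.nn.LayerNorm and torch.nn.Embedding) and, per its new docstring, weight_decay_modules_blacklist entries given as strings or RF module types such as "rf.LayerNorm".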
returnn/util/basic.py CHANGED
@@ -2459,8 +2459,12 @@ class DictRefKeys(Generic[K, V]):
     Like `dict`, but hash and equality of the keys
     """

-    def __init__(self):
+    def __init__(self, items: Union[None, Iterable[Tuple[K, V]], Dict[K, V]] = None, /, **kwargs):
         self._d = {}  # type: Dict[RefIdEq[K], V]
+        if items is not None:
+            self.update(items)
+        if kwargs:
+            self.update(kwargs)

     def __repr__(self):
         return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()])
@@ -2489,6 +2493,15 @@ class DictRefKeys(Generic[K, V]):
     def __contains__(self, item: K):
         return RefIdEq(item) in self._d

+    def update(self, other: Union[Dict[K, V], Iterable[Tuple[K, V]]], /):
+        """
+        :param other: dict or iterable of (key, value) tuples
+        """
+        if isinstance(other, dict):
+            other = other.items()
+        for k, v in other:
+            self[k] = v
+

 def make_dll_name(basename):
     """
returnn-1.20250820.123936.dist-info/METADATA → returnn-1.20250821.93927.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250820.123936
+Version: 1.20250821.93927
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20250820.123936.dist-info/RECORD → returnn-1.20250821.93927.dist-info/RECORD
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
+returnn/PKG-INFO,sha256=oANuix-AgPHTYt9t0MfUfM0PyjXdMuEBJrN2b99tMUI,5214
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=UEqWo5ULLq_sbbrAc2dmq5fjH29qlfAXGYB_hw8vpZQ,77
+returnn/_setup_info_generated.py,sha256=o3ap30O-BqlF3l-gxibSvi0MgnNA3LJEG6o9FVoM0no,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
 returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
-returnn/torch/updater.py,sha256=9Ju_LuEdXO7LMxt9rs9_6ReePG5y1h36N3coN696rVI,30285
+returnn/torch/updater.py,sha256=7lMoA01Yzp18MY5jjIFncsajTjOD713pK38nU6r-jiE,31999
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
 returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
 returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
 returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
-returnn/util/basic.py,sha256=9Ig-7XLtvXk3yfycmBEhdJG-WVNDtoND3DmDyXOl018,142627
+returnn/util/basic.py,sha256=UjHujX9pSu_dOgTxozWD0ujj5eSpyj_zD5vFU6bfyms,143096
 returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5w,70843
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250820.123936.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250820.123936.dist-info/METADATA,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
-returnn-1.20250820.123936.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20250820.123936.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250820.123936.dist-info/RECORD,,
+returnn-1.20250821.93927.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250821.93927.dist-info/METADATA,sha256=oANuix-AgPHTYt9t0MfUfM0PyjXdMuEBJrN2b99tMUI,5214
+returnn-1.20250821.93927.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250821.93927.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250821.93927.dist-info/RECORD,,