returnn-1.20250819.10249-py3-none-any.whl → returnn-1.20250820.171158-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250819.10249
+Version: 1.20250820.171158
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250819.010249'
-long_version = '1.20250819.010249+git.9c1f159'
+version = '1.20250820.171158'
+long_version = '1.20250820.171158+git.d60d270'
returnn/torch/updater.py CHANGED
@@ -5,11 +5,10 @@ and model param update logic in general.
 
 from __future__ import annotations
 
-from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Iterator, Set, Dict, List, Tuple
+from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Set, Dict, List, Tuple
 import os
 import gc
 import torch
-import typing
 
 import returnn
 from returnn.log import log
@@ -130,8 +129,8 @@ class Updater:
         else:
             raise NotImplementedError("not implemented for not callable dynamic_learning_rate")
 
-        self._optimizer_opts = None
-        self.optimizer = None  # type: typing.Optional[torch.optim.Optimizer]
+        self._optimizer_opts: Optional[Dict[str, Any]] = None
+        self.optimizer: Optional[torch.optim.Optimizer] = None
 
         self._grad_clip = self.config.float("gradient_clip", 0.0)
         self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
@@ -481,7 +480,7 @@ class Updater:
 
     def _get_optimizer_param_groups(
         self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
-    ) -> Union[List[Dict[str, Any]], Iterator[torch.nn.Parameter]]:
+    ) -> Union[Iterable[Dict[str, Any]], Iterable[torch.nn.Parameter]]:
         """
         The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
         This function creates a blacklist of network modules and splits the optimizer groups in two:
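For context on the docstring in the hunk above: the split it describes can be pictured with a short, hypothetical sketch (not RETURNN's actual implementation) that sorts parameters into a decayed group and a non-decayed group, using the default blacklist of LayerNorm and Embedding modules that appears later in this diff:

import torch

def split_param_groups(model: torch.nn.Module, weight_decay: float):
    # Default blacklist per this diff; biases are also excluded from decay.
    blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)
    wd_params, no_wd_params = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, blacklist) or name.endswith("bias"):
                no_wd_params.append(param)  # excluded from weight decay
            else:
                wd_params.append(param)  # regular weights: decayed
    return [
        {"params": wd_params, "weight_decay": weight_decay},
        {"params": no_wd_params, "weight_decay": 0.0},
    ]

Such a list of group dicts is what e.g. torch.optim.AdamW accepts in place of a plain parameter iterable, which also matches the broadened return annotation above.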
@@ -514,10 +513,17 @@ class Updater:
         if custom_param_groups is not None:
             assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
             rf_model = wrapped_pt_module_to_rf_module(self.network)
-            custom_param_groups = custom_param_groups(
-                model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
+            custom_param_groups_ = custom_param_groups(
+                model=self.network,
+                rf_model=rf_model,
+                optimizer_class=optim_class,
+                optimizer_opts=optimizer_opts,
+                **get_fwd_compat_kwargs(),
             )
-            return custom_param_groups
+            assert isinstance(custom_param_groups_, Iterable) and all(
+                isinstance(group, dict) for group in custom_param_groups_
+            ), f"invalid param_groups_custom {custom_param_groups!r} result {custom_param_groups_!r} type"
+            return custom_param_groups_
 
         network_params = self.network.parameters()
 
@@ -545,7 +551,7 @@ class Updater:
         # Parameters without weight decay: biases + LayerNorm/Embedding layers.
         wd_params = set()
         no_wd_params = set()
-        blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
+        blacklist_wd_modules: Any = optimizer_opts.pop("weight_decay_modules_blacklist", None)
         if blacklist_wd_modules is None:
             blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         else:
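Two behavioral points in the updater.py hunks above are worth illustrating. First, a user-supplied param_groups_custom callable is now invoked with keyword arguments plus **get_fwd_compat_kwargs(), and its result must be an iterable of dicts (the new assertion). A hypothetical config-side callable compatible with that call; the keyword names are taken from the diff, and the trailing **kwargs absorbs any forward-compatibility keywords:

def my_param_groups(*, model, rf_model, optimizer_class, optimizer_opts, **kwargs):
    # Put biases into their own group without weight decay.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if name.endswith("bias") else decay).append(param)
    # Must be an iterable of dicts; any other return type now fails the assertion.
    return [{"params": decay}, {"params": no_decay, "weight_decay": 0.0}]

Second, the last hunk touches the weight_decay_modules_blacklist option popped from the optimizer opts. A hypothetical config snippet overriding the default blacklist; whether entries may also be given in another form (e.g. as strings) is not visible in this hunk, so module classes are used:

import torch

optimizer = {
    "class": "adamw",
    "weight_decay": 0.01,
    # Default, per the hunk above: (torch.nn.LayerNorm, torch.nn.Embedding).
    "weight_decay_modules_blacklist": (torch.nn.LayerNorm, torch.nn.Embedding, torch.nn.GroupNorm),
}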
returnn/util/basic.py CHANGED
@@ -2459,8 +2459,12 @@ class DictRefKeys(Generic[K, V]):
     Like `dict`, but hash and equality of the keys
     """
 
-    def __init__(self):
+    def __init__(self, items: Union[None, Iterable[Tuple[K, V]], Dict[K, V]] = None, /, **kwargs):
         self._d = {}  # type: Dict[RefIdEq[K], V]
+        if items is not None:
+            self.update(items)
+        if kwargs:
+            self.update(kwargs)
 
     def __repr__(self):
         return "DictRefKeys(%s)" % ", ".join(["%r: %r" % (k, v) for (k, v) in self.items()])
@@ -2489,6 +2493,15 @@ class DictRefKeys(Generic[K, V]):
     def __contains__(self, item: K):
         return RefIdEq(item) in self._d
 
+    def update(self, other: Union[Dict[K, V], Iterable[Tuple[K, V]]], /):
+        """
+        :param other: dict or iterable of (key, value) tuples
+        """
+        if isinstance(other, dict):
+            other = other.items()
+        for k, v in other:
+            self[k] = v
+
 
 def make_dll_name(basename):
     """
returnn-1.20250819.10249.dist-info/METADATA → returnn-1.20250820.171158.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250819.10249
+Version: 1.20250820.171158
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20250819.10249.dist-info/RECORD → returnn-1.20250820.171158.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=40ciCZzddEgWfHHnfFmRo7cpK8dukyBH8HYxTaEd5XY,5214
+returnn/PKG-INFO,sha256=Cbw-LFRDg3cxVzUdgT7yaNpnAUDTjOw5nRYgq8jjH8A,5215
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=poqBmOb1nT6ZUEe9pGqrp91VojHENemIF3zMYt80T4g,77
+returnn/_setup_info_generated.py,sha256=OQkFtzn37F7h2FgPXE84vkON338V4uRqbMen7VsZWR8,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
 returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
 returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
 returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
-returnn/torch/updater.py,sha256=Vyh5w6ZFVc1hQvyyoWpeienQdlBVLZ2HYfjFZRQB3cQ,30035
+returnn/torch/updater.py,sha256=-v_uY-8jDhreXPmjJYR4cgrlW_7ZI4kt2X2xIZdX_DE,30377
 returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
 returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
 returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -233,7 +233,7 @@ returnn/torch/util/gradient_checkpoint.py,sha256=iLy-FB65DC8O6LxzmMvFjnSdpIVpko8
 returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,1693
 returnn/torch/util/scaled_gradient.py,sha256=C5e79mpqtxdtw08OTSy413TSBSlOertRisc-ioiFIaU,3191
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
-returnn/util/basic.py,sha256=9Ig-7XLtvXk3yfycmBEhdJG-WVNDtoND3DmDyXOl018,142627
+returnn/util/basic.py,sha256=UjHujX9pSu_dOgTxozWD0ujj5eSpyj_zD5vFU6bfyms,143096
 returnn/util/better_exchook.py,sha256=39yvRecluDgYhViwSkaQ8crJ_cBWI63KeEGuK4RKe5w,70843
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250819.10249.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250819.10249.dist-info/METADATA,sha256=40ciCZzddEgWfHHnfFmRo7cpK8dukyBH8HYxTaEd5XY,5214
-returnn-1.20250819.10249.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20250819.10249.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250819.10249.dist-info/RECORD,,
+returnn-1.20250820.171158.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250820.171158.dist-info/METADATA,sha256=Cbw-LFRDg3cxVzUdgT7yaNpnAUDTjOw5nRYgq8jjH8A,5215
+returnn-1.20250820.171158.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250820.171158.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250820.171158.dist-info/RECORD,,