returnn 1.20250817.33823__py3-none-any.whl → 1.20250820.123936__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. See the package's registry page for more details (the original "Click here" link is not preserved in this copy).

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250817.33823
3
+ Version: 1.20250820.123936
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250817.033823'
2
- long_version = '1.20250817.033823+git.4f41ff3'
1
+ version = '1.20250820.123936'
2
+ long_version = '1.20250820.123936+git.9c74169'
@@ -993,7 +993,10 @@ class TorchBackend(Backend[torch.Tensor]):
993
993
  if clip_to_valid:
994
994
  if axis.dyn_size_ext is not None:
995
995
  indices = rf.clip_by_value(
996
- indices, 0, axis.get_dyn_size_ext_for_device(indices.device) - 1, allow_broadcast_all_sources=True
996
+ indices,
997
+ 0,
998
+ rf.cast(axis.get_dyn_size_ext_for_device(indices.device), indices.dtype) - 1,
999
+ allow_broadcast_all_sources=True,
997
1000
  )
998
1001
  else:
999
1002
  indices = indices.copy()
returnn/torch/updater.py CHANGED
@@ -5,11 +5,10 @@ and model param update logic in general.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Iterator, Set, Dict, List, Tuple
8
+ from typing import Optional, Union, Any, Type, Callable, Sequence, Iterable, Set, Dict, List, Tuple
9
9
  import os
10
10
  import gc
11
11
  import torch
12
- import typing
13
12
 
14
13
  import returnn
15
14
  from returnn.log import log
@@ -130,8 +129,8 @@ class Updater:
130
129
  else:
131
130
  raise NotImplementedError("not implemented for not callable dynamic_learning_rate")
132
131
 
133
- self._optimizer_opts = None
134
- self.optimizer = None # type: typing.Optional[torch.optim.Optimizer]
132
+ self._optimizer_opts: Optional[Dict[str, Any]] = None
133
+ self.optimizer: Optional[torch.optim.Optimizer] = None
135
134
 
136
135
  self._grad_clip = self.config.float("gradient_clip", 0.0)
137
136
  self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
@@ -481,7 +480,7 @@ class Updater:
481
480
 
482
481
  def _get_optimizer_param_groups(
483
482
  self, optim_class: Type[torch.optim.Optimizer], optimizer_opts: Dict[str, Any]
484
- ) -> Union[List[Dict[str, Any]], Iterator[torch.nn.Parameter]]:
483
+ ) -> Union[Iterable[Dict[str, Any]], Iterable[torch.nn.Parameter]]:
485
484
  """
486
485
  The weight_decay parameter from AdamW affects the weights of layers such as LayerNorm and Embedding.
487
486
  This function creates a blacklist of network modules and splits the optimizer groups in two:
@@ -514,10 +513,13 @@ class Updater:
514
513
  if custom_param_groups is not None:
515
514
  assert callable(custom_param_groups), f"invalid param_groups_custom {custom_param_groups!r}"
516
515
  rf_model = wrapped_pt_module_to_rf_module(self.network)
517
- custom_param_groups = custom_param_groups(
516
+ custom_param_groups_ = custom_param_groups(
518
517
  model=self.network, rf_model=rf_model, optimizer_class=optim_class, optimizer_opts=optimizer_opts
519
518
  )
520
- return custom_param_groups
519
+ assert isinstance(custom_param_groups_, Iterable) and all(
520
+ isinstance(group, dict) for group in custom_param_groups_
521
+ ), f"invalid param_groups_custom {custom_param_groups!r} result {custom_param_groups_!r} type"
522
+ return custom_param_groups_
521
523
 
522
524
  network_params = self.network.parameters()
523
525
 
@@ -545,7 +547,7 @@ class Updater:
545
547
  # Parameters without weight decay: biases + LayerNorm/Embedding layers.
546
548
  wd_params = set()
547
549
  no_wd_params = set()
548
- blacklist_wd_modules = optimizer_opts.pop("weight_decay_modules_blacklist", None)
550
+ blacklist_wd_modules: Any = optimizer_opts.pop("weight_decay_modules_blacklist", None)
549
551
  if blacklist_wd_modules is None:
550
552
  blacklist_wd_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
551
553
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250817.33823
3
+ Version: 1.20250820.123936
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=B57ObmyaxFzp_TjsXhRqtVvzxZWq3hS54vw34tZHZsQ,5214
1
+ returnn/PKG-INFO,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=BSqlencj66qyQkfuBO7fze_8cqW9p9M0Si-qLIGy1-k,77
6
+ returnn/_setup_info_generated.py,sha256=UEqWo5ULLq_sbbrAc2dmq5fjH29qlfAXGYB_hw8vpZQ,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -208,7 +208,7 @@ returnn/torch/README.md,sha256=jzJ2FpOHW02vxN69yKaV97C9LI-hmvjBglKfdZXIDdc,85
208
208
  returnn/torch/__init__.py,sha256=MHEUyNHB20Vy89uKAqZoj6FxJKF1Gq3HW-i6ra1pNcI,24
209
209
  returnn/torch/distributed.py,sha256=_lyJR71HIoCHpMi5GztGM7YwrX54Am8zSkjnDkE1Lbk,7524
210
210
  returnn/torch/engine.py,sha256=JSsQZZiVs9TxRyFEJuR3iH-YZb9sRw7TzoIAIqmplZY,78275
211
- returnn/torch/updater.py,sha256=Vyh5w6ZFVc1hQvyyoWpeienQdlBVLZ2HYfjFZRQB3cQ,30035
211
+ returnn/torch/updater.py,sha256=9Ju_LuEdXO7LMxt9rs9_6ReePG5y1h36N3coN696rVI,30285
212
212
  returnn/torch/data/__init__.py,sha256=6cLNEi8KoGI12PF6akN7mI_mtjlx-0hcQAfMYoExwik,132
213
213
  returnn/torch/data/extern_data.py,sha256=5al706ZaYtHWLp5VH2vS-rW69YXP3NHyOFRKY0WY714,7810
214
214
  returnn/torch/data/pipeline.py,sha256=HgIL0jQsPcgvh_SPC4wQ6BzclmrnpFja-UiboF_GPN4,29459
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
216
216
  returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
217
217
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
218
218
  returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
219
- returnn/torch/frontend/_backend.py,sha256=zzKN4_NJK3_I7Ehk8VlhhaXQ_jUEx8K73br8C0Q41p0,103081
219
+ returnn/torch/frontend/_backend.py,sha256=1o6v9neXLTGVu_53QmoPn_2DbbuBC-iyojL9qe5DYBQ,103166
220
220
  returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
221
221
  returnn/torch/frontend/bridge.py,sha256=c_mVBCBo29sjm8Bhxarv00szwGPgxjwoIqAHOmceGQw,7842
222
222
  returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250817.33823.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250817.33823.dist-info/METADATA,sha256=B57ObmyaxFzp_TjsXhRqtVvzxZWq3hS54vw34tZHZsQ,5214
258
- returnn-1.20250817.33823.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250817.33823.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250817.33823.dist-info/RECORD,,
256
+ returnn-1.20250820.123936.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250820.123936.dist-info/METADATA,sha256=04mKkkm6MNJQzWwBZq5enBZVM1vpvfr-0W705kv7vsk,5215
258
+ returnn-1.20250820.123936.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250820.123936.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250820.123936.dist-info/RECORD,,