returnn 1.20250425.85727__py3-none-any.whl → 1.20250430.145858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250425.85727
3
+ Version: 1.20250430.145858
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250425.085727'
2
- long_version = '1.20250425.085727+git.547e726'
1
+ version = '1.20250430.145858'
2
+ long_version = '1.20250430.145858+git.6447e67'
returnn/datasets/hdf.py CHANGED
@@ -1073,6 +1073,8 @@ class SimpleHDFWriter:
1073
1073
  which can be read later by :class:`HDFDataset`.
1074
1074
 
1075
1075
  Note that we dump to a temp file first, and only at :func:`close` we move it over to the real destination.
1076
+
1077
+ Can be used as a context manager, i.e. with the `with` statement.
1076
1078
  """
1077
1079
 
1078
1080
  def __init__(
@@ -1413,6 +1415,12 @@ class SimpleHDFWriter:
1413
1415
  os.remove(self.tmp_filename)
1414
1416
  self.tmp_filename = None
1415
1417
 
1418
+ def __enter__(self):
1419
+ return self
1420
+
1421
+ def __exit__(self, exc_type, exc_val, exc_tb):
1422
+ self.close()
1423
+
1416
1424
 
1417
1425
  class HDFDatasetWriter:
1418
1426
  """
@@ -39,6 +39,7 @@ __all__ = [
39
39
  "pad_packed",
40
40
  "gather",
41
41
  "scatter",
42
+ "scatter_mean",
42
43
  "scatter_argmax",
43
44
  "scatter_logsumexp",
44
45
  "scatter_logmeanexp",
@@ -807,8 +808,8 @@ def scatter(
807
808
  :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
808
809
  :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
809
810
  :param indices_dim:
810
- :param mode: "sum", "max", "min", "logsumexp", "logmeanexp", "argmax".
811
- (Note: If you ever need mean, argmin, etc, please open an issue/PR.)
811
+ :param mode: "sum", "max", "min", "mean", "logsumexp", "logmeanexp", "argmax".
812
+ (Note: If you ever need another mode, please open an issue/PR.)
812
813
  :param fill_value:
813
814
  :param out_dim: The indices target dim.
814
815
  If not given, will be automatically determined as the sparse_dim from indices.
@@ -817,6 +818,8 @@ def scatter(
817
818
  :param use_mask:
818
819
  :return: [batch_dims..., out_dim(s)..., feature_dims...]
819
820
  """
821
+ if mode == "mean":
822
+ return scatter_mean(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
820
823
  if mode == "logsumexp":
821
824
  return scatter_logsumexp(
822
825
  source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim
@@ -863,6 +866,35 @@ def scatter(
863
866
  return out
864
867
 
865
868
 
869
+ def scatter_mean(
870
+ source: Tensor,
871
+ *,
872
+ indices: Tensor,
873
+ indices_dim: Union[Dim, Sequence[Dim]],
874
+ fill_value: Optional[Union[int, float]] = None,
875
+ out_dim: Optional[Union[Dim, Sequence[Dim]]] = None,
876
+ ) -> Tensor:
877
+ """
878
+ Scatters into new zero-tensor.
879
+ If entries in indices are duplicated, the corresponding values in source will be mean'ed together.
880
+ This is like :func:`scatter` with ``mode="mean"``.
881
+
882
+ :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
883
+ :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
884
+ :param indices_dim:
885
+ :param fill_value:
886
+ :param out_dim: The indices target dim.
887
+ If not given, will be automatically determined as the sparse_dim from indices.
888
+ If multiple out dims, use indices into the merged out dims,
889
+ and then we use :func:`rf.split_dims` afterwards.
890
+ :return: [batch_dims..., out_dim(s)..., feature_dims...]
891
+ """
892
+ ones = rf.ones(dims=indices.dims, dtype=source.dtype, device=source.device)
893
+ counts = rf.scatter(ones, indices=indices, indices_dim=indices_dim, fill_value=1, out_dim=out_dim)
894
+ y = scatter(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
895
+ return y / counts
896
+
897
+
866
898
  def scatter_argmax(
867
899
  source: Tensor,
868
900
  *,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250425.85727
3
+ Version: 1.20250430.145858
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=d21Lj9SyyBGpYI76EmdiZA-0o461Y9Obq0xdEUtuym0,5214
1
+ returnn/PKG-INFO,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=mn_I1EXYlj19uyaatMxZYx7k2X5fwxbUzLRPoDt_juI,77
6
+ returnn/_setup_info_generated.py,sha256=mczEirj8AMV_Zo8t71vBq_GPVXqMycX3DI91nhfpNmo,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -19,7 +19,7 @@ returnn/datasets/cached.py,sha256=DIRdWrxBmsZG8O_9eVxBO5mcdo4f5KU-Xb-4wVz59Io,25
19
19
  returnn/datasets/cached2.py,sha256=_6pza3IG68JexaExhj1ld3fP6pE7T-G804driJ9Z_qo,12141
20
20
  returnn/datasets/distrib_files.py,sha256=wMOP0GX4vwaSwKtcHPEcj_zFKS__xVNNCKze5JkZ930,29881
21
21
  returnn/datasets/generating.py,sha256=E_6KpnSu8ChqG3pb4VTChWDsBTonIwFFAj53SI9NSow,99846
22
- returnn/datasets/hdf.py,sha256=yqzr-nzqlt02QZoW2uFowKT19gd5e-9mJpHCKSQxW8o,67643
22
+ returnn/datasets/hdf.py,sha256=fPlzmZtyblyzurRkqQUWKAWDqwzU6NPdJEqF2OuIEpU,67833
23
23
  returnn/datasets/lm.py,sha256=5hSdBgmgTP0IzO2p-JjiWtny0Zb0M20goXtjlw4JVR4,99206
24
24
  returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
25
25
  returnn/datasets/meta.py,sha256=EySwPQUqIAzvocAoSpMxszHbymXjJeCSGhDn0T1BO-0,95355
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
82
82
  returnn/frontend/_utils.py,sha256=4A3MSRM0i86J77550uR_AjcBEPu6nymLUZ9Xd1V3Fkc,12073
83
- returnn/frontend/array_.py,sha256=eYwH-NVAoHpVrFdJv08lCqh3jvfoZV_ZBEoWHjsBz0o,50090
83
+ returnn/frontend/array_.py,sha256=o_NSq87pB5I2XvFUjk40Dobqx6tTfEY1wzgmaelujgM,51511
84
84
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
85
85
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
86
86
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250425.85727.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250425.85727.dist-info/METADATA,sha256=d21Lj9SyyBGpYI76EmdiZA-0o461Y9Obq0xdEUtuym0,5214
258
- returnn-1.20250425.85727.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250425.85727.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250425.85727.dist-info/RECORD,,
256
+ returnn-1.20250430.145858.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250430.145858.dist-info/METADATA,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
258
+ returnn-1.20250430.145858.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250430.145858.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250430.145858.dist-info/RECORD,,