returnn 1.20250429.161207__py3-none-any.whl → 1.20250430.145858__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250429.161207
3
+ Version: 1.20250430.145858
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
1
- version = '1.20250429.161207'
2
- long_version = '1.20250429.161207+git.d4d26bf'
1
+ version = '1.20250430.145858'
2
+ long_version = '1.20250430.145858+git.6447e67'
returnn/frontend/array_.py CHANGED
@@ -39,6 +39,7 @@ __all__ = [
39
39
  "pad_packed",
40
40
  "gather",
41
41
  "scatter",
42
+ "scatter_mean",
42
43
  "scatter_argmax",
43
44
  "scatter_logsumexp",
44
45
  "scatter_logmeanexp",
@@ -807,8 +808,8 @@ def scatter(
807
808
  :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
808
809
  :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
809
810
  :param indices_dim:
810
- :param mode: "sum", "max", "min", "logsumexp", "logmeanexp", "argmax".
811
- (Note: If you ever need mean, argmin, etc, please open an issue/PR.)
811
+ :param mode: "sum", "max", "min", "mean", "logsumexp", "logmeanexp", "argmax".
812
+ (Note: If you ever need another mode, please open an issue/PR.)
812
813
  :param fill_value:
813
814
  :param out_dim: The indices target dim.
814
815
  If not given, will be automatically determined as the sparse_dim from indices.
@@ -817,6 +818,8 @@ def scatter(
817
818
  :param use_mask:
818
819
  :return: [batch_dims..., out_dim(s)..., feature_dims...]
819
820
  """
821
+ if mode == "mean":
822
+ return scatter_mean(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
820
823
  if mode == "logsumexp":
821
824
  return scatter_logsumexp(
822
825
  source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim
@@ -863,6 +866,35 @@ def scatter(
863
866
  return out
864
867
 
865
868
 
869
+ def scatter_mean(
870
+ source: Tensor,
871
+ *,
872
+ indices: Tensor,
873
+ indices_dim: Union[Dim, Sequence[Dim]],
874
+ fill_value: Optional[Union[int, float]] = None,
875
+ out_dim: Optional[Union[Dim, Sequence[Dim]]] = None,
876
+ ) -> Tensor:
877
+ """
878
+ Scatters into new zero-tensor.
879
+ If entries in indices are duplicated, the corresponding values in source will be mean'ed together.
880
+ This is like :func:`scatter` with ``mode="mean"``.
881
+
882
+ :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
883
+ :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
884
+ :param indices_dim:
885
+ :param fill_value:
886
+ :param out_dim: The indices target dim.
887
+ If not given, will be automatically determined as the sparse_dim from indices.
888
+ If multiple out dims, use indices into the merged out dims,
889
+ and then we use :func:`rf.split_dims` afterwards.
890
+ :return: [batch_dims..., out_dim(s)..., feature_dims...]
891
+ """
892
+ ones = rf.ones(dims=indices.dims, dtype=source.dtype, device=source.device)
893
+ counts = rf.scatter(ones, indices=indices, indices_dim=indices_dim, fill_value=1, out_dim=out_dim)
894
+ y = scatter(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
895
+ return y / counts
896
+
897
+
866
898
  def scatter_argmax(
867
899
  source: Tensor,
868
900
  *,
returnn-1.20250430.145858.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250429.161207
3
+ Version: 1.20250430.145858
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
returnn-1.20250430.145858.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=HyJ82YwGT_Vw0szcqg-a4TqmDp-OQ2YsRXQ8rM5gTg8,5215
1
+ returnn/PKG-INFO,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=DR3lkkJjhQ4JxpXrXSh3Yr0jKAWL6KbWNKJUE9DMJPs,77
6
+ returnn/_setup_info_generated.py,sha256=mczEirj8AMV_Zo8t71vBq_GPVXqMycX3DI91nhfpNmo,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
82
82
  returnn/frontend/_utils.py,sha256=4A3MSRM0i86J77550uR_AjcBEPu6nymLUZ9Xd1V3Fkc,12073
83
- returnn/frontend/array_.py,sha256=eYwH-NVAoHpVrFdJv08lCqh3jvfoZV_ZBEoWHjsBz0o,50090
83
+ returnn/frontend/array_.py,sha256=o_NSq87pB5I2XvFUjk40Dobqx6tTfEY1wzgmaelujgM,51511
84
84
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
85
85
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
86
86
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250429.161207.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250429.161207.dist-info/METADATA,sha256=HyJ82YwGT_Vw0szcqg-a4TqmDp-OQ2YsRXQ8rM5gTg8,5215
258
- returnn-1.20250429.161207.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250429.161207.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250429.161207.dist-info/RECORD,,
256
+ returnn-1.20250430.145858.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250430.145858.dist-info/METADATA,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
258
+ returnn-1.20250430.145858.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250430.145858.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250430.145858.dist-info/RECORD,,