returnn 1.20250429.161207__py3-none-any.whl → 1.20250430.145858__py3-none-any.whl
This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/array_.py +34 -2
- {returnn-1.20250429.161207.dist-info → returnn-1.20250430.145858.dist-info}/METADATA +1 -1
- {returnn-1.20250429.161207.dist-info → returnn-1.20250430.145858.dist-info}/RECORD +8 -8
- {returnn-1.20250429.161207.dist-info → returnn-1.20250430.145858.dist-info}/LICENSE +0 -0
- {returnn-1.20250429.161207.dist-info → returnn-1.20250430.145858.dist-info}/WHEEL +0 -0
- {returnn-1.20250429.161207.dist-info → returnn-1.20250430.145858.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250429.161207'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250430.145858'
|
|
2
|
+
long_version = '1.20250430.145858+git.6447e67'
|
returnn/frontend/array_.py
CHANGED
|
@@ -39,6 +39,7 @@ __all__ = [
|
|
|
39
39
|
"pad_packed",
|
|
40
40
|
"gather",
|
|
41
41
|
"scatter",
|
|
42
|
+
"scatter_mean",
|
|
42
43
|
"scatter_argmax",
|
|
43
44
|
"scatter_logsumexp",
|
|
44
45
|
"scatter_logmeanexp",
|
|
@@ -807,8 +808,8 @@ def scatter(
|
|
|
807
808
|
:param source: [batch_dims..., indices_dim(s)..., feature_dims...]
|
|
808
809
|
:param indices: [batch_dims..., indices_dim(s)...] -> out_dim
|
|
809
810
|
:param indices_dim:
|
|
810
|
-
:param mode: "sum", "max", "min", "logsumexp", "logmeanexp", "argmax".
|
|
811
|
-
(Note: If you ever need
|
|
811
|
+
:param mode: "sum", "max", "min", "mean", "logsumexp", "logmeanexp", "argmax".
|
|
812
|
+
(Note: If you ever need another mode, please open an issue/PR.)
|
|
812
813
|
:param fill_value:
|
|
813
814
|
:param out_dim: The indices target dim.
|
|
814
815
|
If not given, will be automatically determined as the sparse_dim from indices.
|
|
@@ -817,6 +818,8 @@ def scatter(
|
|
|
817
818
|
:param use_mask:
|
|
818
819
|
:return: [batch_dims..., out_dim(s)..., feature_dims...]
|
|
819
820
|
"""
|
|
821
|
+
if mode == "mean":
|
|
822
|
+
return scatter_mean(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
|
|
820
823
|
if mode == "logsumexp":
|
|
821
824
|
return scatter_logsumexp(
|
|
822
825
|
source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim
|
|
@@ -863,6 +866,35 @@ def scatter(
|
|
|
863
866
|
return out
|
|
864
867
|
|
|
865
868
|
|
|
869
|
+
def scatter_mean(
|
|
870
|
+
source: Tensor,
|
|
871
|
+
*,
|
|
872
|
+
indices: Tensor,
|
|
873
|
+
indices_dim: Union[Dim, Sequence[Dim]],
|
|
874
|
+
fill_value: Optional[Union[int, float]] = None,
|
|
875
|
+
out_dim: Optional[Union[Dim, Sequence[Dim]]] = None,
|
|
876
|
+
) -> Tensor:
|
|
877
|
+
"""
|
|
878
|
+
Scatters into new zero-tensor.
|
|
879
|
+
If entries in indices are duplicated, the corresponding values in source will be mean'ed together.
|
|
880
|
+
This is like :func:`scatter` with ``mode="mean"``.
|
|
881
|
+
|
|
882
|
+
:param source: [batch_dims..., indices_dim(s)..., feature_dims...]
|
|
883
|
+
:param indices: [batch_dims..., indices_dim(s)...] -> out_dim
|
|
884
|
+
:param indices_dim:
|
|
885
|
+
:param fill_value:
|
|
886
|
+
:param out_dim: The indices target dim.
|
|
887
|
+
If not given, will be automatically determined as the sparse_dim from indices.
|
|
888
|
+
If multiple out dims, use indices into the merged out dims,
|
|
889
|
+
and then we use :func:`rf.split_dims` afterwards.
|
|
890
|
+
:return: [batch_dims..., out_dim(s)..., feature_dims...]
|
|
891
|
+
"""
|
|
892
|
+
ones = rf.ones(dims=indices.dims, dtype=source.dtype, device=source.device)
|
|
893
|
+
counts = rf.scatter(ones, indices=indices, indices_dim=indices_dim, fill_value=1, out_dim=out_dim)
|
|
894
|
+
y = scatter(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
|
|
895
|
+
return y / counts
|
|
896
|
+
|
|
897
|
+
|
|
866
898
|
def scatter_argmax(
|
|
867
899
|
source: Tensor,
|
|
868
900
|
*,
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=mczEirj8AMV_Zo8t71vBq_GPVXqMycX3DI91nhfpNmo,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
|
|
|
80
80
|
returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
|
|
81
81
|
returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
|
|
82
82
|
returnn/frontend/_utils.py,sha256=4A3MSRM0i86J77550uR_AjcBEPu6nymLUZ9Xd1V3Fkc,12073
|
|
83
|
-
returnn/frontend/array_.py,sha256=
|
|
83
|
+
returnn/frontend/array_.py,sha256=o_NSq87pB5I2XvFUjk40Dobqx6tTfEY1wzgmaelujgM,51511
|
|
84
84
|
returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
|
|
85
85
|
returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
|
|
86
86
|
returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250429.161207.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250429.161207.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250429.161207.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250429.161207.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250430.145858.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250430.145858.dist-info/METADATA,sha256=h2zLUzvzAFUbBjf8XCgka4s04uduHZmGrkLit9JyDow,5215
|
|
258
|
+
returnn-1.20250430.145858.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250430.145858.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250430.145858.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|