returnn 1.20250429.161207-py3-none-any.whl → 1.20250508.93313-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of returnn might be problematic.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250429.161207
+Version: 1.20250508.93313
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn/_setup_info_generated.py CHANGED
@@ -1,2 +1,2 @@
-version = '1.20250429.161207'
-long_version = '1.20250429.161207+git.d4d26bf'
+version = '1.20250508.093313'
+long_version = '1.20250508.093313+git.4f05ac7'
returnn/frontend/array_.py CHANGED
@@ -39,6 +39,7 @@ __all__ = [
     "pad_packed",
     "gather",
     "scatter",
+    "scatter_mean",
     "scatter_argmax",
     "scatter_logsumexp",
     "scatter_logmeanexp",
@@ -807,8 +808,8 @@ def scatter(
     :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
     :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
     :param indices_dim:
-    :param mode: "sum", "max", "min", "logsumexp", "logmeanexp", "argmax".
-        (Note: If you ever need mean, argmin, etc, please open an issue/PR.)
+    :param mode: "sum", "max", "min", "mean", "logsumexp", "logmeanexp", "argmax".
+        (Note: If you ever need another mode, please open an issue/PR.)
     :param fill_value:
     :param out_dim: The indices target dim.
         If not given, will be automatically determined as the sparse_dim from indices.
@@ -817,6 +818,8 @@ def scatter(
     :param use_mask:
     :return: [batch_dims..., out_dim(s)..., feature_dims...]
     """
+    if mode == "mean":
+        return scatter_mean(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
     if mode == "logsumexp":
         return scatter_logsumexp(
             source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim
@@ -863,6 +866,35 @@ def scatter(
     return out


+def scatter_mean(
+    source: Tensor,
+    *,
+    indices: Tensor,
+    indices_dim: Union[Dim, Sequence[Dim]],
+    fill_value: Optional[Union[int, float]] = None,
+    out_dim: Optional[Union[Dim, Sequence[Dim]]] = None,
+) -> Tensor:
+    """
+    Scatters into new zero-tensor.
+    If entries in indices are duplicated, the corresponding values in source will be mean'ed together.
+    This is like :func:`scatter` with ``mode="mean"``.
+
+    :param source: [batch_dims..., indices_dim(s)..., feature_dims...]
+    :param indices: [batch_dims..., indices_dim(s)...] -> out_dim
+    :param indices_dim:
+    :param fill_value:
+    :param out_dim: The indices target dim.
+        If not given, will be automatically determined as the sparse_dim from indices.
+        If multiple out dims, use indices into the merged out dims,
+        and then we use :func:`rf.split_dims` afterwards.
+    :return: [batch_dims..., out_dim(s)..., feature_dims...]
+    """
+    ones = rf.ones(dims=indices.dims, dtype=source.dtype, device=source.device)
+    counts = rf.scatter(ones, indices=indices, indices_dim=indices_dim, fill_value=1, out_dim=out_dim)
+    y = scatter(source, indices=indices, indices_dim=indices_dim, fill_value=fill_value, out_dim=out_dim)
+    return y / counts
+
+
 def scatter_argmax(
     source: Tensor,
     *,
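The new scatter_mean above averages, rather than sums, all source entries that scatter onto the same target index: it scatters the values, scatters a tensor of ones to count how many entries hit each index (with fill_value=1 so unused indices do not divide by zero), and divides. Below is a minimal NumPy sketch of that same counts-and-divide trick for the 1-D case; the helper name scatter_mean_1d is illustrative only and not part of the RETURNN API.

import numpy as np

def scatter_mean_1d(source: np.ndarray, indices: np.ndarray, out_size: int) -> np.ndarray:
    """Average all source values that share the same target index (sketch of scatter mode="mean")."""
    sums = np.zeros(out_size, dtype=source.dtype)
    np.add.at(sums, indices, source)  # scatter-add the values
    counts = np.zeros(out_size, dtype=source.dtype)
    np.add.at(counts, indices, 1)  # scatter-add ones -> how often each index occurs
    counts = np.maximum(counts, 1)  # mirrors fill_value=1 above: unused indices stay 0 instead of producing NaN
    return sums / counts

# indices 0 and 2 are each hit twice, so their values are averaged:
print(scatter_mean_1d(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 2, 0, 2]), out_size=3))
# -> [2. 0. 3.]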
returnn/util/better_exchook.py CHANGED
@@ -58,6 +58,7 @@ import threading
 import keyword
 import inspect
 import contextlib
+from weakref import WeakKeyDictionary

 try:
     import typing
@@ -1564,6 +1565,9 @@ def get_func_str_from_code_object(co, frame=None):
     return co.co_name


+_func_from_code_object_cache = WeakKeyDictionary()  # code object -> function
+
+
 def get_func_from_code_object(co, frame=None):
     """
     :param types.CodeType co:
@@ -1580,6 +1584,11 @@ def get_func_from_code_object(co, frame=None):
     import types

     assert isinstance(co, (types.CodeType, DummyFrame))
+    co_is_code_object = isinstance(co, types.CodeType)
+    if co_is_code_object:
+        candidate = _func_from_code_object_cache.get(co)
+        if candidate:
+            return candidate
     _attr_name = "__code__" if PY3 else "func_code"
     if frame and frame.f_code.co_nlocals > 0:
         func_name = frame.f_code.co_name
@@ -1587,18 +1596,23 @@ def get_func_from_code_object(co, frame=None):
         if frame_self is not None:
             candidate = getattr(frame_self.__class__, func_name, None)
             if candidate and (getattr(candidate, _attr_name, None) is co or isinstance(co, DummyFrame)):
+                if co_is_code_object:
+                    _func_from_code_object_cache[co] = candidate
                 return candidate
     try:
         candidate = getattr(_get_loaded_module_from_filename(co.co_filename), co.co_name, None)
     except ImportError:  # some modules have lazy loaders, but those might fail here
         candidate = None
     if candidate and (getattr(candidate, _attr_name, None) is co or isinstance(co, DummyFrame)):
+        if co_is_code_object:
+            _func_from_code_object_cache[co] = candidate
         return candidate
     if isinstance(co, DummyFrame):
         return None
     candidates = gc.get_referrers(co)
     candidates = [f for f in candidates if getattr(f, _attr_name, None) is co]
     if candidates:
+        _func_from_code_object_cache[co] = candidates[0]
         return candidates[0]
     return None

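The better_exchook change memoizes get_func_from_code_object: resolving a code object back to its function can be expensive (it may fall back to scanning gc.get_referrers), so successful lookups are now stored in a WeakKeyDictionary keyed by the code object, and cache entries disappear automatically once the code object itself is garbage collected. Below is a minimal sketch of that caching pattern; the names _cache and find_func are illustrative and not the better_exchook implementation.

import gc
import types
from weakref import WeakKeyDictionary

_cache = WeakKeyDictionary()  # code object -> function; entry vanishes with the code object

def find_func(co: types.CodeType):
    """Resolve a code object to a function, caching the slow gc-based lookup."""
    cached = _cache.get(co)
    if cached is not None:
        return cached
    # Slow path: scan every object that holds a reference to this code object.
    for referrer in gc.get_referrers(co):
        if getattr(referrer, "__code__", None) is co:
            _cache[co] = referrer
            return referrer
    return None

def example():
    return 42

assert find_func(example.__code__) is example  # slow path, fills the cache
assert find_func(example.__code__) is example  # fast path, answered from the cache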
returnn-1.20250508.93313.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: returnn
-Version: 1.20250429.161207
+Version: 1.20250508.93313
 Summary: The RWTH extensible training framework for universal recurrent neural networks
 Home-page: https://github.com/rwth-i6/returnn/
 Author: Albert Zeyer
returnn-1.20250508.93313.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
-returnn/PKG-INFO,sha256=HyJ82YwGT_Vw0szcqg-a4TqmDp-OQ2YsRXQ8rM5gTg8,5215
+returnn/PKG-INFO,sha256=Rsc8t3mGIL2V2iJflkOuNthzptMNFK9cMC-KiImaBDg,5214
 returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
 returnn/__main__.py,sha256=qBFbuB1yN3adgVM5pXt2-Yq9vorjRNchNPL8kDKx44M,31752
 returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
 returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
-returnn/_setup_info_generated.py,sha256=DR3lkkJjhQ4JxpXrXSh3Yr0jKAWL6KbWNKJUE9DMJPs,77
+returnn/_setup_info_generated.py,sha256=cYPqIVcVewY03CWaIMT3pOmIFS88wUWDSj14rjsgy3Q,77
 returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
 returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
 returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,84
 returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
 returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
 returnn/frontend/_utils.py,sha256=4A3MSRM0i86J77550uR_AjcBEPu6nymLUZ9Xd1V3Fkc,12073
-returnn/frontend/array_.py,sha256=eYwH-NVAoHpVrFdJv08lCqh3jvfoZV_ZBEoWHjsBz0o,50090
+returnn/frontend/array_.py,sha256=o_NSq87pB5I2XvFUjk40Dobqx6tTfEY1wzgmaelujgM,51511
 returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
 returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
 returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -234,7 +234,7 @@ returnn/torch/util/module.py,sha256=MXHIrF9Isu575DDJIa81212ULKwdqu1oOLxDVZecVSk,
 returnn/torch/util/scaled_gradient.py,sha256=3585VuNypBty-pW6r3BKK047H3MqZQSdMjXeYAb4cmU,3192
 returnn/util/__init__.py,sha256=UIG1qw4idqhW71BV60ha7h9PktxvEVcBIu0lYRossK8,336
 returnn/util/basic.py,sha256=jwtaaZyOV7fUjKDRXVHDy-K5kwR1mPrkAZrzc5STOvE,142554
-returnn/util/better_exchook.py,sha256=TAtb_ZyM-357UnOg_HMoBZUSxzt0WPgumlvprmlCprA,63921
+returnn/util/better_exchook.py,sha256=98XnUZIWpYN7NfklSGt_5hYNplADVFQnh857esKxjdI,64475
 returnn/util/bpe.py,sha256=LWFhICZsEOnMwNws0lybPNzKRX6rSr8yKCvP65vjl9Y,19656
 returnn/util/debug.py,sha256=wuRzdg9zB84WWCGyTjmRR_zYypu8gXxlc0nZ6si9OC8,28224
 returnn/util/debug_helpers.py,sha256=0EINLK4uLtoSt5_kHs1M2NIFpMd0S7i4c4rx90U4fJk,2914
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
 returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
 returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
 returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
-returnn-1.20250429.161207.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
-returnn-1.20250429.161207.dist-info/METADATA,sha256=HyJ82YwGT_Vw0szcqg-a4TqmDp-OQ2YsRXQ8rM5gTg8,5215
-returnn-1.20250429.161207.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-returnn-1.20250429.161207.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
-returnn-1.20250429.161207.dist-info/RECORD,,
+returnn-1.20250508.93313.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
+returnn-1.20250508.93313.dist-info/METADATA,sha256=Rsc8t3mGIL2V2iJflkOuNthzptMNFK9cMC-KiImaBDg,5214
+returnn-1.20250508.93313.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+returnn-1.20250508.93313.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
+returnn-1.20250508.93313.dist-info/RECORD,,