returnn 1.20250902.133328__py3-none-any.whl → 1.20250904.121337__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250902.133328
3
+ Version: 1.20250904.121337
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250902.133328'
2
- long_version = '1.20250902.133328+git.75be98d'
1
+ version = '1.20250904.121337'
2
+ long_version = '1.20250904.121337+git.3e0ed7d'
@@ -1312,6 +1312,7 @@ def top_p_mask(
1312
1312
  axis: Dim,
1313
1313
  p: Union[float, Tensor],
1314
1314
  one_more: bool = True,
1315
+ min_tokens_to_keep: int = 1,
1315
1316
  ) -> Tensor:
1316
1317
  """
1317
1318
  Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
@@ -1321,6 +1322,8 @@ def top_p_mask(
1321
1322
  :param p: the probability mass to keep
1322
1323
  :param one_more: if True (default), keep also the first token above the threshold.
1323
1324
  (It's enabled by default to follow the behavior of the original implementation.)
1325
+ :param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
1326
+ With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
1324
1327
  :return: mask {probs_dims..., axis} of the top-p tokens.
1325
1328
  ``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
1326
1329
  """
@@ -1334,5 +1337,7 @@ def top_p_mask(
1334
1337
  if one_more:
1335
1338
  # keep also the first token above the threshold
1336
1339
  mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
1340
+ if min_tokens_to_keep > (1 if one_more else 0):
1341
+ mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
1337
1342
  mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
1338
1343
  return mask
@@ -2184,7 +2184,7 @@ class _DimMixin:
2184
2184
  other = other.dimension # makes matching easier
2185
2185
  if isinstance(other, int) and other == 1:
2186
2186
  return self
2187
- if self.is_constant_static_dim() and isinstance(other, _d.Dim):
2187
+ if self.is_constant_static_dim() and isinstance(other, _d.Dim) and not other.is_constant_static_dim():
2188
2188
  return self.dimension * other # use rmul
2189
2189
  cache_key = ("mul", other)
2190
2190
  cache = self.get_same_base()._make_extra().cache_dim_math
@@ -2571,14 +2571,19 @@ class _MathFindMatchingAdditive:
2571
2571
 
2572
2572
 
2573
2573
  def _math_find_matching_mult(start: Dim, other: Union[int, Dim], *, right: bool) -> Optional[Dim]:
2574
- if (isinstance(other, int) or other.is_constant_static_dim()) and start.is_constant_static_dim():
2574
+ # we assume, if other is Dim, then it is not constant static dim
2575
+ if isinstance(other, int) and start.is_constant_static_dim():
2575
2576
  return _math_get_dim_via_bin_op([start, other] if right else [other, start], "mul")
2576
2577
  c_op = start.derived_from_op
2577
2578
  if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2:
2578
2579
  if right:
2579
2580
  return c_op.inputs[0] * (c_op.inputs[1] * other)
2580
- else:
2581
- return (other * c_op.inputs[0]) * c_op.inputs[1]
2581
+ # Don't do right=False -> (other * c_op.inputs[0]) * c_op.inputs[1],
2582
+ # because this can lead to infinite recursions,
2583
+ # and also we don't have a proper normalized form for multiplication.
2584
+ # However, if both left-most factors are constant static dims, then we can merge it.
2585
+ elif isinstance(other, int) and c_op.inputs[0].is_constant_static_dim():
2586
+ return (other * c_op.inputs[0].dimension) * c_op.inputs[1]
2582
2587
  return None
2583
2588
 
2584
2589
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250902.133328
3
+ Version: 1.20250904.121337
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=-ZvhSryWeKNwB4MqJwJvrQ4-_RmbzINVbo0YIJGkPus,5215
1
+ returnn/PKG-INFO,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=4NOZ-Gcm8VzS9-1m9hc0Kf87l-aNk43-f_x9EdT_c9M,77
6
+ returnn/_setup_info_generated.py,sha256=T0irKkmiSBL4AXt10Ru0CoqiCOYguD_lDSMZW2GPivw,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
82
82
  returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
83
- returnn/frontend/array_.py,sha256=j6rayxqV4ki5vohH-ZC7N3J8_CouNCRRRP_pE89O-rE,53921
83
+ returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
84
84
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
85
85
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
86
86
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
154
154
  returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
155
155
  returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
156
156
  returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
157
- returnn/tensor/_dim_extra.py,sha256=rwtDR5WRS8wqgKj4WkPaWtaKa8UJYTrS76ZhX0W5bP4,115580
157
+ returnn/tensor/_dim_extra.py,sha256=D1lDB-zjF1tPhBQFApbui2AlyARdTx0hIFKRhTtk4T4,116033
158
158
  returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
159
159
  returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
160
160
  returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250902.133328.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250902.133328.dist-info/METADATA,sha256=-ZvhSryWeKNwB4MqJwJvrQ4-_RmbzINVbo0YIJGkPus,5215
258
- returnn-1.20250902.133328.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250902.133328.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250902.133328.dist-info/RECORD,,
256
+ returnn-1.20250904.121337.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250904.121337.dist-info/METADATA,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
258
+ returnn-1.20250904.121337.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250904.121337.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250904.121337.dist-info/RECORD,,