returnn 1.20250903.215851__py3-none-any.whl → 1.20250904.121337__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250903.215851
3
+ Version: 1.20250904.121337
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250903.215851'
2
- long_version = '1.20250903.215851+git.7651133'
1
+ version = '1.20250904.121337'
2
+ long_version = '1.20250904.121337+git.3e0ed7d'
@@ -1312,6 +1312,7 @@ def top_p_mask(
1312
1312
  axis: Dim,
1313
1313
  p: Union[float, Tensor],
1314
1314
  one_more: bool = True,
1315
+ min_tokens_to_keep: int = 1,
1315
1316
  ) -> Tensor:
1316
1317
  """
1317
1318
  Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
@@ -1321,6 +1322,8 @@ def top_p_mask(
1321
1322
  :param p: the probability mass to keep
1322
1323
  :param one_more: if True (default), keep also the first token above the threshold.
1323
1324
  (It's enabled by default to follow the behavior of the original implementation.)
1325
+ :param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
1326
+ With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
1324
1327
  :return: mask {probs_dims..., axis} of the top-p tokens.
1325
1328
  ``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
1326
1329
  """
@@ -1334,5 +1337,7 @@ def top_p_mask(
1334
1337
  if one_more:
1335
1338
  # keep also the first token above the threshold
1336
1339
  mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
1340
+ if min_tokens_to_keep > (1 if one_more else 0):
1341
+ mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
1337
1342
  mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
1338
1343
  return mask
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250903.215851
3
+ Version: 1.20250904.121337
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=S-XwueZUBumfIMPZ4Qb-kGHRIs68AQJmjdrJW7duTXw,5215
1
+ returnn/PKG-INFO,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=qKXi9Lxgy7ppKbJEo1P_VbQHVIGgicUKX9qjJTVqo3o,77
6
+ returnn/_setup_info_generated.py,sha256=T0irKkmiSBL4AXt10Ru0CoqiCOYguD_lDSMZW2GPivw,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
82
82
  returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
83
- returnn/frontend/array_.py,sha256=j6rayxqV4ki5vohH-ZC7N3J8_CouNCRRRP_pE89O-rE,53921
83
+ returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
84
84
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
85
85
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
86
86
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250903.215851.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250903.215851.dist-info/METADATA,sha256=S-XwueZUBumfIMPZ4Qb-kGHRIs68AQJmjdrJW7duTXw,5215
258
- returnn-1.20250903.215851.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250903.215851.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250903.215851.dist-info/RECORD,,
256
+ returnn-1.20250904.121337.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250904.121337.dist-info/METADATA,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
258
+ returnn-1.20250904.121337.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250904.121337.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250904.121337.dist-info/RECORD,,