returnn 1.20250903.215851__py3-none-any.whl → 1.20250904.121337__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/array_.py +5 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.121337.dist-info}/METADATA +1 -1
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.121337.dist-info}/RECORD +8 -8
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.121337.dist-info}/LICENSE +0 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.121337.dist-info}/WHEEL +0 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.121337.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250904.121337'
|
|
2
|
+
long_version = '1.20250904.121337+git.3e0ed7d'
|
returnn/frontend/array_.py
CHANGED
|
@@ -1312,6 +1312,7 @@ def top_p_mask(
|
|
|
1312
1312
|
axis: Dim,
|
|
1313
1313
|
p: Union[float, Tensor],
|
|
1314
1314
|
one_more: bool = True,
|
|
1315
|
+
min_tokens_to_keep: int = 1,
|
|
1315
1316
|
) -> Tensor:
|
|
1316
1317
|
"""
|
|
1317
1318
|
Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
|
|
@@ -1321,6 +1322,8 @@ def top_p_mask(
|
|
|
1321
1322
|
:param p: the probability mass to keep
|
|
1322
1323
|
:param one_more: if True (default), keep also the first token above the threshold.
|
|
1323
1324
|
(It's enabled by default to follow the behavior of the original implementation.)
|
|
1325
|
+
:param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
|
|
1326
|
+
With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
|
|
1324
1327
|
:return: mask {probs_dims..., axis} of the top-p tokens.
|
|
1325
1328
|
``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
|
|
1326
1329
|
"""
|
|
@@ -1334,5 +1337,7 @@ def top_p_mask(
|
|
|
1334
1337
|
if one_more:
|
|
1335
1338
|
# keep also the first token above the threshold
|
|
1336
1339
|
mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
|
|
1340
|
+
if min_tokens_to_keep > (1 if one_more else 0):
|
|
1341
|
+
mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
|
|
1337
1342
|
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1338
1343
|
return mask
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=T0irKkmiSBL4AXt10Ru0CoqiCOYguD_lDSMZW2GPivw,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
|
|
|
80
80
|
returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
|
|
81
81
|
returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
|
|
82
82
|
returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
|
|
83
|
-
returnn/frontend/array_.py,sha256=
|
|
83
|
+
returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
|
|
84
84
|
returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
|
|
85
85
|
returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
|
|
86
86
|
returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
256
|
+
returnn-1.20250904.121337.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250904.121337.dist-info/METADATA,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
|
|
258
|
+
returnn-1.20250904.121337.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250904.121337.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250904.121337.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|