returnn 1.20250902.133328__py3-none-any.whl → 1.20250904.121337__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/array_.py +5 -0
- returnn/tensor/_dim_extra.py +9 -4
- {returnn-1.20250902.133328.dist-info → returnn-1.20250904.121337.dist-info}/METADATA +1 -1
- {returnn-1.20250902.133328.dist-info → returnn-1.20250904.121337.dist-info}/RECORD +9 -9
- {returnn-1.20250902.133328.dist-info → returnn-1.20250904.121337.dist-info}/LICENSE +0 -0
- {returnn-1.20250902.133328.dist-info → returnn-1.20250904.121337.dist-info}/WHEEL +0 -0
- {returnn-1.20250902.133328.dist-info → returnn-1.20250904.121337.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250902.133328'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250904.121337'
|
|
2
|
+
long_version = '1.20250904.121337+git.3e0ed7d'
|
returnn/frontend/array_.py
CHANGED
|
@@ -1312,6 +1312,7 @@ def top_p_mask(
|
|
|
1312
1312
|
axis: Dim,
|
|
1313
1313
|
p: Union[float, Tensor],
|
|
1314
1314
|
one_more: bool = True,
|
|
1315
|
+
min_tokens_to_keep: int = 1,
|
|
1315
1316
|
) -> Tensor:
|
|
1316
1317
|
"""
|
|
1317
1318
|
Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
|
|
@@ -1321,6 +1322,8 @@ def top_p_mask(
|
|
|
1321
1322
|
:param p: the probability mass to keep
|
|
1322
1323
|
:param one_more: if True (default), keep also the first token above the threshold.
|
|
1323
1324
|
(It's enabled by default to follow the behavior of the original implementation.)
|
|
1325
|
+
:param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
|
|
1326
|
+
With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
|
|
1324
1327
|
:return: mask {probs_dims..., axis} of the top-p tokens.
|
|
1325
1328
|
``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
|
|
1326
1329
|
"""
|
|
@@ -1334,5 +1337,7 @@ def top_p_mask(
|
|
|
1334
1337
|
if one_more:
|
|
1335
1338
|
# keep also the first token above the threshold
|
|
1336
1339
|
mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
|
|
1340
|
+
if min_tokens_to_keep > (1 if one_more else 0):
|
|
1341
|
+
mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
|
|
1337
1342
|
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1338
1343
|
return mask
|
returnn/tensor/_dim_extra.py
CHANGED
|
@@ -2184,7 +2184,7 @@ class _DimMixin:
|
|
|
2184
2184
|
other = other.dimension # makes matching easier
|
|
2185
2185
|
if isinstance(other, int) and other == 1:
|
|
2186
2186
|
return self
|
|
2187
|
-
if self.is_constant_static_dim() and isinstance(other, _d.Dim):
|
|
2187
|
+
if self.is_constant_static_dim() and isinstance(other, _d.Dim) and not other.is_constant_static_dim():
|
|
2188
2188
|
return self.dimension * other # use rmul
|
|
2189
2189
|
cache_key = ("mul", other)
|
|
2190
2190
|
cache = self.get_same_base()._make_extra().cache_dim_math
|
|
@@ -2571,14 +2571,19 @@ class _MathFindMatchingAdditive:
|
|
|
2571
2571
|
|
|
2572
2572
|
|
|
2573
2573
|
def _math_find_matching_mult(start: Dim, other: Union[int, Dim], *, right: bool) -> Optional[Dim]:
|
|
2574
|
-
if
|
|
2574
|
+
# we assume, if other is Dim, then it is not constant static dim
|
|
2575
|
+
if isinstance(other, int) and start.is_constant_static_dim():
|
|
2575
2576
|
return _math_get_dim_via_bin_op([start, other] if right else [other, start], "mul")
|
|
2576
2577
|
c_op = start.derived_from_op
|
|
2577
2578
|
if c_op and c_op.kind == "mul" and len(c_op.inputs) == 2:
|
|
2578
2579
|
if right:
|
|
2579
2580
|
return c_op.inputs[0] * (c_op.inputs[1] * other)
|
|
2580
|
-
|
|
2581
|
-
|
|
2581
|
+
# Don't do right=False -> (other * c_op.inputs[0]) * c_op.inputs[1],
|
|
2582
|
+
# because this can lead to infinite recursions,
|
|
2583
|
+
# and also we don't have a proper normalized form for multiplication.
|
|
2584
|
+
# However, if both left-most factors are constant static dims, then we can merge it.
|
|
2585
|
+
elif isinstance(other, int) and c_op.inputs[0].is_constant_static_dim():
|
|
2586
|
+
return (other * c_op.inputs[0].dimension) * c_op.inputs[1]
|
|
2582
2587
|
return None
|
|
2583
2588
|
|
|
2584
2589
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256
|
|
1
|
+
returnn/PKG-INFO,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=T0irKkmiSBL4AXt10Ru0CoqiCOYguD_lDSMZW2GPivw,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
|
|
|
80
80
|
returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
|
|
81
81
|
returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
|
|
82
82
|
returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
|
|
83
|
-
returnn/frontend/array_.py,sha256=
|
|
83
|
+
returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
|
|
84
84
|
returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
|
|
85
85
|
returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
|
|
86
86
|
returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
|
|
@@ -154,7 +154,7 @@ returnn/sprint/extern_interface.py,sha256=l-v1X-Yg0UpTFe7Y3c4FwWOqpSNuv9Oy5EzqlK
|
|
|
154
154
|
returnn/sprint/interface.py,sha256=1j5SB0V8hSW8A5song9ciZtcBnZoKKfNipk9ezOIMuA,36491
|
|
155
155
|
returnn/tensor/README.md,sha256=X6BqcRLrPLPnwF9yR69uqIFrMnNluj9pBkOPHwNgzuo,501
|
|
156
156
|
returnn/tensor/__init__.py,sha256=on6j5PEOQpck50UcsR4nJzJSDmoVy34z1Oq4efv6Ax0,154
|
|
157
|
-
returnn/tensor/_dim_extra.py,sha256=
|
|
157
|
+
returnn/tensor/_dim_extra.py,sha256=D1lDB-zjF1tPhBQFApbui2AlyARdTx0hIFKRhTtk4T4,116033
|
|
158
158
|
returnn/tensor/_tensor_extra.py,sha256=gbSl6HMtn8WFYloanew_RaNNwx3eCpnKv3UfCkntJiQ,164923
|
|
159
159
|
returnn/tensor/_tensor_mixin_base.py,sha256=H5z86I0NejxrSgMH1c5oXQzBqS6L9HpvP4y7oegBaSc,643
|
|
160
160
|
returnn/tensor/_tensor_op_overloads.py,sha256=HklwuTBjy7mH_665VKaCUdu-oC3aa7Uz1ZQiCz4jeZc,5448
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250902.133328.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250902.133328.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250902.133328.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250902.133328.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250904.121337.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250904.121337.dist-info/METADATA,sha256=E0KAiZhUNf-yyIY-MbzXczY25ftmtsffuGszhUWrvUU,5215
|
|
258
|
+
returnn-1.20250904.121337.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250904.121337.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250904.121337.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|