returnn 1.20250903.215851__py3-none-any.whl → 1.20250904.142552__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/frontend/array_.py +5 -0
- returnn/tf/layers/basic.py +2 -0
- returnn/torch/frontend/_backend.py +2 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/METADATA +1 -1
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/RECORD +10 -10
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/LICENSE +0 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/WHEEL +0 -0
- {returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.20250903.215851'
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20250904.142552'
|
|
2
|
+
long_version = '1.20250904.142552+git.6024737'
|
returnn/frontend/array_.py
CHANGED
|
@@ -1312,6 +1312,7 @@ def top_p_mask(
|
|
|
1312
1312
|
axis: Dim,
|
|
1313
1313
|
p: Union[float, Tensor],
|
|
1314
1314
|
one_more: bool = True,
|
|
1315
|
+
min_tokens_to_keep: int = 1,
|
|
1315
1316
|
) -> Tensor:
|
|
1316
1317
|
"""
|
|
1317
1318
|
Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
|
|
@@ -1321,6 +1322,8 @@ def top_p_mask(
|
|
|
1321
1322
|
:param p: the probability mass to keep
|
|
1322
1323
|
:param one_more: if True (default), keep also the first token above the threshold.
|
|
1323
1324
|
(It's enabled by default to follow the behavior of the original implementation.)
|
|
1325
|
+
:param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
|
|
1326
|
+
With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
|
|
1324
1327
|
:return: mask {probs_dims..., axis} of the top-p tokens.
|
|
1325
1328
|
``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
|
|
1326
1329
|
"""
|
|
@@ -1334,5 +1337,7 @@ def top_p_mask(
|
|
|
1334
1337
|
if one_more:
|
|
1335
1338
|
# keep also the first token above the threshold
|
|
1336
1339
|
mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
|
|
1340
|
+
if min_tokens_to_keep > (1 if one_more else 0):
|
|
1341
|
+
mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
|
|
1337
1342
|
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1338
1343
|
return mask
|
returnn/tf/layers/basic.py
CHANGED
|
@@ -10488,6 +10488,7 @@ class TopKLayer(LayerBase):
|
|
|
10488
10488
|
self._sub_layers = {}
|
|
10489
10489
|
for key, (v, a) in sub_outputs.items():
|
|
10490
10490
|
sub_out_data = self.output.copy_template(name="%s/%s" % (self.name, key))
|
|
10491
|
+
sub_out_data.feature_dim = None
|
|
10491
10492
|
sub_out_data.dtype = "int32"
|
|
10492
10493
|
sub_out_data.sparse_dim = a
|
|
10493
10494
|
sub_out_data.placeholder = v
|
|
@@ -10527,6 +10528,7 @@ class TopKLayer(LayerBase):
|
|
|
10527
10528
|
axis = [in_data.get_dim_tag_from_description(a) for a in axis]
|
|
10528
10529
|
out_dims = [dim for dim in in_data.dim_tags if dim not in axis] + [k_dim]
|
|
10529
10530
|
out_data = in_data.copy_template(name=name).copy_template_new_dim_tags(out_dims)
|
|
10531
|
+
out_data.feature_dim = None
|
|
10530
10532
|
if for_indices is not None:
|
|
10531
10533
|
assert 0 <= for_indices < len(axis)
|
|
10532
10534
|
out_data.dtype = "int32"
|
returnn/torch/frontend/_backend.py
CHANGED
|
@@ -1572,6 +1572,7 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
1572
1572
|
indices_out_raw = indices_raw % a.dimension
|
|
1573
1573
|
indices_raw = indices_raw // a.dimension
|
|
1574
1574
|
indices = values.copy_template(name=f"top_k_indices_{a.name or i}")
|
|
1575
|
+
indices.feature_dim = None
|
|
1575
1576
|
indices.dtype = TorchBackend.get_dtype_name_raw(indices_out_raw)
|
|
1576
1577
|
indices.sparse_dim = a
|
|
1577
1578
|
indices.raw_tensor = indices_out_raw
|
|
@@ -1588,6 +1589,7 @@ class TorchBackend(Backend[torch.Tensor]):
|
|
|
1588
1589
|
values = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_values")
|
|
1589
1590
|
values.raw_tensor = values_raw
|
|
1590
1591
|
indices = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_indices")
|
|
1592
|
+
indices.feature_dim = None
|
|
1591
1593
|
indices.dtype = TorchBackend.get_dtype_name_raw(indices_raw)
|
|
1592
1594
|
indices.sparse_dim = axis
|
|
1593
1595
|
indices.raw_tensor = indices_raw
|
{returnn-1.20250903.215851.dist-info → returnn-1.20250904.142552.dist-info}/RECORD
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=KpNzqgPrE8JNlqZK4ogzbI8bQolj51p_nX4vs8qfw04,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=iGayBdALfCwuaadY-5wJH64F9dsvpvsr_oHZ7kD0uq8,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
|
|
|
80
80
|
returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
|
|
81
81
|
returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
|
|
82
82
|
returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
|
|
83
|
-
returnn/frontend/array_.py,sha256=
|
|
83
|
+
returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
|
|
84
84
|
returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
|
|
85
85
|
returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
|
|
86
86
|
returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
|
|
@@ -193,7 +193,7 @@ returnn/tf/frontend_low_level/__init__.py,sha256=34469k3KzMUIGowxReOZnbf6WdTjxY7
|
|
|
193
193
|
returnn/tf/frontend_low_level/_backend.py,sha256=JwwRRIGnElqBC4bTImdB7w3U1u_SJESeZHYLmq86wog,24479
|
|
194
194
|
returnn/tf/layers/__init__.py,sha256=Ngu-X84nWFgz7ndDu88DqoZ-5lUMMTQWH4g7N8pSoCg,72
|
|
195
195
|
returnn/tf/layers/base.py,sha256=sUxEfh6WxaHWHG7O3cfxB6gG6YpEHkFKUJVayKvTBSI,152968
|
|
196
|
-
returnn/tf/layers/basic.py,sha256=
|
|
196
|
+
returnn/tf/layers/basic.py,sha256=PMYNoMq8qH41QhWhJPg5Uc409GZHkcnecouorg9sqJY,615466
|
|
197
197
|
returnn/tf/layers/rec.py,sha256=3f6M_5aAMPvx7aAHdPV3VSFRHf7tjpp8lrXSzmk1I5c,548435
|
|
198
198
|
returnn/tf/layers/segmental_model.py,sha256=wUyDZGr-eTVIIQWcsHLML0wtOxuWn_NFKOIrUKQcvoI,21515
|
|
199
199
|
returnn/tf/layers/signal_processing.py,sha256=vRlkN7k7otk9_Qdv0qr_l6V0VT5Q6dO2MxwZWb2HH2M,52693
|
|
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
|
|
|
216
216
|
returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
|
|
217
217
|
returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
|
|
218
218
|
returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
|
|
219
|
-
returnn/torch/frontend/_backend.py,sha256=
|
|
219
|
+
returnn/torch/frontend/_backend.py,sha256=XeiXlfGK8hy4wmMbVhQCTY7o4FFZ6TZb5cO2FgKl2zw,103244
|
|
220
220
|
returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
|
|
221
221
|
returnn/torch/frontend/bridge.py,sha256=c_mVBCBo29sjm8Bhxarv00szwGPgxjwoIqAHOmceGQw,7842
|
|
222
222
|
returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
|
|
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
253
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
254
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
255
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.20250903.215851.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.20250903.215851.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
-
returnn-1.20250903.215851.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
-
returnn-1.20250903.215851.dist-info/RECORD,,
|
|
256
|
+
returnn-1.20250904.142552.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
257
|
+
returnn-1.20250904.142552.dist-info/METADATA,sha256=KpNzqgPrE8JNlqZK4ogzbI8bQolj51p_nX4vs8qfw04,5215
|
|
258
|
+
returnn-1.20250904.142552.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
259
|
+
returnn-1.20250904.142552.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
260
|
+
returnn-1.20250904.142552.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|