returnn 1.20250903.215851__py3-none-any.whl → 1.20250904.142552__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250903.215851
3
+ Version: 1.20250904.142552
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250903.215851'
2
- long_version = '1.20250903.215851+git.7651133'
1
+ version = '1.20250904.142552'
2
+ long_version = '1.20250904.142552+git.6024737'
@@ -1312,6 +1312,7 @@ def top_p_mask(
1312
1312
  axis: Dim,
1313
1313
  p: Union[float, Tensor],
1314
1314
  one_more: bool = True,
1315
+ min_tokens_to_keep: int = 1,
1315
1316
  ) -> Tensor:
1316
1317
  """
1317
1318
  Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
@@ -1321,6 +1322,8 @@ def top_p_mask(
1321
1322
  :param p: the probability mass to keep
1322
1323
  :param one_more: if True (default), keep also the first token above the threshold.
1323
1324
  (It's enabled by default to follow the behavior of the original implementation.)
1325
+ :param min_tokens_to_keep: ensure that at least this many tokens are kept (default 1).
1326
+ With one_more=True, at least one token is always kept anyway, so min_tokens_to_keep=1 has no additional effect.
1324
1327
  :return: mask {probs_dims..., axis} of the top-p tokens.
1325
1328
  ``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
1326
1329
  """
@@ -1334,5 +1337,7 @@ def top_p_mask(
1334
1337
  if one_more:
1335
1338
  # keep also the first token above the threshold
1336
1339
  mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
1340
+ if min_tokens_to_keep > (1 if one_more else 0):
1341
+ mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
1337
1342
  mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
1338
1343
  return mask
@@ -10488,6 +10488,7 @@ class TopKLayer(LayerBase):
10488
10488
  self._sub_layers = {}
10489
10489
  for key, (v, a) in sub_outputs.items():
10490
10490
  sub_out_data = self.output.copy_template(name="%s/%s" % (self.name, key))
10491
+ sub_out_data.feature_dim = None
10491
10492
  sub_out_data.dtype = "int32"
10492
10493
  sub_out_data.sparse_dim = a
10493
10494
  sub_out_data.placeholder = v
@@ -10527,6 +10528,7 @@ class TopKLayer(LayerBase):
10527
10528
  axis = [in_data.get_dim_tag_from_description(a) for a in axis]
10528
10529
  out_dims = [dim for dim in in_data.dim_tags if dim not in axis] + [k_dim]
10529
10530
  out_data = in_data.copy_template(name=name).copy_template_new_dim_tags(out_dims)
10531
+ out_data.feature_dim = None
10530
10532
  if for_indices is not None:
10531
10533
  assert 0 <= for_indices < len(axis)
10532
10534
  out_data.dtype = "int32"
@@ -1572,6 +1572,7 @@ class TorchBackend(Backend[torch.Tensor]):
1572
1572
  indices_out_raw = indices_raw % a.dimension
1573
1573
  indices_raw = indices_raw // a.dimension
1574
1574
  indices = values.copy_template(name=f"top_k_indices_{a.name or i}")
1575
+ indices.feature_dim = None
1575
1576
  indices.dtype = TorchBackend.get_dtype_name_raw(indices_out_raw)
1576
1577
  indices.sparse_dim = a
1577
1578
  indices.raw_tensor = indices_out_raw
@@ -1588,6 +1589,7 @@ class TorchBackend(Backend[torch.Tensor]):
1588
1589
  values = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_values")
1589
1590
  values.raw_tensor = values_raw
1590
1591
  indices = source.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=k_dim, name="top_k_indices")
1592
+ indices.feature_dim = None
1591
1593
  indices.dtype = TorchBackend.get_dtype_name_raw(indices_raw)
1592
1594
  indices.sparse_dim = axis
1593
1595
  indices.raw_tensor = indices_raw
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250903.215851
3
+ Version: 1.20250904.142552
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=S-XwueZUBumfIMPZ4Qb-kGHRIs68AQJmjdrJW7duTXw,5215
1
+ returnn/PKG-INFO,sha256=KpNzqgPrE8JNlqZK4ogzbI8bQolj51p_nX4vs8qfw04,5215
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=qKXi9Lxgy7ppKbJEo1P_VbQHVIGgicUKX9qjJTVqo3o,77
6
+ returnn/_setup_info_generated.py,sha256=iGayBdALfCwuaadY-5wJH64F9dsvpvsr_oHZ7kD0uq8,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -80,7 +80,7 @@ returnn/frontend/_cache.py,sha256=Uao2xzfvVaKABk1fkxcpXzxKIGJaI9FwwlTvvoNUstk,85
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
82
82
  returnn/frontend/_utils.py,sha256=uVQldGHyYKIyhSEmumJ04ix5eP5tjZw4CEC0w6-zhyQ,12074
83
- returnn/frontend/array_.py,sha256=j6rayxqV4ki5vohH-ZC7N3J8_CouNCRRRP_pE89O-rE,53921
83
+ returnn/frontend/array_.py,sha256=ci9NnYqwDxryOoiHCNg8DbOb9yWJScVSm7FnCqgywPY,54257
84
84
  returnn/frontend/attention.py,sha256=GKt-Xqnz8sIyXVrE0i4VCS7J2Wu7dmoH_BA0Cu8CrXQ,45769
85
85
  returnn/frontend/backend.py,sha256=iQ9w4xl8Ea7bgpb0VUaCKq50rV5Bl2E5J8Rhd-oqD_c,883
86
86
  returnn/frontend/build_from_dict.py,sha256=rfWa2rjjhIR_kIQED_nMrygrQBunS6unegzWTLVbC98,3017
@@ -193,7 +193,7 @@ returnn/tf/frontend_low_level/__init__.py,sha256=34469k3KzMUIGowxReOZnbf6WdTjxY7
193
193
  returnn/tf/frontend_low_level/_backend.py,sha256=JwwRRIGnElqBC4bTImdB7w3U1u_SJESeZHYLmq86wog,24479
194
194
  returnn/tf/layers/__init__.py,sha256=Ngu-X84nWFgz7ndDu88DqoZ-5lUMMTQWH4g7N8pSoCg,72
195
195
  returnn/tf/layers/base.py,sha256=sUxEfh6WxaHWHG7O3cfxB6gG6YpEHkFKUJVayKvTBSI,152968
196
- returnn/tf/layers/basic.py,sha256=zHDPLP97jSvYYZcMPqQVOVxFk6I1BfXd71XVfs0VIkQ,615386
196
+ returnn/tf/layers/basic.py,sha256=PMYNoMq8qH41QhWhJPg5Uc409GZHkcnecouorg9sqJY,615466
197
197
  returnn/tf/layers/rec.py,sha256=3f6M_5aAMPvx7aAHdPV3VSFRHf7tjpp8lrXSzmk1I5c,548435
198
198
  returnn/tf/layers/segmental_model.py,sha256=wUyDZGr-eTVIIQWcsHLML0wtOxuWn_NFKOIrUKQcvoI,21515
199
199
  returnn/tf/layers/signal_processing.py,sha256=vRlkN7k7otk9_Qdv0qr_l6V0VT5Q6dO2MxwZWb2HH2M,52693
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
216
216
  returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
217
217
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
218
218
  returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
219
- returnn/torch/frontend/_backend.py,sha256=1o6v9neXLTGVu_53QmoPn_2DbbuBC-iyojL9qe5DYBQ,103166
219
+ returnn/torch/frontend/_backend.py,sha256=XeiXlfGK8hy4wmMbVhQCTY7o4FFZ6TZb5cO2FgKl2zw,103244
220
220
  returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
221
221
  returnn/torch/frontend/bridge.py,sha256=c_mVBCBo29sjm8Bhxarv00szwGPgxjwoIqAHOmceGQw,7842
222
222
  returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250903.215851.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250903.215851.dist-info/METADATA,sha256=S-XwueZUBumfIMPZ4Qb-kGHRIs68AQJmjdrJW7duTXw,5215
258
- returnn-1.20250903.215851.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250903.215851.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250903.215851.dist-info/RECORD,,
256
+ returnn-1.20250904.142552.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250904.142552.dist-info/METADATA,sha256=KpNzqgPrE8JNlqZK4ogzbI8bQolj51p_nX4vs8qfw04,5215
258
+ returnn-1.20250904.142552.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250904.142552.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250904.142552.dist-info/RECORD,,