returnn 1.20250814.205214__py3-none-any.whl → 1.20250819.10249__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of returnn might be problematic. Click here for more details.

returnn/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250814.205214
3
+ Version: 1.20250819.10249
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,2 +1,2 @@
1
- version = '1.20250814.205214'
2
- long_version = '1.20250814.205214+git.26a4c87'
1
+ version = '1.20250819.010249'
2
+ long_version = '1.20250819.010249+git.9c1f159'
@@ -1153,6 +1153,13 @@ class Backend(Generic[T]):
1153
1153
  """
1154
1154
  raise NotImplementedError
1155
1155
 
1156
+ @staticmethod
1157
+ def random_choice_with_replacement(dims: Sequence[Dim], *, probs: Tensor, axis: Dim) -> Tensor:
1158
+ """
1159
+ random choice with replacement. See `rf.random_choice_with_replacement` for details.
1160
+ """
1161
+ raise NotImplementedError
1162
+
1156
1163
  @staticmethod
1157
1164
  def masked_select(
1158
1165
  tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None
returnn/frontend/rand.py CHANGED
@@ -65,6 +65,7 @@ __all__ = [
65
65
  "random_normal",
66
66
  "random_truncated_normal",
67
67
  "random_choice_without_replacement",
68
+ "random_choice_with_replacement",
68
69
  ]
69
70
 
70
71
 
@@ -379,3 +380,16 @@ def random_choice_without_replacement(
379
380
  scores = log_probs + scores_random_sample
380
381
  _, indices, _ = rf.top_k(scores, k_dim=num_samples_dim, axis=axis)
381
382
  return indices
383
+
384
+
385
+ def random_choice_with_replacement(dims: Sequence[Dim], *, probs: Tensor, axis: Dim) -> Tensor:
386
+ """
387
+ Randomly sample with replacement.
388
+
389
+ :param dims: {common_dims..., new_dims...}. Defines how many samples to draw. Defines the output shape.
390
+ :param probs: {common_dims..., axis}
391
+ :param axis: feature axis, where to sample from
392
+ :return: random indices shape dims -> axis.
393
+ """
394
+ # noinspection PyProtectedMember
395
+ return probs._raw_backend.random_choice_with_replacement(dims=dims, probs=probs, axis=axis)
@@ -993,7 +993,10 @@ class TorchBackend(Backend[torch.Tensor]):
993
993
  if clip_to_valid:
994
994
  if axis.dyn_size_ext is not None:
995
995
  indices = rf.clip_by_value(
996
- indices, 0, axis.get_dyn_size_ext_for_device(indices.device) - 1, allow_broadcast_all_sources=True
996
+ indices,
997
+ 0,
998
+ rf.cast(axis.get_dyn_size_ext_for_device(indices.device), indices.dtype) - 1,
999
+ allow_broadcast_all_sources=True,
997
1000
  )
998
1001
  else:
999
1002
  indices = indices.copy()
@@ -1718,6 +1721,28 @@ class TorchBackend(Backend[torch.Tensor]):
1718
1721
  )
1719
1722
  return out
1720
1723
 
1724
+ @staticmethod
1725
+ def random_choice_with_replacement(dims: Sequence[Dim], *, probs: Tensor, axis: Dim) -> Tensor:
1726
+ """random choice with replacement"""
1727
+ assert all(d == axis or d in dims for d in probs.dims), (
1728
+ f"random_choice_with_replacement: dims {dims} not compatible with probs {probs} and axis {axis}"
1729
+ )
1730
+ common_dims = [d for d in dims if d in probs.dims]
1731
+ assert axis not in common_dims
1732
+ probs = probs.copy_transpose(common_dims + [axis])
1733
+ non_common_dims = [d for d in dims if d not in common_dims]
1734
+ num_samples = prod([d.get_dim_value() for d in non_common_dims])
1735
+ if len(common_dims) >= 2:
1736
+ probs, flat_common_dim = rf.merge_dims(probs, dims=common_dims)
1737
+ out_raw = torch.multinomial(probs.raw_tensor, num_samples=num_samples, replacement=True)
1738
+ out_raw = out_raw.reshape(
1739
+ [d.get_dim_value() for d in common_dims] + [d.get_dim_value() for d in non_common_dims]
1740
+ )
1741
+ out = rf.convert_to_tensor(out_raw, dims=common_dims + non_common_dims, sparse_dim=axis)
1742
+ out = out.copy_transpose(dims)
1743
+ out.name = "random_choice_with_replacement"
1744
+ return out
1745
+
1721
1746
  @staticmethod
1722
1747
  def masked_select(
1723
1748
  tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: returnn
3
- Version: 1.20250814.205214
3
+ Version: 1.20250819.10249
4
4
  Summary: The RWTH extensible training framework for universal recurrent neural networks
5
5
  Home-page: https://github.com/rwth-i6/returnn/
6
6
  Author: Albert Zeyer
@@ -1,9 +1,9 @@
1
- returnn/PKG-INFO,sha256=dDWKCPCaWRWqMS1-os5IZOo9uFYqHjXcDFKcI19Hqzo,5215
1
+ returnn/PKG-INFO,sha256=40ciCZzddEgWfHHnfFmRo7cpK8dukyBH8HYxTaEd5XY,5214
2
2
  returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
3
3
  returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
4
4
  returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
5
5
  returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
6
- returnn/_setup_info_generated.py,sha256=Z6UV10sIpgDmM1cDSqz87OBlBD3G32Ss3yF09lnHfnQ,77
6
+ returnn/_setup_info_generated.py,sha256=poqBmOb1nT6ZUEe9pGqrp91VojHENemIF3zMYt80T4g,77
7
7
  returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
8
8
  returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
9
9
  returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
@@ -75,7 +75,7 @@ returnn/extern/graph_editor/subgraph.py,sha256=q9o0zVBLDrTIidaXg5WG5daDW0mLbwv2J
75
75
  returnn/extern/graph_editor/transform.py,sha256=qMGSenpbAnGqdG6QP6iWjlm6_ccySYJaZKOoAj1dbOM,29348
76
76
  returnn/extern/graph_editor/util.py,sha256=HfRbyQPmQ6_n5-O-096n0KeJtllQXFtaurpeJS_URZ0,18706
77
77
  returnn/frontend/__init__.py,sha256=2aS7nbxXniIrBp2DODl0xN0f3IJ_dX4Bi9ZlR7W5_DE,1472
78
- returnn/frontend/_backend.py,sha256=pAnVAbZhIGKD-10tp0Mx7AO1GZNghYu7AVAPhiimN-k,50471
78
+ returnn/frontend/_backend.py,sha256=39l5MC1DaT0MPklMM8HXAW9nqisIIZQ9g2QSHOOtPQE,50741
79
79
  returnn/frontend/_cache.py,sha256=JAhi7L-raQ3A-NC3JUYDtdRTwT3BGJJGGZxrZ8MfEWQ,8403
80
80
  returnn/frontend/_numpy_backend.py,sha256=fZjks7p3dgxVZ6tSDazTTgBxNjJqXjfqgw_7mA7rDEE,9066
81
81
  returnn/frontend/_random_journal.py,sha256=_ktP_mjgx8vtQQGX_DofdhewJj0aPiczefTWeemPkmo,5457
@@ -110,7 +110,7 @@ returnn/frontend/parameter.py,sha256=zvrkhSYC1c_O9kVwgHvOtOnWNurl5J28lkS0i1LQpWU
110
110
  returnn/frontend/parametrizations.py,sha256=ptNgBw5IiPXVpB3QGse7AGAhdXp8X1rCqYUl2Mae8aI,2876
111
111
  returnn/frontend/parametrize.py,sha256=VhgTEP7ehON950Q4bkCy8rvg9641moEKAXn0XzomK6E,7216
112
112
  returnn/frontend/piecewise_linear.py,sha256=TdL6wzop8P1dcIZwkEbJFvSUZSI1cbhS3XKzlWQkEVI,1964
113
- returnn/frontend/rand.py,sha256=Levgf5VtOOBKDSgz0869Jf3VW4BWxYZuRXsa_fOxNI4,12969
113
+ returnn/frontend/rand.py,sha256=2x7AHSYH_tZkzTk_q3t3GA_yYRNeKsVbJjw2InqSGDk,13542
114
114
  returnn/frontend/rec.py,sha256=6YSsSG7fdtfvvg24vmexSg8R2aVCcKHBdGLh-Mgn9Co,8037
115
115
  returnn/frontend/reduce.py,sha256=gRSvBJZNHa757IqBxGw4hu5eiO3pjie_ptEwUXHLSCs,10340
116
116
  returnn/frontend/run_ctx.py,sha256=yyOMUCKTOe19C4z2Nfly4YCLBmQ9ihip6nGrkW-Y6qg,23789
@@ -216,7 +216,7 @@ returnn/torch/data/queued_data_iter.py,sha256=PoOsGHdHVZjTmcyfq_ZOw--P6hyfTdmAWI
216
216
  returnn/torch/data/returnn_dataset_wrapper.py,sha256=2CaDapzrlqahANuq-nyVAtv5ENHuM8A7okORwYJDisg,8006
217
217
  returnn/torch/data/tensor_utils.py,sha256=-Teqi--LLbt6q_5mDRdoHZHmPgSdC83W706ukif_YiU,1284
218
218
  returnn/torch/frontend/__init__.py,sha256=AA48HZnC17ASuKA0EWy8loZ-Bib_yUtqF4T1wYvjst4,62
219
- returnn/torch/frontend/_backend.py,sha256=a9qcpUJrSDtH7KR6ZIpB4sijm6ztRlZ4myAe2P0dtaE,101875
219
+ returnn/torch/frontend/_backend.py,sha256=1o6v9neXLTGVu_53QmoPn_2DbbuBC-iyojL9qe5DYBQ,103166
220
220
  returnn/torch/frontend/_rand.py,sha256=1JgIkV2XmpgJD86zXZ-NCAe-QuoP2swr6NaS1oz3Qa8,1830
221
221
  returnn/torch/frontend/bridge.py,sha256=c_mVBCBo29sjm8Bhxarv00szwGPgxjwoIqAHOmceGQw,7842
222
222
  returnn/torch/frontend/raw_ops.py,sha256=lF0h-KtYYsdaaqQADylVZp9qzPskOOXA4MfmYDyx5IU,296
@@ -253,8 +253,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
253
253
  returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
254
254
  returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
255
255
  returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
256
- returnn-1.20250814.205214.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
- returnn-1.20250814.205214.dist-info/METADATA,sha256=dDWKCPCaWRWqMS1-os5IZOo9uFYqHjXcDFKcI19Hqzo,5215
258
- returnn-1.20250814.205214.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
- returnn-1.20250814.205214.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
- returnn-1.20250814.205214.dist-info/RECORD,,
256
+ returnn-1.20250819.10249.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
257
+ returnn-1.20250819.10249.dist-info/METADATA,sha256=40ciCZzddEgWfHHnfFmRo7cpK8dukyBH8HYxTaEd5XY,5214
258
+ returnn-1.20250819.10249.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
259
+ returnn-1.20250819.10249.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
260
+ returnn-1.20250819.10249.dist-info/RECORD,,