returnn-1.20260105.192646-py3-none-any.whl → returnn-1.20260119.15400-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. returnn/PKG-INFO +1 -1
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +110 -42
  5. returnn/frontend/__init__.py +1 -0
  6. returnn/frontend/_backend.py +41 -0
  7. returnn/frontend/_native/__init__.py +22 -0
  8. returnn/frontend/_numpy_backend.py +7 -0
  9. returnn/frontend/_utils.py +1 -1
  10. returnn/frontend/array_.py +6 -5
  11. returnn/frontend/assert_.py +35 -0
  12. returnn/frontend/device.py +14 -1
  13. returnn/frontend/encoder/conformer.py +19 -0
  14. returnn/frontend/loss.py +183 -3
  15. returnn/frontend/math_.py +54 -14
  16. returnn/native_op.cpp +104 -174
  17. returnn/native_op.py +36 -31
  18. returnn/tensor/_dim_extra.py +7 -7
  19. returnn/tensor/_tensor_extra.py +10 -10
  20. returnn/tensor/utils.py +1 -1
  21. returnn/tf/frontend_layers/_backend.py +3 -1
  22. returnn/tf/layers/basic.py +13 -2
  23. returnn/tf/native_op.py +16 -5
  24. returnn/tf/util/basic.py +7 -201
  25. returnn/torch/engine.py +120 -3
  26. returnn/torch/frontend/_backend.py +166 -22
  27. returnn/torch/frontend/bridge.py +61 -0
  28. returnn/torch/frontend/compile_helper.py +106 -0
  29. returnn/torch/util/array_.py +30 -0
  30. returnn/torch/util/assert_.py +122 -0
  31. returnn/torch/util/native_op.py +885 -0
  32. returnn/torch/util/native_op_code_compiler.py +308 -0
  33. returnn/util/basic.py +3 -1
  34. returnn/util/cuda_env.py +332 -0
  35. returnn/util/debug.py +1 -0
  36. returnn/util/fsa.py +17 -13
  37. returnn/util/native_code_compiler.py +104 -47
  38. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
  39. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
  40. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  41. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  42. {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/frontend/loss.py CHANGED
@@ -8,7 +8,15 @@ from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 
 
-__all__ = ["cross_entropy", "ctc_loss", "ctc_greedy_decode", "edit_distance"]
+__all__ = [
+    "cross_entropy",
+    "ctc_loss",
+    "ctc_best_path",
+    "ctc_greedy_decode",
+    "ctc_durations_from_path",
+    "ctc_no_label_loop_blank_durations_from_path",
+    "edit_distance",
+]
 
 
 def cross_entropy(
@@ -64,6 +72,8 @@ def ctc_loss(
     targets_spatial_dim: Dim,
     blank_index: int,
     max_approx: bool = False,
+    use_native_op: Optional[bool] = None,
+    label_loop: bool = True,
 ) -> Tensor:
     """
     Calculates the CTC loss.
@@ -80,6 +90,8 @@ def ctc_loss(
     :param targets_spatial_dim: spatial dim of targets
     :param blank_index: vocab index of the blank symbol
     :param max_approx: if True, use max instead of sum over alignments (max approx, Viterbi)
+    :param use_native_op: whether to use our native op
+    :param label_loop: whether label loops are allowed (standard for CTC). False is like RNA topology.
     :return: loss shape [B...]
     """
     # noinspection PyProtectedMember
@@ -91,6 +103,42 @@ def ctc_loss(
         targets_spatial_dim=targets_spatial_dim,
         blank_index=blank_index,
         max_approx=max_approx,
+        use_native_op=use_native_op,
+        label_loop=label_loop,
+    )
+
+
+def ctc_best_path(
+    *,
+    logits: Tensor,
+    logits_normalized: bool = False,
+    targets: Tensor,
+    input_spatial_dim: Dim,
+    targets_spatial_dim: Dim,
+    blank_index: int,
+    label_loop: bool = True,
+) -> Tensor:
+    """
+    Calculates the CTC best path.
+
+    :param logits: (before softmax). shape [B...,input_spatial,C]
+    :param logits_normalized: whether the logits are already normalized (e.g. via log-softmax)
+    :param targets: sparse. shape [B...,targets_spatial] -> C
+    :param input_spatial_dim: spatial dim of input logits
+    :param targets_spatial_dim: spatial dim of targets
+    :param blank_index: vocab index of the blank symbol
+    :param label_loop: whether label loops are allowed (standard for CTC). False is like RNA topology.
+    :return: best path, shape [B...,targets_spatial] -> C
+    """
+    # noinspection PyProtectedMember
+    return logits._raw_backend.ctc_best_path(
+        logits=logits,
+        logits_normalized=logits_normalized,
+        targets=targets,
+        input_spatial_dim=input_spatial_dim,
+        targets_spatial_dim=targets_spatial_dim,
+        blank_index=blank_index,
+        label_loop=label_loop,
     )
 
 
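The two hunks above wire `use_native_op` and `label_loop` through `ctc_loss` and add the new `ctc_best_path` entry point; both dispatch through the tensor's `_raw_backend`. Below is a minimal usage sketch. The helper name and all tensor/dim arguments are illustrative, not from the diff, and the `use_native_op=None` comment reflects assumed "auto" semantics of the default:

import returnn.frontend as rf

def align_and_score(logits, targets, enc_spatial_dim, targets_spatial_dim, blank_index):
    # Viterbi-style CTC loss: max over alignments instead of the sum.
    loss = rf.ctc_loss(
        logits=logits,
        targets=targets,
        input_spatial_dim=enc_spatial_dim,
        targets_spatial_dim=targets_spatial_dim,
        blank_index=blank_index,
        max_approx=True,
        use_native_op=None,  # default; presumably lets the backend decide (assumed semantics)
    )
    # Best alignment under standard CTC topology; label_loop=False would
    # instead use the RNA-like topology without label loops.
    path = rf.ctc_best_path(
        logits=logits,
        targets=targets,
        input_spatial_dim=enc_spatial_dim,
        targets_spatial_dim=targets_spatial_dim,
        blank_index=blank_index,
        label_loop=True,
    )
    return loss, path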
@@ -132,6 +180,133 @@ def ctc_greedy_decode(
     return labels, out_spatial_dim
 
 
+def ctc_durations_from_path(
+    *,
+    path: Tensor,
+    path_spatial_dim: Dim,
+    blank_index: int,
+    targets_spatial_dim: Optional[Dim] = None,
+    out_spatial_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Given a CTC path (alignment), compute the durations of each label and the blanks.
+    Specifically, assuming that we have N labels in the target sequence,
+    there are N label durations and N+1 blank durations
+    (one before the first label, one after the last label, and one between each pair of labels),
+    resulting in a total of 2N+1 durations.
+    The returned durations tensor will have shape [B,...,T'] where T' = 2 * N + 1,
+    corresponding to the durations for the state sequence [blank_0, label_1, blank_1, label_2, ..., label_N, blank_N].
+
+    :param path: CTC path (alignment), shape [B...,path_spatial_dim] -> label indices (including blanks)
+    :param path_spatial_dim: spatial dim of path
+    :param blank_index: index of the blank label
+    :param targets_spatial_dim: if given, asserts that the computed number of labels matches this size
+    :param out_spatial_dim: if given, asserts that the output spatial dim size matches 2 * target_spatial_dim + 1
+    :return: (durations, out_spatial_dim).
+        durations shape [B...,out_spatial_dim] where out_spatial_dim = 2 * N + 1,
+        where N is the number of labels in the target sequence.
+    """
+    # example path: [_ _ a a b _ _ c c c _]
+    path_shifted = rf.shift_right(path, axis=path_spatial_dim, pad_value=blank_index)
+    # path_shifted: [_ _ _ a a b _ _ c c c]
+    new_label_mask = rf.logical_and(path != blank_index, path != path_shifted)
+    new_label_mask = new_label_mask.copy_masked(False, dims=[path_spatial_dim])
+    num_labels = rf.reduce_sum(rf.cast(new_label_mask, "int32"), axis=path_spatial_dim)
+    if targets_spatial_dim is not None:
+        rf.assert_(
+            targets_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels,
+            "target_spatial_dim size does not match number of labels in path",
+        )
+    else:
+        targets_spatial_dim = Dim(
+            rf.copy_to_device(num_labels, rf.get_default_dim_size_device()), name="target_spatial"
+        )
+    # new_label_mask: [0 0 1 0 1 0 0 1 0 0 0]
+    blank_idx = rf.cumsum(rf.cast(new_label_mask, "int32"), spatial_dim=path_spatial_dim)
+    # label_idx = blank_idx - 1
+    # label_idx: [-1 -1 0 0 1 1 1 2 2 2 2]
+    # blank_idx: [0 0 1 1 2 2 2 3 3 3 3]
+    blank_idx_x2 = blank_idx * 2
+    # blank_idx_x2: [0 0 2 2 4 4 4 6 6 6 6]
+    state_idx = blank_idx_x2 + rf.where(path == blank_index, 0, -1)
+    # state_idx: [0 0 1 1 3 4 4 5 5 5 6]
+    if out_spatial_dim is not None:
+        rf.assert_(
+            out_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels * 2 + 1,
+            "out_spatial_dim size does not match 2 * target_spatial_dim + 1",
+        )
+    else:
+        out_spatial_dim = targets_spatial_dim * 2 + 1
+    out = rf.scatter(rf.ones_like(state_idx), indices=state_idx, indices_dim=path_spatial_dim, out_dim=out_spatial_dim)
+    # out state seq: [ _ a _ b _ c _ ]
+    # out:           [ 2 2 0 1 2 3 1 ]
+    return out, out_spatial_dim
+
+
+def ctc_no_label_loop_blank_durations_from_path(
+    *,
+    path: Tensor,
+    path_spatial_dim: Dim,
+    blank_index: int,
+    targets_spatial_dim: Optional[Dim] = None,
+    out_spatial_dim: Optional[Dim] = None,
+) -> Tuple[Tensor, Dim]:
+    """
+    Given a CTC-without-label-loop (``label_loop=False`` in :func:`ctc_best_path`) (RNA) path (alignment),
+    compute the durations of all the blanks.
+    Specifically, assuming that we have N labels in the target sequence,
+    there are N+1 blank durations
+    (one before the first label, one after the last label, and one between each pair of labels).
+
+    :param path: CTC path (alignment), shape [B...,path_spatial_dim] -> label indices (including blanks)
+    :param path_spatial_dim: spatial dim of path
+    :param blank_index: index of the blank label
+    :param targets_spatial_dim: if given, asserts that the computed number of labels matches this size
+    :param out_spatial_dim: if given, asserts that the output spatial dim size matches target_spatial_dim + 1
+    :return: (durations, out_spatial_dim),
+        where the durations are for the blank labels only,
+        durations shape [B...,out_spatial_dim] where out_spatial_dim = N + 1,
+        where N is the number of labels in the target sequence.
+    """
+    # example path: [_ _ _ a b _ _ c _]
+    new_label_mask = path != blank_index
+    new_label_mask = new_label_mask.copy_masked(False, dims=[path_spatial_dim])
+    num_labels = rf.reduce_sum(rf.cast(new_label_mask, "int32"), axis=path_spatial_dim)
+    if targets_spatial_dim is not None:
+        rf.assert_(
+            targets_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels,
+            "target_spatial_dim size does not match number of labels in path",
+        )
+    else:
+        targets_spatial_dim = Dim(
+            rf.copy_to_device(num_labels, rf.get_default_dim_size_device()), name="target_spatial"
+        )
+    # new_label_mask: [0 0 0 1 1 0 0 1 0]
+    blank_idx = rf.cumsum(rf.cast(new_label_mask, "int32"), spatial_dim=path_spatial_dim)
+    # blank_idx: [0 0 0 1 2 2 2 3 3]
+    blank_idx = rf.where(
+        (path == blank_index) & rf.sequence_mask(path_spatial_dim, device=path.device),
+        blank_idx,
+        rf.reduce_max(num_labels, axis=num_labels.dims) + 1,
+    )
+    # blank_idx: [0 0 0 4 4 2 2 4 3]
+    if out_spatial_dim is not None:
+        rf.assert_(
+            out_spatial_dim.get_size_tensor(device=num_labels.device) == num_labels + 1,
+            "out_spatial_dim size does not match target_spatial_dim + 1",
+        )
+    else:
+        out_spatial_dim = targets_spatial_dim + 1
+    out_spatial_dim_ext = out_spatial_dim + 1  # for the extra label index used above
+    out = rf.scatter(
+        rf.ones_like(blank_idx), indices=blank_idx, indices_dim=path_spatial_dim, out_dim=out_spatial_dim_ext
+    )
+    out, _ = rf.slice(out, axis=out_spatial_dim_ext, size=out_spatial_dim)
+    # out state seq: [ _ a _ b _ c _ ]
+    # out:           [ 3 0 2 1 ]
+    return out, out_spatial_dim
+
+
 def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim, *, dtype: str = "int32") -> Tensor:
     """
     :param a: [B,Ta]
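The inline comments in `ctc_durations_from_path` trace the example path [_ _ a a b _ _ c c c _]. The NumPy sketch below replays the same index arithmetic on a flat array so the numbers can be checked by hand; the vocabulary (a=0, b=1, c=2, blank=3) is an assumption for illustration, the named-dim `rf` ops are replaced by flat-array equivalents, and sequence padding (handled by `copy_masked` in the real code) is ignored:

import numpy as np

blank = 3  # hypothetical vocab for illustration: a=0, b=1, c=2, blank=3
# example path from the docstring comments: [_ _ a a b _ _ c c c _]
path = np.array([blank, blank, 0, 0, 1, blank, blank, 2, 2, 2, blank])

# mirrors rf.shift_right(path, pad_value=blank_index)
path_shifted = np.concatenate([[blank], path[:-1]])
# a frame starts a new label iff it is non-blank and differs from its predecessor
new_label_mask = (path != blank) & (path != path_shifted)
num_labels = int(new_label_mask.sum())  # N = 3 (a, b, c)

# state index per frame: 2*k for the blank segment after k seen labels,
# 2*k - 1 for the k-th label itself
blank_idx = np.cumsum(new_label_mask)
state_idx = blank_idx * 2 + np.where(path == blank, 0, -1)

# durations = number of frames that fall into each of the 2N+1 states
durations = np.bincount(state_idx, minlength=2 * num_labels + 1)
print(durations.tolist())  # -> [2, 2, 0, 1, 2, 3, 1], matching the comments above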
@@ -141,13 +316,18 @@ def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim,
     :param dtype:
     :return: [B]
     """
-    import numpy  # just for iinfo on dtype to get max value
+    # noinspection PyProtectedMember
+    backend = a._raw_backend
+    if backend.have_edit_distance():
+        return backend.edit_distance(a, a_spatial_dim, b, b_spatial_dim)
+
+    from numpy import iinfo
 
     # The axis permutation is just an efficiency optimization.
     a = a.copy_transpose([a_spatial_dim] + a.remaining_dims(a_spatial_dim))
     b = b.copy_transpose([b_spatial_dim] + b.remaining_dims(b_spatial_dim))
     dev = a.device
-    max_dist_err = numpy.iinfo(dtype).max
+    max_dist_err = iinfo(dtype).max
     n_a_max_len = a_spatial_dim.get_dim_value()
     n_b_max_len = b_spatial_dim.get_dim_value()
     if int(n_a_max_len) < int(n_b_max_len):
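The `edit_distance` change adds a backend dispatch: when the active backend reports `have_edit_distance()` (e.g. via the new native-op code elsewhere in this release), its implementation is used, and the pure-RF dynamic-programming code only remains as fallback. The public call is unchanged; a sketch, assuming sparse label tensors `hyp` and `ref` with their spatial dims already defined:

import returnn.frontend as rf

# Batched Levenshtein distance between hypothesis and reference label sequences.
dist = rf.edit_distance(hyp, hyp_spatial_dim, ref, ref_spatial_dim)  # [B], int32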
returnn/frontend/math_.py CHANGED
@@ -3,7 +3,6 @@ Math ops
 """
 
 from __future__ import annotations
-import typing
 from typing import Optional, Sequence, Union, Tuple, overload
 import numpy
 from returnn.tensor import Tensor, Dim
@@ -77,7 +76,7 @@ __all__ = [
 ]
 
 
-@typing.overload
+@overload
 def compare(
     a: Tensor,
     kind: str,
@@ -86,7 +85,19 @@ def compare(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """compare with two tensors"""
+    """compare"""
+
+
+@overload
+def compare(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Tensor:
+    """compare"""
 
 
 _CompareMap = {
@@ -138,7 +149,7 @@ def compare_bc(
     return compare(a, kind, b, allow_broadcast_all_sources=True, dim_order=dim_order)
 
 
-@typing.overload
+@overload
 def combine(
     a: Tensor,
     kind: str,
@@ -147,7 +158,19 @@ def combine(
     allow_broadcast_all_sources: Optional[bool] = None,
     dim_order: Optional[Sequence[Dim]] = None,
 ) -> Tensor:
-    """combine with two tensors"""
+    """combine"""
+
+
+@overload
+def combine(
+    a: Union[Tensor, _RawTensorTypes],
+    kind: str,
+    b: Union[Tensor, _RawTensorTypes],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Union[Tensor, _RawTensorTypes]:
+    """combine"""
 
 
 _CombineMap = {
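The second `@overload` for `compare` and `combine` (together with the switch from `@typing.overload` to the directly imported `@overload`) documents that either operand may be a raw scalar or raw backend tensor, not only a `Tensor`. A sketch, with `x` assumed to be an existing `rf.Tensor`; per the new `combine` overload, the result may itself stay a raw value when no `Tensor` is involved:

import returnn.frontend as rf

mask = rf.compare(x, "less", 0.5)  # scalar operand, same as x < 0.5
y = rf.combine(x, "mul", 2.0)      # same as x * 2.0
z = rf.combine(3.0, "add", 4.0)    # may stay a raw scalar, per the overload's return type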
@@ -332,7 +355,12 @@ def logical_not(a: Tensor) -> Tensor:
 
 @overload
 def opt_logical_or(a: bool, b: bool) -> bool:
-    """logical or"""
+    """opt logical or"""
+
+
+@overload
+def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical or"""
 
 
 def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -350,7 +378,12 @@ def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tens
 
 @overload
 def opt_logical_and(a: bool, b: bool) -> bool:
-    """logical and"""
+    """opt logical and"""
+
+
+@overload
+def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+    """opt logical and"""
 
 
 def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
@@ -416,16 +449,23 @@ def minimum(a: Tensor, b: Union[Tensor, _RawTensorTypes], *other_tensors) -> Ten
 
 def clip_by_value(
     x: Tensor,
-    clip_value_min: Union[Tensor, _RawTensorTypes],
-    clip_value_max: Union[Tensor, _RawTensorTypes],
+    clip_value_min: Union[None, Tensor, _RawTensorTypes] = None,
+    clip_value_max: Union[None, Tensor, _RawTensorTypes] = None,
     *,
     allow_broadcast_all_sources: bool = False,
 ) -> Tensor:
     """clip by value"""
-    # noinspection PyProtectedMember
-    return x._raw_backend.clip_by_value(
-        x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
-    )
+    if clip_value_min is not None and clip_value_max is not None:
+        # noinspection PyProtectedMember
+        return x._raw_backend.clip_by_value(
+            x, clip_value_min, clip_value_max, allow_broadcast_all_sources=allow_broadcast_all_sources
+        )
+    elif clip_value_min is not None and clip_value_max is None:
+        return maximum(x, clip_value_min)
+    elif clip_value_min is None and clip_value_max is not None:
+        return minimum(x, clip_value_max)
+    else:
+        return x
 
 
 def identity(x: Tensor) -> Tensor:
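With both bounds of `clip_by_value` now optional, one-sided clipping degrades to `maximum`/`minimum`, and the no-bound call is the identity. A small usage sketch, with `x` assumed to be an existing `rf.Tensor`:

import returnn.frontend as rf

y = rf.clip_by_value(x, 0.0, 1.0)            # both bounds: backend clip, as before
y = rf.clip_by_value(x, clip_value_min=0.0)  # lower bound only: rf.maximum(x, 0.0)
y = rf.clip_by_value(x, clip_value_max=1.0)  # upper bound only: rf.minimum(x, 1.0)
y = rf.clip_by_value(x)                      # no bounds: x returned unchanged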
@@ -541,7 +581,7 @@ def floor(a: Tensor) -> Tensor:
 
 # noinspection PyShadowingBuiltins
 def round(a: Tensor) -> Tensor:
-    """round"""
+    """round. the result dtype is same as input dtype, still float"""
     # noinspection PyProtectedMember
     return a._raw_backend.activation(a, "round")
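The clarified `round` docstring is a useful reminder that the op keeps the floating dtype. Where integer values are needed (e.g. as indices), an explicit cast must follow; a one-line sketch, with `x` assumed to be a float `rf.Tensor`:

import returnn.frontend as rf

idx = rf.cast(rf.round(x), "int32")  # round() stays float; cast to get integers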