returnn 1.20251027.232712__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
returnn/torch/util/native_op.py (new file)
@@ -0,0 +1,885 @@
+ """
+ Native ops for Torch, similar to :mod:`returnn.tf.native_op`.
+ """
+
+ from __future__ import annotations
+ from typing import Optional, Any, Tuple, Dict
+ import os
+ import sys
+ from textwrap import dedent
+ from threading import RLock
+
+ import torch
+
+ from returnn import native_op
+ from .native_op_code_compiler import OpCodeCompiler
+
+
+ _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
+ _base_dir = os.path.realpath(_base_dir)  # Make canonical path-name.
+
+
+ class OpDescription(native_op.NativeOpBaseMixin):
+     """
+     Meta-info about an op, used by :class:`OpMaker`.
+     """
+
+     @classmethod
+     def from_gen_base(cls, gen_base):
+         """
+         :param returnn.native_op.NativeOpGenBase|type[returnn.native_op.NativeOpGenBase] gen_base:
+         :rtype: OpDescription
+         """
+         name = gen_base.__name__
+         assert gen_base.in_info is not None
+         assert gen_base.out_info is not None
+         assert gen_base.c_fw_code is not None
+         return OpDescription(
+             in_info=gen_base.in_info,
+             out_info=gen_base.out_info,
+             c_fw_code=gen_base.c_fw_code,
+             c_bw_code=gen_base.c_bw_code,
+             c_extra_support_code=gen_base.c_extra_support_code,
+             cpu_support=gen_base.cpu_support,
+             grad_input_map=gen_base.grad_input_map,
+             name=name,
+         )
+
+     @property
+     def is_grad_defined(self) -> bool:
+         """
+         :return: whether the gradient is defined
+         """
+         return bool(self.c_bw_code)
+
+     def grad(self) -> Optional[OpDescription]:
+         """
+         :rtype: OpDescription|None
+         """
+         if not self.is_grad_defined:
+             return None
+         kwargs = self.kwargs_for_grad_op()
+         return OpDescription(**kwargs)
+
+
+ class OpMaker:
+     """
+     https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
+     https://docs.pytorch.org/cppdocs/
+     """
+
+     with_cuda: Optional[bool] = None
+     global_lock = RLock()
+     mod_cache = {}  # cache_key -> mod
+     op_cache = {}  # cache_key -> op
+     log_stream = sys.stdout
+
+     def __init__(
+         self,
+         description: OpDescription,
+         *,
+         compiler_opts: Optional[Dict[str, str]] = None,
+         with_cuda: Optional[bool] = None,
+     ):
+         """
+         :param description:
+         :param compiler_opts: passed on to OpCodeCompiler as kwargs
+         :param with_cuda: override auto-detection of CUDA availability
+         """
+         if with_cuda is not None:
+             self.with_cuda = with_cuda
+         else:
+             self._cls_init_with_cuda()
+         self.description = description
+         self.name = description.name
+         self.compiler_opts = compiler_opts or {}
+
+     @classmethod
+     def _cls_init_with_cuda(cls):
+         if cls.with_cuda is None:
+             cls.with_cuda = torch.cuda.is_available()
+
+     @property
+     def op_name(self) -> str:
+         """op name"""
+         return self.name
+
+     @property
+     def cache_key(self) -> str:
+         """cache key"""
+         return self.name
+
+     @property
+     def support_native_op_cpp_filename(self) -> str:
+         """
+         :return: filename of NativeOp.cpp
+         """
+         support_native_op_cpp_filename = "%s/native_op.cpp" % _base_dir
+         assert os.path.exists(support_native_op_cpp_filename)
+         return support_native_op_cpp_filename
+
+     def _make_code(self):
+         # https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
+         # https://docs.pytorch.org/cppdocs/
+
+         # We also include NativeOp.cpp.
+
+         # noinspection PyProtectedMember
+         in_info, out_info, _ = native_op.NativeOpBaseMixin._resolve_want_inplace_dummy(
+             in_info=self.description.in_info, out_info=self.description.out_info
+         )
+
+         # noinspection PyShadowingNames
+         def map_name(v: Dict[str, Any], is_out: bool = False) -> str:
+             """name"""
+             name = v["name"].lower()
+             if is_out:
+                 name = "_out_%s" % name
+             else:
+                 name = "_in_%s" % name
+             return name
+
+         # noinspection PyShadowingNames,PyUnusedLocal
+         def map_type(v: Dict[str, Any]) -> str:
+             """dtype"""
+             t = v.get("dtype", "float32")
+             if t == "float32":
+                 return "at::kFloat"
+             elif t == "int32":
+                 return "at::kInt"
+             else:
+                 raise NotImplementedError("unsupported dtype %r" % t)
+
+         def make_compute_code(*, cuda: bool = False) -> str:
+             """compute code"""
+
+             code_device_specific = ""
+             if cuda:
+                 tensor_vars = [v for v in in_info if _schema_type_str(v) == "Tensor"]
+                 if tensor_vars:
+                     code_device_specific += dedent(f"""\
+                         at::Device _device = {map_name(tensor_vars[0])}.device();
+                         at::OptionalDeviceGuard _device_guard(_device);
+                         """)
+                 else:
+                     code_device_specific += "at::Device _device = at::kCUDA;\n"
+
+             code_forward_io = ""
+             out_is_ref = {}  # output vars which are inplace, out_name -> in_idx
+             # want_inplace: output-index which this input should operate on
+             # Unlike the Theano variant, we always do it inplace,
+             # so the user has to make a copy if this is not the intention.
+             for in_idx, v in enumerate(in_info):
+                 out_idx = v.get("want_inplace", -1)
+                 if out_idx >= 0:
+                     code_forward_io += dedent(f"""\
+                         torch::Tensor {map_name(out_info[out_idx], is_out=True)} = {map_name(v)};  // inplace
+                         """)
+                     out_name = out_info[out_idx]["name"]
+                     assert out_name not in out_is_ref
+                     out_is_ref[out_name] = in_idx
+
+             code_set_io = ""
+             for in_idx, v in enumerate(in_info):
+                 if _schema_type_str(v) != "Tensor":
+                     code_set_io += dedent(f"""\
+                         torch::Tensor {map_name(v)}_tensor = torch::tensor({map_name(v)}, torch::dtype({map_type(v)}));
+                         """)
+                     continue  # scalar input
+                 ndim = len(v["shape"])
+                 code_set_io += dedent(f"""\
+                     if({map_name(v)}.scalar_type() != {map_type(v)}) {{
+                         {map_name(v)} = {map_name(v)}.to(torch::dtype({map_type(v)}));
+                     }}
+                     """)
+                 code_set_io += dedent(f"""\
+                     TORCH_CHECK(
+                         {map_name(v)}.dim() == {ndim},
+                         "{v["name"]} shape ndim is not {ndim}, got shape ", {map_name(v)}.sizes());
+                     """)
+                 for axis, d in enumerate(v["shape"]):
+                     if isinstance(d, int):
+                         code_set_io += dedent(f"""\
+                             TORCH_CHECK(
+                                 {map_name(v)}.size({axis}) == {d},
+                                 "{v["name"]} shape[{axis}] != {d}, got shape ", {map_name(v)}.sizes());
+                             """)
+
+             for out_idx, v in enumerate(out_info):
+                 out_name = out_info[out_idx]["name"]
+                 if out_name in out_is_ref:  # is ref on input
+                     pass
+                 else:  # no ref
+                     cshape = "{%s}" % ", ".join(
+                         [
+                             str(dim) if isinstance(dim, int) else f"{map_name(in_info[dim[0]])}.size({dim[1]})"
+                             for dim in v["shape"]
+                         ]
+                     )
+                     code_set_io += dedent(f"""\
+                         torch::Tensor {map_name(v, is_out=True)}
+                             = torch::zeros({cshape}, torch::dtype({map_type(v)}){".device(_device)" if cuda else ""});
+                         """)
+
+             code_set_contiguous = ""
+             for v in in_info:
+                 if v.get("need_contiguous", False) and _schema_type_str(v) == "Tensor":
+                     code_set_contiguous += dedent(f"""\
+                         if(!{map_name(v)}.is_contiguous()) {{
+                             {map_name(v)} = {map_name(v)}.contiguous();
+                         }}
+                         """)
+
+             # The user code uses inputs and outputs arrays.
+             _code_wrap_io_input_vars_list = [
+                 f"&{map_name(v)}" + ("" if _schema_type_str(v) == "Tensor" else "_tensor") for v in in_info
+             ]
+             code_wrap_io = dedent(f"""\
+                 static const int n_inputs = {len(in_info)}, n_outputs = {len(out_info)};
+                 torch::Tensor* inputs[n_inputs] = {{
+                     {", ".join(_code_wrap_io_input_vars_list)} }};
+                 torch::Tensor* _outputs_ptr[n_outputs] = {{
+                     {", ".join(f"&{map_name(v, is_out=True)}" for v in out_info)} }};
+                 torch::Tensor** outputs[n_outputs] = {{
+                     {", ".join(f"&_outputs_ptr[{i}]" for i in range(len(out_info)))} }};
+                 """)
+
+             code_user = self.description.c_fw_code % {"fail": "assert(false);"}
+
+             code_return = "return std::make_tuple(%s);\n" % ", ".join([map_name(v, is_out=True) for v in out_info])
+
+             code_compute = "\n\n".join(
+                 [
+                     code_device_specific,
+                     code_forward_io,
+                     code_set_io,
+                     code_set_contiguous,
+                     code_wrap_io,
+                     code_user,
+                     code_return,
+                 ]
+             )
+
+             return code_compute
+
+         code_header = ""
+         code_header += dedent("""\
+             #include <torch/extension.h>
+             #include <torch/types.h>
+             #include <c10/core/CPUAllocator.h>
+             """)
+         if self.with_cuda:
+             code_header += dedent("""\
+                 #include <cuda.h>
+                 #include <cuda_runtime.h>
+                 #include <math_constants.h>
+                 #include <ATen/cuda/CUDAContext.h>
+                 #include <c10/cuda/CUDACachingAllocator.h>
+                 """)
+
+         def _schema_type_str(v: Dict[str, Any], *, c: bool = False) -> str:
+             if v.get("host_memory", False):
+                 assert v["ndim"] == 0  # not supported otherwise...
+                 dtype = v.get("dtype", "float32")
+                 if dtype == "float32":
+                     if c:
+                         return "float32_t"
+                     return "float"
+                 elif dtype == "int32":
+                     if c:
+                         return "int64_t"  # int8_t, int64_t and bool are supported as an integral argument type
+                     return "int"
+                 else:
+                     raise NotImplementedError("unsupported dtype %r" % dtype)
+             if c:
+                 return "torch::Tensor"
+             return "Tensor"
+
+         # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func
+         func_schema_str = "(%s)" % ", ".join(
+             f"{_schema_type_str(v)} {v['name']}" for v in in_info
+         ) + " -> (%s)" % ", ".join(_schema_type_str(v) for v in out_info)
+
+         code_header += dedent(
+             f"""\
+             #define _ns  // so _ns::something will use the root namespace
+             #define TORCH 1
+             #define CUDA 0
+             #include "{self.support_native_op_cpp_filename}"
+
+             TORCH_LIBRARY({self.op_name}, m) {{
+                 m.def("{self.op_name}{func_schema_str}");
+             }}
+
+             """
+         )
+
+         if self.description.cpu_support:
+             # noinspection PyProtectedMember
+             code_cpu_op = self.description._reduce_c_extra_support_code(self.description.c_extra_support_code)
+             code_cpu_op += dedent(
+                 f"""\
+
+                 std::tuple<{", ".join(["torch::Tensor"] * len(out_info))}>
+                 {self.op_name}_cpu(
+                     {", ".join(f"{_schema_type_str(v, c=True)} {map_name(v)}" for v in in_info)}
+                 ) {{
+                 """
+             )
+             code_cpu_op += make_compute_code()
+             code_cpu_op += dedent(
+                 f"""\
+                 }}
+
+                 TORCH_LIBRARY_IMPL({self.op_name}, CPU, m) {{
+                     m.impl("{self.op_name}", &{self.op_name}_cpu);
+                 }}
+                 """
+             )
+         else:
+             code_cpu_op = ""
+
+         if self.with_cuda:
+             # noinspection PyProtectedMember
+             code_cuda_op = dedent(f"""\
+                 namespace _cuda_impl {{
+
+                 #ifdef _ns
+                 #undef _ns
+                 #define _ns _ns
+                 #endif
+                 namespace _ns = ::_cuda_impl;
+                 #undef Ndarray_memcpy
+                 #undef Ndarray_memset
+                 #undef Ndarray_sgemm
+                 #undef Ndarray_sgemv
+                 #undef Ndarray_sgemm_batched
+                 #undef DEF_KERNEL
+                 #undef start_dev_kernel
+                 #undef assert_cmp
+                 #undef threadIdx
+                 #undef blockIdx
+                 #undef blockDim
+                 #undef gridDim
+                 #undef DEF_SHARED
+                 #undef DEV_FUNC
+                 #undef HANDLE_LAST_ERROR
+                 #undef HOST_FUNC
+                 #undef INF_F
+                 #undef NAN_F
+                 #undef elem_atomic_add
+                 #undef elem_atomic_cas
+                 #undef elem_atomic_min
+                 #undef float_as_int
+                 #undef int_as_float
+                 #undef start_dev_kernel2
+                 #undef CHECK_WITH_MSG
+
+                 #undef CUDA
+                 #define CUDA 1
+
+                 #include "{self.support_native_op_cpp_filename}"
+
+                 #undef CUDA  // name collision in Torch code below
+                 """)
+             # noinspection PyProtectedMember
+             code_cuda_op += self.description._reduce_c_extra_support_code(self.description.c_extra_support_code)
+             code_cuda_op += dedent(f"""\
+
+                 std::tuple<{", ".join(["torch::Tensor"] * len(out_info))}>
+                 {self.op_name}_cuda(
+                     {", ".join(f"{_schema_type_str(v, c=True)} {map_name(v)}" for v in in_info)}
+                 ) {{
+                 """)
+             code_cuda_op += make_compute_code(cuda=True)
+             code_cuda_op += dedent(f"""\
+                 }}
+
+                 TORCH_LIBRARY_IMPL({self.op_name}, CUDA, m) {{
+                     m.impl("{self.op_name}", &{self.op_name}_cuda);
+                 }}
+                 }}  // namespace _cuda_impl
+                 """)
+         else:
+             code_cuda_op = ""
+
+         return code_header + code_cpu_op + code_cuda_op
+
+     def _make_mod(self):
+         if self.cache_key in self.mod_cache:
+             return self.mod_cache[self.cache_key]
+
+         comp = OpCodeCompiler(
+             base_name=self.name,
+             code_version=self.description.code_version,
+             code=self._make_code(),
+             include_deps=[self.support_native_op_cpp_filename],
+             use_cuda_if_available=self.with_cuda,
+             log_stream=self.log_stream,
+             **dict(self.compiler_opts),
+         )
+         mod = comp.load_module()
+         mod._op_compiler = comp
+         self.mod_cache[self.cache_key] = mod
+         return mod
+
+     def make_op(self):
+         """
+         :return: op
+         """
+         with self.global_lock:
+             if self.cache_key in self.op_cache:
+                 return self.op_cache[self.cache_key]
+             mod = self._make_mod()
+             op = getattr(mod, self.op_name)
+             op._op_maker = self
+             op._op_module = mod
+             self.op_cache[self.cache_key] = op
+
+             if self.description.is_grad_defined:
+                 pass  # not implemented yet...
+
+             return op
+
+
+ def make_op(cls, **kwargs):
+     """
+     :param type[returnn.native_op.NativeOpGenBase] cls:
+     :param kwargs: passed to OpMaker
+     :return: op
+     :rtype: (torch.Tensor) -> tuple[torch.Tensor]
+     """
+     maker = OpMaker(OpDescription.from_gen_base(cls), **kwargs)
+     return maker.make_op()
+
+
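For orientation, a minimal usage sketch of this factory, assuming `returnn.native_op.EditDistanceOp` as the generic op description (the `edit_distance` wrapper further below does the same):

    import torch
    from returnn import native_op
    from returnn.torch.util.native_op import make_op

    # The first call compiles the C++/CUDA extension; the result is cached per op name.
    op = make_op(native_op.EditDistanceOp)
    a = torch.tensor([[1, 2, 3]], dtype=torch.int32)  # (batch,time1)
    a_len = torch.tensor([3], dtype=torch.int32)
    b = torch.tensor([[1, 3, 3]], dtype=torch.int32)  # (batch,time2)
    b_len = torch.tensor([3], dtype=torch.int32)
    dist = op(a, a_len, b, b_len)  # (batch,) int32; here one substitution, i.e. [1]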
+ def ctc_loss(
+     *,
+     logits: torch.Tensor,
+     logits_seq_lens: torch.Tensor,
+     targets: torch.Tensor,
+     targets_seq_lens: torch.Tensor,
+     label_loop: bool = True,
+     logits_time_major: bool = False,
+     logits_normalize: bool = True,
+     blank_index: int = -1,
+     max_approx: bool = False,
+ ) -> torch.Tensor:
+     """
+     Similar to :func:`tf.nn.ctc_loss`.
+     We use our :func:`fast_baum_welch`.
+     Also see :class:`FastBaumWelchLoss`.
+
+     :param logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
+     :param logits_seq_lens: shape (batch,) of int32|int64
+     :param logits_time_major:
+     :param targets: batch-major, [batch,time]
+     :param targets_seq_lens: (batch,)
+     :param label_loop: (ctc_merge_repeated in tf.nn.ctc_loss)
+     :param logits_normalize: apply log_softmax on logits (default).
+         If False, the logits are assumed to be already normalized log-probs,
+         and the gradient is computed w.r.t. those directly.
+     :param blank_index: vocab index of the blank symbol
+     :param max_approx: use max approximation (Viterbi) instead of full sum
+     :return: loss, shape (batch,)
+     """
+     from .array_ import sequence_mask_time_major
+
+     assert logits.ndim == 3
+     dim = logits.shape[-1]
+     if not logits_time_major:
+         logits = torch.transpose(logits, 0, 1)  # (time,batch,dim)
+
+     if blank_index < 0:
+         blank_index += dim
+     assert 0 <= blank_index < dim
+     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
+         targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
+     )
+
+     seq_mask = sequence_mask_time_major(logits_seq_lens)  # (time,batch), bool
+
+     if max_approx:
+         log_probs = torch.log_softmax(logits, dim=-1) if logits_normalize else logits  # (time,batch,dim)
+         alignment, _ = fast_viterbi(
+             am_scores=log_probs,
+             am_seq_len=logits_seq_lens,
+             edges=edges,
+             weights=weights,
+             start_end_states=start_end_states,
+         )
+         # alignment is (time,batch). torch.gather requires int64 indices.
+         log_probs_ = torch.gather(log_probs, 2, alignment.long().unsqueeze(-1))  # (time,batch,1)
+         log_probs_ = log_probs_.squeeze(-1)  # (time,batch)
+         log_probs_ = torch.where(seq_mask, log_probs_, 0.0)
+         loss = -torch.sum(log_probs_, dim=0)  # (batch,)
+         return loss
+
+     loss = _FastBaumWelchScoresAutogradFunc.apply(logits, logits_normalize, seq_mask, edges, weights, start_end_states)
+     return loss
+
+
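A usage sketch for `ctc_loss` with random batch-major inputs (shapes follow the docstring above; with the default `blank_index=-1`, the last vocab index is the blank):

    import torch
    from returnn.torch.util.native_op import ctc_loss

    n_batch, n_time, n_target, dim = 2, 17, 7, 11  # dim includes the blank label
    logits = torch.randn(n_batch, n_time, dim, requires_grad=True)
    logits_seq_lens = torch.tensor([17, 13], dtype=torch.int32)
    targets = torch.randint(0, dim - 1, (n_batch, n_target), dtype=torch.int32)  # excludes blank
    targets_seq_lens = torch.tensor([7, 5], dtype=torch.int32)  # max must match n_target

    loss = ctc_loss(
        logits=logits,
        logits_seq_lens=logits_seq_lens,
        targets=targets,
        targets_seq_lens=targets_seq_lens,
    )  # (batch,)
    loss.sum().backward()  # gradients w.r.t. logits come from the custom autograd function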
+ # noinspection PyMethodOverriding,PyAbstractClass,PyMissingOrEmptyDocstring
+ class _FastBaumWelchScoresAutogradFunc(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         logits: torch.Tensor,
+         logits_normalize: bool,
+         seq_mask: torch.Tensor,
+         edges: torch.Tensor,
+         weights: torch.Tensor,
+         start_end_states: torch.Tensor,
+         state_buffer: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if logits_normalize:
+             log_sm = torch.log_softmax(logits, dim=-1)  # (time,batch,dim)
+         else:
+             log_sm = logits
+         fwdbwd, obs_scores = fast_baum_welch(
+             am_scores=-log_sm,
+             seq_mask=seq_mask,
+             edges=edges,
+             weights=weights,
+             start_end_states=start_end_states,
+             state_buffer=state_buffer,
+         )
+         loss = obs_scores[0]  # (batch,)
+         ctx.grad_wrt_softmax_in = logits_normalize
+         if logits_normalize:
+             ctx.save_for_backward(log_sm, seq_mask, fwdbwd)
+         else:
+             ctx.save_for_backward(seq_mask, fwdbwd)
+         return loss
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
+         if ctx.grad_wrt_softmax_in:
+             log_sm, seq_mask, fwdbwd = ctx.saved_tensors
+         else:
+             log_sm = None
+             seq_mask, fwdbwd = ctx.saved_tensors
+         bw = torch.exp(-fwdbwd)  # (time,batch,dim)
+         if ctx.grad_wrt_softmax_in:
+             grad_x = torch.exp(log_sm) - bw  # (time,batch,dim)
+         else:
+             grad_x = -bw  # (time,batch,dim)
+         grad_x = torch.where(seq_mask[:, :, None], grad_x, 0.0)  # seq_mask (time,batch), broadcast over dim
+         grad_x *= grad_output[None, :, None]
+         # One gradient per forward input
+         # (logits, logits_normalize, seq_mask, edges, weights, start_end_states, state_buffer).
+         return grad_x, None, None, None, None, None, None
+
+
+ def get_ctc_fsa_fast_bw(
+     *, targets: torch.Tensor, seq_lens: torch.Tensor, blank_idx: int, label_loop: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     See :class:`NativeOp.GetCtcFsaFastBwOp`.
+     Generates an FSA with CTC topology. The output format is compatible with :func:`fast_baum_welch`.
+
+     :param targets: shape (batch,time), int32
+     :param seq_lens: shape (batch), int32
+     :param blank_idx: vocab index of the blank symbol
+     :param label_loop: True -> normal CTC; False -> RNA-like
+     :return: edges, weights, start_end_states;
+         edges is (4,num_edges), int32, edges of the graph (from,to,emission_idx,sequence_idx).
+         weights is (num_edges,), float32. all zero.
+         start_end_states is (2,batch), int32, (start,end) state idx in FSA.
+     """
+     assert targets.ndim == 2
+     targets = targets.to(torch.int32)
+     n_batch, n_time = targets.shape
+
+     from .assert_ import assert_
+
+     # The check on the seq lens is important,
+     # because invalid seq lens might not directly lead to an error here
+     # but might just return an invalid FSA.
+     # An invalid FSA can however later cause a crash in the FastBaumWelchOp.
+     assert_(seq_lens.max() == n_time, "get_ctc_fsa_fast_bw seq_lens invalid")
+
+     n_edges = n_batch * (5 * (n_time - 1) + 10)  # see op documentation
+     weights = torch.zeros((n_edges,), device=targets.device)
+     maker = OpMaker(OpDescription.from_gen_base(native_op.GetCtcFsaFastBwOp))
+     op = maker.make_op()
+     edges, start_end_states = op(targets, seq_lens, blank_idx, weights, label_loop)
+
+     return edges, weights, start_end_states
+
+
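To get a feel for the output format, a small sketch for a single sequence [0, 1] with blank index 2:

    import torch
    from returnn.torch.util.native_op import get_ctc_fsa_fast_bw

    targets = torch.tensor([[0, 1]], dtype=torch.int32)  # (batch,time)
    seq_lens = torch.tensor([2], dtype=torch.int32)
    edges, weights, start_end_states = get_ctc_fsa_fast_bw(targets=targets, seq_lens=seq_lens, blank_idx=2)
    # edges[0]/edges[1]: source/target state indices; edges[2]: emitted label (0, 1 or blank 2);
    # edges[3]: sequence index within the batch. weights are all zero, per the docstring.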
+ def fast_baum_welch(
+     *,
+     am_scores: torch.Tensor,
+     seq_mask: torch.Tensor,
+     edges: torch.Tensor,
+     weights: torch.Tensor,
+     start_end_states: torch.Tensor,
+     state_buffer: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     :param am_scores: (time, batch, dim), in -log space
+     :param seq_mask: (time, batch) -> 0 or 1 (index mask, via seq lens)
+     :param edges: (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
+     :param weights: (num_edges,), weights of the edges
+     :param start_end_states: (2, batch), (start,end) state idx in automaton.
+         There is only one single automaton.
+     :param state_buffer: (2, num_states)
+     :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
+     """
+     from .assert_ import assert_
+
+     # edges, weights, start_end_states, state_buffer = SprintAlignmentAutomataOp(self.sprint_opts)(self.network.tags)
+     op = make_fast_baum_welch_op()
+     float_idx = seq_mask.float()
+     if state_buffer is None:
+         last_state_idx = start_end_states[1].max()  # see get_automata_for_batch
+         assert_(last_state_idx >= 0, "fast_baum_welch last_state_idx must be >= 0")
+         state_buffer = torch.zeros((2, last_state_idx + 1), device=am_scores.device)
+     fwdbwd, obs_scores = op(am_scores, edges, weights, start_end_states, float_idx, state_buffer)  # noqa
+     return fwdbwd, obs_scores
+
+
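Since fwdbwd is in -log space, the soft alignment (per-frame label posteriors) is exp(-fwdbwd). A sketch, reusing `edges`, `weights` and `start_end_states` from the get_ctc_fsa_fast_bw example above, and assuming time-major log-softmax scores `log_sm` and a boolean (time,batch) `seq_mask`:

    fwdbwd, obs_scores = fast_baum_welch(
        am_scores=-log_sm,  # expects -log space
        seq_mask=seq_mask,
        edges=edges,
        weights=weights,
        start_end_states=start_end_states,
    )
    soft_alignment = torch.exp(-fwdbwd)  # (time,batch,dim); sums to ~1 over dim within the mask
    full_sum_score = obs_scores[0]  # (batch,), -log p, as used by ctc_loss above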
+ def make_fast_baum_welch_op(**kwargs):
+     """
+     :return: op
+     :rtype: (torch.Tensor) -> tuple[torch.Tensor]
+     """
+     maker = OpMaker(OpDescription.from_gen_base(native_op.FastBaumWelchOp), **kwargs)
+     return maker.make_op()
+
+
+ def fast_viterbi(
+     *,
+     am_scores: torch.Tensor,
+     am_seq_len: torch.Tensor,
+     edges: torch.Tensor,
+     weights: torch.Tensor,
+     start_end_states: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     :param am_scores: (time, batch, dim), in +log space, already normalized / just used as-is
+     :param am_seq_len: (batch,), int32
+     :param edges: (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
+     :param weights: (num_edges,), weights of the edges
+     :param start_end_states: (2, batch), (start,end) state idx in automaton.
+         There is only one single automaton.
+     :return: (alignment, scores), alignment is (time, batch), scores is (batch,), in +log space.
+         Note: the scores are not differentiable here.
+         Do gather+sum on the am_scores by the alignment to get a differentiable score.
+     """
+     last_state_idx = start_end_states[1].max()
+     n_states = last_state_idx + 1
+     maker = OpMaker(OpDescription.from_gen_base(native_op.FastViterbiOp))
+     op = maker.make_op()
+     alignment, scores = op(am_scores, am_seq_len, edges, weights, start_end_states, n_states)
+     return alignment, scores
+
+
+ def ctc_best_path(
+     *,
+     logits: torch.Tensor,
+     logits_seq_lens: torch.Tensor,
+     targets: torch.Tensor,
+     targets_seq_lens: torch.Tensor,
+     label_loop: bool = True,
+     logits_time_major: bool = False,
+     logits_normalize: bool = True,
+     blank_index: int = -1,
+ ) -> torch.Tensor:
+     """
+     :param logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
+     :param logits_seq_lens: shape (batch,) of int32|int64
+     :param logits_time_major:
+     :param targets: batch-major, [batch,time]
+     :param targets_seq_lens: (batch,)
+     :param label_loop: (ctc_merge_repeated in tf.nn.ctc_loss)
+     :param logits_normalize: apply log_softmax on logits (default).
+         If False, the logits are assumed to be already normalized log-probs.
+     :param blank_index: vocab index of the blank symbol
+     :return: alignment, (time, batch). Note: to get the scores, do gather+sum on the am_scores by the alignment.
+     """
+     assert logits.ndim == 3
+     dim = logits.shape[-1]
+     if not logits_time_major:
+         logits = torch.transpose(logits, 0, 1)  # (time,batch,dim)
+     if logits_normalize:
+         log_sm = torch.log_softmax(logits, dim=-1)  # (time,batch,dim)
+     else:
+         log_sm = logits
+
+     if blank_index < 0:
+         blank_index += dim
+     assert 0 <= blank_index < dim
+     edges, weights, start_end_states = get_ctc_fsa_fast_bw(
+         targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
+     )
+
+     alignment, _ = fast_viterbi(
+         am_scores=log_sm, am_seq_len=logits_seq_lens, edges=edges, weights=weights, start_end_states=start_end_states
+     )
+     return alignment
+
+
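As the docstring notes, a differentiable path score can be recovered from the (non-differentiable) alignment by gather+sum over the log-probs. A sketch, with batch-major `logits` as in the ctc_loss example above; `sequence_mask_time_major` is assumed from `returnn.torch.util.array_`:

    alignment = ctc_best_path(
        logits=logits,
        logits_seq_lens=logits_seq_lens,
        targets=targets,
        targets_seq_lens=targets_seq_lens,
    )  # (time,batch), int32
    log_sm = torch.log_softmax(logits.transpose(0, 1), dim=-1)  # (time,batch,dim)
    scores = torch.gather(log_sm, 2, alignment.long().unsqueeze(-1)).squeeze(-1)  # (time,batch)
    seq_mask = sequence_mask_time_major(logits_seq_lens)  # (time,batch), bool
    viterbi_score = torch.where(seq_mask, scores, 0.0).sum(dim=0)  # (batch,), differentiable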
+ def edit_distance(a, a_len, b, b_len):
+     """
+     Wraps :class:`NativeOp.EditDistanceOp`.
+
+     :param torch.Tensor a: (batch,time1), int32
+     :param torch.Tensor a_len: (batch,), int32
+     :param torch.Tensor b: (batch,time2), int32
+     :param torch.Tensor b_len: (batch,), int32
+     :return: (batch,) tensor, int32, un-normalized edit distance
+     :rtype: torch.Tensor
+     """
+     maker = OpMaker(OpDescription.from_gen_base(native_op.EditDistanceOp))
+     op = maker.make_op()
+     return op(a, a_len, b, b_len)
+
+
+ def optimal_completion_edit_distance(a, a_len, b, b_len):
+     """
+     Wraps :class:`NativeOp.OptimalCompletionEditDistanceOp`.
+
+     :param torch.Tensor a: (batch,time1), int32. prefix
+     :param torch.Tensor a_len: (batch,), int32
+     :param torch.Tensor b: (batch,time2), int32
+     :param torch.Tensor b_len: (batch,), int32
+     :return: (batch,) tensor, int32, un-normalized edit distance
+     :rtype: torch.Tensor
+     """
+     maker = OpMaker(OpDescription.from_gen_base(native_op.OptimalCompletionEditDistanceOp))
+     op = maker.make_op()
+     return op(a, a_len, b, b_len)
+
+
+ def optimal_completion_edit_distance_per_successor(a, a_len, b, b_len, successors):
+     """
+     Wraps :class:`NativeOp.OptimalCompletionEditDistancePerSuccessorOp`.
+
+     :param torch.Tensor a: (batch,time1), int32. prefix
+     :param torch.Tensor a_len: (batch,), int32
+     :param torch.Tensor b: (batch,time2), int32
+     :param torch.Tensor b_len: (batch,), int32
+     :param torch.Tensor|int successors: (n_labels,), int32. a scalar means torch.arange(successors)
+     :return: (batch,n_labels) tensor, int32, un-normalized edit distance
+     :rtype: torch.Tensor
+     """
+     if isinstance(successors, int):
+         n_labels = successors
+         successors = torch.arange(0, n_labels, 1, device=a.device)
+     assert isinstance(successors, torch.Tensor)
+     maker = OpMaker(OpDescription.from_gen_base(native_op.OptimalCompletionEditDistancePerSuccessorOp))
+     op = maker.make_op()
+     return op(a, a_len, b, b_len, successors)
+
+
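A sketch of the typical use of the per-successor variant: given a decoded prefix `a`, score every possible next label against the reference `b` (tiny hypothetical shapes):

    n_labels = 5
    a = torch.tensor([[1, 2]], dtype=torch.int32)  # (batch,time1), decoded prefix
    a_len = torch.tensor([2], dtype=torch.int32)
    b = torch.tensor([[1, 2, 3]], dtype=torch.int32)  # (batch,time2), reference
    b_len = torch.tensor([3], dtype=torch.int32)
    dists = optimal_completion_edit_distance_per_successor(a, a_len, b, b_len, n_labels)  # (batch,n_labels)
    best_next = dists.argmin(dim=1)  # here label 3, which continues the prefix optimally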
+ def next_edit_distance_row(last_row, a, a_n, a_ended, b, b_len):
+     """
+     Wraps :class:`NativeOp.NextEditDistanceRowOp`.
+
+     :param torch.Tensor last_row: 2d (batch,b_time + 1), int32. last edit distances
+     :param torch.Tensor a: symbols. 1d (batch,), int32. current.
+     :param torch.Tensor|int a_n: scalar or 1d (batch,), int32. current position
+     :param torch.Tensor a_ended: 1d (batch,), int32 (cast from bool, because int32 is easier to handle)
+     :param torch.Tensor b: symbols. 2d (batch,b_time), int32
+     :param torch.Tensor b_len: 1d (batch,), int32
+     :return: 2d (batch,b_time + 1), int32, next (unnormalized) edit distance row
+     :rtype: torch.Tensor
+     """
+     a_ended = a_ended.int()
+     if isinstance(a_n, int):
+         a_n = torch.tensor(a_n, device=a.device)
+     if a_n.ndim == 0:
+         a_n = a_n[None].tile(a_ended.shape)
+     maker = OpMaker(OpDescription.from_gen_base(native_op.NextEditDistanceRowOp))
+     op = maker.make_op()
+     return op(last_row, a, a_n, a_ended, b, b_len)
+
+
+ def edit_distance_via_next_edit_distance_row(a, a_len, b, b_len, optimal_completion=False, full_row_output=False):
+     """
+     This is mostly for demonstration and debugging.
+     Should be equivalent to :func:`edit_distance` or :func:`optimal_completion_edit_distance`
+     (which should be much faster).
+
+     :param torch.Tensor a: (batch,time1), int32
+     :param torch.Tensor a_len: (batch,), int32
+     :param torch.Tensor b: (batch,time2), int32
+     :param torch.Tensor b_len: (batch,), int32
+     :param bool optimal_completion: calc optimal completion edit distance instead
+     :param bool full_row_output: outputs the full final row
+     :return: (batch,) or (batch,time2+1) tensor, int32, un-normalized edit distance
+     :rtype: torch.Tensor
+     """
+     batch_size = a.size(0)
+     time1 = a.size(1)
+     time2 = b.size(1)
+
+     row = torch.arange(time2 + 1, device=a.device)[None, :].tile((batch_size, 1))  # (B,time2+1)
+     for i in range(time1):
+         a_ended = i >= a_len  # (B,)
+         a_cur = a[:, i]  # (B,)
+         row = next_edit_distance_row(a=a_cur, a_n=i, a_ended=a_ended, b=b, b_len=b_len, last_row=row)
+
+     if full_row_output:
+         assert not optimal_completion  # assert the default, this would not have an effect
+         return row
+     elif not optimal_completion:
+         return row[:, -1]
+     else:
+         return torch.min(row, dim=1).values
+
+
+ def next_edit_distance_reduce(last_row, a, a_n, a_ended, b, b_len, optimal_completion=False, a_blank_idx=None):
+     """
+     Wraps :class:`NativeOp.NextEditDistanceReduceOp`.
+
+     :param torch.Tensor last_row: 2d (batch,b_time + 1), int32. last edit distances
+     :param torch.Tensor a: symbols. 2d (batch|1,n_labels), int32. current.
+     :param torch.Tensor a_n: scalar or 1d (batch,), int32. current position
+     :param torch.Tensor a_ended: 1d (batch,), int32 (cast from bool, because int32 is easier to handle)
+     :param torch.Tensor b: symbols. 2d (batch,b_time), int32
+     :param torch.Tensor b_len: 1d (batch,), int32
+     :param torch.Tensor|int|None a_blank_idx: scalar, int32
+     :param bool|torch.Tensor optimal_completion:
+     :return: 2d (batch,n_labels), int32, next (unnormalized) (optimal completion) edit distance
+     :rtype: torch.Tensor
+     """
+     a_ended = a_ended.int()
+     if a_n.ndim == 0:
+         a_n = a_n[None].tile(a_ended.shape)
+     if a_blank_idx is None:
+         a_blank_idx = -1
+     maker = OpMaker(OpDescription.from_gen_base(native_op.NextEditDistanceReduceOp))
+     op = maker.make_op()
+     return op(last_row, a, a_n, a_ended, b, b_len, optimal_completion, a_blank_idx)
+
+
+ def optimal_completion_edit_distance_per_successor_via_next_edit_distance(a, a_len, b, b_len, successors):
+     """
+     Uses :func:`next_edit_distance_reduce` and :func:`edit_distance_via_next_edit_distance_row`.
+     Mostly for demonstration/testing.
+     In practice, you would do something similar, but in your own loop.
+     Similar to :func:`optimal_completion_edit_distance_per_successor`,
+     but the handling of ended sequences (from ``a``) is different.
+
+     :param torch.Tensor a: (batch,time1), int32. prefix
+     :param torch.Tensor a_len: (batch,), int32
+     :param torch.Tensor b: (batch,time2), int32
+     :param torch.Tensor b_len: (batch,), int32
+     :param torch.Tensor|int successors: (n_labels,), int32. a scalar means torch.arange(successors)
+     :return: (batch,n_labels) tensor, int32, un-normalized edit distance
+     :rtype: torch.Tensor
+     """
+     if isinstance(successors, int):
+         n_labels = successors
+         successors = torch.arange(0, n_labels, 1, device=a.device)[None]  # (1,n_labels)
+     assert isinstance(successors, torch.Tensor)
+     last_row = edit_distance_via_next_edit_distance_row(a, a_len, b, b_len, full_row_output=True)
+     return next_edit_distance_reduce(
+         last_row,
+         a=successors,
+         a_n=torch.tensor(a.size(1), device=a.device),
+         a_ended=a_len != a.size(1),
+         b=b,
+         b_len=b_len,
+         optimal_completion=True,
+     )
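As the docstrings above suggest, in a step-wise decoder one would maintain the edit-distance row incrementally and query the reduce op before committing to a label. A minimal sketch of such a loop, where `steps` (yielding one (batch,) int32 label per decoder step), `successors` ((1,n_labels) int32), `ended` ((batch,) bool), `b` and `b_len` are hypothetical names:

    row = torch.arange(b.size(1) + 1, device=b.device)[None, :].tile((b.size(0), 1))  # (batch,time2+1)
    for i, a_cur in enumerate(steps):
        # optimal-completion distance for every candidate next label, e.g. for a loss or for pruning
        succ_dists = next_edit_distance_reduce(
            row, successors, torch.tensor(i, device=b.device), ended, b, b_len, optimal_completion=True
        )  # (batch,n_labels)
        # then commit the chosen label and advance the row
        row = next_edit_distance_row(row, a_cur, i, ended, b, b_len)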