returnn 1.20260105.192646__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +1 -1
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +110 -42
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +6 -5
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +19 -0
- returnn/frontend/loss.py +183 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +104 -174
- returnn/native_op.py +36 -31
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +1 -1
- returnn/tf/frontend_layers/_backend.py +3 -1
- returnn/tf/layers/basic.py +13 -2
- returnn/tf/native_op.py +16 -5
- returnn/tf/util/basic.py +7 -201
- returnn/torch/engine.py +120 -3
- returnn/torch/frontend/_backend.py +166 -22
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +3 -1
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +1 -0
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +42 -36
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20260105.192646.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,885 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Native ops for Torch, similar to :mod:`returnn.tf.native_op`.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from typing import Optional, Any, Tuple, Dict
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
from textwrap import dedent
|
|
10
|
+
from threading import RLock
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from returnn import native_op
|
|
15
|
+
from .native_op_code_compiler import OpCodeCompiler
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Directory two levels above this file's directory (i.e. the ``returnn`` package root for
# returnn/torch/util/native_op.py), resolved to a canonical path. native_op.cpp is looked up there.
_base_dir = os.path.realpath(
    os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OpDescription(native_op.NativeOpBaseMixin):
    """
    Meta-info about an op (input/output info, C++ forward/backward code, ...), used by :class:`OpMaker`.
    """

    @classmethod
    def from_gen_base(cls, gen_base):
        """
        Create a description from a :class:`returnn.native_op.NativeOpGenBase` subclass or instance.

        :param returnn.native_op.NativeOpGenBase|type[returnn.native_op.NativeOpGenBase] gen_base:
        :rtype: OpDescription
        """
        # Classes carry ``__name__`` directly; instances do not, so use their class name instead.
        name = gen_base.__name__ if isinstance(gen_base, type) else type(gen_base).__name__
        assert gen_base.in_info is not None
        assert gen_base.out_info is not None
        assert gen_base.c_fw_code is not None
        return OpDescription(
            in_info=gen_base.in_info,
            out_info=gen_base.out_info,
            c_fw_code=gen_base.c_fw_code,
            c_bw_code=gen_base.c_bw_code,
            c_extra_support_code=gen_base.c_extra_support_code,
            cpu_support=gen_base.cpu_support,
            grad_input_map=gen_base.grad_input_map,
            name=name,
        )

    @property
    def is_grad_defined(self) -> bool:
        """
        :return: whether the gradient is defined, i.e. whether backward C code (``c_bw_code``) is set
        """
        return bool(self.c_bw_code)

    def grad(self) -> Optional[OpDescription]:
        """
        :return: description of the gradient op (built from :func:`kwargs_for_grad_op`),
            or None if no gradient is defined
        :rtype: OpDescription|None
        """
        if not self.is_grad_defined:
            return None
        kwargs = self.kwargs_for_grad_op()
        return OpDescription(**kwargs)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class OpMaker:
    """
    Generates C++ source for a native op described by :class:`OpDescription`,
    registers it as a Torch custom op (``TORCH_LIBRARY`` / ``TORCH_LIBRARY_IMPL``),
    compiles it via :class:`OpCodeCompiler` and returns the resulting op.

    Compiled modules and ops are cached in class-level dicts (shared by all instances),
    keyed by :attr:`cache_key`.

    https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
    https://docs.pytorch.org/cppdocs/
    """

    # Class-wide default. It is resolved lazily to torch.cuda.is_available() on first construction
    # without an explicit with_cuda (see _cls_init_with_cuda); an explicit argument only sets it per instance.
    with_cuda: Optional[bool] = None
    # Held by make_op while looking up / building / caching the op.
    global_lock = RLock()
    mod_cache = {}  # cache_key -> mod
    op_cache = {}  # cache_key -> op
    # Passed as log_stream to OpCodeCompiler.
    log_stream = sys.stdout

    def __init__(
        self,
        description: OpDescription,
        *,
        compiler_opts: Optional[Dict[str, str]] = None,
        with_cuda: Optional[bool] = None,
    ):
        """
        :param description: the op to build
        :param compiler_opts: passed on to OpCodeCompiler as kwargs
        :param with_cuda: override auto-detection of CUDA availability
        """
        if with_cuda is not None:
            self.with_cuda = with_cuda
        else:
            self._cls_init_with_cuda()
        self.description = description
        self.name = description.name
        self.compiler_opts = compiler_opts or {}

    @classmethod
    def _cls_init_with_cuda(cls):
        """Set the class-wide ``with_cuda`` default from ``torch.cuda.is_available()`` if not set yet."""
        if cls.with_cuda is None:
            cls.with_cuda = torch.cuda.is_available()

    @property
    def op_name(self) -> str:
        """op name; used as the Torch library namespace and the op name in the generated code"""
        return self.name

    @property
    def cache_key(self) -> str:
        """
        cache key for mod_cache / op_cache

        NOTE(review): only the op name is used, so once an op is cached, later OpMaker instances
        for the same name with different compiler_opts / with_cuda get the cached op/module.
        """
        return self.name

    @property
    def support_native_op_cpp_filename(self) -> str:
        """
        :return: filename of NativeOp.cpp (``native_op.cpp`` in ``_base_dir``), asserted to exist
        """
        support_native_op_cpp_filename = "%s/native_op.cpp" % _base_dir
        assert os.path.exists(support_native_op_cpp_filename)
        return support_native_op_cpp_filename

    def _make_code(self):
        """
        Generate the full C++ source for this op.

        Layout of the result:
        a header (Torch includes, CUDA includes if ``with_cuda``, native_op.cpp included with ``CUDA 0``,
        and the ``TORCH_LIBRARY`` schema definition);
        then, if ``cpu_support``, the CPU kernel plus ``TORCH_LIBRARY_IMPL(..., CPU, ...)``;
        then, if ``with_cuda``, the CUDA kernel inside ``namespace _cuda_impl`` (native_op.cpp included
        again with ``CUDA 1``) plus ``TORCH_LIBRARY_IMPL(..., CUDA, ...)``.

        :return: C++ source code
        """
        # https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
        # https://docs.pytorch.org/cppdocs/

        # We also include NativeOp.cpp.

        # noinspection PyProtectedMember
        in_info, out_info, _ = native_op.NativeOpBaseMixin._resolve_want_inplace_dummy(
            in_info=self.description.in_info, out_info=self.description.out_info
        )

        # noinspection PyShadowingNames
        def map_name(v: Dict[str, Any], is_out: bool = False) -> str:
            """C++ variable name for an in/out info entry: lowercased, prefixed with ``_in_`` / ``_out_``"""
            name = v["name"].lower()
            if is_out:
                name = "_out_%s" % name
            else:
                name = "_in_%s" % name
            return name

        # noinspection PyShadowingNames,PyUnusedLocal
        def map_type(v: Dict[str, Any]) -> str:
            """ATen scalar type for an in/out info entry (default float32); only float32 and int32 supported"""
            t = v.get("dtype", "float32")
            if t == "float32":
                return "at::kFloat"
            elif t == "int32":
                return "at::kInt"
            else:
                raise NotImplementedError("unsupported dtype %r" % t)

        def make_compute_code(*, cuda: bool = False) -> str:
            """
            Generate the body of the kernel function (CPU or CUDA variant).

            Sections, in order: device selection/guard (CUDA only), in-place output aliases,
            input dtype conversion and ndim/static-dim checks, allocation of zero-initialized outputs,
            contiguity fixes, the ``inputs``/``outputs`` pointer arrays the user code refers to,
            the user forward code (``c_fw_code``), and the return statement.

            :param cuda: generate the CUDA variant
            :return: C++ function body
            """

            code_device_specific = ""
            if cuda:
                tensor_vars = [v for v in in_info if _schema_type_str(v) == "Tensor"]
                if tensor_vars:
                    # Run on the device of the first tensor input.
                    code_device_specific += dedent(f"""\
                        at::Device _device = {map_name(tensor_vars[0])}.device();
                        at::OptionalDeviceGuard _device_guard(_device);
                        """)
                else:
                    # No tensor inputs: fall back to the default CUDA device.
                    code_device_specific += "at::Device _device = at::kCUDA;\n"

            code_forward_io = ""
            out_is_ref = {}  # output vars which are inplace, out_name -> in_idx
            # want_inplace: output-index which this input should operate on
            # Unlike the Theano variant, we always do it inplace,
            # so the user has to make a copy if this is not the intention.
            for in_idx, v in enumerate(in_info):
                out_idx = v.get("want_inplace", -1)
                if out_idx >= 0:
                    code_forward_io += dedent(f"""\
                        torch::Tensor {map_name(out_info[out_idx], is_out=True)} = {map_name(v)}; // inplace
                        """)
                    out_name = out_info[out_idx]["name"]
                    assert out_name not in out_is_ref
                    out_is_ref[out_name] = in_idx

            code_set_io = ""
            for in_idx, v in enumerate(in_info):
                if _schema_type_str(v) != "Tensor":
                    # Scalar (host_memory) input: wrap it into a tensor (no device given) so it can be
                    # referenced from the `inputs` array below.
                    code_set_io += dedent(f"""\
                        torch::Tensor {map_name(v)}_tensor = torch::tensor({map_name(v)}, torch::dtype({map_type(v)}));
                        """)
                    continue  # scalar input
                ndim = len(v["shape"])
                # Convert to the declared dtype if needed, then check ndim and all static (int) dims.
                code_set_io += dedent(f"""\
                    if({map_name(v)}.scalar_type() != {map_type(v)}) {{
                      {map_name(v)} = {map_name(v)}.to(torch::dtype({map_type(v)}));
                    }}
                    """)
                code_set_io += dedent(f"""\
                    TORCH_CHECK(
                      {map_name(v)}.dim() == {ndim},
                      "{v["name"]} shape ndim is not {ndim}, got shape ", {map_name(v)}.sizes());
                    """)
                for axis, d in enumerate(v["shape"]):
                    if isinstance(d, int):
                        code_set_io += dedent(f"""\
                            TORCH_CHECK(
                              {map_name(v)}.size({axis}) == {d},
                              "{v["name"]} shape[{axis}] != {d}, got shape ", {map_name(v)}.sizes());
                            """)

            for out_idx, v in enumerate(out_info):
                out_name = out_info[out_idx]["name"]
                if out_name in out_is_ref:  # is ref on input
                    pass
                else:  # no ref
                    # Output shape: int entries are static sizes; other entries are (in_idx, axis)
                    # pairs referring to the size of an input tensor.
                    cshape = "{%s}" % ", ".join(
                        [
                            str(dim) if isinstance(dim, int) else f"{map_name(in_info[dim[0]])}.size({dim[1]})"
                            for dim in v["shape"]
                        ]
                    )
                    code_set_io += dedent(f"""\
                        torch::Tensor {map_name(v, is_out=True)}
                          = torch::zeros({cshape}, torch::dtype({map_type(v)}){".device(_device)" if cuda else ""});
                        """)

            code_set_contiguous = ""
            for v in in_info:
                if v.get("need_contiguous", False) and _schema_type_str(v) == "Tensor":
                    code_set_contiguous += dedent(f"""\
                        if(!{map_name(v)}.is_contiguous()) {{
                          {map_name(v)} = {map_name(v)}.contiguous();
                        }}
                        """)

            # The user code uses inputs and outputs arrays.
            _code_wrap_io_input_vars_list = [
                f"&{map_name(v)}" + ("" if _schema_type_str(v) == "Tensor" else "_tensor") for v in in_info
            ]
            code_wrap_io = dedent(f"""\
                static const int n_inputs = {len(in_info)}, n_outputs = {len(out_info)};
                torch::Tensor* inputs[n_inputs] = {{
                  {", ".join(_code_wrap_io_input_vars_list)} }};
                torch::Tensor* _outputs_ptr[n_outputs] = {{
                  {", ".join(f"&{map_name(v, is_out=True)}" for v in out_info)} }};
                torch::Tensor** outputs[n_outputs] = {{
                  {", ".join(f"&_outputs_ptr[{i}]" for i in range(len(out_info)))} }};
                """)

            # %(fail)s placeholders in the user code become assert(false).
            code_user = self.description.c_fw_code % {"fail": "assert(false);"}

            code_return = "return std::make_tuple(%s);\n" % ", ".join([map_name(v, is_out=True) for v in out_info])

            code_compute = "\n\n".join(
                [
                    code_device_specific,
                    code_forward_io,
                    code_set_io,
                    code_set_contiguous,
                    code_wrap_io,
                    code_user,
                    code_return,
                ]
            )

            return code_compute

        code_header = ""
        code_header += dedent("""\
            #include <torch/extension.h>
            #include <torch/types.h>
            #include <c10/core/CPUAllocator.h>
            """)
        if self.with_cuda:
            code_header += dedent("""\
                #include <cuda.h>
                #include <cuda_runtime.h>
                #include <math_constants.h>
                #include <ATen/cuda/CUDAContext.h>
                #include <c10/cuda/CUDACachingAllocator.h>
                """)

        def _schema_type_str(v: Dict[str, Any], *, c: bool = False) -> str:
            """
            Map an in/out info entry to a type string.

            Tensor args map to ``Tensor`` (schema) / ``torch::Tensor`` (C++).
            Host-memory args must be 0-dim and map to schema scalars:
            float32 -> ``float`` / ``float32_t``, int32 -> ``int`` / ``int64_t``.

            :param v: in_info/out_info entry
            :param c: if True, return the C++ argument type instead of the schema type
            """
            if v.get("host_memory", False):
                assert v["ndim"] == 0  # not supported otherwise...
                dtype = v.get("dtype", "float32")
                if dtype == "float32":
                    if c:
                        # NOTE(review): PyTorch binds schema `float` to C++ `double`; `float32_t` may not
                        # be accepted by the op registration - confirm if any op has a float32 host_memory input.
                        return "float32_t"
                    return "float"
                elif dtype == "int32":
                    if c:
                        return "int64_t"  # int8_t, int64_t and bool are supported as an integral argument type
                    return "int"
                else:
                    raise NotImplementedError("unsupported dtype %r" % dtype)
            if c:
                return "torch::Tensor"
            return "Tensor"

        # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func
        func_schema_str = "(%s)" % ", ".join(
            f"{_schema_type_str(v)} {v['name']}" for v in in_info
        ) + " -> (%s)" % ", ".join(_schema_type_str(v) for v in out_info)

        code_header += dedent(
            f"""\
            #define _ns // so _ns::something will use the root namespace
            #define TORCH 1
            #define CUDA 0
            #include "{self.support_native_op_cpp_filename}"

            TORCH_LIBRARY({self.op_name}, m) {{
              m.def("{self.op_name}{func_schema_str}");
            }}

            """
        )

        if self.description.cpu_support:
            # noinspection PyProtectedMember
            code_cpu_op = self.description._reduce_c_extra_support_code(self.description.c_extra_support_code)
            code_cpu_op += dedent(
                f"""\

                std::tuple<{", ".join(["torch::Tensor"] * len(out_info))}>
                {self.op_name}_cpu(
                  {", ".join(f"{_schema_type_str(v, c=True)} {map_name(v)}" for v in in_info)}
                ) {{
                """
            )
            code_cpu_op += make_compute_code()
            code_cpu_op += dedent(
                f"""\
                }}

                TORCH_LIBRARY_IMPL({self.op_name}, CPU, m) {{
                  m.impl("{self.op_name}", &{self.op_name}_cpu);
                }}
                """
            )
        else:
            code_cpu_op = ""

        if self.with_cuda:
            # Re-include native_op.cpp in its own namespace with CUDA=1. The #undefs drop macros
            # from the first (CUDA=0) inclusion before they are defined again.
            # noinspection PyProtectedMember
            code_cuda_op = dedent(f"""\
                namespace _cuda_impl {{

                #ifdef _ns
                #undef _ns
                #define _ns _ns
                #endif
                namespace _ns = ::_cuda_impl;
                #undef Ndarray_memcpy
                #undef Ndarray_memset
                #undef Ndarray_sgemm
                #undef Ndarray_sgemv
                #undef Ndarray_sgemm_batched
                #undef DEF_KERNEL
                #undef start_dev_kernel
                #undef assert_cmp
                #undef threadIdx
                #undef blockIdx
                #undef blockDim
                #undef gridDim
                #undef DEF_SHARED
                #undef DEV_FUNC
                #undef HANDLE_LAST_ERROR
                #undef HOST_FUNC
                #undef INF_F
                #undef NAN_F
                #undef elem_atomic_add
                #undef elem_atomic_cas
                #undef elem_atomic_min
                #undef float_as_int
                #undef int_as_float
                #undef start_dev_kernel2
                #undef CHECK_WITH_MSG

                #undef CUDA
                #define CUDA 1

                #include "{self.support_native_op_cpp_filename}"

                #undef CUDA // name collision in Torch code below
                """)
            # noinspection PyProtectedMember
            code_cuda_op += self.description._reduce_c_extra_support_code(self.description.c_extra_support_code)
            code_cuda_op += dedent(f"""\

                std::tuple<{", ".join(["torch::Tensor"] * len(out_info))}>
                {self.op_name}_cuda(
                  {", ".join(f"{_schema_type_str(v, c=True)} {map_name(v)}" for v in in_info)}
                ) {{
                """)
            code_cuda_op += make_compute_code(cuda=True)
            code_cuda_op += dedent(f"""\
                }}

                TORCH_LIBRARY_IMPL({self.op_name}, CUDA, m) {{
                  m.impl("{self.op_name}", &{self.op_name}_cuda);
                }}
                }} // namespace _cuda_impl
                """)
        else:
            code_cuda_op = ""

        return code_header + code_cpu_op + code_cuda_op

    def _make_mod(self):
        """
        Compile the generated code via OpCodeCompiler and load it, or return the module cached under cache_key.
        Takes no lock itself; make_op calls it while holding global_lock.

        :return: loaded module, with the compiler attached as ``_op_compiler``
        """
        if self.cache_key in self.mod_cache:
            return self.mod_cache[self.cache_key]

        comp = OpCodeCompiler(
            base_name=self.name,
            code_version=self.description.code_version,
            code=self._make_code(),
            include_deps=[self.support_native_op_cpp_filename],
            use_cuda_if_available=self.with_cuda,
            log_stream=self.log_stream,
            **dict(self.compiler_opts),
        )
        mod = comp.load_module()
        mod._op_compiler = comp
        self.mod_cache[self.cache_key] = mod
        return mod

    def make_op(self):
        """
        Get the op, building and caching it (under global_lock) on first use.
        The op gets ``_op_maker`` and ``_op_module`` attributes attached. Gradients are not registered yet.

        :return: op
        """
        with self.global_lock:
            if self.cache_key in self.op_cache:
                return self.op_cache[self.cache_key]
            mod = self._make_mod()
            op = getattr(mod, self.op_name)
            op._op_maker = self
            op._op_module = mod
            self.op_cache[self.cache_key] = op

            if self.description.is_grad_defined:
                pass  # not implemented yet...

        return op
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def make_op(cls, **kwargs):
    """
    Build (or fetch from the cache) the Torch op for a native op generator class.

    :param type[returnn.native_op.NativeOpGenBase] cls: the op generator class
    :param kwargs: forwarded to :class:`OpMaker`
    :return: op
    :rtype: (torch.Tensor) -> tuple[torch.Tensor]
    """
    description = OpDescription.from_gen_base(cls)
    return OpMaker(description, **kwargs).make_op()
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def ctc_loss(
    *,
    logits: torch.Tensor,
    logits_seq_lens: torch.Tensor,
    targets: torch.Tensor,
    targets_seq_lens: torch.Tensor,
    label_loop: bool = True,
    logits_time_major: bool = False,
    logits_normalize: bool = True,
    blank_index: int = -1,
    max_approx: bool = False,
) -> torch.Tensor:
    """
    CTC loss, similar to :func:`tf.nn.ctc_loss`.
    The full sum is computed with our :func:`fast_baum_welch`. With ``max_approx``, the best path
    is found with :func:`fast_viterbi`, and the loss is the negative sum of its log-probs.

    :param logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
    :param logits_seq_lens: shape (batch,) of int32|int64
    :param targets: batch-major, [batch,time]
    :param targets_seq_lens: (batch,)
    :param label_loop: (ctc_merge_repeated in tf.nn.ctc_loss)
    :param logits_time_major: whether ``logits`` is (time,batch,dim); otherwise (batch,time,dim)
    :param logits_normalize: apply log_softmax on logits (default).
        If False, ``logits`` are used as-is as log-probabilities, and the gradient is taken w.r.t. them directly.
    :param blank_index: vocab index of the blank symbol. negative values count from the end of the vocab
    :param max_approx: use max approximation (Viterbi) instead of full sum
    :return: loss, shape (batch,)
    """
    from .array_ import sequence_mask_time_major

    assert logits.ndim == 3
    dim = logits.shape[-1]
    if not logits_time_major:
        logits = torch.transpose(logits, 0, 1)  # (time,batch,dim)

    if blank_index < 0:
        blank_index += dim
    assert 0 <= blank_index < dim
    edges, weights, start_end_states = get_ctc_fsa_fast_bw(
        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
    )

    seq_mask = sequence_mask_time_major(logits_seq_lens)  # (time,batch), bool

    if max_approx:
        log_probs = torch.log_softmax(logits, dim=-1) if logits_normalize else logits  # (time,batch,dim)
        alignment, _ = fast_viterbi(
            am_scores=log_probs,
            am_seq_len=logits_seq_lens,
            edges=edges,
            weights=weights,
            start_end_states=start_end_states,
        )
        # alignment is (time,batch) and comes from the native op, whose integer outputs are int32.
        # torch.gather expects an int64 (Long) index, so convert explicitly.
        log_probs_ = torch.gather(log_probs, 2, alignment.long().unsqueeze(-1))  # (time,batch,1)
        log_probs_ = log_probs_.squeeze(-1)  # (time,batch)
        log_probs_ = torch.where(seq_mask, log_probs_, 0.0)  # ignore padded frames
        loss = -torch.sum(log_probs_, dim=0)  # (batch,)
        return loss

    loss = _FastBaumWelchScoresAutogradFunc.apply(logits, logits_normalize, seq_mask, edges, weights, start_end_states)
    return loss
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
# noinspection PyMethodOverriding,PyAbstractClass,PyMissingOrEmptyDocstring
class _FastBaumWelchScoresAutogradFunc(torch.autograd.Function):
    """
    Full-sum (Baum-Welch) loss as a custom autograd function.

    Forward returns the per-sequence loss (batch,) taken from the observation scores of
    :func:`fast_baum_welch`. Backward uses the fwd-bwd scores (-log space) saved in forward.
    """

    @staticmethod
    def forward(
        ctx,
        logits: torch.Tensor,
        logits_normalize: bool,
        seq_mask: torch.Tensor,
        edges: torch.Tensor,
        weights: torch.Tensor,
        start_end_states: torch.Tensor,
        state_buffer: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if logits_normalize:
            log_sm = torch.log_softmax(logits, dim=-1)  # (time,batch,dim)
        else:
            log_sm = logits
        fwdbwd, obs_scores = fast_baum_welch(
            am_scores=-log_sm,
            seq_mask=seq_mask,
            edges=edges,
            weights=weights,
            start_end_states=start_end_states,
            state_buffer=state_buffer,
        )
        loss = obs_scores[0]  # (batch,)
        ctx.grad_wrt_softmax_in = logits_normalize
        if logits_normalize:
            ctx.save_for_backward(log_sm, seq_mask, fwdbwd)
        else:
            ctx.save_for_backward(seq_mask, fwdbwd)
        return loss

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        if ctx.grad_wrt_softmax_in:
            log_sm, seq_mask, fwdbwd = ctx.saved_tensors
        else:
            log_sm = None
            seq_mask, fwdbwd = ctx.saved_tensors
        bw = torch.exp(-fwdbwd)  # (time,batch,dim); fwdbwd is in -log space
        if ctx.grad_wrt_softmax_in:
            grad_x = torch.exp(log_sm) - bw  # (time,batch,dim)
        else:
            grad_x = -bw  # (time,batch,dim)
        # seq_mask is (time,batch): add the new axis for dim (last), so it broadcasts
        # over the labels and not over the batch.
        grad_x = torch.where(seq_mask[:, :, None], grad_x, 0.0)
        grad_x *= grad_output[None, :, None]
        # One gradient per forward() input:
        # logits, logits_normalize, seq_mask, edges, weights, start_end_states, state_buffer.
        return grad_x, None, None, None, None, None, None
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def get_ctc_fsa_fast_bw(
    *, targets: torch.Tensor, seq_lens: torch.Tensor, blank_idx: int, label_loop: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build an FSA with CTC topology, wrapping :class:`NativeOp.GetCtcFsaFastBwOp`.
    The result can be passed directly to :func:`fast_baum_welch` / :func:`fast_viterbi`.

    :param targets: shape (batch,time), int32 (converted to int32 if needed)
    :param seq_lens: shape (batch), int32
    :param blank_idx: vocab index of the blank symbol
    :param label_loop: True -> normal CTC; False -> RNA-like
    :return: edges, weights, start_end_states;
      edges is (4,num_edges), int32, edges of the graph (from,to,emission_idx,sequence_idx).
      weights is (num_edges,), float32. all zero.
      start_end_states is (2,batch), int32, (start,end) state idx in FSA.
    """
    assert targets.ndim == 2
    targets = targets.to(torch.int32)
    batch_size, max_time = targets.shape

    from .assert_ import assert_

    # Validate the seq lens up front: invalid seq lens may not fail here but silently
    # yield an invalid FSA, which can crash FastBaumWelchOp later on.
    assert_(seq_lens.max() == max_time, "get_ctc_fsa_fast_bw seq_lens invalid")

    num_edges = batch_size * (5 * (max_time - 1) + 10)  # see op documentation
    weights = torch.zeros((num_edges,), device=targets.device)
    fsa_op = OpMaker(OpDescription.from_gen_base(native_op.GetCtcFsaFastBwOp)).make_op()
    edges, start_end_states = fsa_op(targets, seq_lens, blank_idx, weights, label_loop)

    return edges, weights, start_end_states
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def fast_baum_welch(
    *,
    am_scores: torch.Tensor,
    seq_mask: torch.Tensor,
    edges: torch.Tensor,
    weights: torch.Tensor,
    start_end_states: torch.Tensor,
    state_buffer: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward-backward (Baum-Welch) over an FSA, wrapping :class:`NativeOp.FastBaumWelchOp`.

    :param am_scores: (time, batch, dim), in -log space
    :param seq_mask: (time, batch) -> 0 or 1 (index mask, via seq lens)
    :param edges: (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
    :param weights: (num_edges,), weights of the edges
    :param start_end_states: (2, batch), (start,end) state idx in automaton.
        there is only one single automaton.
    :param state_buffer: (2, num_states). If not given, a zero buffer is allocated
        on the device of ``am_scores``, sized by the largest end state index.
    :return: (fwdbwd, obs_scores), fwdbwd is (time, batch, dim), obs_scores is (time, batch), in -log space
    """
    from .assert_ import assert_

    op = make_fast_baum_welch_op()
    float_idx = seq_mask.float()
    if state_buffer is None:
        last_state_idx = start_end_states[1].max()  # see get_automata_for_batch
        assert_(last_state_idx >= 0, "fast_baum_welch last_state_idx must be >= 0")
        # state_buffer is a regular (2-dim, thus non-host-memory) tensor input of the op, like am_scores.
        # Allocate it on the same device, so that CUDA inputs don't get a CPU buffer.
        state_buffer = torch.zeros((2, int(last_state_idx) + 1), device=am_scores.device)
    fwdbwd, obs_scores = op(am_scores, edges, weights, start_end_states, float_idx, state_buffer)  # noqa
    return fwdbwd, obs_scores
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def make_fast_baum_welch_op(**kwargs):
    """
    Build (or fetch from the cache) the op for :class:`NativeOp.FastBaumWelchOp`.

    :param kwargs: forwarded to :class:`OpMaker`
    :return: op
    :rtype: (torch.Tensor) -> tuple[torch.Tensor]
    """
    description = OpDescription.from_gen_base(native_op.FastBaumWelchOp)
    return OpMaker(description, **kwargs).make_op()
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def fast_viterbi(
    *,
    am_scores: torch.Tensor,
    am_seq_len: torch.Tensor,
    edges: torch.Tensor,
    weights: torch.Tensor,
    start_end_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Best path (Viterbi) through an FSA, wrapping :class:`NativeOp.FastViterbiOp`.

    :param am_scores: (time, batch, dim), in +log space, already normalized / just used as-is
    :param am_seq_len: (batch,), int32
    :param edges: (4,num_edges), edges of the graph (from,to,emission_idx,sequence_idx)
    :param weights: (num_edges,), weights of the edges
    :param start_end_states: (2, batch), (start,end) state idx in automaton.
        there is only one single automaton.
    :return: (alignment, scores), alignment is (time, batch), scores is (batch,), in +log space.
        The scores are not differentiable here.
        To get differentiable scores, gather the am_scores by the alignment and sum them.
    """
    # The number of FSA states follows from the largest end state index.
    n_states = start_end_states[1].max() + 1
    viterbi_op = OpMaker(OpDescription.from_gen_base(native_op.FastViterbiOp)).make_op()
    alignment, scores = viterbi_op(am_scores, am_seq_len, edges, weights, start_end_states, n_states)
    return alignment, scores
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def ctc_best_path(
    *,
    logits: torch.Tensor,
    logits_seq_lens: torch.Tensor,
    targets: torch.Tensor,
    targets_seq_lens: torch.Tensor,
    label_loop: bool = True,
    logits_time_major: bool = False,
    logits_normalize: bool = True,
    blank_index: int = -1,
) -> torch.Tensor:
    """
    Best CTC path (forced alignment) of the targets, via :func:`fast_viterbi` on the CTC FSA.

    :param logits: (time,batch,dim) or (batch,time,dim). unnormalized (before softmax)
    :param logits_seq_lens: shape (batch,) of int32|int64
    :param targets: batch-major, [batch,time]
    :param targets_seq_lens: (batch,)
    :param label_loop: (ctc_merge_repeated in tf.nn.ctc_loss)
    :param logits_time_major: whether ``logits`` is (time,batch,dim); otherwise (batch,time,dim)
    :param logits_normalize: apply log_softmax on logits (default).
        If False, ``logits`` are used as-is as log-probabilities.
    :param blank_index: vocab index of the blank symbol. negative values count from the end of the vocab
    :return: alignment, (time, batch).
        To get the scores, gather the (log-softmax) logits by the alignment and sum them.
    """
    assert logits.ndim == 3
    vocab_size = logits.shape[-1]
    if not logits_time_major:
        logits = logits.transpose(0, 1)  # -> (time,batch,dim)
    log_probs = torch.log_softmax(logits, dim=-1) if logits_normalize else logits  # (time,batch,dim)

    if blank_index < 0:
        blank_index += vocab_size
    assert 0 <= blank_index < vocab_size
    fsa_edges, fsa_weights, fsa_start_end = get_ctc_fsa_fast_bw(
        targets=targets, seq_lens=targets_seq_lens, blank_idx=blank_index, label_loop=label_loop
    )

    best_path, _ = fast_viterbi(
        am_scores=log_probs,
        am_seq_len=logits_seq_lens,
        edges=fsa_edges,
        weights=fsa_weights,
        start_end_states=fsa_start_end,
    )
    return best_path
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def edit_distance(a, a_len, b, b_len):
    """
    Batched edit distance between ``a`` and ``b``; wraps :class:`NativeOp.EditDistanceOp`.

    :param torch.Tensor a: (batch,time1), int32
    :param torch.Tensor a_len: (batch,), int32
    :param torch.Tensor b: (batch,time2), int32
    :param torch.Tensor b_len: (batch,), int32
    :return: (batch,) tensor, int32, un-normalized edit distance
    :rtype: torch.Tensor
    """
    edit_distance_op = OpMaker(OpDescription.from_gen_base(native_op.EditDistanceOp)).make_op()
    return edit_distance_op(a, a_len, b, b_len)
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
def optimal_completion_edit_distance(a, a_len, b, b_len):
    """
    Per-sequence optimal completion edit distance, where ``a`` is treated as a prefix.
    Wraps :class:`NativeOp.OptimalCompletionEditDistanceOp`.

    :param torch.Tensor a: (batch,time1), int32. prefix
    :param torch.Tensor a_len: (batch,), int32
    :param torch.Tensor b: (batch,time2), int32
    :param torch.Tensor b_len: (batch,), int32
    :return: (batch,) tensor, int32, un-normalized edit distance
    :rtype: torch.Tensor
    """
    op_description = OpDescription.from_gen_base(native_op.OptimalCompletionEditDistanceOp)
    oc_edit_distance_op = OpMaker(op_description).make_op()
    return oc_edit_distance_op(a, a_len, b, b_len)
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
def optimal_completion_edit_distance_per_successor(a, a_len, b, b_len, successors):
    """
    Optimal completion edit distance of the prefix ``a`` extended by each successor label.
    Wraps :class:`NativeOp.OptimalCompletionEditDistancePerSuccessorOp`.

    :param torch.Tensor a: (batch,time1), int32. prefix
    :param torch.Tensor a_len: (batch,), int32
    :param torch.Tensor b: (batch,time2), int32
    :param torch.Tensor b_len: (batch,), int32
    :param torch.Tensor|int successors: (n_labels,), int32. scalar means torch.arange(successors)
    :return: (batch,n_labels) tensor, int32, un-normalized edit distance
    :rtype: torch.Tensor
    """
    if isinstance(successors, int):
        n_labels = successors
        # Same dtype (int32) as the other op inputs and the same device as the data,
        # instead of torch.arange's default int64 on the default (CPU) device.
        successors = torch.arange(n_labels, dtype=torch.int32, device=a.device)
    assert isinstance(successors, torch.Tensor)
    maker = OpMaker(OpDescription.from_gen_base(native_op.OptimalCompletionEditDistancePerSuccessorOp))
    op = maker.make_op()
    return op(a, a_len, b, b_len, successors)
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def next_edit_distance_row(last_row, a, a_n, a_ended, b, b_len):
    """
    Wraps :class:`NativeOp.NextEditDistanceRowOp`.

    Computes the next row of the edit distance table, given the previous row and
    the current symbol of ``a``.

    :param torch.Tensor last_row: 2d (batch,b_time + 1), int32. last edit distances
    :param torch.Tensor a: symbols. 1d (batch,), int32. current.
    :param torch.Tensor|int a_n: scalar or 1d (batch,), int32. current position
    :param torch.Tensor a_ended: 1d (batch,), int32 (casted from bool, because int32 easier to handle)
    :param torch.Tensor b: symbols. 2d (batch,b_time), int32
    :param torch.Tensor b_len: 1d (batch,), int32
    :return: 2d (batch,b_time + 1), int32, next (unnormalized) edit distance row
    :rtype: torch.Tensor
    """
    a_ended = a_ended.int()
    # Normalize a_n (Python int or tensor) to int32 on the data device.
    # torch.tensor(int) alone would infer int64, which does not match the int32 contract.
    a_n = torch.as_tensor(a_n, dtype=torch.int32, device=a.device)
    if a_n.ndim == 0:
        # Broadcast the scalar position to one entry per batch element.
        a_n = a_n[None].tile(a_ended.shape)
    maker = OpMaker(OpDescription.from_gen_base(native_op.NextEditDistanceRowOp))
    op = maker.make_op()
    return op(last_row, a, a_n, a_ended, b, b_len)
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def edit_distance_via_next_edit_distance_row(a, a_len, b, b_len, optimal_completion=False, full_row_output=False):
    """
    This is mostly for demonstration and debugging.
    Should be equivalent to :func:`edit_distance` or :func:`optimal_completion_edit_distance`
    (which should be much faster).

    It iterates over the time axis of ``a`` and calls :func:`next_edit_distance_row` once per step.

    :param torch.Tensor a: (batch,time1), int32
    :param torch.Tensor a_len: (batch,), int32
    :param torch.Tensor b: (batch,time2), int32
    :param torch.Tensor b_len: (batch,), int32
    :param bool optimal_completion: calc optimal completion edit distance instead
    :param bool full_row_output: outputs the full final row
    :return: (batch,) or (batch,time2+1) tensor, int32, un-normalized edit distance
    :rtype: torch.Tensor
    """
    batch_size = a.size(0)
    time1 = a.size(1)
    time2 = b.size(1)

    # Initial row: distance from the empty prefix of a to each prefix of b.
    # Built as int32, matching the op's contract. torch.arange defaults to int64.
    row = torch.arange(time2 + 1, dtype=torch.int32, device=a.device)[None, :].tile((batch_size, 1))  # (B,time2+1)
    for i in range(time1):
        a_ended = i >= a_len  # (B,)
        a_cur = a[:, i]  # (B,)
        row = next_edit_distance_row(a=a_cur, a_n=i, a_ended=a_ended, b=b, b_len=b_len, last_row=row)

    if full_row_output:
        assert not optimal_completion  # assert the default, this would not have an effect
        return row
    elif not optimal_completion:
        return row[:, -1]
    else:
        # Optimal completion: best distance over all prefixes of b.
        return torch.min(row, dim=1).values
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
def next_edit_distance_reduce(last_row, a, a_n, a_ended, b, b_len, optimal_completion=False, a_blank_idx=None):
    """
    Wraps :class:`NativeOp.NextEditDistanceReduceOp`.

    Given the last edit distance row, computes the next (optimal completion) edit distance
    for each of the ``n_labels`` candidate symbols in ``a``.

    :param torch.Tensor last_row: 2d (batch,b_time + 1), int32. last edit distances
    :param torch.Tensor a: symbols. 2d (batch|1,n_labels), int32. current.
    :param torch.Tensor|int a_n: scalar (Python int or 0-d tensor) or 1d (batch,), int32. current position
    :param torch.Tensor a_ended: 1d (batch,), int32 (casted from bool, because int32 easier to handle)
    :param torch.Tensor b: symbols. 2d (batch,b_time), int32
    :param torch.Tensor b_len: 1d (batch,), int32
    :param torch.Tensor|int|None a_blank_idx: scalar, int32
    :param bool|torch.Tensor optimal_completion:
    :return: 2d (batch,n_labels), int32, next (unnormalized) (optimal completion) edit distance
    :rtype: torch.Tensor
    """
    a_ended = a_ended.int()
    # Accept a plain Python int, as next_edit_distance_row does (a_n.ndim would fail on an int).
    # Also place it as int32 on the same device as last_row.
    a_n = torch.as_tensor(a_n, dtype=torch.int32, device=last_row.device)
    if a_n.ndim == 0:
        # Broadcast the scalar position to one entry per batch element.
        a_n = a_n[None].tile(a_ended.shape)
    if a_blank_idx is None:
        a_blank_idx = -1  # NOTE(review): presumably means "no blank symbol" for the native op — confirm
    maker = OpMaker(OpDescription.from_gen_base(native_op.NextEditDistanceReduceOp))
    op = maker.make_op()
    return op(last_row, a, a_n, a_ended, b, b_len, optimal_completion, a_blank_idx)
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
def optimal_completion_edit_distance_per_successor_via_next_edit_distance(a, a_len, b, b_len, successors):
    """
    Uses :func:`next_edit_distance_reduce` and :func:`edit_distance_via_next_edit_distance_row`.
    Mostly for demonstration/testing.
    In practice, you would do something similar, but in your own loop.
    Similar to :func:`optimal_completion_edit_distance_per_successor`,
    but the handling of ended sequences (from ``a``) is different.

    :param torch.Tensor a: (batch,time1), int32. prefix
    :param torch.Tensor a_len: (batch,), int32
    :param torch.Tensor b: (batch,time2), int32
    :param torch.Tensor b_len: (batch,), int32
    :param torch.Tensor|int successors: (n_labels,), int32. scalar means torch.arange(successors)
    :return: (batch,n_labels) tensor, int32, un-normalized edit distance
    :rtype: torch.Tensor
    """
    if isinstance(successors, int):
        n_labels = successors
        # int32 on the data device, shared over the batch via a leading broadcast axis.
        successors = torch.arange(n_labels, dtype=torch.int32, device=a.device)[None]  # (1,n_labels)
    assert isinstance(successors, torch.Tensor)
    last_row = edit_distance_via_next_edit_distance_row(a, a_len, b, b_len, full_row_output=True)
    time1 = a.size(1)
    return next_edit_distance_reduce(
        last_row,
        a=successors,
        # Position after the full (padded) prefix, as int32 on the same device as the data.
        a_n=torch.tensor(time1, dtype=torch.int32, device=a.device),
        # Sequences shorter than the padded length have already ended.
        a_ended=a_len != time1,
        b=b,
        b_len=b_len,
        optimal_completion=True,
    )
|