returnn 1.20251027.232712__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/__old_mod_loader__.py +26 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/lm.py +130 -42
- returnn/datasets/meta.py +93 -43
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/__init__.py +1 -0
- returnn/frontend/_backend.py +41 -0
- returnn/frontend/_native/__init__.py +22 -0
- returnn/frontend/_numpy_backend.py +7 -0
- returnn/frontend/_utils.py +1 -1
- returnn/frontend/array_.py +48 -2
- returnn/frontend/assert_.py +35 -0
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/device.py +14 -1
- returnn/frontend/encoder/conformer.py +20 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +222 -3
- returnn/frontend/math_.py +54 -14
- returnn/native_op.cpp +182 -172
- returnn/native_op.py +36 -31
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +7 -7
- returnn/tensor/_tensor_extra.py +10 -10
- returnn/tensor/utils.py +8 -5
- returnn/tf/frontend_layers/_backend.py +7 -3
- returnn/tf/layers/basic.py +27 -40
- returnn/tf/native_op.py +27 -63
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +22 -197
- returnn/torch/engine.py +157 -6
- returnn/torch/frontend/_backend.py +280 -29
- returnn/torch/frontend/bridge.py +61 -0
- returnn/torch/frontend/compile_helper.py +106 -0
- returnn/torch/util/array_.py +30 -0
- returnn/torch/util/assert_.py +122 -0
- returnn/torch/util/exception_helper.py +7 -1
- returnn/torch/util/native_op.py +885 -0
- returnn/torch/util/native_op_code_compiler.py +308 -0
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/cuda_env.py +332 -0
- returnn/util/debug.py +12 -2
- returnn/util/file_cache.py +15 -1
- returnn/util/fsa.py +17 -13
- returnn/util/native_code_compiler.py +104 -47
- returnn/util/task_system.py +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
- {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
|
@@ -11,6 +11,7 @@ __all__ = [
|
|
|
11
11
|
"SentencePieces",
|
|
12
12
|
"CharacterTargets",
|
|
13
13
|
"Utf8ByteTargets",
|
|
14
|
+
"HuggingFaceTokenizer",
|
|
14
15
|
]
|
|
15
16
|
|
|
16
17
|
from typing import Optional, Union, Type, Callable, List, Dict
|
|
@@ -691,3 +692,92 @@ class Utf8ByteTargets(Vocabulary):
|
|
|
691
692
|
assert ((seq >= 0) & (seq < 256)).all(), f"invalid byte value, must be within 0-255: {seq}"
|
|
692
693
|
seq = seq.astype(numpy.uint8)
|
|
693
694
|
return bytearray(seq).decode(encoding="utf8")
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
class HuggingFaceTokenizer(Vocabulary):
    """
    Uses the `AutoTokenizer` class from the `transformers` package.

    The tokenizer is loaded from a local directory.
    The label list is not built at construction time, only lazily on first access to :attr:`labels`.
    """

    def __init__(self, *, huggingface_repo_dir: str):
        """
        :param str huggingface_repo_dir: the directory containing the `tokenizer_config.json` file.
        """
        import transformers  # noqa

        # Make sure it is a string. (Could be e.g. Sis Path.)
        huggingface_repo_dir = str(huggingface_repo_dir)
        self._opts = {"huggingface_repo_dir": huggingface_repo_dir}
        self._cache_key = huggingface_repo_dir
        # NOTE(review): trust_remote_code=True allows the tokenizer repo to run its own Python code while loading.
        # Only use this with directories you trust.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(huggingface_repo_dir, trust_remote_code=True)
        super().__init__(
            vocab_file=None,
            seq_postfix=None,
            unknown_label=self.tokenizer.unk_token_id,
            eos_label=self.tokenizer.eos_token_id,
            bos_label=self.tokenizer.bos_token_id,
            pad_label=self.tokenizer.pad_token_id,
        )

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._opts)

    def _parse_vocab(self):
        self.num_labels = len(self.tokenizer)
        # Do not load labels/vocab here. This is not really needed.

    @property
    def labels(self) -> List[str]:
        """list of labels, built lazily and shared via the cache under the repo dir as key"""
        if self._cache_key and self._cache_key in self._cache:
            self._vocab, self._labels = self._cache[self._cache_key]
            assert self.num_labels == len(self._vocab) == len(self._labels)
        else:
            self._labels = [self.tokenizer._convert_id_to_token(i) for i in range(self.num_labels)]  # noqa
            self._vocab = {label: i for (i, label) in enumerate(self._labels)}
            if self._cache_key:
                self._cache[self._cache_key] = (self._vocab, self._labels)
        return self._labels

    def is_id_valid(self, idx: int) -> bool:
        """
        :param idx:
        :return: whether idx is within ``[0, len(tokenizer))``
        """
        return 0 <= idx < len(self.tokenizer)

    def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
        """
        :param idx:
        :param default: returned for an invalid idx. If it is the ``KeyError`` class, a ``KeyError`` is raised instead.
        :return: the token string for idx
        """
        if not self.is_id_valid(idx):
            # Do the range check ourselves, so that the outcome for an invalid idx
            # does not depend on the tokenizer implementation.
            if default is KeyError:
                raise KeyError("idx %i out of range" % idx)
            return default
        return self.tokenizer.convert_ids_to_tokens(idx)

    def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
        """
        :param label:
        :param default: returned for an unknown label. If it is the ``KeyError`` class, a ``KeyError`` is raised instead.
        :return: the id of the label
        """
        # `convert_tokens_to_ids` is the public HF tokenizer API; it accepts a single token string as well.
        res = self.tokenizer.convert_tokens_to_ids(label)
        unk_id = self.unknown_label_id
        # Check for None first: comparing None with `<` would raise a TypeError.
        if res is None or res < 0 or res == unk_id:
            # It could be that the label really is the unknown-label, or it could be that the label is unknown.
            if unk_id is not None and label == self.id_to_label(unk_id):
                return unk_id
            if default is KeyError:
                raise KeyError("label %r not found" % label)
            return default
        return res

    def get_seq(self, sentence: str) -> List[int]:
        """
        :param sentence: raw text. It is passed as-is to the tokenizer, which does the tokenization itself.
        :return: the ``input_ids`` as returned by the tokenizer
            (NOTE(review): this presumably includes whatever special tokens the tokenizer adds by default -- confirm)
        """
        return self.tokenizer(sentence)["input_ids"]

    def get_seq_labels(self, seq):
        """
        :param list[int]|numpy.ndarray seq: 1D sequence
        :return: the decoded text, with special tokens skipped
        :rtype: str
        """
        return self.tokenizer.decode(seq, skip_special_tokens=True)
|
returnn/frontend/__init__.py
CHANGED
returnn/frontend/_backend.py
CHANGED
|
@@ -42,6 +42,11 @@ class Backend(Generic[T]):
|
|
|
42
42
|
"""
|
|
43
43
|
raise NotImplementedError
|
|
44
44
|
|
|
45
|
+
@staticmethod
def assert_(condition: Tensor, message: str):
    """
    Backend hook for asserting that ``condition`` holds.

    This base implementation is only a placeholder and always raises.

    :param condition: the condition to check
    :param message: error message to report when the condition does not hold
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
|
49
|
+
|
|
45
50
|
@staticmethod
|
|
46
51
|
def get_tensor_dependencies(x: Tensor) -> Sequence[Tensor]:
|
|
47
52
|
"""
|
|
@@ -624,12 +629,48 @@ class Backend(Generic[T]):
|
|
|
624
629
|
targets_spatial_dim: Dim,
|
|
625
630
|
blank_index: int,
|
|
626
631
|
max_approx: bool = False,
|
|
632
|
+
use_native_op: Optional[bool] = None,
|
|
633
|
+
label_loop: bool = True,
|
|
627
634
|
) -> Tensor:
|
|
628
635
|
"""
|
|
629
636
|
Calculates the CTC loss.
|
|
630
637
|
"""
|
|
631
638
|
raise NotImplementedError
|
|
632
639
|
|
|
640
|
+
@staticmethod
def ctc_best_path(
    *,
    logits: Tensor,
    logits_normalized: bool = False,
    targets: Tensor,
    input_spatial_dim: Dim,
    targets_spatial_dim: Dim,
    blank_index: int,
    label_loop: bool = True,
) -> Tensor:
    """
    Calculates the CTC best path.

    This base implementation is only a placeholder and always raises.
    All arguments are keyword-only.

    :param logits: the model outputs over ``input_spatial_dim``
    :param logits_normalized: presumably whether ``logits`` are already log-normalized -- TODO confirm
    :param targets: the target labels over ``targets_spatial_dim``
    :param input_spatial_dim: spatial dim of ``logits``
    :param targets_spatial_dim: spatial dim of ``targets``
    :param blank_index: index of the CTC blank label
    :param label_loop: presumably whether labels may loop (repeat over frames) in the CTC topology -- TODO confirm
    :return: the best path (exact shape is defined by the implementing backend)
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
|
655
|
+
|
|
656
|
+
@staticmethod
def have_edit_distance() -> bool:
    """
    :return: whether we have an edit_distance implementation.
        The base implementation reports False.
    """
    return False
|
|
662
|
+
|
|
663
|
+
@staticmethod
def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
    """
    Backend hook for the edit distance between the sequences in ``a`` and ``b``.

    This base implementation is only a placeholder and always raises.

    :param a: [B,Ta]
    :param a_spatial_dim: Ta
    :param b: [B,Tb]
    :param b_spatial_dim: Tb
    :return: [B]
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
|
673
|
+
|
|
633
674
|
@staticmethod
|
|
634
675
|
def have_sequence_mask_raw() -> bool:
|
|
635
676
|
"""
|
|
@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
|
|
|
67
67
|
|
|
68
68
|
|
|
69
69
|
_is_set_up = False
|
|
70
|
+
_enabled = True
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def set_enabled(enabled: bool):
    """
    Enable or disable the native code setup.

    This only sets the module-level flag ``_enabled``.
    It does not undo a setup which was already performed.

    :param enabled: the new value of the flag
    """
    global _enabled
    _enabled = enabled
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def is_set_up() -> bool:
    """
    :return: whether the native code is set up.
        This reports the module-level flag ``_is_set_up``.
    """
    return _is_set_up
|
|
70
88
|
|
|
71
89
|
|
|
72
90
|
def setup():
|
|
@@ -76,6 +94,8 @@ def setup():
|
|
|
76
94
|
global _is_set_up
|
|
77
95
|
if _is_set_up:
|
|
78
96
|
return
|
|
97
|
+
if not _enabled:
|
|
98
|
+
return
|
|
79
99
|
_is_set_up = True # only try once
|
|
80
100
|
|
|
81
101
|
from returnn.tensor import Tensor, Dim
|
|
@@ -177,6 +197,8 @@ def setup_torch():
|
|
|
177
197
|
global _is_set_up_torch
|
|
178
198
|
if _is_set_up_torch:
|
|
179
199
|
return
|
|
200
|
+
if not _enabled:
|
|
201
|
+
return
|
|
180
202
|
_is_set_up_torch = True # only try once
|
|
181
203
|
|
|
182
204
|
import torch
|
|
@@ -26,6 +26,13 @@ class NumpyBackend(Backend[numpy.ndarray]):
|
|
|
26
26
|
"""executing eagerly"""
|
|
27
27
|
return True
|
|
28
28
|
|
|
29
|
+
@staticmethod
|
|
30
|
+
def assert_(condition: Tensor, message: str):
|
|
31
|
+
"""assert"""
|
|
32
|
+
assert condition.dims == (), "condition for assert must be a scalar"
|
|
33
|
+
if not condition.raw_tensor.item():
|
|
34
|
+
raise AssertionError(message)
|
|
35
|
+
|
|
29
36
|
@staticmethod
|
|
30
37
|
def get_dtype_name_raw(raw_tensor: numpy.ndarray) -> str:
|
|
31
38
|
"""
|
returnn/frontend/_utils.py
CHANGED
|
@@ -110,7 +110,7 @@ def bin_op_out_template(
|
|
|
110
110
|
all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
|
|
111
111
|
else:
|
|
112
112
|
all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
|
|
113
|
-
if all(set(x.dims) != set(all_dims) for x in (a, b)):
|
|
113
|
+
if all([set(x.dims) != set(all_dims) for x in (a, b)]):
|
|
114
114
|
if allow_broadcast_all_sources is False:
|
|
115
115
|
raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
|
|
116
116
|
elif allow_broadcast_all_sources is None:
|
returnn/frontend/array_.py
CHANGED
|
@@ -54,6 +54,7 @@ __all__ = [
|
|
|
54
54
|
"one_hot",
|
|
55
55
|
"top_k_mask",
|
|
56
56
|
"top_p_mask",
|
|
57
|
+
"repeat",
|
|
57
58
|
]
|
|
58
59
|
|
|
59
60
|
|
|
@@ -84,6 +85,10 @@ def convert_to_tensor(
|
|
|
84
85
|
:return: tensor
|
|
85
86
|
"""
|
|
86
87
|
if isinstance(value, Tensor): # fast path
|
|
88
|
+
if device and value.device != device:
|
|
89
|
+
value = rf.copy_to_device(value, device)
|
|
90
|
+
if dtype and value.dtype != dtype:
|
|
91
|
+
value = rf.cast(value, dtype=dtype)
|
|
87
92
|
return value
|
|
88
93
|
if isinstance(value, (tuple, list)):
|
|
89
94
|
value = numpy.array(value, dtype=dtype)
|
|
@@ -190,7 +195,7 @@ def merge_dims(
|
|
|
190
195
|
if out_dim is None:
|
|
191
196
|
from returnn.util.basic import prod
|
|
192
197
|
|
|
193
|
-
if any(d.need_masking() for d in dims[1:]):
|
|
198
|
+
if any([d.need_masking() for d in dims[1:]]):
|
|
194
199
|
# The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
|
|
195
200
|
# This would then potentially discard some of the data in the tensor in subsequent operations,
|
|
196
201
|
# when masking is applied.
|
|
@@ -905,7 +910,7 @@ def scatter(
|
|
|
905
910
|
else:
|
|
906
911
|
raise ValueError(f"scatter: invalid mode {mode!r}")
|
|
907
912
|
indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
|
|
908
|
-
if any(dim.need_masking() for dim in indices_dim):
|
|
913
|
+
if any([dim.need_masking() for dim in indices_dim]):
|
|
909
914
|
if use_mask is None:
|
|
910
915
|
use_mask = rf.use_mask_default(
|
|
911
916
|
default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
|
|
@@ -1341,3 +1346,44 @@ def top_p_mask(
|
|
|
1341
1346
|
mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
|
|
1342
1347
|
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1343
1348
|
return mask
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
def repeat(
    values: Tensor, *, in_spatial_dim: Dim, repeats: Tensor, out_spatial_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
    """
    Repeats certain elements in a tensor along a given spatial dimension.
    0 repeats means to remove that element.

    This can be used to implement duration-based expansion, e.g. in text-to-speech.

    Sketch of the index computation for a single sequence with ``repeats = [2, 0, 1]``
    (assuming the scatter below sums up the entries which fall onto the same index -- TODO confirm):
    the cumulative sum is ``[2, 2, 3]``, scattering ones to those positions gives the counts ``[0, 0, 2, 1]``,
    the cumulative sum of the counts is ``[0, 0, 2, 3]``, and after dropping the last entry,
    ``[0, 0, 2]`` are the source indices to gather from.

    :param values: [common..., values..., in_spatial_dim]
    :param in_spatial_dim:
    :param repeats: [common..., repeats..., in_spatial_dim] -> int32 durations / number of repetitions for each element
    :param out_spatial_dim: if given, it is used as the output dim.
        Its size tensor is filled in here in case it is not set yet.
    :return: expanded_values: [common..., values..., repeats..., out_spatial_dim], out_spatial_dim
    """
    # Similar to masked_select
    # Zero out the repeats in the padded area, so that the padding does not contribute to the cumulative sum.
    repeats = repeats.copy_masked(0, dims=[in_spatial_dim])
    idxs = rf.cumsum(repeats, spatial_dim=in_spatial_dim)  # [batch...,in_spatial_dim] -> idx in out_spatial_dim + 1
    # The last entry of the cumulative sum is the total number of output elements.
    new_size = rf.gather(idxs, indices=in_spatial_dim.get_dim_value_tensor() - 1, axis=in_spatial_dim)  # [batch...]
    dim_dev = rf.get_default_dim_size_device()
    if out_spatial_dim is None:
        out_spatial_dim = Dim(rf.copy_to_device(new_size, dim_dev), name="repeat")
    elif out_spatial_dim.dyn_size_ext is None:
        out_spatial_dim.dyn_size_ext = rf.copy_to_device(new_size, dim_dev)
    elif out_spatial_dim.dyn_size_ext is not None and out_spatial_dim.dyn_size_ext.raw_tensor is None:
        # The size tensor exists as a template, but has no data yet. Fill in the data.
        out_spatial_dim.dyn_size_ext.raw_tensor = rf.copy_to_device(new_size, dim_dev).raw_tensor
    # One extra slot is needed, because the cumulative sum can be equal to the output size itself.
    out_spatial_dim_ext = out_spatial_dim + 1
    rel_idx_counts = rf.scatter(
        rf.expand_dims(rf.ones((), device=values.device, dtype="int32"), dims=idxs.dims),
        indices=idxs,
        indices_dim=in_spatial_dim,
        out_dim=out_spatial_dim_ext,
    )
    # rel_idx_counts: [batch...,out_spatial_dim+1] -> count of how many times each index was selected
    idxs_ = rf.cumsum(rel_idx_counts, spatial_dim=out_spatial_dim_ext)
    # idxs_: [batch...,out_spatial_dim+1] -> idx in in_spatial_dim
    idxs_, _ = rf.slice(idxs_, axis=out_spatial_dim_ext, size=out_spatial_dim)  # remove last element
    # idxs_: [batch...,out_spatial_dim] -> idx in in_spatial_dim (potentially with invalid indices in padded area)
    return rf.gather(values, indices=idxs_, axis=in_spatial_dim, clip_to_valid=True), out_spatial_dim
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Assertion utility functions for validating conditions in Python code.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from typing import Union
|
|
7
|
+
import returnn.frontend as rf
|
|
8
|
+
from returnn.tensor import Tensor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
__all__ = ["assert_"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def assert_(condition: Union[Tensor, bool], message: str):
    """
    Check that the given condition holds.

    A plain Python bool is checked right away.
    A tensor is first reduced over all its dims (logical "all") to a scalar,
    and then the check is handed over to the raw backend of that tensor.
    According to the original note, this runs async on GPU.

    :param condition: bool, or tensor of any shape
    :param message: message of the error raised when the condition does not hold
    :return: nothing
    :raises AssertionError: if a bool condition is False
    :raises TypeError: if the condition is neither a bool nor a Tensor
    """
    if isinstance(condition, bool):
        if condition:
            return
        raise AssertionError(message)

    if not isinstance(condition, Tensor):
        raise TypeError(f"Condition must be a boolean or a Tensor, got {type(condition)}")

    scalar_cond = condition
    if scalar_cond.dims:
        scalar_cond = rf.reduce_all(scalar_cond, axis=scalar_cond.dims)  # reduce to scalar
    # noinspection PyProtectedMember
    backend = scalar_cond._raw_backend
    backend.assert_(scalar_cond, message=message)
|
returnn/frontend/attention.py
CHANGED
|
@@ -24,6 +24,7 @@ __all__ = [
|
|
|
24
24
|
"LearnedRelativePositionalEncoding",
|
|
25
25
|
"relative_positional_encoding",
|
|
26
26
|
"sinusoidal_positional_encoding",
|
|
27
|
+
"sinusoidal_encoding",
|
|
27
28
|
]
|
|
28
29
|
|
|
29
30
|
|
|
@@ -454,7 +455,7 @@ class RelPosSelfAttention(SelfAttentionBase):
|
|
|
454
455
|
pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=axis)
|
|
455
456
|
else:
|
|
456
457
|
pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
|
|
457
|
-
query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim
|
|
458
|
+
query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim, device=source.device
|
|
458
459
|
)
|
|
459
460
|
if self.pos_emb_dropout:
|
|
460
461
|
pos_emb = rf.dropout(pos_emb, self.pos_emb_dropout)
|
|
@@ -483,6 +484,7 @@ class RelPosSelfAttention(SelfAttentionBase):
|
|
|
483
484
|
matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
|
|
484
485
|
|
|
485
486
|
scores = matrix_ac + matrix_bd # (batch, head, time1, time2)
|
|
487
|
+
del matrix_ac, matrix_bd
|
|
486
488
|
scores *= self.key_dim_per_head.dimension**-0.5
|
|
487
489
|
att_weights = rf.softmax(scores, axis=hist_dim)
|
|
488
490
|
att_weights = rf.dropout(att_weights, self.att_dropout, axis=self.att_dropout_broadcast and hist_dim)
|
|
@@ -609,7 +611,10 @@ class RelPosCausalSelfAttention(CausalSelfAttention):
|
|
|
609
611
|
pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=hist_dim)
|
|
610
612
|
else:
|
|
611
613
|
pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
|
|
612
|
-
query_spatial_dim=axis,
|
|
614
|
+
query_spatial_dim=axis,
|
|
615
|
+
key_value_spatial_dim=hist_dim,
|
|
616
|
+
feat_dim=self.pos_emb_feat_dim,
|
|
617
|
+
device=source.device,
|
|
613
618
|
)
|
|
614
619
|
# pos_emb_spatial_dim is 2*time1-1 if axis!=single_step_dim, else time1
|
|
615
620
|
if self.pos_emb_dropout:
|
|
@@ -724,6 +729,7 @@ class CrossAttention(rf.Module):
|
|
|
724
729
|
"""
|
|
725
730
|
Transformer encoder output. This is intended as an initial API suggestion.
|
|
726
731
|
"""
|
|
732
|
+
assert axis in encoder.dims
|
|
727
733
|
k, v = self.forward_kv(encoder)
|
|
728
734
|
return rf.State(k=k, v=v, kv_axis=axis)
|
|
729
735
|
|
|
@@ -811,7 +817,9 @@ class LearnedRelativePositionalEncoding(rf.Module):
|
|
|
811
817
|
:return: tensor of shape [spatial_dim * 2 - 1, feat_dim], and the out spatial dim (spatial_dim * 2 - 1).
|
|
812
818
|
In the center is the rel pos i-j=0. All to the right are for i-j>0, all to the left for i-j<0.
|
|
813
819
|
"""
|
|
814
|
-
indices, out_spatial_dim = _make_indices(
|
|
820
|
+
indices, out_spatial_dim = _make_indices(
|
|
821
|
+
query_spatial_dim, key_value_spatial_dim, query_offset, device=self.pos_emb.device
|
|
822
|
+
)
|
|
815
823
|
indices = rf.clip_by_value(indices, -self.clipping, 0 if self.causal else self.clipping)
|
|
816
824
|
# Shift values to be >= 0. Each integer still uniquely identifies a relative position difference.
|
|
817
825
|
indices = indices + self.clipping
|
|
@@ -851,8 +859,9 @@ def _make_indices(
|
|
|
851
859
|
query_spatial_dim: Dim,
|
|
852
860
|
key_value_spatial_dim: Dim,
|
|
853
861
|
query_offset: Optional[Union[int, Tensor]] = None,
|
|
862
|
+
device: Optional[str] = None,
|
|
854
863
|
) -> Tuple[Tensor, Dim]:
|
|
855
|
-
kv_pos_vec = rf.range_over_dim(key_value_spatial_dim) # [kv_len]
|
|
864
|
+
kv_pos_vec = rf.range_over_dim(key_value_spatial_dim, device=device) # [kv_len]
|
|
856
865
|
|
|
857
866
|
# See also RelativePositionalEncodingLayer
|
|
858
867
|
if query_spatial_dim == single_step_dim:
|
|
@@ -865,7 +874,7 @@ def _make_indices(
|
|
|
865
874
|
query_offset = key_value_spatial_dim.get_size_tensor() - 1
|
|
866
875
|
else:
|
|
867
876
|
query_spatial_dim_m1 = query_spatial_dim - 1
|
|
868
|
-
q_pos_vec = rf.range_over_dim(query_spatial_dim_m1) # [q_len-1]
|
|
877
|
+
q_pos_vec = rf.range_over_dim(query_spatial_dim_m1, device=device) # [q_len-1]
|
|
869
878
|
|
|
870
879
|
# The masking in the output is quite custom (left+right masking), so our seq lens don't make sense,
|
|
871
880
|
# and might even cause to fail some tests (that e.g. max(q_seq_len+k_seq_len-1) == shape).
|
|
@@ -902,6 +911,7 @@ def relative_positional_encoding(
|
|
|
902
911
|
feat_dim: Dim,
|
|
903
912
|
query_offset: int = 0,
|
|
904
913
|
dtype: Optional[str] = None,
|
|
914
|
+
device: Optional[str] = None,
|
|
905
915
|
) -> Tuple[Tensor, Dim]:
|
|
906
916
|
"""
|
|
907
917
|
Implements relative positional encoding, Transformer-XL style (https://arxiv.org/abs/1901.02860),
|
|
@@ -924,7 +934,9 @@ def relative_positional_encoding(
|
|
|
924
934
|
"""
|
|
925
935
|
if not dtype:
|
|
926
936
|
dtype = rf.get_default_float_dtype()
|
|
927
|
-
|
|
937
|
+
if not device:
|
|
938
|
+
device = rf.get_default_device()
|
|
939
|
+
cache_key = (query_spatial_dim, key_value_spatial_dim, feat_dim, query_offset, dtype, device)
|
|
928
940
|
cache_entry = _relative_positional_encoding_cache.get(cache_key)
|
|
929
941
|
if cache_entry is not None:
|
|
930
942
|
return cache_entry
|
|
@@ -932,7 +944,7 @@ def relative_positional_encoding(
|
|
|
932
944
|
|
|
933
945
|
with rf.control_flow_ctx(None):
|
|
934
946
|
# See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
|
|
935
|
-
indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset)
|
|
947
|
+
indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset, device=device)
|
|
936
948
|
|
|
937
949
|
feat2_dim = feat_dim.div_left(2)
|
|
938
950
|
div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype) * -(2.0 * math.log(1e4) / feat_dim.dimension))
|
|
@@ -986,7 +998,6 @@ def sinusoidal_positional_encoding(
|
|
|
986
998
|
cache_entry = _sinusoidal_positional_encoding_cache.get(cache_key)
|
|
987
999
|
if cache_entry is not None:
|
|
988
1000
|
return cache_entry
|
|
989
|
-
import math
|
|
990
1001
|
|
|
991
1002
|
with rf.control_flow_ctx(None):
|
|
992
1003
|
# See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
|
|
@@ -997,26 +1008,49 @@ def sinusoidal_positional_encoding(
|
|
|
997
1008
|
indices = rf.range_over_dim(spatial_dim, device=device) # [len]
|
|
998
1009
|
if offset is not None:
|
|
999
1010
|
indices = indices + offset
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
feat2_dim = feat_dim.div_left(2)
|
|
1003
|
-
div_term = rf.exp(
|
|
1004
|
-
rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -(math.log(base) / (feat2_dim.dimension - 1))
|
|
1005
|
-
)
|
|
1006
|
-
arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
|
|
1007
|
-
arg_cos = arg_sin + math.pi / 2.0
|
|
1008
|
-
arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
|
|
1009
|
-
arg, feat_dim_ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
|
|
1010
|
-
emb = rf.sin(arg)
|
|
1011
|
+
emb = sinusoidal_encoding(indices, base=base, feat_dim=feat_dim, dtype=dtype)
|
|
1011
1012
|
emb.verify_out_shape(
|
|
1012
1013
|
{feat_dim} | indices.dims_set | ({spatial_dim} if spatial_dim != single_step_dim else set()),
|
|
1013
1014
|
allow_missing_implicit_dims=True,
|
|
1014
1015
|
)
|
|
1015
|
-
emb.feature_dim = feat_dim
|
|
1016
1016
|
_sinusoidal_positional_encoding_cache.set(cache_key, emb)
|
|
1017
1017
|
return emb
|
|
1018
1018
|
|
|
1019
1019
|
|
|
1020
|
+
def sinusoidal_encoding(
    indices: Tensor,
    *,
    feat_dim: Dim,
    base: Union[int, float] = 1e4,
    dtype: Optional[str] = None,
) -> Tensor:
    """
    Sinusoidal encoding of arbitrary indices.

    The first half of ``feat_dim`` holds the sine part, the second half the cosine part
    (computed as a sine shifted by pi/2), each over geometrically spaced frequencies.

    :param indices: [...], to be encoded
    :param feat_dim: feature dim of the output. It is split in two halves via ``feat_dim.div_left(2)``.
    :param base: base for the angles
    :param dtype: data type. If not given, the default float dtype is used.
    :return: tensor of shape [..., feat_dim]
    """
    import math

    if not dtype:
        dtype = rf.get_default_float_dtype()

    device = indices.device
    feat2_dim = feat_dim.div_left(2)
    # The frequencies are spread over `dimension - 1` steps.
    # With only a single frequency (feat_dim of size 2), that would be a division by zero.
    # The only range entry is 0 in that case, so the divisor does not matter, and the scale is exp(0) = 1.
    num_steps = max(feat2_dim.dimension - 1, 1)
    div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -(math.log(base) / num_steps))
    arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
    # cos(x) == sin(x + pi/2), so a single sin call below covers both halves.
    arg_cos = arg_sin + math.pi / 2.0
    arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
    arg, feat_dim_ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
    emb = rf.sin(arg)
    emb.feature_dim = feat_dim
    return emb
|
|
1052
|
+
|
|
1053
|
+
|
|
1020
1054
|
_att_dropout_broadcast_shown_warning = False
|
|
1021
1055
|
|
|
1022
1056
|
|