returnn 1.20250901.123052__py3-none-any.whl → 1.20260105.192646__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- returnn/PKG-INFO +2 -2
- returnn/_setup_info_generated.py +2 -2
- returnn/config.py +1 -1
- returnn/datasets/basic.py +29 -13
- returnn/datasets/distrib_files.py +61 -3
- returnn/datasets/generating.py +12 -21
- returnn/datasets/huggingface.py +434 -0
- returnn/datasets/lm.py +20 -0
- returnn/datasets/meta.py +179 -60
- returnn/datasets/multi_proc.py +1 -1
- returnn/datasets/postprocessing.py +597 -108
- returnn/datasets/text_dict.py +1 -1
- returnn/datasets/util/vocabulary.py +90 -0
- returnn/frontend/_backend.py +7 -0
- returnn/frontend/array_.py +54 -1
- returnn/frontend/attention.py +54 -20
- returnn/frontend/conv.py +273 -54
- returnn/frontend/decoder/transformer.py +36 -17
- returnn/frontend/encoder/conformer.py +1 -0
- returnn/frontend/encoder/transformer.py +2 -0
- returnn/frontend/loss.py +40 -1
- returnn/frontend/module.py +8 -1
- returnn/frontend/nested.py +9 -0
- returnn/native_op.cpp +80 -0
- returnn/sprint/cache.py +12 -13
- returnn/tensor/_dim_extra.py +51 -29
- returnn/tensor/_tensor_extra.py +6 -1
- returnn/tensor/utils.py +7 -4
- returnn/tf/frontend_layers/_backend.py +11 -2
- returnn/tf/frontend_low_level/_backend.py +15 -0
- returnn/tf/layers/basic.py +16 -38
- returnn/tf/native_op.py +11 -58
- returnn/tf/network.py +1 -1
- returnn/tf/util/basic.py +19 -0
- returnn/torch/data/returnn_dataset_wrapper.py +9 -3
- returnn/torch/engine.py +67 -2
- returnn/torch/frontend/_backend.py +119 -7
- returnn/torch/util/diagnose_gpu.py +65 -31
- returnn/torch/util/exception_helper.py +7 -1
- returnn/util/basic.py +6 -7
- returnn/util/better_exchook.py +4 -0
- returnn/util/collect_outputs_dict.py +79 -0
- returnn/util/debug.py +11 -2
- returnn/util/file_cache.py +42 -4
- returnn/util/task_system.py +1 -1
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
- {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/datasets/text_dict.py
CHANGED
|
@@ -100,7 +100,7 @@ class TextDictDataset(CachedDataset2):
|
|
|
100
100
|
print(f"{self}: Warning: literal_py_to_pickle.literal_eval failed:", file=log.v3)
|
|
101
101
|
print(f" {type(exc).__name__}: {exc}", file=log.v3)
|
|
102
102
|
print(" Fallback to eval...", file=log.v3)
|
|
103
|
-
data: Dict[str, Any] = eval(txt)
|
|
103
|
+
data: Dict[str, Any] = eval(txt, {"nan": float("nan"), "inf": float("inf")})
|
|
104
104
|
assert data is not None
|
|
105
105
|
assert isinstance(data, dict)
|
|
106
106
|
assert len(data) > 0
|
|
@@ -11,6 +11,7 @@ __all__ = [
|
|
|
11
11
|
"SentencePieces",
|
|
12
12
|
"CharacterTargets",
|
|
13
13
|
"Utf8ByteTargets",
|
|
14
|
+
"HuggingFaceTokenizer",
|
|
14
15
|
]
|
|
15
16
|
|
|
16
17
|
from typing import Optional, Union, Type, Callable, List, Dict
|
|
@@ -691,3 +692,92 @@ class Utf8ByteTargets(Vocabulary):
|
|
|
691
692
|
assert ((seq >= 0) & (seq < 256)).all(), f"invalid byte value, must be within 0-255: {seq}"
|
|
692
693
|
seq = seq.astype(numpy.uint8)
|
|
693
694
|
return bytearray(seq).decode(encoding="utf8")
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
class HuggingFaceTokenizer(Vocabulary):
    """
    Uses the `AutoTokenizer` class from the `transformers` package.

    Label ids are the tokenizer's token ids, and the number of labels is ``len(tokenizer)``.
    """

    def __init__(self, *, huggingface_repo_dir: str):
        """
        :param str huggingface_repo_dir: the directory containing the `tokenizer_config.json` file.
            Converted via ``str()``, so e.g. a Sisyphus Path is also accepted.
        """
        import transformers  # noqa

        # Make sure it is a string. (Could be e.g. Sis Path.)
        huggingface_repo_dir = str(huggingface_repo_dir)
        self._opts = {"huggingface_repo_dir": huggingface_repo_dir}
        self._cache_key = huggingface_repo_dir
        # Note: trust_remote_code=True allows running Python code shipped with the tokenizer directory,
        # so only point this at trusted tokenizer directories.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(huggingface_repo_dir, trust_remote_code=True)
        super().__init__(
            vocab_file=None,
            seq_postfix=None,
            unknown_label=self.tokenizer.unk_token_id,
            eos_label=self.tokenizer.eos_token_id,
            bos_label=self.tokenizer.bos_token_id,
            pad_label=self.tokenizer.pad_token_id,
        )

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._opts)

    def _parse_vocab(self):
        # presumably invoked by the Vocabulary base __init__ (not visible here) -- TODO confirm
        self.num_labels = len(self.tokenizer)
        # Do not load labels/vocab here. This is not really needed.

    @property
    def labels(self) -> List[str]:
        """list of labels (token strings), indexed by token id"""
        if self._cache_key and self._cache_key in self._cache:
            self._vocab, self._labels = self._cache[self._cache_key]
            assert self.num_labels == len(self._vocab) == len(self._labels)
        else:
            # Public batch API: for slow tokenizers this also resolves added tokens
            # (counted in len(tokenizer)), which the private `_convert_id_to_token` does not know about.
            self._labels = self.tokenizer.convert_ids_to_tokens(list(range(self.num_labels)))
            self._vocab = {label: i for (i, label) in enumerate(self._labels)}
            if self._cache_key:
                self._cache[self._cache_key] = (self._vocab, self._labels)
        return self._labels

    def is_id_valid(self, idx: int) -> bool:
        """
        :param idx: token id
        :return: whether ``idx`` is within ``[0, len(tokenizer))``
        """
        return 0 <= idx < len(self.tokenizer)

    def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
        """
        :param idx: token id
        :param default: returned if ``idx`` is not a valid id.
            If this is ``KeyError`` (the default), a KeyError is raised instead.
        :return: the token string for ``idx``
        """
        if not self.is_id_valid(idx):
            if default is KeyError:
                raise KeyError("id %r out of range" % (idx,))
            return default
        # convert_ids_to_tokens only treats a plain `int` as a single id (not e.g. numpy integers).
        return self.tokenizer.convert_ids_to_tokens(int(idx))

    def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
        """
        :param label: token string
        :param default: returned if ``label`` is not in the vocab.
            If this is ``KeyError`` (the default), a KeyError is raised instead.
        :return: token id
        """
        res = self.tokenizer.convert_tokens_to_ids(label)
        # Check `None` first: `None < 0` would raise a TypeError.
        if res is None or res < 0 or res == self.unknown_label_id:
            # It could be that the label really is the unknown-label, or it could be that the label is unknown.
            if self.unknown_label_id is not None and label == self.id_to_label(self.unknown_label_id):
                return self.unknown_label_id
            if default is KeyError:
                raise KeyError("label %r not found" % label)
            return default
        return res

    def get_seq(self, sentence: str) -> List[int]:
        """
        :param sentence: raw text. It is tokenized by the HuggingFace tokenizer
            (it is not a whitespace-separated sequence of vocab entries).
        :return: token ids as given by ``tokenizer(sentence)["input_ids"]``,
            including any special tokens the tokenizer adds by default
        """
        return self.tokenizer(sentence)["input_ids"]

    def get_seq_labels(self, seq):
        """
        :param list[int]|numpy.ndarray seq: 1D sequence of token ids
        :return: decoded text, with special tokens skipped
        :rtype: str
        """
        return self.tokenizer.decode(seq, skip_special_tokens=True)
|
returnn/frontend/_backend.py
CHANGED
|
@@ -66,6 +66,13 @@ class Backend(Generic[T]):
|
|
|
66
66
|
"""
|
|
67
67
|
raise NotImplementedError
|
|
68
68
|
|
|
69
|
+
@staticmethod
def should_pickle_tensor(raw_tensor: T) -> bool:
    """
    Decide whether ``raw_tensor`` is kept when a tensor is pickled, instead of being replaced by ``None``.

    This default implementation keeps every tensor; subclasses may override it.

    :param raw_tensor: backend-specific raw tensor
    :return: whether the tensor should be included in a pickle or set to `None`.
    """
    del raw_tensor  # not inspected by the default implementation
    keep = True
    return keep
|
|
75
|
+
|
|
69
76
|
@staticmethod
|
|
70
77
|
def cond(pred: Tensor, true_fn: Callable, false_fn: Callable):
|
|
71
78
|
"""
|
returnn/frontend/array_.py
CHANGED
|
@@ -54,6 +54,7 @@ __all__ = [
|
|
|
54
54
|
"one_hot",
|
|
55
55
|
"top_k_mask",
|
|
56
56
|
"top_p_mask",
|
|
57
|
+
"repeat",
|
|
57
58
|
]
|
|
58
59
|
|
|
59
60
|
|
|
@@ -84,6 +85,10 @@ def convert_to_tensor(
|
|
|
84
85
|
:return: tensor
|
|
85
86
|
"""
|
|
86
87
|
if isinstance(value, Tensor): # fast path
|
|
88
|
+
if device and value.device != device:
|
|
89
|
+
value = rf.copy_to_device(value, device)
|
|
90
|
+
if dtype and value.dtype != dtype:
|
|
91
|
+
value = rf.cast(value, dtype=dtype)
|
|
87
92
|
return value
|
|
88
93
|
if isinstance(value, (tuple, list)):
|
|
89
94
|
value = numpy.array(value, dtype=dtype)
|
|
@@ -1195,7 +1200,10 @@ def reverse_sequence(tensor: Tensor, *, axis: Dim, handle_dynamic_dims: bool = T
|
|
|
1195
1200
|
if not handle_dynamic_dims or not axis.need_masking():
|
|
1196
1201
|
# noinspection PyProtectedMember
|
|
1197
1202
|
return tensor._raw_backend.flip_no_mask(tensor, axis=axis)
|
|
1198
|
-
indices =
|
|
1203
|
+
indices = (
|
|
1204
|
+
rf.combine_bc(axis.get_size_tensor(device=tensor.device), "-", rf.range_over_dim(axis, device=tensor.device))
|
|
1205
|
+
- 1
|
|
1206
|
+
)
|
|
1199
1207
|
return rf.gather(tensor, indices=indices, axis=axis, clip_to_valid=True)
|
|
1200
1208
|
|
|
1201
1209
|
|
|
@@ -1309,6 +1317,7 @@ def top_p_mask(
|
|
|
1309
1317
|
axis: Dim,
|
|
1310
1318
|
p: Union[float, Tensor],
|
|
1311
1319
|
one_more: bool = True,
|
|
1320
|
+
min_tokens_to_keep: int = 1,
|
|
1312
1321
|
) -> Tensor:
|
|
1313
1322
|
"""
|
|
1314
1323
|
Top-p filtering, e.g. as used in Nucleus sampling (https://arxiv.org/abs/1904.09751).
|
|
@@ -1318,6 +1327,8 @@ def top_p_mask(
|
|
|
1318
1327
|
:param p: the probability mass to keep
|
|
1319
1328
|
:param one_more: if True (default), keep also the first token above the threshold.
|
|
1320
1329
|
(It's enabled by default to follow the behavior of the original implementation.)
|
|
1330
|
+
:param min_tokens_to_keep: ensure to keep at least these many tokens (default 1)
|
|
1331
|
+
With one_more=True, min_tokens_to_keep=1 is anyway guaranteed.
|
|
1321
1332
|
:return: mask {probs_dims..., axis} of the top-p tokens.
|
|
1322
1333
|
``sum(probs[mask]) <= p``, or slightly more if ``one_more`` is True.
|
|
1323
1334
|
"""
|
|
@@ -1331,5 +1342,47 @@ def top_p_mask(
|
|
|
1331
1342
|
if one_more:
|
|
1332
1343
|
# keep also the first token above the threshold
|
|
1333
1344
|
mask = rf.shift_right(mask, axis=sorted_dim, pad_value=True)
|
|
1345
|
+
if min_tokens_to_keep > (1 if one_more else 0):
|
|
1346
|
+
mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
|
|
1334
1347
|
mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
|
|
1335
1348
|
return mask
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
def repeat(
    values: Tensor, *, in_spatial_dim: Dim, repeats: Tensor, out_spatial_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
    """
    Repeat every element of ``values`` along ``in_spatial_dim`` as many times as given by ``repeats``.
    An element with 0 repeats is dropped from the output.

    Typical use: duration-based expansion, e.g. in text-to-speech.

    :param values: [common..., values..., in_spatial_dim]
    :param in_spatial_dim: the axis whose elements get repeated
    :param repeats: [common..., repeats..., in_spatial_dim] -> int32 durations / number of repetitions for each element
    :param out_spatial_dim: optional. If given without (raw) sizes, the sizes are filled in here.
    :return: expanded_values: [common..., values..., repeats..., out_spatial_dim], out_spatial_dim
    """
    # Same idea as masked_select. Padded frames must not produce output frames.
    repeats = repeats.copy_masked(0, dims=[in_spatial_dim])
    # Inclusive prefix sum: end_pos[i] is the exclusive end position of element i in the output.
    end_pos = rf.cumsum(repeats, spatial_dim=in_spatial_dim)
    # Padding was masked to 0, so the prefix sum at the last index is the total output length.
    last_in_idx = in_spatial_dim.get_dim_value_tensor() - 1
    out_len = rf.gather(end_pos, indices=last_in_idx, axis=in_spatial_dim)  # [batch...]

    if out_spatial_dim is None:
        out_spatial_dim = Dim(out_len, name="repeat")
    elif out_spatial_dim.dyn_size_ext is None:
        out_spatial_dim.dyn_size_ext = out_len
    elif out_spatial_dim.dyn_size_ext.raw_tensor is None:
        out_spatial_dim.dyn_size_ext.raw_tensor = out_len.raw_tensor

    # One extra slot, because end positions go up to and including the output length.
    out_ext_dim = out_spatial_dim + 1
    ones = rf.expand_dims(rf.ones((), device=values.device, dtype="int32"), dims=end_pos.dims)
    # boundary_counts[t]: number of elements whose exclusive end position is exactly t.
    boundary_counts = rf.scatter(ones, indices=end_pos, indices_dim=in_spatial_dim, out_dim=out_ext_dim)
    # Running count of finished elements = index (in in_spatial_dim) of the element covering position t.
    src_idx = rf.cumsum(boundary_counts, spatial_dim=out_ext_dim)
    src_idx, _ = rf.slice(src_idx, axis=out_ext_dim, size=out_spatial_dim)  # drop the extra last slot
    # In the padded area, src_idx can be out of range, hence clip_to_valid.
    return rf.gather(values, indices=src_idx, axis=in_spatial_dim, clip_to_valid=True), out_spatial_dim
|
returnn/frontend/attention.py
CHANGED
|
@@ -24,6 +24,7 @@ __all__ = [
|
|
|
24
24
|
"LearnedRelativePositionalEncoding",
|
|
25
25
|
"relative_positional_encoding",
|
|
26
26
|
"sinusoidal_positional_encoding",
|
|
27
|
+
"sinusoidal_encoding",
|
|
27
28
|
]
|
|
28
29
|
|
|
29
30
|
|
|
@@ -454,7 +455,7 @@ class RelPosSelfAttention(SelfAttentionBase):
|
|
|
454
455
|
pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=axis)
|
|
455
456
|
else:
|
|
456
457
|
pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
|
|
457
|
-
query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim
|
|
458
|
+
query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim, device=source.device
|
|
458
459
|
)
|
|
459
460
|
if self.pos_emb_dropout:
|
|
460
461
|
pos_emb = rf.dropout(pos_emb, self.pos_emb_dropout)
|
|
@@ -483,6 +484,7 @@ class RelPosSelfAttention(SelfAttentionBase):
|
|
|
483
484
|
matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
|
|
484
485
|
|
|
485
486
|
scores = matrix_ac + matrix_bd # (batch, head, time1, time2)
|
|
487
|
+
del matrix_ac, matrix_bd
|
|
486
488
|
scores *= self.key_dim_per_head.dimension**-0.5
|
|
487
489
|
att_weights = rf.softmax(scores, axis=hist_dim)
|
|
488
490
|
att_weights = rf.dropout(att_weights, self.att_dropout, axis=self.att_dropout_broadcast and hist_dim)
|
|
@@ -609,7 +611,10 @@ class RelPosCausalSelfAttention(CausalSelfAttention):
|
|
|
609
611
|
pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=hist_dim)
|
|
610
612
|
else:
|
|
611
613
|
pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
|
|
612
|
-
query_spatial_dim=axis,
|
|
614
|
+
query_spatial_dim=axis,
|
|
615
|
+
key_value_spatial_dim=hist_dim,
|
|
616
|
+
feat_dim=self.pos_emb_feat_dim,
|
|
617
|
+
device=source.device,
|
|
613
618
|
)
|
|
614
619
|
# pos_emb_spatial_dim is 2*time1-1 if axis!=single_step_dim, else time1
|
|
615
620
|
if self.pos_emb_dropout:
|
|
@@ -724,6 +729,7 @@ class CrossAttention(rf.Module):
|
|
|
724
729
|
"""
|
|
725
730
|
Transformer encoder output. This is intended as an initial API suggestion.
|
|
726
731
|
"""
|
|
732
|
+
assert axis in encoder.dims
|
|
727
733
|
k, v = self.forward_kv(encoder)
|
|
728
734
|
return rf.State(k=k, v=v, kv_axis=axis)
|
|
729
735
|
|
|
@@ -811,7 +817,9 @@ class LearnedRelativePositionalEncoding(rf.Module):
|
|
|
811
817
|
:return: tensor of shape [spatial_dim * 2 - 1, feat_dim], and the out spatial dim (spatial_dim * 2 - 1).
|
|
812
818
|
In the center is the rel pos i-j=0. All to the right are for i-j>0, all to the left for i-j<0.
|
|
813
819
|
"""
|
|
814
|
-
indices, out_spatial_dim = _make_indices(
|
|
820
|
+
indices, out_spatial_dim = _make_indices(
|
|
821
|
+
query_spatial_dim, key_value_spatial_dim, query_offset, device=self.pos_emb.device
|
|
822
|
+
)
|
|
815
823
|
indices = rf.clip_by_value(indices, -self.clipping, 0 if self.causal else self.clipping)
|
|
816
824
|
# Shift values to be >= 0. Each integer still uniquely identifies a relative position difference.
|
|
817
825
|
indices = indices + self.clipping
|
|
@@ -851,8 +859,9 @@ def _make_indices(
|
|
|
851
859
|
query_spatial_dim: Dim,
|
|
852
860
|
key_value_spatial_dim: Dim,
|
|
853
861
|
query_offset: Optional[Union[int, Tensor]] = None,
|
|
862
|
+
device: Optional[str] = None,
|
|
854
863
|
) -> Tuple[Tensor, Dim]:
|
|
855
|
-
kv_pos_vec = rf.range_over_dim(key_value_spatial_dim) # [kv_len]
|
|
864
|
+
kv_pos_vec = rf.range_over_dim(key_value_spatial_dim, device=device) # [kv_len]
|
|
856
865
|
|
|
857
866
|
# See also RelativePositionalEncodingLayer
|
|
858
867
|
if query_spatial_dim == single_step_dim:
|
|
@@ -865,7 +874,7 @@ def _make_indices(
|
|
|
865
874
|
query_offset = key_value_spatial_dim.get_size_tensor() - 1
|
|
866
875
|
else:
|
|
867
876
|
query_spatial_dim_m1 = query_spatial_dim - 1
|
|
868
|
-
q_pos_vec = rf.range_over_dim(query_spatial_dim_m1) # [q_len-1]
|
|
877
|
+
q_pos_vec = rf.range_over_dim(query_spatial_dim_m1, device=device) # [q_len-1]
|
|
869
878
|
|
|
870
879
|
# The masking in the output is quite custom (left+right masking), so our seq lens don't make sense,
|
|
871
880
|
# and might even cause to fail some tests (that e.g. max(q_seq_len+k_seq_len-1) == shape).
|
|
@@ -902,6 +911,7 @@ def relative_positional_encoding(
|
|
|
902
911
|
feat_dim: Dim,
|
|
903
912
|
query_offset: int = 0,
|
|
904
913
|
dtype: Optional[str] = None,
|
|
914
|
+
device: Optional[str] = None,
|
|
905
915
|
) -> Tuple[Tensor, Dim]:
|
|
906
916
|
"""
|
|
907
917
|
Implements relative positional encoding, Transformer-XL style (https://arxiv.org/abs/1901.02860),
|
|
@@ -924,7 +934,9 @@ def relative_positional_encoding(
|
|
|
924
934
|
"""
|
|
925
935
|
if not dtype:
|
|
926
936
|
dtype = rf.get_default_float_dtype()
|
|
927
|
-
|
|
937
|
+
if not device:
|
|
938
|
+
device = rf.get_default_device()
|
|
939
|
+
cache_key = (query_spatial_dim, key_value_spatial_dim, feat_dim, query_offset, dtype, device)
|
|
928
940
|
cache_entry = _relative_positional_encoding_cache.get(cache_key)
|
|
929
941
|
if cache_entry is not None:
|
|
930
942
|
return cache_entry
|
|
@@ -932,7 +944,7 @@ def relative_positional_encoding(
|
|
|
932
944
|
|
|
933
945
|
with rf.control_flow_ctx(None):
|
|
934
946
|
# See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
|
|
935
|
-
indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset)
|
|
947
|
+
indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset, device=device)
|
|
936
948
|
|
|
937
949
|
feat2_dim = feat_dim.div_left(2)
|
|
938
950
|
div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype) * -(2.0 * math.log(1e4) / feat_dim.dimension))
|
|
@@ -986,7 +998,6 @@ def sinusoidal_positional_encoding(
|
|
|
986
998
|
cache_entry = _sinusoidal_positional_encoding_cache.get(cache_key)
|
|
987
999
|
if cache_entry is not None:
|
|
988
1000
|
return cache_entry
|
|
989
|
-
import math
|
|
990
1001
|
|
|
991
1002
|
with rf.control_flow_ctx(None):
|
|
992
1003
|
# See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
|
|
@@ -997,26 +1008,49 @@ def sinusoidal_positional_encoding(
|
|
|
997
1008
|
indices = rf.range_over_dim(spatial_dim, device=device) # [len]
|
|
998
1009
|
if offset is not None:
|
|
999
1010
|
indices = indices + offset
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
feat2_dim = feat_dim.div_left(2)
|
|
1003
|
-
div_term = rf.exp(
|
|
1004
|
-
rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -(math.log(base) / (feat2_dim.dimension - 1))
|
|
1005
|
-
)
|
|
1006
|
-
arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
|
|
1007
|
-
arg_cos = arg_sin + math.pi / 2.0
|
|
1008
|
-
arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
|
|
1009
|
-
arg, feat_dim_ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
|
|
1010
|
-
emb = rf.sin(arg)
|
|
1011
|
+
emb = sinusoidal_encoding(indices, base=base, feat_dim=feat_dim, dtype=dtype)
|
|
1011
1012
|
emb.verify_out_shape(
|
|
1012
1013
|
{feat_dim} | indices.dims_set | ({spatial_dim} if spatial_dim != single_step_dim else set()),
|
|
1013
1014
|
allow_missing_implicit_dims=True,
|
|
1014
1015
|
)
|
|
1015
|
-
emb.feature_dim = feat_dim
|
|
1016
1016
|
_sinusoidal_positional_encoding_cache.set(cache_key, emb)
|
|
1017
1017
|
return emb
|
|
1018
1018
|
|
|
1019
1019
|
|
|
1020
|
+
def sinusoidal_encoding(
    indices: Tensor,
    *,
    feat_dim: Dim,
    base: Union[int, float] = 1e4,
    dtype: Optional[str] = None,
) -> Tensor:
    """
    Sinusoidal encoding of (e.g. position) indices.

    The first half of ``feat_dim`` holds ``sin(indices * w_k)``. The second half holds ``cos(indices * w_k)``,
    computed as ``sin(x + pi/2)``. The frequencies are ``w_k = base ** (-k / (feat_dim/2 - 1))``
    for ``k = 0 .. feat_dim/2 - 1``, i.e. geometrically spaced from 1 down to ``1/base``.

    :param indices: [...], to be encoded
    :param feat_dim: output feature dim; it is split in halves via ``div_left(2)``, so it is expected to be even
    :param base: base for the angles
    :param dtype: data type; defaults to ``rf.get_default_float_dtype()``
    :return: tensor of shape [..., feat_dim], with ``feature_dim`` set to ``feat_dim``
    """
    import math

    if not dtype:
        dtype = rf.get_default_float_dtype()

    device = indices.device
    feat2_dim = feat_dim.div_left(2)
    # Exponent step between consecutive frequencies. For feat_dim == 2 (feat2_dim == 1) the denominator
    # would be 0 (ZeroDivisionError). There is only the single frequency k=0 then, and any step gives exp(0) == 1.
    freq_step = math.log(base) / max(feat2_dim.dimension - 1, 1)
    div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -freq_step)
    arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
    arg_cos = arg_sin + math.pi / 2.0  # cos(x) == sin(x + pi/2)
    arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
    arg, feat_dim_ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
    emb = rf.sin(arg)
    emb.feature_dim = feat_dim
    return emb
|
|
1052
|
+
|
|
1053
|
+
|
|
1020
1054
|
_att_dropout_broadcast_shown_warning = False
|
|
1021
1055
|
|
|
1022
1056
|
|