returnn 1.20251027.232712__py3-none-any.whl → 1.20260119.15400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. returnn/PKG-INFO +2 -2
  2. returnn/__old_mod_loader__.py +26 -2
  3. returnn/_setup_info_generated.py +2 -2
  4. returnn/datasets/lm.py +130 -42
  5. returnn/datasets/meta.py +93 -43
  6. returnn/datasets/postprocessing.py +597 -108
  7. returnn/datasets/util/vocabulary.py +90 -0
  8. returnn/frontend/__init__.py +1 -0
  9. returnn/frontend/_backend.py +41 -0
  10. returnn/frontend/_native/__init__.py +22 -0
  11. returnn/frontend/_numpy_backend.py +7 -0
  12. returnn/frontend/_utils.py +1 -1
  13. returnn/frontend/array_.py +48 -2
  14. returnn/frontend/assert_.py +35 -0
  15. returnn/frontend/attention.py +54 -20
  16. returnn/frontend/conv.py +273 -54
  17. returnn/frontend/device.py +14 -1
  18. returnn/frontend/encoder/conformer.py +20 -0
  19. returnn/frontend/encoder/transformer.py +2 -0
  20. returnn/frontend/loss.py +222 -3
  21. returnn/frontend/math_.py +54 -14
  22. returnn/native_op.cpp +182 -172
  23. returnn/native_op.py +36 -31
  24. returnn/sprint/cache.py +12 -13
  25. returnn/tensor/_dim_extra.py +7 -7
  26. returnn/tensor/_tensor_extra.py +10 -10
  27. returnn/tensor/utils.py +8 -5
  28. returnn/tf/frontend_layers/_backend.py +7 -3
  29. returnn/tf/layers/basic.py +27 -40
  30. returnn/tf/native_op.py +27 -63
  31. returnn/tf/network.py +1 -1
  32. returnn/tf/util/basic.py +22 -197
  33. returnn/torch/engine.py +157 -6
  34. returnn/torch/frontend/_backend.py +280 -29
  35. returnn/torch/frontend/bridge.py +61 -0
  36. returnn/torch/frontend/compile_helper.py +106 -0
  37. returnn/torch/util/array_.py +30 -0
  38. returnn/torch/util/assert_.py +122 -0
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/torch/util/native_op.py +885 -0
  41. returnn/torch/util/native_op_code_compiler.py +308 -0
  42. returnn/util/basic.py +6 -7
  43. returnn/util/better_exchook.py +4 -0
  44. returnn/util/cuda_env.py +332 -0
  45. returnn/util/debug.py +12 -2
  46. returnn/util/file_cache.py +15 -1
  47. returnn/util/fsa.py +17 -13
  48. returnn/util/native_code_compiler.py +104 -47
  49. returnn/util/task_system.py +1 -1
  50. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/METADATA +2 -2
  51. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/RECORD +54 -48
  52. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/WHEEL +1 -1
  53. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/LICENSE +0 -0
  54. {returnn-1.20251027.232712.dist-info → returnn-1.20260119.15400.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,7 @@ __all__ = [
11
11
  "SentencePieces",
12
12
  "CharacterTargets",
13
13
  "Utf8ByteTargets",
14
+ "HuggingFaceTokenizer",
14
15
  ]
15
16
 
16
17
  from typing import Optional, Union, Type, Callable, List, Dict
@@ -691,3 +692,92 @@ class Utf8ByteTargets(Vocabulary):
691
692
  assert ((seq >= 0) & (seq < 256)).all(), f"invalid byte value, must be within 0-255: {seq}"
692
693
  seq = seq.astype(numpy.uint8)
693
694
  return bytearray(seq).decode(encoding="utf8")
695
+
696
+
697
class HuggingFaceTokenizer(Vocabulary):
    """
    Uses the `AutoTokenizer` class from the `transformers` package.

    Label ids are the tokenizer's token ids, and ``num_labels`` is ``len(tokenizer)``.
    The unknown/EOS/BOS/padding labels are taken from the tokenizer's special token ids.
    """

    def __init__(self, *, huggingface_repo_dir: str):
        """
        :param str huggingface_repo_dir: the directory containing the `tokenizer_config.json` file.
        """
        import transformers  # noqa

        # Make sure it is a string. (Could be e.g. Sis Path.)
        huggingface_repo_dir = str(huggingface_repo_dir)
        self._opts = {"huggingface_repo_dir": huggingface_repo_dir}
        self._cache_key = huggingface_repo_dir
        # NOTE(security): trust_remote_code=True allows executing Python code shipped inside
        # huggingface_repo_dir. Only point this at trusted directories.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(huggingface_repo_dir, trust_remote_code=True)
        super().__init__(
            vocab_file=None,
            seq_postfix=None,
            unknown_label=self.tokenizer.unk_token_id,
            eos_label=self.tokenizer.eos_token_id,
            bos_label=self.tokenizer.bos_token_id,
            pad_label=self.tokenizer.pad_token_id,
        )

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._opts)

    def _parse_vocab(self):
        self.num_labels = len(self.tokenizer)
        # Do not load labels/vocab here. This is not really needed.

    @property
    def labels(self) -> List[str]:
        """
        List of labels: entry i is the token string of token id i.
        Computed lazily on first access and shared via the class-level cache, keyed by the repo dir.
        """
        if self._cache_key and self._cache_key in self._cache:
            self._vocab, self._labels = self._cache[self._cache_key]
            assert self.num_labels == len(self._vocab) == len(self._labels)
        else:
            # Use the public API, as id_to_label does. Unlike the private `_convert_id_to_token`,
            # it also resolves added tokens, which are counted in len(tokenizer) == num_labels.
            self._labels = self.tokenizer.convert_ids_to_tokens(list(range(self.num_labels)))
            self._vocab = {label: i for (i, label) in enumerate(self._labels)}
            if self._cache_key:
                self._cache[self._cache_key] = (self._vocab, self._labels)
        return self._labels

    def is_id_valid(self, idx: int) -> bool:
        """
        :param idx: label id
        :return: whether idx is within [0, len(tokenizer))
        """
        return 0 <= idx < len(self.tokenizer)

    def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
        """
        :param idx: label id
        :param default: returned for an invalid id. If this is KeyError (the default), KeyError is raised instead.
        :return: token string of idx
        """
        if not self.is_id_valid(idx):
            if default is KeyError:
                raise KeyError("label id %r out of range [0, %i)" % (idx, len(self.tokenizer)))
            return default
        # int(): the tokenizer only treats Python ints as single ids (e.g. numpy ints would be iterated).
        return self.tokenizer.convert_ids_to_tokens(int(idx))

    def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
        """
        :param label: token string
        :param default: returned if the label is not in the vocab.
            If this is KeyError (the default), KeyError is raised instead.
        :return: label id
        """
        # `convert_tokens_to_ids` is the public transformers API (it accepts a single str token).
        # For an unknown token it may return the unk id, or None if the tokenizer has no unk token.
        res = self.tokenizer.convert_tokens_to_ids(label)
        # Check None first: `None < 0` would raise TypeError.
        if res is None or res < 0 or res == self.unknown_label_id:
            # It could be that the label really is the unknown-label, or it could be that the label is unknown.
            if self.unknown_label_id is not None and label == self.id_to_label(self.unknown_label_id, default=None):
                return self.unknown_label_id
            if default is KeyError:
                raise KeyError("label %r not found" % label)
            return default
        return res

    def get_seq(self, sentence: str) -> List[int]:
        """
        :param sentence: raw text, tokenized by the HuggingFace tokenizer (not split on whitespace).
            NOTE(review): transformers tokenizers add special tokens (e.g. BOS) by default. Confirm this
            is wanted for the used tokenizer.
        :return: token ids
        """
        return self.tokenizer(sentence)["input_ids"]

    def get_seq_labels(self, seq):
        """
        :param list[int]|numpy.ndarray seq: 1D sequence
        :return: decoded text, with special tokens skipped
        :rtype: str
        """
        return self.tokenizer.decode(seq, skip_special_tokens=True)
@@ -19,6 +19,7 @@ from .state import *
19
19
 
20
20
  # Now the rest, in alphabetical order.
21
21
  from .array_ import *
22
+ from .assert_ import *
22
23
  from .attention import *
23
24
  from .backend import *
24
25
  from .build_from_dict import *
@@ -42,6 +42,11 @@ class Backend(Generic[T]):
42
42
  """
43
43
  raise NotImplementedError
44
44
 
45
@staticmethod
def assert_(condition: Tensor, message: str):
    """
    Backend hook for :func:`rf.assert_`.

    The frontend reduces a non-scalar condition with ``rf.reduce_all`` before calling this, so the
    condition received here is a scalar tensor (presumably of bool dtype). Implementations raise an
    AssertionError with ``message`` when the condition is false. The frontend docs note that this
    may run asynchronously on GPU.

    :param condition: scalar condition tensor
    :param message: message of the AssertionError
    :raise NotImplementedError: this base class does not implement it; concrete backends override it
    """
    raise NotImplementedError
49
+
45
50
  @staticmethod
46
51
  def get_tensor_dependencies(x: Tensor) -> Sequence[Tensor]:
47
52
  """
@@ -624,12 +629,48 @@ class Backend(Generic[T]):
624
629
  targets_spatial_dim: Dim,
625
630
  blank_index: int,
626
631
  max_approx: bool = False,
632
+ use_native_op: Optional[bool] = None,
633
+ label_loop: bool = True,
627
634
  ) -> Tensor:
628
635
  """
629
636
  Calculates the CTC loss.
630
637
  """
631
638
  raise NotImplementedError
632
639
 
640
@staticmethod
def ctc_best_path(
    *,
    logits: Tensor,
    logits_normalized: bool = False,
    targets: Tensor,
    input_spatial_dim: Dim,
    targets_spatial_dim: Dim,
    blank_index: int,
    label_loop: bool = True,
) -> Tensor:
    """
    Calculates the CTC best path.

    The arguments mirror those of ``ctc_loss`` in this class. Presumably this returns the
    highest-scoring CTC alignment of ``targets`` given ``logits``. TODO confirm the output format
    (e.g. per-frame labels over input_spatial_dim) against the backend implementations.

    :param logits: per-frame scores over the vocabulary, along input_spatial_dim
        (assumes [batch..., input_spatial_dim, vocab] -- TODO confirm)
    :param logits_normalized: presumably whether ``logits`` are already log-probabilities
    :param targets: label sequences, along targets_spatial_dim
    :param input_spatial_dim: time dim of ``logits``
    :param targets_spatial_dim: label dim of ``targets``
    :param blank_index: vocabulary index of the CTC blank label
    :param label_loop: same option as in ``ctc_loss`` (presumably whether labels may repeat over frames)
    :raise NotImplementedError: this base class does not implement it; concrete backends override it
    """
    raise NotImplementedError
655
+
656
@staticmethod
def have_edit_distance() -> bool:
    """
    Tell whether this backend provides :func:`edit_distance`.

    The base class has no implementation, hence this answers False.
    Backends that implement ``edit_distance`` override it.

    :return: whether we have an edit_distance implementation
    """
    return False
662
+
663
@staticmethod
def edit_distance(a: Tensor, a_spatial_dim: Dim, b: Tensor, b_spatial_dim: Dim) -> Tensor:
    """
    Edit distance between the sequences ``a`` and ``b``, per batch entry.
    Presumably Levenshtein distance (unit-cost insertion/deletion/substitution) -- TODO confirm.

    Only available if :func:`have_edit_distance` returns True for the backend.

    :param a: [B,Ta]
    :param a_spatial_dim: Ta
    :param b: [B,Tb]
    :param b_spatial_dim: Tb
    :return: [B]
    :raise NotImplementedError: this base class does not implement it; concrete backends override it
    """
    raise NotImplementedError
673
+
633
674
  @staticmethod
634
675
  def have_sequence_mask_raw() -> bool:
635
676
  """
@@ -67,6 +67,24 @@ def _code_hash_md5(filename: str) -> str:
67
67
 
68
68
 
69
69
  _is_set_up = False
70
+ _enabled = True
71
+
72
+
73
def set_enabled(enabled: bool):
    """
    Globally switch the native code setup on or off.

    While disabled, ``setup()`` and ``setup_torch()`` return immediately without marking themselves
    as done, so a later call after re-enabling will still perform the setup. A setup that has
    already happened is not undone by disabling.

    :param enabled: False to skip the native setup, True to allow it
    """
    global _enabled
    _enabled = enabled
81
+
82
+
83
def is_set_up() -> bool:
    """
    Report whether ``setup()`` has already run.

    The flag is set at the start of ``setup()`` ("only try once"), so True means that the setup was
    attempted, not necessarily that it succeeded. It does not reflect ``setup_torch()``, which keeps
    its own flag.

    :return: the module-level ``_is_set_up`` flag
    """
    return _is_set_up
70
88
 
71
89
 
72
90
  def setup():
@@ -76,6 +94,8 @@ def setup():
76
94
  global _is_set_up
77
95
  if _is_set_up:
78
96
  return
97
+ if not _enabled:
98
+ return
79
99
  _is_set_up = True # only try once
80
100
 
81
101
  from returnn.tensor import Tensor, Dim
@@ -177,6 +197,8 @@ def setup_torch():
177
197
  global _is_set_up_torch
178
198
  if _is_set_up_torch:
179
199
  return
200
+ if not _enabled:
201
+ return
180
202
  _is_set_up_torch = True # only try once
181
203
 
182
204
  import torch
@@ -26,6 +26,13 @@ class NumpyBackend(Backend[numpy.ndarray]):
26
26
  """executing eagerly"""
27
27
  return True
28
28
 
29
@staticmethod
def assert_(condition: Tensor, message: str):
    """
    Eagerly check a scalar condition (numpy executes eagerly, so there is nothing to defer).

    :param condition: scalar tensor (no dims) whose raw value is read via ``.item()``
    :param message: message of the AssertionError raised when the condition is false
    """
    assert condition.dims == (), "condition for assert must be a scalar"
    holds = bool(condition.raw_tensor.item())
    if holds:
        return
    raise AssertionError(message)
35
+
29
36
  @staticmethod
30
37
  def get_dtype_name_raw(raw_tensor: numpy.ndarray) -> str:
31
38
  """
@@ -110,7 +110,7 @@ def bin_op_out_template(
110
110
  all_dims.extend([dim_ for dim_ in a.dims if dim_ == dim])
111
111
  else:
112
112
  all_dims.extend([dim_ for dim_ in b.dims if dim_ == dim])
113
- if all(set(x.dims) != set(all_dims) for x in (a, b)):
113
+ if all([set(x.dims) != set(all_dims) for x in (a, b)]):
114
114
  if allow_broadcast_all_sources is False:
115
115
  raise ValueError(f"compare: sources {a!r} {b!r} not allowed with allow_broadcast_all_sources=False")
116
116
  elif allow_broadcast_all_sources is None:
@@ -54,6 +54,7 @@ __all__ = [
54
54
  "one_hot",
55
55
  "top_k_mask",
56
56
  "top_p_mask",
57
+ "repeat",
57
58
  ]
58
59
 
59
60
 
@@ -84,6 +85,10 @@ def convert_to_tensor(
84
85
  :return: tensor
85
86
  """
86
87
  if isinstance(value, Tensor): # fast path
88
+ if device and value.device != device:
89
+ value = rf.copy_to_device(value, device)
90
+ if dtype and value.dtype != dtype:
91
+ value = rf.cast(value, dtype=dtype)
87
92
  return value
88
93
  if isinstance(value, (tuple, list)):
89
94
  value = numpy.array(value, dtype=dtype)
@@ -190,7 +195,7 @@ def merge_dims(
190
195
  if out_dim is None:
191
196
  from returnn.util.basic import prod
192
197
 
193
- if any(d.need_masking() for d in dims[1:]):
198
+ if any([d.need_masking() for d in dims[1:]]):
194
199
  # The dynamic sizes as calculated via dim math would not correctly describe how the tensor looks like.
195
200
  # This would then potentially discard some of the data in the tensor in subsequent operations,
196
201
  # when masking is applied.
@@ -905,7 +910,7 @@ def scatter(
905
910
  else:
906
911
  raise ValueError(f"scatter: invalid mode {mode!r}")
907
912
  indices_dim = indices_dim if isinstance(indices_dim, (list, tuple)) else [indices_dim]
908
- if any(dim.need_masking() for dim in indices_dim):
913
+ if any([dim.need_masking() for dim in indices_dim]):
909
914
  if use_mask is None:
910
915
  use_mask = rf.use_mask_default(
911
916
  default=True, default_false_for_behavior_version_up_to=22, func_name="scatter"
@@ -1341,3 +1346,44 @@ def top_p_mask(
1341
1346
  mask = mask | (rf.range_over_dim(sorted_dim, device=mask.device) < min_tokens_to_keep)
1342
1347
  mask = rf.scatter(mask, indices=sorted_indices, indices_dim=sorted_dim)
1343
1348
  return mask
1349
+
1350
+
1351
def repeat(
    values: Tensor, *, in_spatial_dim: Dim, repeats: Tensor, out_spatial_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
    """
    Repeat each element of ``values`` along ``in_spatial_dim`` as often as given by ``repeats``.
    A repeat count of 0 drops that element.

    Typical use: duration-based expansion, e.g. in text-to-speech.

    :param values: [common..., values..., in_spatial_dim]
    :param in_spatial_dim:
    :param repeats: [common..., repeats..., in_spatial_dim] -> int32 durations / number of repetitions for each element
    :param out_spatial_dim: optional; created (or its size filled in) if not given/complete
    :return: expanded_values: [common..., values..., repeats..., out_spatial_dim], out_spatial_dim
    """
    # Approach similar to masked_select. Padded frames must not contribute to the output length.
    repeats = repeats.copy_masked(0, dims=[in_spatial_dim])
    # end_pos[..., i] = repeats[..., 0] + ... + repeats[..., i]:
    # one past the last output frame produced by input element i.
    end_pos = rf.cumsum(repeats, spatial_dim=in_spatial_dim)
    # Output length = end position of the last input element.
    last_in_idx = in_spatial_dim.get_dim_value_tensor() - 1
    total_len = rf.gather(end_pos, indices=last_in_idx, axis=in_spatial_dim)  # [batch...]

    size_device = rf.get_default_dim_size_device()
    if out_spatial_dim is None:
        out_spatial_dim = Dim(rf.copy_to_device(total_len, size_device), name="repeat")
    elif out_spatial_dim.dyn_size_ext is None:
        out_spatial_dim.dyn_size_ext = rf.copy_to_device(total_len, size_device)
    elif out_spatial_dim.dyn_size_ext.raw_tensor is None:
        # Template size exists but has no value yet: fill it in.
        out_spatial_dim.dyn_size_ext.raw_tensor = rf.copy_to_device(total_len, size_device).raw_tensor

    # One extra slot, because end positions can be equal to the output length.
    ext_dim = out_spatial_dim + 1
    ones = rf.expand_dims(rf.ones((), device=values.device, dtype="int32"), dims=end_pos.dims)
    # counts[..., j] = number of input elements whose span ends at output position j.
    counts = rf.scatter(
        ones,
        indices=end_pos,
        indices_dim=in_spatial_dim,
        out_dim=ext_dim,
    )
    # Since end_pos is non-decreasing, the number of input elements ending at or before j
    # is exactly the index of the input element that produces output frame j.
    src_idx = rf.cumsum(counts, spatial_dim=ext_dim)
    src_idx, _ = rf.slice(src_idx, axis=ext_dim, size=out_spatial_dim)  # drop the extra slot
    # The padded area may contain out-of-range indices, hence clip_to_valid.
    expanded = rf.gather(values, indices=src_idx, axis=in_spatial_dim, clip_to_valid=True)
    return expanded, out_spatial_dim
@@ -0,0 +1,35 @@
1
+ """
2
+ Assertion utility functions for validating conditions in Python code.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from typing import Union
7
+ import returnn.frontend as rf
8
+ from returnn.tensor import Tensor
9
+
10
+
11
+ __all__ = ["assert_"]
12
+
13
+
14
def assert_(condition: Union[Tensor, bool], message: str):
    """
    Raise an AssertionError carrying ``message`` unless ``condition`` holds.

    A plain Python bool is checked immediately. A Tensor is first reduced with ``rf.reduce_all``
    over all of its dims (if it has any) and then handed to its backend's ``assert_``.
    This runs async on GPU.

    :param condition: Python bool, or a condition Tensor of any shape
    :param message: message of the AssertionError
    :return: nothing
    :raises TypeError: if condition is neither a bool nor a Tensor
    """
    if isinstance(condition, bool):
        if condition:
            return
        raise AssertionError(message)

    if not isinstance(condition, Tensor):
        raise TypeError(f"Condition must be a boolean or a Tensor, got {type(condition)}")

    # The backend expects a scalar condition.
    scalar_condition = rf.reduce_all(condition, axis=condition.dims) if condition.dims else condition
    # noinspection PyProtectedMember
    scalar_condition._raw_backend.assert_(scalar_condition, message=message)
@@ -24,6 +24,7 @@ __all__ = [
24
24
  "LearnedRelativePositionalEncoding",
25
25
  "relative_positional_encoding",
26
26
  "sinusoidal_positional_encoding",
27
+ "sinusoidal_encoding",
27
28
  ]
28
29
 
29
30
 
@@ -454,7 +455,7 @@ class RelPosSelfAttention(SelfAttentionBase):
454
455
  pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=axis)
455
456
  else:
456
457
  pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
457
- query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim
458
+ query_spatial_dim=axis, key_value_spatial_dim=axis, feat_dim=self.pos_emb_feat_dim, device=source.device
458
459
  )
459
460
  if self.pos_emb_dropout:
460
461
  pos_emb = rf.dropout(pos_emb, self.pos_emb_dropout)
@@ -483,6 +484,7 @@ class RelPosSelfAttention(SelfAttentionBase):
483
484
  matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
484
485
 
485
486
  scores = matrix_ac + matrix_bd # (batch, head, time1, time2)
487
+ del matrix_ac, matrix_bd
486
488
  scores *= self.key_dim_per_head.dimension**-0.5
487
489
  att_weights = rf.softmax(scores, axis=hist_dim)
488
490
  att_weights = rf.dropout(att_weights, self.att_dropout, axis=self.att_dropout_broadcast and hist_dim)
@@ -609,7 +611,10 @@ class RelPosCausalSelfAttention(CausalSelfAttention):
609
611
  pos_emb, pos_emb_spatial_dim = self.learned_pos_emb(query_spatial_dim=axis, key_value_spatial_dim=hist_dim)
610
612
  else:
611
613
  pos_emb, pos_emb_spatial_dim = relative_positional_encoding(
612
- query_spatial_dim=axis, key_value_spatial_dim=hist_dim, feat_dim=self.pos_emb_feat_dim
614
+ query_spatial_dim=axis,
615
+ key_value_spatial_dim=hist_dim,
616
+ feat_dim=self.pos_emb_feat_dim,
617
+ device=source.device,
613
618
  )
614
619
  # pos_emb_spatial_dim is 2*time1-1 if axis!=single_step_dim, else time1
615
620
  if self.pos_emb_dropout:
@@ -724,6 +729,7 @@ class CrossAttention(rf.Module):
724
729
  """
725
730
  Transformer encoder output. This is intended as an initial API suggestion.
726
731
  """
732
+ assert axis in encoder.dims
727
733
  k, v = self.forward_kv(encoder)
728
734
  return rf.State(k=k, v=v, kv_axis=axis)
729
735
 
@@ -811,7 +817,9 @@ class LearnedRelativePositionalEncoding(rf.Module):
811
817
  :return: tensor of shape [spatial_dim * 2 - 1, feat_dim], and the out spatial dim (spatial_dim * 2 - 1).
812
818
  In the center is the rel pos i-j=0. All to the right are for i-j>0, all to the left for i-j<0.
813
819
  """
814
- indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset)
820
+ indices, out_spatial_dim = _make_indices(
821
+ query_spatial_dim, key_value_spatial_dim, query_offset, device=self.pos_emb.device
822
+ )
815
823
  indices = rf.clip_by_value(indices, -self.clipping, 0 if self.causal else self.clipping)
816
824
  # Shift values to be >= 0. Each integer still uniquely identifies a relative position difference.
817
825
  indices = indices + self.clipping
@@ -851,8 +859,9 @@ def _make_indices(
851
859
  query_spatial_dim: Dim,
852
860
  key_value_spatial_dim: Dim,
853
861
  query_offset: Optional[Union[int, Tensor]] = None,
862
+ device: Optional[str] = None,
854
863
  ) -> Tuple[Tensor, Dim]:
855
- kv_pos_vec = rf.range_over_dim(key_value_spatial_dim) # [kv_len]
864
+ kv_pos_vec = rf.range_over_dim(key_value_spatial_dim, device=device) # [kv_len]
856
865
 
857
866
  # See also RelativePositionalEncodingLayer
858
867
  if query_spatial_dim == single_step_dim:
@@ -865,7 +874,7 @@ def _make_indices(
865
874
  query_offset = key_value_spatial_dim.get_size_tensor() - 1
866
875
  else:
867
876
  query_spatial_dim_m1 = query_spatial_dim - 1
868
- q_pos_vec = rf.range_over_dim(query_spatial_dim_m1) # [q_len-1]
877
+ q_pos_vec = rf.range_over_dim(query_spatial_dim_m1, device=device) # [q_len-1]
869
878
 
870
879
  # The masking in the output is quite custom (left+right masking), so our seq lens don't make sense,
871
880
  # and might even cause to fail some tests (that e.g. max(q_seq_len+k_seq_len-1) == shape).
@@ -902,6 +911,7 @@ def relative_positional_encoding(
902
911
  feat_dim: Dim,
903
912
  query_offset: int = 0,
904
913
  dtype: Optional[str] = None,
914
+ device: Optional[str] = None,
905
915
  ) -> Tuple[Tensor, Dim]:
906
916
  """
907
917
  Implements relative positional encoding, Transformer-XL style (https://arxiv.org/abs/1901.02860),
@@ -924,7 +934,9 @@ def relative_positional_encoding(
924
934
  """
925
935
  if not dtype:
926
936
  dtype = rf.get_default_float_dtype()
927
- cache_key = (query_spatial_dim, key_value_spatial_dim, feat_dim, query_offset, dtype)
937
+ if not device:
938
+ device = rf.get_default_device()
939
+ cache_key = (query_spatial_dim, key_value_spatial_dim, feat_dim, query_offset, dtype, device)
928
940
  cache_entry = _relative_positional_encoding_cache.get(cache_key)
929
941
  if cache_entry is not None:
930
942
  return cache_entry
@@ -932,7 +944,7 @@ def relative_positional_encoding(
932
944
 
933
945
  with rf.control_flow_ctx(None):
934
946
  # See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
935
- indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset)
947
+ indices, out_spatial_dim = _make_indices(query_spatial_dim, key_value_spatial_dim, query_offset, device=device)
936
948
 
937
949
  feat2_dim = feat_dim.div_left(2)
938
950
  div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype) * -(2.0 * math.log(1e4) / feat_dim.dimension))
@@ -986,7 +998,6 @@ def sinusoidal_positional_encoding(
986
998
  cache_entry = _sinusoidal_positional_encoding_cache.get(cache_key)
987
999
  if cache_entry is not None:
988
1000
  return cache_entry
989
- import math
990
1001
 
991
1002
  with rf.control_flow_ctx(None):
992
1003
  # See also RelativePositionalEncodingLayer, LearnedRelativePositionalEncoding
@@ -997,26 +1008,49 @@ def sinusoidal_positional_encoding(
997
1008
  indices = rf.range_over_dim(spatial_dim, device=device) # [len]
998
1009
  if offset is not None:
999
1010
  indices = indices + offset
1000
- indices = rf.copy_to_device(indices, device)
1001
-
1002
- feat2_dim = feat_dim.div_left(2)
1003
- div_term = rf.exp(
1004
- rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -(math.log(base) / (feat2_dim.dimension - 1))
1005
- )
1006
- arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
1007
- arg_cos = arg_sin + math.pi / 2.0
1008
- arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
1009
- arg, feat_dim_ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
1010
- emb = rf.sin(arg)
1011
+ emb = sinusoidal_encoding(indices, base=base, feat_dim=feat_dim, dtype=dtype)
1011
1012
  emb.verify_out_shape(
1012
1013
  {feat_dim} | indices.dims_set | ({spatial_dim} if spatial_dim != single_step_dim else set()),
1013
1014
  allow_missing_implicit_dims=True,
1014
1015
  )
1015
- emb.feature_dim = feat_dim
1016
1016
  _sinusoidal_positional_encoding_cache.set(cache_key, emb)
1017
1017
  return emb
1018
1018
 
1019
1019
 
1020
def sinusoidal_encoding(
    indices: Tensor,
    *,
    feat_dim: Dim,
    base: Union[int, float] = 1e4,
    dtype: Optional[str] = None,
) -> Tensor:
    """
    Sinusoidal encoding of (e.g. position) indices.

    With K = feat_dim / 2 and k = 0..K-1, the frequencies are freq_k = base ** (-k / (K - 1)),
    geometrically spaced from 1 down to 1/base. The first half of feat_dim holds
    sin(indices * freq_k); the second half holds cos(indices * freq_k), computed as
    sin(x + pi/2).

    :param indices: [...], to be encoded
    :param feat_dim: output feature dim; expected to be even (it is split via ``div_left(2)``)
    :param base: base for the angles
    :param dtype: data type; defaults to ``rf.get_default_float_dtype()``
    :return: tensor of shape [..., feat_dim], on the same device as ``indices``
    """
    import math

    if not dtype:
        dtype = rf.get_default_float_dtype()

    device = indices.device
    feat2_dim = feat_dim.div_left(2)
    # As in tensor2tensor's get_timing_signal_1d: max(K - 1, 1) avoids a ZeroDivisionError for
    # K == 1 (feat_dim == 2). In that case the only frequency is exp(0) = 1 anyway.
    log_freq_step = math.log(base) / max(feat2_dim.dimension - 1, 1)
    div_term = rf.exp(rf.range_over_dim(feat2_dim, dtype=dtype, device=device) * -log_freq_step)
    arg_sin = rf.combine_bc(rf.cast(indices, dtype), "*", div_term)
    arg_cos = arg_sin + math.pi / 2.0
    arg, feat_dim_ = rf.concat((arg_sin, feat2_dim), (arg_cos, feat2_dim))
    arg, _ = rf.replace_dim(arg, in_dim=feat_dim_, out_dim=feat_dim)
    emb = rf.sin(arg)
    emb.feature_dim = feat_dim
    return emb
1052
+
1053
+
1020
1054
  _att_dropout_broadcast_shown_warning = False
1021
1055
 
1022
1056