minicpmo_utils-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
cosyvoice/tokenizer/tokenizer.py

@@ -0,0 +1,279 @@
+import base64
+import os
+from functools import lru_cache
+from typing import Optional
+import torch
+from transformers import AutoTokenizer
+from whisper.tokenizer import Tokenizer
+
+import tiktoken
+
+LANGUAGES = {
+    "en": "english",
+    "zh": "chinese",
+    "de": "german",
+    "es": "spanish",
+    "ru": "russian",
+    "ko": "korean",
+    "fr": "french",
+    "ja": "japanese",
+    "pt": "portuguese",
+    "tr": "turkish",
+    "pl": "polish",
+    "ca": "catalan",
+    "nl": "dutch",
+    "ar": "arabic",
+    "sv": "swedish",
+    "it": "italian",
+    "id": "indonesian",
+    "hi": "hindi",
+    "fi": "finnish",
+    "vi": "vietnamese",
+    "he": "hebrew",
+    "uk": "ukrainian",
+    "el": "greek",
+    "ms": "malay",
+    "cs": "czech",
+    "ro": "romanian",
+    "da": "danish",
+    "hu": "hungarian",
+    "ta": "tamil",
+    "no": "norwegian",
+    "th": "thai",
+    "ur": "urdu",
+    "hr": "croatian",
+    "bg": "bulgarian",
+    "lt": "lithuanian",
+    "la": "latin",
+    "mi": "maori",
+    "ml": "malayalam",
+    "cy": "welsh",
+    "sk": "slovak",
+    "te": "telugu",
+    "fa": "persian",
+    "lv": "latvian",
+    "bn": "bengali",
+    "sr": "serbian",
+    "az": "azerbaijani",
+    "sl": "slovenian",
+    "kn": "kannada",
+    "et": "estonian",
+    "mk": "macedonian",
+    "br": "breton",
+    "eu": "basque",
+    "is": "icelandic",
+    "hy": "armenian",
+    "ne": "nepali",
+    "mn": "mongolian",
+    "bs": "bosnian",
+    "kk": "kazakh",
+    "sq": "albanian",
+    "sw": "swahili",
+    "gl": "galician",
+    "mr": "marathi",
+    "pa": "punjabi",
+    "si": "sinhala",
+    "km": "khmer",
+    "sn": "shona",
+    "yo": "yoruba",
+    "so": "somali",
+    "af": "afrikaans",
+    "oc": "occitan",
+    "ka": "georgian",
+    "be": "belarusian",
+    "tg": "tajik",
+    "sd": "sindhi",
+    "gu": "gujarati",
+    "am": "amharic",
+    "yi": "yiddish",
+    "lo": "lao",
+    "uz": "uzbek",
+    "fo": "faroese",
+    "ht": "haitian creole",
+    "ps": "pashto",
+    "tk": "turkmen",
+    "nn": "nynorsk",
+    "mt": "maltese",
+    "sa": "sanskrit",
+    "lb": "luxembourgish",
+    "my": "myanmar",
+    "bo": "tibetan",
+    "tl": "tagalog",
+    "mg": "malagasy",
+    "as": "assamese",
+    "tt": "tatar",
+    "haw": "hawaiian",
+    "ln": "lingala",
+    "ha": "hausa",
+    "ba": "bashkir",
+    "jw": "javanese",
+    "su": "sundanese",
+    "yue": "cantonese",
+    "minnan": "minnan",
+    "wuyu": "wuyu",
+    "dialect": "dialect",
+    "zh/en": "zh/en",
+    "en/zh": "en/zh",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+    **{language: code for code, language in LANGUAGES.items()},
+    "burmese": "my",
+    "valencian": "ca",
+    "flemish": "nl",
+    "haitian": "ht",
+    "letzeburgesch": "lb",
+    "pushto": "ps",
+    "panjabi": "pa",
+    "moldavian": "ro",
+    "moldovan": "ro",
+    "sinhalese": "si",
+    "castilian": "es",
+    "mandarin": "zh",
+}
+
+AUDIO_EVENT = {
+    "ASR": "ASR",
+    "AED": "AED",
+    "SER": "SER",
+    "Speech": "Speech",
+    "/Speech": "/Speech",
+    "BGM": "BGM",
+    "/BGM": "/BGM",
+    "Laughter": "Laughter",
+    "/Laughter": "/Laughter",
+    "Applause": "Applause",
+    "/Applause": "/Applause",
+}
+
+EMOTION = {
+    "HAPPY": "HAPPY",
+    "SAD": "SAD",
+    "ANGRY": "ANGRY",
+    "NEUTRAL": "NEUTRAL",
+}
+
+TTS_Vocal_Token = {
+    "TTS/B": "TTS/B",
+    "TTS/O": "TTS/O",
+    "TTS/Q": "TTS/Q",
+    "TTS/A": "TTS/A",
+    "TTS/CO": "TTS/CO",
+    "TTS/CL": "TTS/CL",
+    "TTS/H": "TTS/H",
+    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+}
+
+
+@lru_cache(maxsize=None)
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
+    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}
+
+    specials = [
+        "<|endoftext|>",
+        "<|startoftranscript|>",
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+        "<|translate|>",
+        "<|transcribe|>",
+        "<|startoflm|>",
+        "<|startofprev|>",
+        "<|nospeech|>",
+        "<|notimestamps|>",
+        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
+        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
+        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+    ]
+
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens,
+    )
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+    multilingual: bool,
+    *,
+    num_languages: int = 99,
+    language: Optional[str] = None,
+    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+) -> Tokenizer:
+    if language is not None:
+        language = language.lower()
+        if language not in LANGUAGES:
+            if language in TO_LANGUAGE_CODE:
+                language = TO_LANGUAGE_CODE[language]
+            else:
+                raise ValueError(f"Unsupported language: {language}")
+
+    if multilingual:
+        encoding_name = "multilingual_zh_ja_yue_char_del"
+        language = language or "en"
+        task = task or "transcribe"
+    else:
+        encoding_name = "gpt2"
+        language = None
+        task = None
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        super().__init__()
+        # NOTE: non-chat model, all these special tokens keep randomly initialized.
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+                "<laughter>", "</laughter>",
+                "[hissing]", "[sigh]", "[vocalized-noise]",
+                "[lipsmack]", "[mn]"
+            ]
+        }
+        self.special_tokens = special_tokens
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
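For orientation, here is a minimal usage sketch of the two tokenizer entry points added in this file: the Whisper-style `get_tokenizer` (backed by the bundled `multilingual_zh_ja_yue_char_del.tiktoken` vocabulary) and the HuggingFace-backed `QwenTokenizer`. The path and sample text below are hypothetical placeholders, and the sketch assumes the package plus its `whisper`, `transformers`, `tiktoken`, and `torch` dependencies are installed.

```python
# Illustrative sketch only, not part of the package diff.
from cosyvoice.tokenizer.tokenizer import get_tokenizer, get_qwen_tokenizer

# Whisper-style tokenizer built on the bundled tiktoken vocabulary.
whisper_tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
ids = whisper_tok.encoding.encode("hello world")
print(whisper_tok.encoding.decode(ids))

# Qwen text tokenizer; "path/to/qwen-tokenizer" is a hypothetical directory
# or hub id accepted by AutoTokenizer.from_pretrained.
qwen_tok = get_qwen_tokenizer(token_path="path/to/qwen-tokenizer",
                              skip_special_tokens=True)
token_ids = qwen_tok.encode("你好 [laughter]")
print(qwen_tok.decode(token_ids))
```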
cosyvoice/transformer/__init__.py

File without changes
cosyvoice/transformer/activation.py

@@ -0,0 +1,84 @@
+# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
+#               2020 Northwestern Polytechnical University (Pengcheng Guo)
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Swish() activation function for Conformer."""
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Swish(torch.nn.Module):
+    """Construct an Swish object."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Return Swish activation function."""
+        return x * torch.sigmoid(x)
+
+
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+class Snake(nn.Module):
+    '''
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake ∶= x + 1/a * sin^2 (xa)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
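As a quick shape check for the two activations above, here is a small, hedged example with arbitrary sizes: Snake expects channel-major (B, C, T) input and keeps one trainable alpha per channel, while Swish is purely elementwise.

```python
# Illustrative sketch only, not part of the package diff.
import torch
from cosyvoice.transformer.activation import Snake, Swish

x = torch.randn(2, 80, 100)                  # (B, C, T): 2 clips, 80 channels, 100 frames
snake = Snake(in_features=80, alpha_logscale=True)
y = snake(x)                                 # alpha (80,) broadcasts as (1, 80, 1)
assert y.shape == x.shape

swish = Swish()
z = swish(torch.randn(4, 16))                # elementwise, so any shape is accepted
```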
cosyvoice/transformer/attention.py

@@ -0,0 +1,330 @@
+# Copyright (c) 2019 Shigeki Karita
+#               2020 Mobvoi Inc (Binbin Zhang)
+#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Multi-Head Attention layer definition."""
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True):
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor, size
+                (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor, size
+                (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor, size
+                (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value, size
+                (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score, size
+                (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
+        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
+        #           1st chunk to ease the onnx export.]
+        #   2. pytorch training
+        if mask.size(2) > 0:  # time2 > 0
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            # For last chunk, time2 might be larger than scores.size(-1)
+            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
+            scores = scores.masked_fill(mask, -float('inf'))
+            attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0)  # (batch, head, time1, time2)
+        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
+        #   1. onnx(16/-1, -1/-1, 16/0)
+        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
+        else:
+            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                 self.h * self.d_k)
+             )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+                1.When applying cross attention between decoder and encoder,
+                the batch padding mask for input is in (#batch, 1, T) shape.
+                2.When applying self attention of encoder,
+                the mask is in (#batch, T, T) shape.
+                3.When applying self attention of decoder,
+                the mask is in (#batch, L, L) shape.
+                4.If the different position in decoder see different block
+                of the encoder, such as Mocha, the passed in mask could be
+                in (#batch, L, T) shape. But there is no such case in current
+                CosyVoice.
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+
+        # NOTE(xcsong):
+        #   when export onnx model, for 1st chunk, we feed
+        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
+        #       and we will always do splitting and
+        #       concatnation(this will simplify onnx export). Note that
+        #       it's OK to concat & split zero-shaped tensors(see code below).
+        #   when export jit model, for 1st chunk, we always feed
+        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = torch.cat((k, v), dim=-1)
+
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding.
+    Paper: https://arxiv.org/abs/1901.02860
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+    """
+
+    def __init__(self,
+                 n_head: int,
+                 n_feat: int,
+                 dropout_rate: float,
+                 key_bias: bool = True):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate, key_bias)
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+                time1 means the length of query vector.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+        ]  # only keep the positions from 0 to time2
+        return x
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        pos_emb: torch.Tensor = torch.empty(0),
+        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2), (0, 0, 0) means fake mask.
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, time2, size).
+            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        # NOTE(xcsong):
+        #   when export onnx model, for 1st chunk, we feed
+        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
+        #       and we will always do splitting and
+        #       concatnation(this will simplify onnx export). Note that
+        #       it's OK to concat & split zero-shaped tensors(see code below).
+        #   when export jit model, for 1st chunk, we always feed
+        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(cache,
+                                                 cache.size(-1) // 2,
+                                                 dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = torch.cat((k, v), dim=-1)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+        if matrix_ac.shape != matrix_bd.shape:
+            matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k)  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask), new_cache
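To make the cache convention described in the docstrings concrete, here is a small, hedged sketch that drives the plain MultiHeadedAttention chunk by chunk. The head count, feature size, and chunk length are arbitrary, and the empty (0, 0, 0) mask and (0, 0, 0, 0) cache follow the "fake" defaults shown in forward().

```python
# Illustrative sketch only, not part of the package diff.
import torch
from cosyvoice.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
mha.eval()

empty_mask = torch.ones(0, 0, 0, dtype=torch.bool)   # "fake mask" => no masking
cache = torch.zeros(0, 0, 0, 0)                      # empty cache for the first chunk

with torch.no_grad():
    chunk1 = torch.randn(1, 16, 256)
    out1, cache = mha(chunk1, chunk1, chunk1, mask=empty_mask, cache=cache)

    # The second chunk also attends to the 16 cached key/value frames.
    chunk2 = torch.randn(1, 16, 256)
    out2, cache = mha(chunk2, chunk2, chunk2, mask=empty_mask, cache=cache)

print(out1.shape, out2.shape, cache.shape)
# torch.Size([1, 16, 256]) torch.Size([1, 16, 256]) torch.Size([1, 4, 32, 128])
```

The returned cache concatenates keys and values along the last dimension, which is why its final size is 2 * d_k (here 128); RelPositionMultiHeadedAttention follows the same convention but additionally expects a pos_emb tensor from the relative positional encoding module.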