minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
@@ -0,0 +1,279 @@
1
+ import base64
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Optional
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from whisper.tokenizer import Tokenizer
8
+
9
+ import tiktoken
10
+
11
# Language code -> lowercase language name.
# NOTE: insertion order is significant. get_encoding() takes the first
# `num_languages` keys of this dict when it registers the language special
# tokens, so entries must not be reordered.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# Reverse lookup (language name -> code), extended with alternative names
# that map onto an existing code.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    {
        "burmese": "my",
        "valencian": "ca",
        "flemish": "nl",
        "haitian": "ht",
        "letzeburgesch": "lb",
        "pushto": "ps",
        "panjabi": "pa",
        "moldavian": "ro",
        "moldovan": "ro",
        "sinhalese": "si",
        "castilian": "es",
        "mandarin": "zh",
    }
)

# Task / audio-event tag names; every key maps to itself.
AUDIO_EVENT = {
    tag: tag
    for tag in (
        "ASR",
        "AED",
        "SER",
        "Speech",
        "/Speech",
        "BGM",
        "/BGM",
        "Laughter",
        "/Laughter",
        "Applause",
        "/Applause",
    )
}

# Emotion tag names; every key maps to itself.
EMOTION = {tag: tag for tag in ("HAPPY", "SAD", "ANGRY", "NEUTRAL")}

# TTS tag names: seven fixed tags followed by TTS/SP01 .. TTS/SP13;
# every key maps to itself.
TTS_Vocal_Token = {
    tag: tag
    for tag in (
        "TTS/B",
        "TTS/O",
        "TTS/Q",
        "TTS/A",
        "TTS/CO",
        "TTS/CL",
        "TTS/H",
        *(f"TTS/SP{i:02d}" for i in range(1, 14)),
    )
}
167
+
168
+
169
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Build the tiktoken ``Encoding`` for a vocabulary bundled under ``assets/``.

    The vocabulary file ``assets/<name>.tiktoken`` (next to this module) holds
    one ``<base64 token> <rank>`` pair per line. Special tokens are appended
    after the mergeable ranks, receiving consecutive ids starting at
    ``len(ranks)``.

    Results are memoized per ``(name, num_languages)`` by ``lru_cache``.

    Args:
        name: Basename (without extension) of the ``.tiktoken`` file to load.
        num_languages: How many language tags, taken from the front of
            ``LANGUAGES``, are registered as special tokens.

    Returns:
        A ``tiktoken.Encoding`` covering the mergeable ranks plus all special
        tokens.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")

    # Read inside a context manager so the handle is closed deterministically.
    # Lines coming from a file keep their trailing newline, so a plain
    # truthiness check never filters anything; strip first so that blank or
    # whitespace-only lines are skipped instead of breaking the 2-way unpack.
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line.strip())
        }
    n_vocab = len(ranks)

    # The order of this list defines the special-token ids; do not reorder.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in AUDIO_EVENT],
        *[f"<|{emotion}|>" for emotion in EMOTION],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in TTS_Vocal_Token],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # <|0.00|> .. <|30.00|> in 0.02 steps
    ]

    # Each special token gets the next free id after the mergeable ranks.
    special_tokens = {token: n_vocab + offset for offset, token in enumerate(specials)}
    n_vocab += len(specials)

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
207
+
208
+
209
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Return a (memoized) ``Tokenizer`` for the requested configuration.

    Args:
        multilingual: Selects the ``multilingual_zh_ja_yue_char_del``
            vocabulary when truthy, otherwise the ``gpt2`` vocabulary.
        num_languages: Forwarded to ``get_encoding`` and to ``Tokenizer``.
        language: Language code or language name (case-insensitive). Names
            and aliases are resolved through ``TO_LANGUAGE_CODE``. Defaults
            to ``"en"`` in multilingual mode; ignored otherwise.
        task: Task name. Defaults to ``"transcribe"`` in multilingual mode;
            ignored otherwise.

    Returns:
        A ``Tokenizer`` wrapping the selected encoding.

    Raises:
        ValueError: If ``language`` is given but is neither a known code nor
            a known name/alias. This is checked even when ``multilingual``
            is false.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            # Not a code: try to resolve it as a language name or alias.
            resolved = TO_LANGUAGE_CODE.get(language)
            if resolved is None:
                raise ValueError(f"Unsupported language: {language}")
            language = resolved

    if not multilingual:
        # The non-multilingual vocabulary carries no language/task context.
        encoding_name, language, task = "gpt2", None, None
    else:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language if language else "en"
        task = task if task else "transcribe"

    return Tokenizer(
        encoding=get_encoding(name=encoding_name, num_languages=num_languages),
        num_languages=num_languages,
        language=language,
        task=task,
    )
239
+
240
+
241
class QwenTokenizer():
    """Wrapper around a pretrained ``AutoTokenizer`` with extra special tokens.

    Attributes are set in ``__init__``:
      - ``special_tokens``: the dict handed to ``add_special_tokens``.
      - ``tokenizer``: the tokenizer loaded from ``token_path``.
      - ``skip_special_tokens``: flag forwarded to ``batch_decode`` in
        ``decode``.
    """

    def __init__(self, token_path, skip_special_tokens=True):
        """Load the tokenizer at ``token_path`` and register the extra tokens."""
        super().__init__()
        extra_markers = [
            '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
            '[breath]', '<strong>', '</strong>', '[noise]',
            '[laughter]', '[cough]', '[clucking]', '[accent]',
            '[quick_breath]',
            "<laughter>", "</laughter>",
            "[hissing]", "[sigh]", "[vocalized-noise]",
            "[lipsmack]", "[mn]",
        ]
        # NOTE: non-chat model, all these special tokens keep randomly initialized.
        self.special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': extra_markers,
        }
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(self.special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        """Tokenize a single string and return its ids as a plain list.

        Extra keyword arguments are accepted but not used.
        """
        batch = self.tokenizer([text], return_tensors="pt")
        # The batch holds exactly one sequence; take it and convert to a list.
        return batch["input_ids"][0].cpu().tolist()

    def decode(self, tokens):
        """Turn a sequence of token ids back into text."""
        ids = torch.tensor(tokens, dtype=torch.int64)
        decoded = self.tokenizer.batch_decode(
            [ids], skip_special_tokens=self.skip_special_tokens
        )
        return decoded[0]
272
+
273
+
274
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool
) -> QwenTokenizer:
    """Return a ``QwenTokenizer``, memoized per ``(token_path, skip_special_tokens)``.

    Repeated calls with the same arguments return the same instance because
    of ``lru_cache``.
    """
    tokenizer = QwenTokenizer(
        token_path=token_path,
        skip_special_tokens=skip_special_tokens,
    )
    return tokenizer
File without changes
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
class Swish(torch.nn.Module):
    """Swish activation module: ``x * sigmoid(x)`` applied elementwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Swish activation to ``x`` and return the result."""
        gate = torch.sigmoid(x)
        return x * gate
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
class Snake(nn.Module):
    '''
    Sine-based periodic activation with a per-channel frequency ``alpha``.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - per-channel parameter of shape (in_features,)
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 10)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: number of channels (length of the alpha vector)
            - alpha: initial scale multiplied into the alpha vector
            - alpha_trainable: whether alpha receives gradients
            - alpha_logscale: if True, alpha is stored in log scale and
              exponentiated in forward()
            alpha is initialized to 1 by default, higher values = higher-frequency.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale alphas start from zeros, linear-scale alphas from ones;
        # either way the vector is multiplied by the `alpha` argument.
        base = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = Parameter(base(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Small constant added to alpha before taking the reciprocal.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        # Reshape alpha from (C,) to (1, C, 1) so it broadcasts over [B, C, T].
        scale = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            scale = torch.exp(scale)
        periodic = torch.pow(torch.sin(x * scale), 2)
        return x + (1.0 / (scale + self.no_div_by_zero)) * periodic
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional key/value cache.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features; must be divisible by n_head.
        dropout_rate (float): Dropout rate applied to the attention weights.
        key_bias (bool): Whether the key projection has a bias term.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct a MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        self.h = n_head
        # d_v is assumed to always equal d_k.
        self.d_k = n_feat // n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split them into heads.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Projected query, (#batch, n_head, time1, d_k).
            torch.Tensor: Projected key, (#batch, n_head, time2, d_k).
            torch.Tensor: Projected value, (#batch, n_head, time2, d_k).

        """
        batch = query.size(0)
        # Project, reshape to (batch, time, head, d_k), then move the head
        # axis in front of time: (batch, head, time, d_k).
        q = self.linear_q(query).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        k = self.linear_k(key).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        v = self.linear_v(value).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Turn attention scores into the attended, output-projected context.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            blocked = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            blocked = blocked[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(blocked, -float('inf'))
            # Zero the blocked positions again after softmax (a fully blocked
            # row would otherwise be NaN).
            attn = torch.softmax(scores, dim=-1).masked_fill(
                blocked, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        context = torch.matmul(self.dropout(attn), value)  # (batch, head, time1, d_k)
        # Merge the heads back: (batch, time1, d_model)
        context = context.transpose(1, 2).contiguous().view(
            n_batch, -1, self.h * self.d_k)

        return self.linear_out(context)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1. For cross attention between decoder and encoder, the
                   batch padding mask for the input is (#batch, 1, T).
                2. For encoder self attention, the mask is (#batch, T, T).
                3. For decoder self attention, the mask is (#batch, L, L).
                4. If different decoder positions see different blocks of
                   the encoder (e.g. Mocha), the mask could be
                   (#batch, L, T). There is no such case in current
                   CosyVoice.
            pos_emb (torch.Tensor): Not used by this class.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when exporting an onnx model, for the 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this simplifies onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when exporting a jit model, for the 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            # The cache stores keys and values side by side on the last axis.
            halves = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([halves[0], k], dim=2)
            v = torch.cat([halves[1], v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attended = self.forward_attention(v, scores, mask)
        return attended, new_cache
198
+
199
+
200
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        key_bias (bool): Whether the key projection has a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct a RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # Bias-free projection applied to the positional embedding.
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Two learnable biases used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor; the last axis keeps the first
                ``x.size(-1) // 2 + 1`` entries.

        """
        n_batch, n_head, n_time, n_pos = x.size(0), x.size(1), x.size(2), x.size(3)
        zero_pad = torch.zeros((n_batch, n_head, n_time, 1),
                               device=x.device,
                               dtype=x.dtype)
        # Prepend one zero column, then reinterpret the buffer with the last
        # two axes swapped in size; dropping the first row afterwards yields
        # the shifted layout.
        padded = torch.cat([zero_pad, x], dim=-1)
        padded = padded.view(n_batch, n_head, n_pos + 1, n_time)
        shifted = padded[:, :, 1:].view_as(x)
        # only keep the positions from 0 to time2
        return shifted[:, :, :, : n_pos // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Put the head axis last-but-one so the (head, d_k) biases broadcast.
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when exporting an onnx model, for the 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this simplifies onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when exporting a jit model, for the 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            # The cache stores keys and values side by side on the last axis.
            halves = torch.split(cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([halves[0], k], dim=2)
            v = torch.cat([halves[1], v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        # Project the positional embedding and split it into heads:
        # (batch_pos, head, pos_len, d_k)
        pos_batch = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(pos_batch, -1, self.h, self.d_k)
        p = p.transpose(1, 2)

        # Add the learned biases, then restore (batch, head, time1, d_k).
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        attended = self.forward_attention(v, scores, mask)
        return attended, new_cache