xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (104)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +23 -1
  4. xinference/core/model.py +1 -6
  5. xinference/core/utils.py +10 -6
  6. xinference/model/audio/core.py +5 -0
  7. xinference/model/audio/cosyvoice.py +25 -3
  8. xinference/model/audio/f5tts.py +15 -10
  9. xinference/model/audio/f5tts_mlx.py +260 -0
  10. xinference/model/audio/fish_speech.py +35 -111
  11. xinference/model/audio/model_spec.json +19 -3
  12. xinference/model/audio/model_spec_modelscope.json +9 -0
  13. xinference/model/audio/utils.py +32 -0
  14. xinference/model/image/core.py +69 -1
  15. xinference/model/image/model_spec.json +127 -4
  16. xinference/model/image/model_spec_modelscope.json +130 -4
  17. xinference/model/image/stable_diffusion/core.py +45 -13
  18. xinference/model/llm/llm_family.json +47 -0
  19. xinference/model/llm/llm_family.py +15 -36
  20. xinference/model/llm/llm_family_modelscope.json +49 -0
  21. xinference/model/llm/mlx/core.py +68 -13
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  24. xinference/model/llm/utils.py +1 -0
  25. xinference/model/llm/vllm/core.py +11 -2
  26. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  27. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  28. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  29. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  30. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  31. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  34. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  35. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  36. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  37. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  38. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  39. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  40. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  41. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  42. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  43. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  44. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  45. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  46. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  47. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  48. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  49. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  50. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  51. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  52. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  53. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  54. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  55. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  56. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  57. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  58. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  59. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  60. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  61. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  62. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  63. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  64. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  65. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  66. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  67. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  68. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  69. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  70. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  71. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  72. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  73. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  74. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  75. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  76. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  77. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  78. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  79. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  80. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  81. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  82. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  83. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  84. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  85. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  86. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  87. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  88. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  89. xinference/thirdparty/matcha/utils/utils.py +2 -2
  90. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
  91. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
  92. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  93. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  94. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  95. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  96. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  99. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  100. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  101. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  102. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  103. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  104. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py
@@ -0,0 +1,279 @@
+ import base64
+ import os
+ from functools import lru_cache
+ from typing import Optional
+ import torch
+ from transformers import AutoTokenizer
+ from whisper.tokenizer import Tokenizer
+
+ import tiktoken
+
+ LANGUAGES = {
+     "en": "english",
+     "zh": "chinese",
+     "de": "german",
+     "es": "spanish",
+     "ru": "russian",
+     "ko": "korean",
+     "fr": "french",
+     "ja": "japanese",
+     "pt": "portuguese",
+     "tr": "turkish",
+     "pl": "polish",
+     "ca": "catalan",
+     "nl": "dutch",
+     "ar": "arabic",
+     "sv": "swedish",
+     "it": "italian",
+     "id": "indonesian",
+     "hi": "hindi",
+     "fi": "finnish",
+     "vi": "vietnamese",
+     "he": "hebrew",
+     "uk": "ukrainian",
+     "el": "greek",
+     "ms": "malay",
+     "cs": "czech",
+     "ro": "romanian",
+     "da": "danish",
+     "hu": "hungarian",
+     "ta": "tamil",
+     "no": "norwegian",
+     "th": "thai",
+     "ur": "urdu",
+     "hr": "croatian",
+     "bg": "bulgarian",
+     "lt": "lithuanian",
+     "la": "latin",
+     "mi": "maori",
+     "ml": "malayalam",
+     "cy": "welsh",
+     "sk": "slovak",
+     "te": "telugu",
+     "fa": "persian",
+     "lv": "latvian",
+     "bn": "bengali",
+     "sr": "serbian",
+     "az": "azerbaijani",
+     "sl": "slovenian",
+     "kn": "kannada",
+     "et": "estonian",
+     "mk": "macedonian",
+     "br": "breton",
+     "eu": "basque",
+     "is": "icelandic",
+     "hy": "armenian",
+     "ne": "nepali",
+     "mn": "mongolian",
+     "bs": "bosnian",
+     "kk": "kazakh",
+     "sq": "albanian",
+     "sw": "swahili",
+     "gl": "galician",
+     "mr": "marathi",
+     "pa": "punjabi",
+     "si": "sinhala",
+     "km": "khmer",
+     "sn": "shona",
+     "yo": "yoruba",
+     "so": "somali",
+     "af": "afrikaans",
+     "oc": "occitan",
+     "ka": "georgian",
+     "be": "belarusian",
+     "tg": "tajik",
+     "sd": "sindhi",
+     "gu": "gujarati",
+     "am": "amharic",
+     "yi": "yiddish",
+     "lo": "lao",
+     "uz": "uzbek",
+     "fo": "faroese",
+     "ht": "haitian creole",
+     "ps": "pashto",
+     "tk": "turkmen",
+     "nn": "nynorsk",
+     "mt": "maltese",
+     "sa": "sanskrit",
+     "lb": "luxembourgish",
+     "my": "myanmar",
+     "bo": "tibetan",
+     "tl": "tagalog",
+     "mg": "malagasy",
+     "as": "assamese",
+     "tt": "tatar",
+     "haw": "hawaiian",
+     "ln": "lingala",
+     "ha": "hausa",
+     "ba": "bashkir",
+     "jw": "javanese",
+     "su": "sundanese",
+     "yue": "cantonese",
+     "minnan": "minnan",
+     "wuyu": "wuyu",
+     "dialect": "dialect",
+     "zh/en": "zh/en",
+     "en/zh": "en/zh",
+ }
+
+ # language code lookup by name, with a few language aliases
+ TO_LANGUAGE_CODE = {
+     **{language: code for code, language in LANGUAGES.items()},
+     "burmese": "my",
+     "valencian": "ca",
+     "flemish": "nl",
+     "haitian": "ht",
+     "letzeburgesch": "lb",
+     "pushto": "ps",
+     "panjabi": "pa",
+     "moldavian": "ro",
+     "moldovan": "ro",
+     "sinhalese": "si",
+     "castilian": "es",
+     "mandarin": "zh",
+ }
+
+ AUDIO_EVENT = {
+     "ASR": "ASR",
+     "AED": "AED",
+     "SER": "SER",
+     "Speech": "Speech",
+     "/Speech": "/Speech",
+     "BGM": "BGM",
+     "/BGM": "/BGM",
+     "Laughter": "Laughter",
+     "/Laughter": "/Laughter",
+     "Applause": "Applause",
+     "/Applause": "/Applause",
+ }
+
+ EMOTION = {
+     "HAPPY": "HAPPY",
+     "SAD": "SAD",
+     "ANGRY": "ANGRY",
+     "NEUTRAL": "NEUTRAL",
+ }
+
+ TTS_Vocal_Token = {
+     "TTS/B": "TTS/B",
+     "TTS/O": "TTS/O",
+     "TTS/Q": "TTS/Q",
+     "TTS/A": "TTS/A",
+     "TTS/CO": "TTS/CO",
+     "TTS/CL": "TTS/CL",
+     "TTS/H": "TTS/H",
+     **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+ }
+
+
+ @lru_cache(maxsize=None)
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
+     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+     ranks = {
+         base64.b64decode(token): int(rank)
+         for token, rank in (line.split() for line in open(vocab_path) if line)
+     }
+     n_vocab = len(ranks)
+     special_tokens = {}
+
+     specials = [
+         "<|endoftext|>",
+         "<|startoftranscript|>",
+         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+         *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+         *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+         "<|translate|>",
+         "<|transcribe|>",
+         "<|startoflm|>",
+         "<|startofprev|>",
+         "<|nospeech|>",
+         "<|notimestamps|>",
+         *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
+         *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
+         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+     ]
+
+     for token in specials:
+         special_tokens[token] = n_vocab
+         n_vocab += 1
+
+     return tiktoken.Encoding(
+         name=os.path.basename(vocab_path),
+         explicit_n_vocab=n_vocab,
+         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+         mergeable_ranks=ranks,
+         special_tokens=special_tokens,
+     )
+
+
+ @lru_cache(maxsize=None)
+ def get_tokenizer(
+     multilingual: bool,
+     *,
+     num_languages: int = 99,
+     language: Optional[str] = None,
+     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+ ) -> Tokenizer:
+     if language is not None:
+         language = language.lower()
+         if language not in LANGUAGES:
+             if language in TO_LANGUAGE_CODE:
+                 language = TO_LANGUAGE_CODE[language]
+             else:
+                 raise ValueError(f"Unsupported language: {language}")
+
+     if multilingual:
+         encoding_name = "multilingual_zh_ja_yue_char_del"
+         language = language or "en"
+         task = task or "transcribe"
+     else:
+         encoding_name = "gpt2"
+         language = None
+         task = None
+
+     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+     return Tokenizer(
+         encoding=encoding, num_languages=num_languages, language=language, task=task
+     )
+
+
+ class QwenTokenizer():
+     def __init__(self, token_path, skip_special_tokens=True):
+         super().__init__()
+         # NOTE: non-chat model, all these special tokens keep randomly initialized.
+         special_tokens = {
+             'eos_token': '<|endoftext|>',
+             'pad_token': '<|endoftext|>',
+             'additional_special_tokens': [
+                 '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                 '[breath]', '<strong>', '</strong>', '[noise]',
+                 '[laughter]', '[cough]', '[clucking]', '[accent]',
+                 '[quick_breath]',
+                 "<laughter>", "</laughter>",
+                 "[hissing]", "[sigh]", "[vocalized-noise]",
+                 "[lipsmack]", "[mn]"
+             ]
+         }
+         self.special_tokens = special_tokens
+         self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+         self.tokenizer.add_special_tokens(special_tokens)
+         self.skip_special_tokens = skip_special_tokens
+
+     def encode(self, text, **kwargs):
+         tokens = self.tokenizer([text], return_tensors="pt")
+         tokens = tokens["input_ids"][0].cpu().tolist()
+         return tokens
+
+     def decode(self, tokens):
+         tokens = torch.tensor(tokens, dtype=torch.int64)
+         text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+         return text
+
+
+ @lru_cache(maxsize=None)
+ def get_qwen_tokenizer(
+     token_path: str,
+     skip_special_tokens: bool
+ ) -> QwenTokenizer:
+     return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
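The new tokenizer.py adds two lru_cache-wrapped constructors: get_tokenizer, which builds a Whisper-style tiktoken encoding from the bundled multilingual_zh_ja_yue_char_del.tiktoken vocabulary, and get_qwen_tokenizer, which wraps a Hugging Face AutoTokenizer and registers CosyVoice's vocal tags as special tokens. A minimal usage sketch, not part of the diff; the Qwen checkpoint path below is a placeholder and the whisper/tiktoken dependencies are assumed to be installed:

# Illustrative only: exercises the helpers added in
# xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py.
from xinference.thirdparty.cosyvoice.tokenizer.tokenizer import (
    get_qwen_tokenizer,
    get_tokenizer,
)

# Whisper-compatible tokenizer over the bundled tiktoken vocabulary.
whisper_tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")

# Qwen wrapper; "Qwen/Qwen2-0.5B" is a placeholder model path.
qwen_tok = get_qwen_tokenizer("Qwen/Qwen2-0.5B", skip_special_tokens=True)
ids = qwen_tok.encode("你好，世界 [laughter]")
print(qwen_tok.decode(ids))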
xinference/thirdparty/cosyvoice/transformer/embedding.py
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 
      """
 
-     def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+     def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
          """Construct an PositionalEncoding object."""
          super(EspnetRelPositionalEncoding, self).__init__()
          self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
          """
          pos_emb = self.pe[
              :,
-             self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
+             self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
          ]
          return pos_emb
xinference/thirdparty/cosyvoice/transformer/encoder_layer.py
@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
          super().__init__()
          self.self_attn = self_attn
          self.feed_forward = feed_forward
-         self.norm1 = nn.LayerNorm(size, eps=1e-5)
-         self.norm2 = nn.LayerNorm(size, eps=1e-5)
+         self.norm1 = nn.LayerNorm(size, eps=1e-12)
+         self.norm2 = nn.LayerNorm(size, eps=1e-12)
          self.dropout = nn.Dropout(dropout_rate)
          self.size = size
          self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
          self.feed_forward = feed_forward
          self.feed_forward_macaron = feed_forward_macaron
          self.conv_module = conv_module
-         self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-         self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+         self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+         self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
          if feed_forward_macaron is not None:
-             self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+             self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
              self.ff_scale = 0.5
          else:
              self.ff_scale = 1.0
          if self.conv_module is not None:
-             self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
+             self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
              self.norm_final = nn.LayerNorm(
-                 size, eps=1e-5)  # for the final output of the block
+                 size, eps=1e-12)  # for the final output of the block
          self.dropout = nn.Dropout(dropout_rate)
          self.size = size
          self.normalize_before = normalize_before
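The only functional change in encoder_layer.py is the LayerNorm epsilon, 1e-5 → 1e-12. Since eps enters only the normalization denominator sqrt(var + eps), the two settings differ visibly only when the per-feature variance is itself tiny. A standalone illustration, not taken from the diff:

import torch
from torch import nn

x = torch.full((1, 4), 3.0)   # nearly constant features: variance ~ 1e-9
x[0, 0] += 1e-4

for eps in (1e-5, 1e-12):
    ln = nn.LayerNorm(4, eps=eps, elementwise_affine=False)
    # With eps=1e-5 the eps term dominates and the output stays near zero;
    # with eps=1e-12 the tiny variance is actually normalized away.
    print(eps, ln(x))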
xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py
@@ -0,0 +1,318 @@
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+ #               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+ #               2024 Alibaba Inc (Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Modified from ESPnet(https://github.com/espnet/espnet)
+ """Encoder definition."""
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from cosyvoice.transformer.convolution import ConvolutionModule
+ from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
+ from cosyvoice.utils.class_utils import (
+     COSYVOICE_EMB_CLASSES,
+     COSYVOICE_SUBSAMPLE_CLASSES,
+     COSYVOICE_ATTENTION_CLASSES,
+     COSYVOICE_ACTIVATION_CLASSES,
+ )
+ from cosyvoice.utils.mask import make_pad_mask
+ from cosyvoice.utils.mask import add_optional_chunk_mask
+
+
+ class Upsample1D(nn.Module):
+     """A 1D upsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         use_conv_transpose (`bool`, default `False`):
+             option to use a convolution transpose.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+     """
+
+     def __init__(self, channels: int, out_channels: int, stride: int = 2):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels
+         self.stride = stride
+         # In this mode, first repeat interpolate, than conv with stride=1
+         self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
+
+     def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+         outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
+         outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+         outputs = self.conv(outputs)
+         return outputs, input_lengths * self.stride
+
+
+ class PreLookaheadLayer(nn.Module):
+     def __init__(self, channels: int, pre_lookahead_len: int = 1):
+         super().__init__()
+         self.channels = channels
+         self.pre_lookahead_len = pre_lookahead_len
+         self.conv1 = nn.Conv1d(
+             channels, channels,
+             kernel_size=pre_lookahead_len + 1,
+             stride=1, padding=0,
+         )
+         self.conv2 = nn.Conv1d(
+             channels, channels,
+             kernel_size=3, stride=1, padding=0,
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         """
+         inputs: (batch_size, seq_len, channels)
+         """
+         outputs = inputs.transpose(1, 2).contiguous()
+         # look ahead
+         outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+         outputs = F.leaky_relu(self.conv1(outputs))
+         # outputs
+         outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
+         outputs = self.conv2(outputs)
+         outputs = outputs.transpose(1, 2).contiguous()
+
+         # residual connection
+         outputs = outputs + inputs
+         return outputs
+
+
+ class UpsampleConformerEncoder(torch.nn.Module):
+
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int = 256,
+         attention_heads: int = 4,
+         linear_units: int = 2048,
+         num_blocks: int = 6,
+         dropout_rate: float = 0.1,
+         positional_dropout_rate: float = 0.1,
+         attention_dropout_rate: float = 0.0,
+         input_layer: str = "conv2d",
+         pos_enc_layer_type: str = "rel_pos",
+         normalize_before: bool = True,
+         static_chunk_size: int = 0,
+         use_dynamic_chunk: bool = False,
+         global_cmvn: torch.nn.Module = None,
+         use_dynamic_left_chunk: bool = False,
+         positionwise_conv_kernel_size: int = 1,
+         macaron_style: bool = True,
+         selfattention_layer_type: str = "rel_selfattn",
+         activation_type: str = "swish",
+         use_cnn_module: bool = True,
+         cnn_module_kernel: int = 15,
+         causal: bool = False,
+         cnn_module_norm: str = "batch_norm",
+         key_bias: bool = True,
+         gradient_checkpointing: bool = False,
+     ):
+         """
+         Args:
+             input_size (int): input dim
+             output_size (int): dimension of attention
+             attention_heads (int): the number of heads of multi head attention
+             linear_units (int): the hidden units number of position-wise feed
+                 forward
+             num_blocks (int): the number of decoder blocks
+             dropout_rate (float): dropout rate
+             attention_dropout_rate (float): dropout rate in attention
+             positional_dropout_rate (float): dropout rate after adding
+                 positional encoding
+             input_layer (str): input layer type.
+                 optional [linear, conv2d, conv2d6, conv2d8]
+             pos_enc_layer_type (str): Encoder positional encoding layer type.
+                 opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+             normalize_before (bool):
+                 True: use layer_norm before each sub-block of a layer.
+                 False: use layer_norm after each sub-block of a layer.
+             static_chunk_size (int): chunk size for static chunk training and
+                 decoding
+             use_dynamic_chunk (bool): whether use dynamic chunk size for
+                 training or not, You can only use fixed chunk(chunk_size > 0)
+                 or dyanmic chunk size(use_dynamic_chunk = True)
+             global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+             use_dynamic_left_chunk (bool): whether use dynamic left chunk in
+                 dynamic chunk training
+             key_bias: whether use bias in attention.linear_k, False for whisper models.
+             gradient_checkpointing: rerunning a forward-pass segment for each
+                 checkpointed segment during backward.
+         """
+         super().__init__()
+         self._output_size = output_size
+
+         self.global_cmvn = global_cmvn
+         self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+             input_size,
+             output_size,
+             dropout_rate,
+             COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                       positional_dropout_rate),
+         )
+
+         self.normalize_before = normalize_before
+         self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+         self.static_chunk_size = static_chunk_size
+         self.use_dynamic_chunk = use_dynamic_chunk
+         self.use_dynamic_left_chunk = use_dynamic_left_chunk
+         self.gradient_checkpointing = gradient_checkpointing
+         activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+         # self-attention module definition
+         encoder_selfattn_layer_args = (
+             attention_heads,
+             output_size,
+             attention_dropout_rate,
+             key_bias,
+         )
+         # feed-forward module definition
+         positionwise_layer_args = (
+             output_size,
+             linear_units,
+             dropout_rate,
+             activation,
+         )
+         # convolution module definition
+         convolution_layer_args = (output_size, cnn_module_kernel, activation,
+                                   cnn_module_norm, causal)
+         self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+         self.encoders = torch.nn.ModuleList([
+             ConformerEncoderLayer(
+                 output_size,
+                 COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                     *encoder_selfattn_layer_args),
+                 PositionwiseFeedForward(*positionwise_layer_args),
+                 PositionwiseFeedForward(
+                     *positionwise_layer_args) if macaron_style else None,
+                 ConvolutionModule(
+                     *convolution_layer_args) if use_cnn_module else None,
+                 dropout_rate,
+                 normalize_before,
+             ) for _ in range(num_blocks)
+         ])
+         self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
+         self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+             input_size,
+             output_size,
+             dropout_rate,
+             COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                       positional_dropout_rate),
+         )
+         self.up_encoders = torch.nn.ModuleList([
+             ConformerEncoderLayer(
+                 output_size,
+                 COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                     *encoder_selfattn_layer_args),
+                 PositionwiseFeedForward(*positionwise_layer_args),
+                 PositionwiseFeedForward(
+                     *positionwise_layer_args) if macaron_style else None,
+                 ConvolutionModule(
+                     *convolution_layer_args) if use_cnn_module else None,
+                 dropout_rate,
+                 normalize_before,
+             ) for _ in range(4)
+         ])
+
+     def output_size(self) -> int:
+         return self._output_size
+
+     def forward(
+         self,
+         xs: torch.Tensor,
+         xs_lens: torch.Tensor,
+         decoding_chunk_size: int = 0,
+         num_decoding_left_chunks: int = -1,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Embed positions in tensor.
+
+         Args:
+             xs: padded input tensor (B, T, D)
+             xs_lens: input length (B)
+             decoding_chunk_size: decoding chunk size for dynamic chunk
+                 0: default for training, use random dynamic chunk.
+                 <0: for decoding, use full chunk.
+                 >0: for decoding, use fixed chunk size as set.
+             num_decoding_left_chunks: number of left chunks, this is for decoding,
+                 the chunk size is decoding_chunk_size.
+                 >=0: use num_decoding_left_chunks
+                 <0: use all left chunks
+         Returns:
+             encoder output tensor xs, and subsampled masks
+             xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+             masks: torch.Tensor batch padding mask after subsample
+                 (B, 1, T' ~= T/subsample_rate)
+         NOTE(xcsong):
+             We pass the `__call__` method of the modules instead of `forward` to the
+             checkpointing API because `__call__` attaches all the hooks of the module.
+             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+         """
+         T = xs.size(1)
+         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+         if self.global_cmvn is not None:
+             xs = self.global_cmvn(xs)
+         xs, pos_emb, masks = self.embed(xs, masks)
+         mask_pad = masks  # (B, 1, T/subsample_rate)
+         chunk_masks = add_optional_chunk_mask(xs, masks,
+                                               self.use_dynamic_chunk,
+                                               self.use_dynamic_left_chunk,
+                                               decoding_chunk_size,
+                                               self.static_chunk_size,
+                                               num_decoding_left_chunks)
+         # lookahead + conformer encoder
+         xs = self.pre_lookahead_layer(xs)
+         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+         # upsample + conformer encoder
+         xs = xs.transpose(1, 2).contiguous()
+         xs, xs_lens = self.up_layer(xs, xs_lens)
+         xs = xs.transpose(1, 2).contiguous()
+         T = xs.size(1)
+         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+         xs, pos_emb, masks = self.up_embed(xs, masks)
+         mask_pad = masks  # (B, 1, T/subsample_rate)
+         chunk_masks = add_optional_chunk_mask(xs, masks,
+                                               self.use_dynamic_chunk,
+                                               self.use_dynamic_left_chunk,
+                                               decoding_chunk_size,
+                                               self.static_chunk_size * self.up_layer.stride,
+                                               num_decoding_left_chunks)
+         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+         if self.normalize_before:
+             xs = self.after_norm(xs)
+         # Here we assume the mask is not changed in encoder layers, so just
+         # return the masks before encoder layers, and the masks will be used
+         # for cross attention with decoder later
+         return xs, masks
+
+     def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                        pos_emb: torch.Tensor,
+                        mask_pad: torch.Tensor) -> torch.Tensor:
+         for layer in self.encoders:
+             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+         return xs
+
+     def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                           pos_emb: torch.Tensor,
+                           mask_pad: torch.Tensor) -> torch.Tensor:
+         for layer in self.up_encoders:
+             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+         return xs
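upsample_encoder.py is new in this release: UpsampleConformerEncoder runs a PreLookaheadLayer plus the usual conformer blocks, then an Upsample1D stage that doubles the token rate before four more conformer blocks. A small shape check for the two self-contained helpers, illustrative only; it assumes the vendored cosyvoice package is importable from the path shown:

import torch
from xinference.thirdparty.cosyvoice.transformer.upsample_encoder import (
    PreLookaheadLayer,
    Upsample1D,
)

x = torch.randn(2, 50, 512)              # (batch, seq_len, channels)
lens = torch.tensor([50, 42])

look = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
assert look(x).shape == (2, 50, 512)     # residual layer keeps the shape

up = Upsample1D(channels=512, out_channels=512, stride=2)
y, y_lens = up(x.transpose(1, 2), lens)  # Upsample1D expects (B, C, T)
assert y.shape == (2, 512, 100)          # time axis doubled by stride=2
assert y_lens.tolist() == [100, 84]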