xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py

@@ -0,0 +1,279 @@
+import base64
+import os
+from functools import lru_cache
+from typing import Optional
+import torch
+from transformers import AutoTokenizer
+from whisper.tokenizer import Tokenizer
+
+import tiktoken
+
+LANGUAGES = {
+    "en": "english",
+    "zh": "chinese",
+    "de": "german",
+    "es": "spanish",
+    "ru": "russian",
+    "ko": "korean",
+    "fr": "french",
+    "ja": "japanese",
+    "pt": "portuguese",
+    "tr": "turkish",
+    "pl": "polish",
+    "ca": "catalan",
+    "nl": "dutch",
+    "ar": "arabic",
+    "sv": "swedish",
+    "it": "italian",
+    "id": "indonesian",
+    "hi": "hindi",
+    "fi": "finnish",
+    "vi": "vietnamese",
+    "he": "hebrew",
+    "uk": "ukrainian",
+    "el": "greek",
+    "ms": "malay",
+    "cs": "czech",
+    "ro": "romanian",
+    "da": "danish",
+    "hu": "hungarian",
+    "ta": "tamil",
+    "no": "norwegian",
+    "th": "thai",
+    "ur": "urdu",
+    "hr": "croatian",
+    "bg": "bulgarian",
+    "lt": "lithuanian",
+    "la": "latin",
+    "mi": "maori",
+    "ml": "malayalam",
+    "cy": "welsh",
+    "sk": "slovak",
+    "te": "telugu",
+    "fa": "persian",
+    "lv": "latvian",
+    "bn": "bengali",
+    "sr": "serbian",
+    "az": "azerbaijani",
+    "sl": "slovenian",
+    "kn": "kannada",
+    "et": "estonian",
+    "mk": "macedonian",
+    "br": "breton",
+    "eu": "basque",
+    "is": "icelandic",
+    "hy": "armenian",
+    "ne": "nepali",
+    "mn": "mongolian",
+    "bs": "bosnian",
+    "kk": "kazakh",
+    "sq": "albanian",
+    "sw": "swahili",
+    "gl": "galician",
+    "mr": "marathi",
+    "pa": "punjabi",
+    "si": "sinhala",
+    "km": "khmer",
+    "sn": "shona",
+    "yo": "yoruba",
+    "so": "somali",
+    "af": "afrikaans",
+    "oc": "occitan",
+    "ka": "georgian",
+    "be": "belarusian",
+    "tg": "tajik",
+    "sd": "sindhi",
+    "gu": "gujarati",
+    "am": "amharic",
+    "yi": "yiddish",
+    "lo": "lao",
+    "uz": "uzbek",
+    "fo": "faroese",
+    "ht": "haitian creole",
+    "ps": "pashto",
+    "tk": "turkmen",
+    "nn": "nynorsk",
+    "mt": "maltese",
+    "sa": "sanskrit",
+    "lb": "luxembourgish",
+    "my": "myanmar",
+    "bo": "tibetan",
+    "tl": "tagalog",
+    "mg": "malagasy",
+    "as": "assamese",
+    "tt": "tatar",
+    "haw": "hawaiian",
+    "ln": "lingala",
+    "ha": "hausa",
+    "ba": "bashkir",
+    "jw": "javanese",
+    "su": "sundanese",
+    "yue": "cantonese",
+    "minnan": "minnan",
+    "wuyu": "wuyu",
+    "dialect": "dialect",
+    "zh/en": "zh/en",
+    "en/zh": "en/zh",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+    **{language: code for code, language in LANGUAGES.items()},
+    "burmese": "my",
+    "valencian": "ca",
+    "flemish": "nl",
+    "haitian": "ht",
+    "letzeburgesch": "lb",
+    "pushto": "ps",
+    "panjabi": "pa",
+    "moldavian": "ro",
+    "moldovan": "ro",
+    "sinhalese": "si",
+    "castilian": "es",
+    "mandarin": "zh",
+}
+
+AUDIO_EVENT = {
+    "ASR": "ASR",
+    "AED": "AED",
+    "SER": "SER",
+    "Speech": "Speech",
+    "/Speech": "/Speech",
+    "BGM": "BGM",
+    "/BGM": "/BGM",
+    "Laughter": "Laughter",
+    "/Laughter": "/Laughter",
+    "Applause": "Applause",
+    "/Applause": "/Applause",
+}
+
+EMOTION = {
+    "HAPPY": "HAPPY",
+    "SAD": "SAD",
+    "ANGRY": "ANGRY",
+    "NEUTRAL": "NEUTRAL",
+}
+
+TTS_Vocal_Token = {
+    "TTS/B": "TTS/B",
+    "TTS/O": "TTS/O",
+    "TTS/Q": "TTS/Q",
+    "TTS/A": "TTS/A",
+    "TTS/CO": "TTS/CO",
+    "TTS/CL": "TTS/CL",
+    "TTS/H": "TTS/H",
+    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
+}
+
+
+@lru_cache(maxsize=None)
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
+    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+    ranks = {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in open(vocab_path) if line)
+    }
+    n_vocab = len(ranks)
+    special_tokens = {}
+
+    specials = [
+        "<|endoftext|>",
+        "<|startoftranscript|>",
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+        "<|translate|>",
+        "<|transcribe|>",
+        "<|startoflm|>",
+        "<|startofprev|>",
+        "<|nospeech|>",
+        "<|notimestamps|>",
+        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
+        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
+        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+    ]
+
+    for token in specials:
+        special_tokens[token] = n_vocab
+        n_vocab += 1
+
+    return tiktoken.Encoding(
+        name=os.path.basename(vocab_path),
+        explicit_n_vocab=n_vocab,
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
+        mergeable_ranks=ranks,
+        special_tokens=special_tokens,
+    )
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+    multilingual: bool,
+    *,
+    num_languages: int = 99,
+    language: Optional[str] = None,
+    task: Optional[str] = None, # Literal["transcribe", "translate", None]
+) -> Tokenizer:
+    if language is not None:
+        language = language.lower()
+        if language not in LANGUAGES:
+            if language in TO_LANGUAGE_CODE:
+                language = TO_LANGUAGE_CODE[language]
+            else:
+                raise ValueError(f"Unsupported language: {language}")
+
+    if multilingual:
+        encoding_name = "multilingual_zh_ja_yue_char_del"
+        language = language or "en"
+        task = task or "transcribe"
+    else:
+        encoding_name = "gpt2"
+        language = None
+        task = None
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        super().__init__()
+        # NOTE: non-chat model, all these special tokens keep randomly initialized.
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+                "<laughter>", "</laughter>",
+                "[hissing]", "[sigh]", "[vocalized-noise]",
+                "[lipsmack]", "[mn]"
+            ]
+        }
+        self.special_tokens = special_tokens
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
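For reference, a minimal usage sketch of the new tokenizer helpers (not part of the release; it assumes the vendored cosyvoice package is importable, and the Qwen tokenizer directory path below is purely illustrative):

# Sketch only: exercises the helpers added in cosyvoice/tokenizer/tokenizer.py.
from cosyvoice.tokenizer.tokenizer import get_tokenizer, get_qwen_tokenizer

# Whisper-style tiktoken tokenizer; multilingual=True selects the bundled
# multilingual_zh_ja_yue_char_del.tiktoken asset added in this release.
whisper_tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
ids = whisper_tok.encode("你好，世界")

# Qwen-based text tokenizer (the local path is hypothetical).
qwen_tok = get_qwen_tokenizer(token_path="/path/to/qwen-tokenizer",
                              skip_special_tokens=True)
tokens = qwen_tok.encode("hello [laughter]")
print(qwen_tok.decode(tokens))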
xinference/thirdparty/cosyvoice/transformer/embedding.py

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         """Construct an PositionalEncoding object."""
         super(EspnetRelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         """
         pos_emb = self.pe[
             :,
-            self.pe.size(1) // 2 - size + 1
+            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
         ]
         return pos_emb
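The second hunk above completes the slice that EspnetRelPositionalEncoding uses to pick out relative positional embeddings. A small standalone illustration of that indexing (not from the diff; the shapes follow the usual ESPnet convention of a table of length 2 * max_len - 1 centered on relative offset 0, and the concrete numbers are hypothetical):

import torch

max_len, d_model, size = 5000, 8, 4
pe = torch.zeros(1, 2 * max_len - 1, d_model)      # positional table, offset 0 sits at the center
center = pe.size(1) // 2
pos_emb = pe[:, center - size + 1: center + size]  # same indexing as the added line
assert pos_emb.size(1) == 2 * size - 1             # 2 * size - 1 embeddings, centered on offset 0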
xinference/thirdparty/cosyvoice/transformer/encoder_layer.py

@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-
-        self.norm2 = nn.LayerNorm(size, eps=1e-
+        self.norm1 = nn.LayerNorm(size, eps=1e-12)
+        self.norm2 = nn.LayerNorm(size, eps=1e-12)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = feed_forward
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-
-        self.norm_mha = nn.LayerNorm(size, eps=1e-
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
             self.norm_final = nn.LayerNorm(
-                size, eps=1e-
+                size, eps=1e-12) # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py

@@ -0,0 +1,318 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+# 2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from cosyvoice.transformer.convolution import ConvolutionModule
+from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from cosyvoice.utils.class_utils import (
+    COSYVOICE_EMB_CLASSES,
+    COSYVOICE_SUBSAMPLE_CLASSES,
+    COSYVOICE_ATTENTION_CLASSES,
+    COSYVOICE_ACTIVATION_CLASSES,
+)
+from cosyvoice.utils.mask import make_pad_mask
+from cosyvoice.utils.mask import add_optional_chunk_mask
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+    """
+
+    def __init__(self, channels: int, out_channels: int, stride: int = 2):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels
+        self.stride = stride
+        # In this mode, first repeat interpolate, than conv with stride=1
+        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
+
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
+        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+        outputs = self.conv(outputs)
+        return outputs, input_lengths * self.stride
+
+
+class PreLookaheadLayer(nn.Module):
+    def __init__(self, channels: int, pre_lookahead_len: int = 1):
+        super().__init__()
+        self.channels = channels
+        self.pre_lookahead_len = pre_lookahead_len
+        self.conv1 = nn.Conv1d(
+            channels, channels,
+            kernel_size=pre_lookahead_len + 1,
+            stride=1, padding=0,
+        )
+        self.conv2 = nn.Conv1d(
+            channels, channels,
+            kernel_size=3, stride=1, padding=0,
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """
+        inputs: (batch_size, seq_len, channels)
+        """
+        outputs = inputs.transpose(1, 2).contiguous()
+        # look ahead
+        outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+        outputs = F.leaky_relu(self.conv1(outputs))
+        # outputs
+        outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
+        outputs = self.conv2(outputs)
+        outputs = outputs.transpose(1, 2).contiguous()
+
+        # residual connection
+        outputs = outputs + inputs
+        return outputs
+
+
+class UpsampleConformerEncoder(torch.nn.Module):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        pos_enc_layer_type: str = "rel_pos",
+        normalize_before: bool = True,
+        static_chunk_size: int = 0,
+        use_dynamic_chunk: bool = False,
+        global_cmvn: torch.nn.Module = None,
+        use_dynamic_left_chunk: bool = False,
+        positionwise_conv_kernel_size: int = 1,
+        macaron_style: bool = True,
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        cnn_module_kernel: int = 15,
+        causal: bool = False,
+        cnn_module_norm: str = "batch_norm",
+        key_bias: bool = True,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            input_size (int): input dim
+            output_size (int): dimension of attention
+            attention_heads (int): the number of heads of multi head attention
+            linear_units (int): the hidden units number of position-wise feed
+                forward
+            num_blocks (int): the number of decoder blocks
+            dropout_rate (float): dropout rate
+            attention_dropout_rate (float): dropout rate in attention
+            positional_dropout_rate (float): dropout rate after adding
+                positional encoding
+            input_layer (str): input layer type.
+                optional [linear, conv2d, conv2d6, conv2d8]
+            pos_enc_layer_type (str): Encoder positional encoding layer type.
+                opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+            normalize_before (bool):
+                True: use layer_norm before each sub-block of a layer.
+                False: use layer_norm after each sub-block of a layer.
+            static_chunk_size (int): chunk size for static chunk training and
+                decoding
+            use_dynamic_chunk (bool): whether use dynamic chunk size for
+                training or not, You can only use fixed chunk(chunk_size > 0)
+                or dyanmic chunk size(use_dynamic_chunk = True)
+            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+            use_dynamic_left_chunk (bool): whether use dynamic left chunk in
+                dynamic chunk training
+            key_bias: whether use bias in attention.linear_k, False for whisper models.
+            gradient_checkpointing: rerunning a forward-pass segment for each
+                checkpointed segment during backward.
+        """
+        super().__init__()
+        self._output_size = output_size
+
+        self.global_cmvn = global_cmvn
+        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+            input_size,
+            output_size,
+            dropout_rate,
+            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                      positional_dropout_rate),
+        )
+
+        self.normalize_before = normalize_before
+        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+        self.static_chunk_size = static_chunk_size
+        self.use_dynamic_chunk = use_dynamic_chunk
+        self.use_dynamic_left_chunk = use_dynamic_left_chunk
+        self.gradient_checkpointing = gradient_checkpointing
+        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+        # self-attention module definition
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            key_bias,
+        )
+        # feed-forward module definition
+        positionwise_layer_args = (
+            output_size,
+            linear_units,
+            dropout_rate,
+            activation,
+        )
+        # convolution module definition
+        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+                                  cnn_module_norm, causal)
+        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+        self.encoders = torch.nn.ModuleList([
+            ConformerEncoderLayer(
+                output_size,
+                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                    *encoder_selfattn_layer_args),
+                PositionwiseFeedForward(*positionwise_layer_args),
+                PositionwiseFeedForward(
+                    *positionwise_layer_args) if macaron_style else None,
+                ConvolutionModule(
+                    *convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+            ) for _ in range(num_blocks)
+        ])
+        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
+        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+            input_size,
+            output_size,
+            dropout_rate,
+            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                      positional_dropout_rate),
+        )
+        self.up_encoders = torch.nn.ModuleList([
+            ConformerEncoderLayer(
+                output_size,
+                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                    *encoder_selfattn_layer_args),
+                PositionwiseFeedForward(*positionwise_layer_args),
+                PositionwiseFeedForward(
+                    *positionwise_layer_args) if macaron_style else None,
+                ConvolutionModule(
+                    *convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+            ) for _ in range(4)
+        ])
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs: torch.Tensor,
+        xs_lens: torch.Tensor,
+        decoding_chunk_size: int = 0,
+        num_decoding_left_chunks: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Embed positions in tensor.
+
+        Args:
+            xs: padded input tensor (B, T, D)
+            xs_lens: input length (B)
+            decoding_chunk_size: decoding chunk size for dynamic chunk
+                0: default for training, use random dynamic chunk.
+                <0: for decoding, use full chunk.
+                >0: for decoding, use fixed chunk size as set.
+            num_decoding_left_chunks: number of left chunks, this is for decoding,
+                the chunk size is decoding_chunk_size.
+                >=0: use num_decoding_left_chunks
+                <0: use all left chunks
+        Returns:
+            encoder output tensor xs, and subsampled masks
+            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+            masks: torch.Tensor batch padding mask after subsample
+                (B, 1, T' ~= T/subsample_rate)
+        NOTE(xcsong):
+            We pass the `__call__` method of the modules instead of `forward` to the
+            checkpointing API because `__call__` attaches all the hooks of the module.
+            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        """
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        xs, pos_emb, masks = self.embed(xs, masks)
+        mask_pad = masks # (B, 1, T/subsample_rate)
+        chunk_masks = add_optional_chunk_mask(xs, masks,
+                                              self.use_dynamic_chunk,
+                                              self.use_dynamic_left_chunk,
+                                              decoding_chunk_size,
+                                              self.static_chunk_size,
+                                              num_decoding_left_chunks)
+        # lookahead + conformer encoder
+        xs = self.pre_lookahead_layer(xs)
+        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+        # upsample + conformer encoder
+        xs = xs.transpose(1, 2).contiguous()
+        xs, xs_lens = self.up_layer(xs, xs_lens)
+        xs = xs.transpose(1, 2).contiguous()
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
+        xs, pos_emb, masks = self.up_embed(xs, masks)
+        mask_pad = masks # (B, 1, T/subsample_rate)
+        chunk_masks = add_optional_chunk_mask(xs, masks,
+                                              self.use_dynamic_chunk,
+                                              self.use_dynamic_left_chunk,
+                                              decoding_chunk_size,
+                                              self.static_chunk_size * self.up_layer.stride,
+                                              num_decoding_left_chunks)
+        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        # Here we assume the mask is not changed in encoder layers, so just
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
+        return xs, masks
+
+    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                       pos_emb: torch.Tensor,
+                       mask_pad: torch.Tensor) -> torch.Tensor:
+        for layer in self.encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs
+
+    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                          pos_emb: torch.Tensor,
+                          mask_pad: torch.Tensor) -> torch.Tensor:
+        for layer in self.up_encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs
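Finally, a minimal forward-pass sketch of the new UpsampleConformerEncoder (not part of the diff; it assumes the vendored cosyvoice package is importable and uses a 512-dim configuration, since PreLookaheadLayer and Upsample1D are hard-coded to 512 channels above):

import torch
from cosyvoice.transformer.upsample_encoder import UpsampleConformerEncoder

# Sketch only: arguments other than the sizes keep the defaults shown above.
encoder = UpsampleConformerEncoder(
    input_size=512,
    output_size=512,        # must match the hard-coded 512-channel lookahead/upsample layers
    input_layer="linear",   # no subsampling, so the frame count is preserved before upsampling
)

xs = torch.randn(2, 100, 512)      # (batch, frames, features)
xs_lens = torch.tensor([100, 80])
ys, masks = encoder(xs, xs_lens)
print(ys.shape)                    # (2, 200, 512): the Upsample1D stage doubles the sequence length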