xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -11
- xinference/client/restful/restful_client.py +8 -2
- xinference/conftest.py +0 -8
- xinference/constants.py +2 -0
- xinference/core/model.py +44 -5
- xinference/core/supervisor.py +13 -7
- xinference/core/utils.py +76 -12
- xinference/core/worker.py +5 -4
- xinference/deploy/cmdline.py +5 -0
- xinference/deploy/utils.py +7 -4
- xinference/model/audio/model_spec.json +2 -2
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/core.py +1 -3
- xinference/model/llm/llm_family.json +263 -4
- xinference/model/llm/llm_family_modelscope.json +302 -0
- xinference/model/llm/mlx/core.py +45 -2
- xinference/model/llm/vllm/core.py +2 -1
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/conversation.py
@@ -1,2 +1,256 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
 SEMANTIC_TOKEN = "<|semantic|>"
+MEL_TOKEN = "<|mel|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+ALL_SPECIAL_TOKENS = [
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    SEMANTIC_TOKEN,
+    MEL_TOKEN,
+    PHONEME_START_TOKEN,
+    PHONEME_END_TOKEN,
+]
+
 CODEBOOK_PAD_TOKEN_ID = 0
+
+
+class FishTokenizerConfig(PretrainedConfig):
+    share_codebook_embeddings: bool = True
+    codebook_size: int = 1024
+    num_codebooks: int = 8
+
+
+class FishTokenizerFast(PreTrainedTokenizerFast):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
+        self.codebook_size = kwargs.pop("codebook_size", 1024)
+        self.num_codebooks = kwargs.pop("num_codebooks", 8)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+
+
+@dataclass(kw_only=True)
+class BasePart:
+    pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    text: str
+
+
+@dataclass(kw_only=True)
+class MelPart(BasePart):
+    mels: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    vq_parts: list[torch.Tensor]
+    mel_parts: list[torch.Tensor]
+    vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+    role: Literal["system", "user", "assistant"]
+    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    add_im_start: bool = True
+    add_im_end: bool = True
+    cal_loss: bool = False
+
+    # By default, ignore the loss of the auto-generated im_start token
+    ignore_im_start_loss: bool = True
+
+    def encode(
+        self: "Message",
+        tokenizer: AutoTokenizer,
+    ) -> EncodedMessage:
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal tokens
+        vq_parts = []
+        mel_parts = []
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+
+        parts = self.parts.copy()
+        if self.add_im_start:
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+
+        if self.add_im_end:
+            parts.append(TextPart(text="<|im_end|>"))
+
+        for part in parts:
+            if isinstance(part, TextPart):
+                tokens = tokenizer.encode(
+                    part.text,
+                    add_special_tokens=False,
+                    truncation=False,
+                    return_tensors="pt",
+                ).int()[0]
+            elif isinstance(part, VQPart):
+                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
+                codes = part.codes.clone() + 1
+
+                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
+                    for i in range(len(codes)):
+                        codes[i] += tokenizer.codebook_size * i
+
+                vq_parts.append(codes)
+            elif isinstance(part, MelPart):
+                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
+                mel_parts.append(part.mels)
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+            if self.cal_loss:
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        assert tokens.shape == labels.shape
+
+        if self.ignore_im_start_loss and self.add_im_start:
+            labels[: len(all_tokens[0])] = -100
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+        )
+
+
+@dataclass
+class Conversation:
+    messages: list[Message]
+
+    def encode(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        add_shift: bool = True,
+    ) -> EncodedMessage:
+        # Build the input_ids and labels
+        tokens = []
+        labels = []
+        vq_parts = []
+        mel_parts = []
+        vq_require_losses = []
+
+        for message in self.messages:
+            encoded = message.encode(
+                tokenizer,
+            )
+            tokens.append(encoded.tokens)
+            labels.append(encoded.labels)
+            vq_parts.extend(encoded.vq_parts)
+            mel_parts.extend(encoded.mel_parts)
+            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+        tokens = torch.cat(tokens, dim=0)
+        labels = torch.cat(labels, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+            vq_require_losses=vq_require_losses,
+        )
+
+    def encode_for_inference(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        num_codebooks: int,
+    ) -> EncodedMessage:
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+            return values
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+        vq_parts = encoded.vq_parts
+        vq_parts = torch.cat(vq_parts, dim=1)
+        values[1:, tokens == semantic_id] = vq_parts
+        return values
+
+    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
+        encoded = self.encode(tokenizer, add_shift=False)
+
+        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
+        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            val = tokenizer.decode(tok, skip_special_tokens=False)
+            if val == "\n":
+                val = "\\n\n"
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        print()
+
+
+if __name__ == "__main__":
+    message0 = Message(
+        role="user",
+        parts=[
+            TextPart(text="Hello, how are you?"),
+            VQPart(codes=torch.zeros((4, 10))),
+        ],
+        cal_loss=False,
+    )
+
+    message1 = Message(
+        role="assistant",
+        parts=[TextPart(text="I'm fine, thank you.")],
+        cal_loss=True,
+    )
+    conversation = Conversation([message0, message1])
+    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    conversation.visualize(tokenizer)
+
+    encoded = conversation.encode(tokenizer)
+    print(encoded)
+    print(tokenizer.batch_decode(encoded.tokens))
xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json
@@ -118,5 +118,6 @@
     "new": "new",
     "Realtime Transform Text": "Realtime Transform Text",
     "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
-    "Text Normalization": "Text Normalization"
+    "Text Normalization": "Text Normalization",
+    "Select Example Audio": "Select Example Audio"
 }
xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json
@@ -118,5 +118,6 @@
     "new": "nuevo",
     "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
     "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
-    "Text Normalization": "Normalización de Texto"
+    "Text Normalization": "Normalización de Texto",
+    "Select Example Audio": "Selecionar áudio de exemplo"
 }
xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json
@@ -118,6 +118,6 @@
     "new": "新規",
     "Realtime Transform Text": "リアルタイム変換テキスト",
     "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
-    "Text Normalization": "テキスト正規化"
-
+    "Text Normalization": "テキスト正規化",
+    "Select Example Audio": "サンプル音声を選択"
 }
xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json
@@ -0,0 +1,123 @@
+{
+    "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+    "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+    "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+    "Accumulate Gradient Batches": "그라디언트 배치 누적",
+    "Add to Processing Area": "처리 영역에 추가",
+    "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+    "Advanced Config": "고급 설정",
+    "Base LLAMA Model": "기본 LLAMA 모델",
+    "Batch Inference": "배치 추론",
+    "Batch Size": "배치 크기",
+    "Changing with the Model Path": "모델 경로에 따라 변경 중",
+    "Chinese": "중국어",
+    "Compile Model": "모델 컴파일",
+    "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+    "Copy": "복사",
+    "Data Preprocessing": "데이터 전처리",
+    "Data Preprocessing Path": "데이터 전처리 경로",
+    "Data Source": "데이터 소스",
+    "Decoder Model Config": "디코더 모델 설정",
+    "Decoder Model Path": "디코더 모델 경로",
+    "Disabled": "비활성화 됨",
+    "Enable Reference Audio": "참고 음성 활성화",
+    "English": "영어",
+    "Error Message": "오류 메시지",
+    "File Preprocessing": "파일 전처리",
+    "Generate": "생성",
+    "Generated Audio": "생성된 오디오",
+    "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+    "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+    "Inference Configuration": "추론 설정",
+    "Inference Server Configuration": "추론 서버 설정",
+    "Inference Server Error": "추론 서버 오류",
+    "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+    "Initial Learning Rate": "초기 학습률",
+    "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+    "Input Text": "입력 텍스트",
+    "Invalid path: {}": "유효하지 않은 경로: {}",
+    "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+    "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+    "Japanese": "일본어",
+    "LLAMA Configuration": "LLAMA 설정",
+    "LLAMA Model Config": "LLAMA 모델 설정",
+    "LLAMA Model Path": "LLAMA 모델 경로",
+    "Labeling Device": "라벨링 장치",
+    "LoRA Model to be merged": "병합할 LoRA 모델",
+    "Maximum Audio Duration": "최대 오디오 길이",
+    "Maximum Length per Sample": "샘플당 최대 길이",
+    "Maximum Training Steps": "최대 학습 단계",
+    "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+    "Merge": "병합",
+    "Merge LoRA": "LoRA 병합",
+    "Merge successfully": "성공적으로 병합 되었습니다.",
+    "Minimum Audio Duration": "최소 오디오 길이",
+    "Model Output Path": "모델 출력 경로",
+    "Model Size": "모델 크기",
+    "Move": "이동",
+    "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+    "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+    "No selected options": "옵션이 선택되지 않았습니다.",
+    "Number of Workers": "작업자 수",
+    "Open Inference Server": "추론 서버 열기",
+    "Open Labeler WebUI": "라벨러 WebUI 열기",
+    "Open Tensorboard": "Tensorboard 열기",
+    "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+    "Optional Label Language": "선택적 라벨 언어",
+    "Optional online ver": "온라인 버전 선택",
+    "Output Path": "출력 경로",
+    "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+    "Precision": "정밀도",
+    "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+    "Put your text here.": "여기에 텍스트를 입력하세요.",
+    "Reference Audio": "참고 오디오",
+    "Reference Text": "참고 텍스트",
+    "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+    "Remove Selected Data": "선택한 데이터 제거",
+    "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+    "Repetition Penalty": "반복 패널티",
+    "Save model every n steps": "n 단계마다 모델 저장",
+    "Select LLAMA ckpt": "LLAMA ckpt 선택",
+    "Select VITS ckpt": "VITS ckpt 선택",
+    "Select VQGAN ckpt": "VQGAN ckpt 선택",
+    "Select source file processing method": "소스 파일 처리 방법 선택",
+    "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+    "Selected: {}": "선택됨: {}",
+    "Speaker": "화자",
+    "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+    "Start Training": "학습 시작",
+    "Streaming Audio": "스트리밍 오디오",
+    "Streaming Generate": "스트리밍 생성",
+    "Tensorboard Host": "Tensorboard 호스트",
+    "Tensorboard Log Path": "Tensorboard 로그 경로",
+    "Tensorboard Port": "Tensorboard 포트",
+    "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+    "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+    "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+    "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+    "Training Configuration": "학습 설정",
+    "Training Error": "학습 오류",
+    "Training stopped": "학습이 중지되었습니다.",
+    "Type name of the speaker": "화자의 이름을 입력하세요.",
+    "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+    "Use LoRA": "LoRA 사용",
+    "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+    "Use filelist": "파일 목록 사용",
+    "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+    "VITS Configuration": "VITS 설정",
+    "VQGAN Configuration": "VQGAN 설정",
+    "Validation Batch Size": "검증 배치 크기",
+    "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+    "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+    "WebUI Host": "WebUI 호스트",
+    "WebUI Port": "WebUI 포트",
+    "Whisper Model": "Whisper 모델",
+    "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+    "latest": "최신",
+    "new": "새로운",
+    "Realtime Transform Text": "실시간 텍스트 변환",
+    "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+    "Text Normalization": "텍스트 정규화",
+    "Select Example Audio": "예시 오디오 선택"
+}
xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import math
 from collections import OrderedDict
@@ -57,6 +58,10 @@ class BaseModelArgs:
     # Initialize the model
     initializer_range: float = 0.02

+    # Dummy vars
+    is_reward_model: bool = False
+    share_codebook_embeddings: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
     model_type: str = "dual_ar"
     n_fast_layer: int = 4
+    fast_dim: int | None = None
+    fast_n_head: int | None = None
+    fast_n_local_heads: int | None = None
+    fast_head_dim: int | None = None
+    fast_intermediate_size: int | None = None
+    fast_attention_qkv_bias: bool | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.fast_dim = self.fast_dim or self.dim
+        self.fast_n_head = self.fast_n_head or self.n_head
+        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+        self.fast_head_dim = self.fast_head_dim or self.head_dim
+        self.fast_intermediate_size = (
+            self.fast_intermediate_size or self.intermediate_size
+        )
+        self.fast_attention_qkv_bias = (
+            self.fast_attention_qkv_bias
+            if self.fast_attention_qkv_bias is not None
+            else self.attention_qkv_bias
+        )


 class KVCache(nn.Module):
@@ -369,7 +396,10 @@ class BaseTransformer(nn.Module):
             model = simple_quantizer.convert_for_runtime()

         weights = torch.load(
-            Path(path) / "model.pth",
+            Path(path) / "model.pth",
+            map_location="cpu",
+            mmap=True,
+            weights_only=True,
         )

         if "state_dict" in weights:
@@ -471,20 +501,46 @@ class DualARTransformer(BaseTransformer):
     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)

+        # Project to fast dim if needed
+        if config.fast_dim is not None and config.fast_dim != config.dim:
+            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+        else:
+            self.fast_project_in = nn.Identity()
+
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(config.codebook_size, config.
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

         # The equivalent bs is so large that sdpa doesn't work
+        override_config = dataclasses.replace(
+            config,
+            dim=config.fast_dim,
+            n_head=config.fast_n_head,
+            n_local_heads=config.fast_n_local_heads,
+            head_dim=config.fast_head_dim,
+            intermediate_size=config.fast_intermediate_size,
+            attention_qkv_bias=config.fast_attention_qkv_bias,
+        )
+
         self.fast_layers = nn.ModuleList(
-            TransformerBlock(
+            TransformerBlock(override_config, use_sdpa=False)
+            for _ in range(config.n_fast_layer)
         )
-        self.fast_norm = RMSNorm(config.
+        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
         self.fast_output = nn.Linear(
-            config.
+            config.fast_dim,
             config.codebook_size,
             bias=False,
         )

+        self.register_buffer(
+            "fast_freqs_cis",
+            precompute_freqs_cis(
+                config.num_codebooks,
+                config.fast_dim // config.fast_n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
         self.apply(self._init_weights)

     def setup_caches(
@@ -492,7 +548,7 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)

-        head_dim = self.config.
+        head_dim = self.config.fast_dim // self.config.fast_n_head

         # Fast transformer
         # The max seq len here is the number of codebooks
@@ -500,7 +556,7 @@ class DualARTransformer(BaseTransformer):
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 self.config.num_codebooks,
-                self.config.
+                self.config.fast_n_local_heads,
                 head_dim,
                 dtype=dtype,
             )
@@ -513,13 +569,13 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         x = parent_result.hidden_states
+        x = self.fast_project_in(x)

         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ] # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
@@ -542,9 +598,11 @@ class DualARTransformer(BaseTransformer):

         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(
+                x = checkpoint(
+                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+                )
             else:
-                x = layer(x, fast_freqs_cis, fast_mask)
+                x = layer(x, self.fast_freqs_cis, fast_mask)

         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
@@ -584,7 +642,7 @@ class DualARTransformer(BaseTransformer):
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
         ] # (B, N, Q, K)
-        fast_freqs_cis = self.
+        fast_freqs_cis = self.fast_freqs_cis[input_pos]

         for layer in self.fast_layers:
             x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -595,6 +653,13 @@ class DualARTransformer(BaseTransformer):

         return codebook_logits

+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos)
+        x.hidden_states = self.fast_project_in(x.hidden_states)
+        return x
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py
@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self

-    def
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self


@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self

-    def
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self


@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):

     def remove_parametrizations(self):
         for conv in self.convs1:
-            remove_parametrizations(
+            conv.remove_parametrizations()
         for conv in self.convs2:
-            remove_parametrizations(
+            conv.remove_parametrizations()


 class ParallelBlock(nn.Module):
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):

     def remove_parametrizations(self):
         for up in self.ups:
-            remove_parametrizations(
+            up.remove_parametrizations()
         for block in self.resblocks:
             block.remove_parametrizations()
-
-
+        self.conv_pre.remove_parametrizations()
+        self.conv_post.remove_parametrizations()


 # DropPath copied from timm library
|