xinference 0.16.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Files changed (54)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/constants.py +1 -0
  5. xinference/core/model.py +10 -3
  6. xinference/core/supervisor.py +8 -2
  7. xinference/core/utils.py +67 -2
  8. xinference/model/audio/model_spec.json +1 -1
  9. xinference/model/image/stable_diffusion/core.py +5 -2
  10. xinference/model/llm/llm_family.json +176 -4
  11. xinference/model/llm/llm_family_modelscope.json +211 -0
  12. xinference/model/llm/mlx/core.py +45 -2
  13. xinference/model/rerank/core.py +11 -4
  14. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  15. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  16. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  17. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  18. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  19. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  20. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  21. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  22. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  23. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  24. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  25. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  26. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  27. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  28. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  29. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  30. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  31. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  32. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  33. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  34. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  35. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  36. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  37. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  38. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
  39. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
  40. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  42. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  46. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  50. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  52. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  53. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  54. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/fish_speech/conversation.py
@@ -1,2 +1,256 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
 SEMANTIC_TOKEN = "<|semantic|>"
+MEL_TOKEN = "<|mel|>"
+PHONEME_START_TOKEN = "<|phoneme_start|>"
+PHONEME_END_TOKEN = "<|phoneme_end|>"
+ALL_SPECIAL_TOKENS = [
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    SEMANTIC_TOKEN,
+    MEL_TOKEN,
+    PHONEME_START_TOKEN,
+    PHONEME_END_TOKEN,
+]
+
 CODEBOOK_PAD_TOKEN_ID = 0
+
+
+class FishTokenizerConfig(PretrainedConfig):
+    share_codebook_embeddings: bool = True
+    codebook_size: int = 1024
+    num_codebooks: int = 8
+
+
+class FishTokenizerFast(PreTrainedTokenizerFast):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
+        self.codebook_size = kwargs.pop("codebook_size", 1024)
+        self.num_codebooks = kwargs.pop("num_codebooks", 8)
+
+
+AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+
+
+@dataclass(kw_only=True)
+class BasePart:
+    pass
+
+
+@dataclass(kw_only=True)
+class VQPart(BasePart):
+    codes: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class TextPart(BasePart):
+    text: str
+
+
+@dataclass(kw_only=True)
+class MelPart(BasePart):
+    mels: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class EncodedMessage:
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    vq_parts: list[torch.Tensor]
+    mel_parts: list[torch.Tensor]
+    vq_require_losses: torch.Tensor | None = None
+
+
+@dataclass(kw_only=True)
+class Message:
+    role: Literal["system", "user", "assistant"]
+    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    add_im_start: bool = True
+    add_im_end: bool = True
+    cal_loss: bool = False
+
+    # By default, ignore the loss of the auto-generated im_start token
+    ignore_im_start_loss: bool = True
+
+    def encode(
+        self: "Message",
+        tokenizer: AutoTokenizer,
+    ) -> EncodedMessage:
+        all_tokens = []
+        all_labels = []
+
+        # Multi-modal tokens
+        vq_parts = []
+        mel_parts = []
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+
+        parts = self.parts.copy()
+        if self.add_im_start:
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
+
+        if self.add_im_end:
+            parts.append(TextPart(text="<|im_end|>"))
+
+        for part in parts:
+            if isinstance(part, TextPart):
+                tokens = tokenizer.encode(
+                    part.text,
+                    add_special_tokens=False,
+                    truncation=False,
+                    return_tensors="pt",
+                ).int()[0]
+            elif isinstance(part, VQPart):
+                tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
+                codes = part.codes.clone() + 1
+
+                if getattr(tokenizer, "share_codebook_embeddings", True) is False:
+                    for i in range(len(codes)):
+                        codes[i] += tokenizer.codebook_size * i
+
+                vq_parts.append(codes)
+            elif isinstance(part, MelPart):
+                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
+                mel_parts.append(part.mels)
+            else:
+                raise ValueError(f"Unsupported part type: {type(part)}")
+
+            all_tokens.append(tokens)
+            if self.cal_loss:
+                all_labels.append(tokens.clone())
+            else:
+                all_labels.append(torch.full_like(tokens, -100))
+
+        tokens = torch.cat(all_tokens, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        assert tokens.shape == labels.shape
+
+        if self.ignore_im_start_loss and self.add_im_start:
+            labels[: len(all_tokens[0])] = -100
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+        )
+
+
+@dataclass
+class Conversation:
+    messages: list[Message]
+
+    def encode(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        add_shift: bool = True,
+    ) -> EncodedMessage:
+        # Build the input_ids and labels
+        tokens = []
+        labels = []
+        vq_parts = []
+        mel_parts = []
+        vq_require_losses = []
+
+        for message in self.messages:
+            encoded = message.encode(
+                tokenizer,
+            )
+            tokens.append(encoded.tokens)
+            labels.append(encoded.labels)
+            vq_parts.extend(encoded.vq_parts)
+            mel_parts.extend(encoded.mel_parts)
+            vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
+
+        tokens = torch.cat(tokens, dim=0)
+        labels = torch.cat(labels, dim=0)
+        vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
+
+        if add_shift:
+            tokens = tokens[:-1]
+            labels = labels[1:]
+
+        assert tokens.dtype in [
+            torch.int,
+            torch.long,
+        ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
+
+        return EncodedMessage(
+            tokens=tokens,
+            labels=labels,
+            vq_parts=vq_parts,
+            mel_parts=mel_parts,
+            vq_require_losses=vq_require_losses,
+        )
+
+    def encode_for_inference(
+        self: "Conversation",
+        tokenizer: AutoTokenizer,
+        num_codebooks: int,
+    ) -> EncodedMessage:
+        encoded = self.encode(tokenizer, add_shift=False)
+        tokens = encoded.tokens
+        values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
+        values[0] = tokens
+
+        if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
+            return values
+
+        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
+            [SEMANTIC_TOKEN, MEL_TOKEN]
+        )
+        vq_parts = encoded.vq_parts
+        vq_parts = torch.cat(vq_parts, dim=1)
+        values[1:, tokens == semantic_id] = vq_parts
+        return values
+
+    def visualize(self: "Conversation", tokenizer: AutoTokenizer):
+        encoded = self.encode(tokenizer, add_shift=False)
+
+        print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
+        print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
+
+        for tok, lab in zip(encoded.tokens, encoded.labels):
+            val = tokenizer.decode(tok, skip_special_tokens=False)
+            if val == "\n":
+                val = "\\n\n"
+
+            if lab == -100:
+                print_in_green(val)
+            else:
+                print_in_blue(val)
+
+        print()
+
+
+if __name__ == "__main__":
+    message0 = Message(
+        role="user",
+        parts=[
+            TextPart(text="Hello, how are you?"),
+            VQPart(codes=torch.zeros((4, 10))),
+        ],
+        cal_loss=False,
+    )
+
+    message1 = Message(
+        role="assistant",
+        parts=[TextPart(text="I'm fine, thank you.")],
+        cal_loss=True,
+    )
+    conversation = Conversation([message0, message1])
+    tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
+    conversation.visualize(tokenizer)
+
+    encoded = conversation.encode(tokenizer)
+    print(encoded)
+    print(tokenizer.batch_decode(encoded.tokens))
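
Note: the new conversation.py gives fish_speech a prompt-building API that interleaves text with VQ codebook frames. A minimal sketch of how it might be driven, assuming the module path as laid out in this wheel and a tokenizer that defines the special tokens registered above (checkpoint path and code shapes are illustrative):

    import torch
    from transformers import AutoTokenizer
    from xinference.thirdparty.fish_speech.fish_speech.conversation import (
        Conversation,
        Message,
        TextPart,
        VQPart,
    )

    # Hypothetical checkpoint; any tokenizer exposing <|semantic|> and <|mel|> works.
    tokenizer = AutoTokenizer.from_pretrained("checkpoints/fish-speech-1.4")

    prompt = Conversation(
        messages=[
            Message(role="user", parts=[TextPart(text="Say hello.")]),
            # 8 codebooks x 10 frames of reference codes; zeros purely for illustration.
            Message(
                role="assistant",
                parts=[VQPart(codes=torch.zeros(8, 10, dtype=torch.int))],
            ),
        ]
    )

    # (num_codebooks + 1, seq_len) prompt matrix for the dual-AR decoder.
    values = prompt.encode_for_inference(tokenizer, num_codebooks=8)
    print(values.shape)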

xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json
@@ -118,5 +118,6 @@
     "new": "new",
     "Realtime Transform Text": "Realtime Transform Text",
     "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
-    "Text Normalization": "Text Normalization"
+    "Text Normalization": "Text Normalization",
+    "Select Example Audio": "Select Example Audio"
 }

xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json
@@ -118,5 +118,6 @@
     "new": "nuevo",
     "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
     "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
-    "Text Normalization": "Normalización de Texto"
+    "Text Normalization": "Normalización de Texto",
+    "Select Example Audio": "Selecionar áudio de exemplo"
 }

xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json
@@ -118,6 +118,6 @@
     "new": "新規",
     "Realtime Transform Text": "リアルタイム変換テキスト",
     "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
-    "Text Normalization": "テキスト正規化"
-
+    "Text Normalization": "テキスト正規化",
+    "Select Example Audio": "サンプル音声を選択"
 }

xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json
@@ -0,0 +1,123 @@
+{
+    "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
+    "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
+    "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
+    "Accumulate Gradient Batches": "그라디언트 배치 누적",
+    "Add to Processing Area": "처리 영역에 추가",
+    "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
+    "Advanced Config": "고급 설정",
+    "Base LLAMA Model": "기본 LLAMA 모델",
+    "Batch Inference": "배치 추론",
+    "Batch Size": "배치 크기",
+    "Changing with the Model Path": "모델 경로에 따라 변경 중",
+    "Chinese": "중국어",
+    "Compile Model": "모델 컴파일",
+    "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
+    "Copy": "복사",
+    "Data Preprocessing": "데이터 전처리",
+    "Data Preprocessing Path": "데이터 전처리 경로",
+    "Data Source": "데이터 소스",
+    "Decoder Model Config": "디코더 모델 설정",
+    "Decoder Model Path": "디코더 모델 경로",
+    "Disabled": "비활성화 됨",
+    "Enable Reference Audio": "참고 음성 활성화",
+    "English": "영어",
+    "Error Message": "오류 메시지",
+    "File Preprocessing": "파일 전처리",
+    "Generate": "생성",
+    "Generated Audio": "생성된 오디오",
+    "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
+    "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
+    "Inference Configuration": "추론 설정",
+    "Inference Server Configuration": "추론 서버 설정",
+    "Inference Server Error": "추론 서버 오류",
+    "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
+    "Initial Learning Rate": "초기 학습률",
+    "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
+    "Input Text": "입력 텍스트",
+    "Invalid path: {}": "유효하지 않은 경로: {}",
+    "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
+    "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
+    "Japanese": "일본어",
+    "LLAMA Configuration": "LLAMA 설정",
+    "LLAMA Model Config": "LLAMA 모델 설정",
+    "LLAMA Model Path": "LLAMA 모델 경로",
+    "Labeling Device": "라벨링 장치",
+    "LoRA Model to be merged": "병합할 LoRA 모델",
+    "Maximum Audio Duration": "최대 오디오 길이",
+    "Maximum Length per Sample": "샘플당 최대 길이",
+    "Maximum Training Steps": "최대 학습 단계",
+    "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
+    "Merge": "병합",
+    "Merge LoRA": "LoRA 병합",
+    "Merge successfully": "성공적으로 병합 되었습니다.",
+    "Minimum Audio Duration": "최소 오디오 길이",
+    "Model Output Path": "모델 출력 경로",
+    "Model Size": "모델 크기",
+    "Move": "이동",
+    "Move files successfully": "파일이 성공적으로 이동되었습니다.",
+    "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
+    "No selected options": "옵션이 선택되지 않았습니다.",
+    "Number of Workers": "작업자 수",
+    "Open Inference Server": "추론 서버 열기",
+    "Open Labeler WebUI": "라벨러 WebUI 열기",
+    "Open Tensorboard": "Tensorboard 열기",
+    "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
+    "Optional Label Language": "선택적 라벨 언어",
+    "Optional online ver": "온라인 버전 선택",
+    "Output Path": "출력 경로",
+    "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
+    "Precision": "정밀도",
+    "Probability of applying Speaker Condition": "화자 조건 적용 확률",
+    "Put your text here.": "여기에 텍스트를 입력하세요.",
+    "Reference Audio": "참고 오디오",
+    "Reference Text": "참고 텍스트",
+    "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
+    "Remove Selected Data": "선택한 데이터 제거",
+    "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
+    "Repetition Penalty": "반복 패널티",
+    "Save model every n steps": "n 단계마다 모델 저장",
+    "Select LLAMA ckpt": "LLAMA ckpt 선택",
+    "Select VITS ckpt": "VITS ckpt 선택",
+    "Select VQGAN ckpt": "VQGAN ckpt 선택",
+    "Select source file processing method": "소스 파일 처리 방법 선택",
+    "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
+    "Selected: {}": "선택됨: {}",
+    "Speaker": "화자",
+    "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
+    "Start Training": "학습 시작",
+    "Streaming Audio": "스트리밍 오디오",
+    "Streaming Generate": "스트리밍 생성",
+    "Tensorboard Host": "Tensorboard 호스트",
+    "Tensorboard Log Path": "Tensorboard 로그 경로",
+    "Tensorboard Port": "Tensorboard 포트",
+    "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
+    "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
+    "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
+    "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
+    "Training Configuration": "학습 설정",
+    "Training Error": "학습 오류",
+    "Training stopped": "학습이 중지되었습니다.",
+    "Type name of the speaker": "화자의 이름을 입력하세요.",
+    "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
+    "Use LoRA": "LoRA 사용",
+    "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
+    "Use filelist": "파일 목록 사용",
+    "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
+    "VITS Configuration": "VITS 설정",
+    "VQGAN Configuration": "VQGAN 설정",
+    "Validation Batch Size": "검증 배치 크기",
+    "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
+    "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
+    "WebUI Host": "WebUI 호스트",
+    "WebUI Port": "WebUI 포트",
+    "Whisper Model": "Whisper 모델",
+    "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
+    "latest": "최신",
+    "new": "새로운",
+    "Realtime Transform Text": "실시간 텍스트 변환",
+    "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
+    "Text Normalization": "텍스트 정규화",
+    "Select Example Audio": "예시 오디오 선택"
+}

xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json
@@ -118,5 +118,6 @@
     "new": "创建新的检查点",
     "Realtime Transform Text": "实时规范化文本",
     "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
-    "Text Normalization": "文本规范化"
+    "Text Normalization": "文本规范化",
+    "Select Example Audio": "选择参考音频"
 }

xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import math
 from collections import OrderedDict
@@ -57,6 +58,10 @@ class BaseModelArgs:
     # Initialize the model
     initializer_range: float = 0.02
 
+    # Dummy vars
+    is_reward_model: bool = False
+    share_codebook_embeddings: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
     model_type: str = "dual_ar"
     n_fast_layer: int = 4
+    fast_dim: int | None = None
+    fast_n_head: int | None = None
+    fast_n_local_heads: int | None = None
+    fast_head_dim: int | None = None
+    fast_intermediate_size: int | None = None
+    fast_attention_qkv_bias: bool | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.fast_dim = self.fast_dim or self.dim
+        self.fast_n_head = self.fast_n_head or self.n_head
+        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+        self.fast_head_dim = self.fast_head_dim or self.head_dim
+        self.fast_intermediate_size = (
+            self.fast_intermediate_size or self.intermediate_size
+        )
+        self.fast_attention_qkv_bias = (
+            self.fast_attention_qkv_bias
+            if self.fast_attention_qkv_bias is not None
+            else self.attention_qkv_bias
+        )
 
 
 class KVCache(nn.Module):
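
Note: the new fast_* fields on DualARModelArgs default to None and are resolved against the base hyperparameters in __post_init__, so configs that never set them behave exactly as before. A self-contained sketch of that fallback pattern (illustrative class and field names, not the real config):

    from dataclasses import dataclass

    @dataclass
    class SlowFastConfig:
        dim: int = 1024
        n_head: int = 16
        fast_dim: int | None = None    # None means "reuse the slow value"
        fast_n_head: int | None = None

        def __post_init__(self):
            # Fall back to the base hyperparameters when no fast override is given.
            self.fast_dim = self.fast_dim or self.dim
            self.fast_n_head = self.fast_n_head or self.n_head

    cfg = SlowFastConfig(fast_dim=512)
    assert (cfg.fast_dim, cfg.fast_n_head) == (512, 16)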

@@ -369,7 +396,10 @@ class BaseTransformer(nn.Module):
             model = simple_quantizer.convert_for_runtime()
 
             weights = torch.load(
-                Path(path) / "model.pth", map_location="cpu", mmap=True
+                Path(path) / "model.pth",
+                map_location="cpu",
+                mmap=True,
+                weights_only=True,
             )
 
             if "state_dict" in weights:
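
Note: the loader now passes weights_only=True, which tells torch.load to unpickle only tensors and plain containers rather than arbitrary Python objects, the safer setting for downloaded checkpoints. For example (illustrative path):

    import torch

    # Refuses to execute pickled code embedded in the checkpoint; raises instead.
    state = torch.load("model.pth", map_location="cpu", mmap=True, weights_only=True)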

@@ -471,20 +501,46 @@ class DualARTransformer(BaseTransformer):
     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
+        # Project to fast dim if needed
+        if config.fast_dim is not None and config.fast_dim != config.dim:
+            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+        else:
+            self.fast_project_in = nn.Identity()
+
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
 
         # The equivalent bs is so large that sdpa doesn't work
+        override_config = dataclasses.replace(
+            config,
+            dim=config.fast_dim,
+            n_head=config.fast_n_head,
+            n_local_heads=config.fast_n_local_heads,
+            head_dim=config.fast_head_dim,
+            intermediate_size=config.fast_intermediate_size,
+            attention_qkv_bias=config.fast_attention_qkv_bias,
+        )
+
         self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+            TransformerBlock(override_config, use_sdpa=False)
+            for _ in range(config.n_fast_layer)
         )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
         self.fast_output = nn.Linear(
-            config.dim,
+            config.fast_dim,
             config.codebook_size,
             bias=False,
         )
 
+        self.register_buffer(
+            "fast_freqs_cis",
+            precompute_freqs_cis(
+                config.num_codebooks,
+                config.fast_dim // config.fast_n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
         self.apply(self._init_weights)
 
     def setup_caches(
@@ -492,7 +548,7 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.fast_dim // self.config.fast_n_head
 
         # Fast transformer
         # The max seq len here is the number of codebooks
@@ -500,7 +556,7 @@ class DualARTransformer(BaseTransformer):
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 self.config.num_codebooks,
-                self.config.n_local_heads,
+                self.config.fast_n_local_heads,
                 head_dim,
                 dtype=dtype,
             )
@@ -513,13 +569,13 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         x = parent_result.hidden_states
+        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
@@ -542,9 +598,11 @@ class DualARTransformer(BaseTransformer):
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+                x = checkpoint(
+                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+                )
             else:
-                x = layer(x, fast_freqs_cis, fast_mask)
+                x = layer(x, self.fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
@@ -584,7 +642,7 @@ class DualARTransformer(BaseTransformer):
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[input_pos]
+        fast_freqs_cis = self.fast_freqs_cis[input_pos]
 
         for layer in self.fast_layers:
             x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -595,6 +653,13 @@ class DualARTransformer(BaseTransformer):
 
         return codebook_logits
 
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos)
+        x.hidden_states = self.fast_project_in(x.hidden_states)
+        return x
+
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
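
Note: taken together, these hunks let the fast (per-codebook) decoder of DualARTransformer run at its own width and give it a dedicated rotary table, fast_freqs_cis, whose length is the number of codebooks rather than the text sequence length, with fast_project_in bridging the slow hidden size. A generic sketch of such a per-codebook rotary table (not the file's precompute_freqs_cis, whose exact layout is outside this diff):

    import torch

    def rotary_table(seq_len: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
        # One row per position -- here per codebook -- with head_dim // 2 complex phases.
        freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(seq_len).float(), freqs)
        return torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim // 2)

    # Hypothetical sizes: 8 codebooks, fast head dim 64.
    fast_freqs_cis = rotary_table(seq_len=8, head_dim=64)
    print(fast_freqs_cis.shape)  # torch.Size([8, 32])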

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py
@@ -102,8 +102,8 @@ class FishConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self
 
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self
 
 
@@ -128,8 +128,8 @@ class FishTransConvNet(nn.Module):
         self.conv = weight_norm(self.conv, name=name, dim=dim)
         return self
 
-    def remove_weight_norm(self):
-        self.conv = remove_parametrizations(self.conv)
+    def remove_parametrizations(self, name="weight"):
+        self.conv = remove_parametrizations(self.conv, name)
         return self
 
 
@@ -178,9 +178,9 @@ class ResBlock1(torch.nn.Module):
 
     def remove_parametrizations(self):
        for conv in self.convs1:
-            remove_parametrizations(conv, tensor_name="weight")
+            conv.remove_parametrizations()
        for conv in self.convs2:
-            remove_parametrizations(conv, tensor_name="weight")
+            conv.remove_parametrizations()
 
 
 class ParallelBlock(nn.Module):
@@ -288,11 +288,11 @@ class HiFiGANGenerator(nn.Module):
 
     def remove_parametrizations(self):
         for up in self.ups:
-            remove_parametrizations(up, tensor_name="weight")
+            up.remove_parametrizations()
         for block in self.resblocks:
             block.remove_parametrizations()
-        remove_parametrizations(self.conv_pre, tensor_name="weight")
-        remove_parametrizations(self.conv_post, tensor_name="weight")
+        self.conv_pre.remove_parametrizations()
+        self.conv_post.remove_parametrizations()
 
 
 # DropPath copied from timm library
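
Note: firefly.py renames the convolution wrappers' remove_weight_norm to remove_parametrizations(name="weight") and routes callers through the wrapper method, matching torch's parametrization API. A minimal illustration of the underlying torch calls the wrappers delegate to:

    import torch
    from torch import nn
    from torch.nn.utils.parametrizations import weight_norm
    from torch.nn.utils.parametrize import remove_parametrizations

    # weight_norm registers a "weight" parametrization; remove_parametrizations bakes
    # it back into a plain tensor for inference.
    conv = weight_norm(nn.Conv1d(4, 4, kernel_size=3), name="weight")
    conv = remove_parametrizations(conv, "weight")
    assert isinstance(conv.weight, torch.Tensor)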

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py
@@ -99,7 +99,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
         if diff > 0:
             result.z = F.pad(result.z, (left, right))
         elif diff < 0:
-            result.z = result.z[..., left:-right]
+            result.z = result.z[..., -left:right]
 
         return result
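
Note: the fsq.py change fixes the trim applied when the decoded sequence is longer than expected (diff < 0). With negative left/right values, the old slice z[..., left:-right] addressed the wrong ends of the tensor and typically produced an empty result, while z[..., -left:right] removes -left samples from the front and -right from the back. A numeric sketch (how diff, left and right are computed lies outside this excerpt; they are assumed here to be the two negative halves of diff):

    import torch

    z = torch.arange(10).view(1, 1, 10)        # (B, C, T) with 3 surplus samples
    diff = -3
    left, right = diff // 2, diff - diff // 2  # -2, -1 (assumed split, see note above)

    old = z[..., left:-right]    # z[..., -2:1]  -> empty slice (pre-1.0.0 behaviour)
    new = z[..., -left:right]    # z[..., 2:-1]  -> trims 2 front / 1 back samples
    print(old.shape, new.shape)  # torch.Size([1, 1, 0]) torch.Size([1, 1, 7])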