xinference 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (83)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +204 -1
  3. xinference/client/restful/restful_client.py +4 -2
  4. xinference/core/image_interface.py +28 -0
  5. xinference/core/model.py +28 -0
  6. xinference/core/supervisor.py +6 -0
  7. xinference/model/audio/fish_speech.py +9 -9
  8. xinference/model/audio/model_spec.json +9 -9
  9. xinference/model/audio/whisper.py +4 -1
  10. xinference/model/image/core.py +2 -1
  11. xinference/model/image/model_spec.json +16 -4
  12. xinference/model/image/model_spec_modelscope.json +16 -4
  13. xinference/model/image/sdapi.py +136 -0
  14. xinference/model/image/stable_diffusion/core.py +148 -20
  15. xinference/model/llm/__init__.py +8 -0
  16. xinference/model/llm/llm_family.json +393 -0
  17. xinference/model/llm/llm_family.py +3 -1
  18. xinference/model/llm/llm_family_modelscope.json +408 -3
  19. xinference/model/llm/sglang/core.py +3 -0
  20. xinference/model/llm/transformers/chatglm.py +1 -1
  21. xinference/model/llm/transformers/core.py +6 -0
  22. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  23. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  24. xinference/model/llm/transformers/qwen2_vl.py +31 -5
  25. xinference/model/llm/utils.py +104 -84
  26. xinference/model/llm/vllm/core.py +8 -0
  27. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +2 -3
  28. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  37. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  38. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  39. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  40. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  42. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  43. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  44. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  45. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  46. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  47. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  48. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  49. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  50. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  51. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  52. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  53. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  54. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  55. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  56. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  57. xinference/types.py +7 -4
  58. xinference/web/ui/build/asset-manifest.json +6 -6
  59. xinference/web/ui/build/index.html +1 -1
  60. xinference/web/ui/build/static/css/{main.632e9148.css → main.5061c4c3.css} +2 -2
  61. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  62. xinference/web/ui/build/static/js/{main.9cfafbd6.js → main.754740c0.js} +3 -3
  63. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  66. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/METADATA +9 -3
  67. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/RECORD +72 -74
  68. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  69. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  72. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  73. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  74. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  75. xinference/web/ui/build/static/css/main.632e9148.css.map +0 -1
  76. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +0 -1
  79. /xinference/web/ui/build/static/js/{main.9cfafbd6.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +0 -0
  80. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  81. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  82. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  83. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/fish_speech/webui/manage.py

@@ -1,9 +1,11 @@
  from __future__ import annotations

+ import os
+
+ os.environ["USE_LIBUV"] = "0"
  import datetime
  import html
  import json
- import os
  import platform
  import shutil
  import signal
@@ -469,7 +471,7 @@ def train_process(
  "--config-name",
  "firefly_gan_vq",
  "--checkpoint-path",
- "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  ]
  )

@@ -485,7 +487,7 @@ def train_process(
  "16",
  ]
  )
- ckpt_path = "checkpoints/fish-speech-1.2-sft/model.pth"
+ ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
  lora_prefix = "lora_" if llama_use_lora else ""
  llama_name = lora_prefix + "text2semantic_" + new_project
  latest = next(
@@ -862,7 +864,7 @@ with gr.Blocks(
  minimum=1,
  maximum=32,
  step=1,
- value=4,
+ value=2,
  )
  llama_data_max_length_slider = gr.Slider(
  label=i18n("Maximum Length per Sample"),
@@ -870,7 +872,7 @@ with gr.Blocks(
  minimum=1024,
  maximum=4096,
  step=128,
- value=1024,
+ value=2048,
  )
  with gr.Row(equal_height=False):
  llama_precision_dropdown = gr.Dropdown(
@@ -925,9 +927,9 @@ with gr.Blocks(
  "Type the path or select from the dropdown"
  ),
  choices=[
- "checkpoints/fish-speech-1.2-sft/model.pth",
+ "checkpoints/fish-speech-1.4/model.pth",
  ],
- value="checkpoints/fish-speech-1.2-sft/model.pth",
+ value="checkpoints/fish-speech-1.4/model.pth",
  allow_custom_value=True,
  interactive=True,
  )
@@ -979,7 +981,7 @@ with gr.Blocks(
  "Type the path or select from the dropdown"
  ),
  choices=list_llama_models(),
- value="checkpoints/fish-speech-1.2-sft",
+ value="checkpoints/fish-speech-1.4",
  allow_custom_value=True,
  interactive=True,
  )
@@ -1042,7 +1044,7 @@ with gr.Blocks(
  "Type the path or select from the dropdown"
  ),
  choices=list_decoder_models(),
- value="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  allow_custom_value=True,
  )
  infer_decoder_config = gr.Dropdown(
@@ -1060,7 +1062,7 @@ with gr.Blocks(
  info=i18n(
  "Type the path or select from the dropdown"
  ),
- value="checkpoints/fish-speech-1.2-sft",
+ value="checkpoints/fish-speech-1.4",
  choices=list_llama_models(),
  allow_custom_value=True,
  )

xinference/thirdparty/fish_speech/tools/api.py

@@ -9,16 +9,20 @@ import wave
  from argparse import ArgumentParser
  from http import HTTPStatus
  from pathlib import Path
- from typing import Annotated, Literal, Optional
+ from typing import Annotated, Any, Literal, Optional

  import numpy as np
+ import ormsgpack
  # import pyrootutils
  import soundfile as sf
  import torch
  import torchaudio
+ # from baize.datastructures import ContentType
  # from kui.asgi import (
  # Body,
+ # FactoryClass,
  # HTTPException,
+ # HttpRequest,
  # HttpView,
  # JSONResponse,
  # Kui,
@@ -27,14 +31,16 @@ import torchaudio
  # )
  # from kui.asgi.routing import MultimethodRoutes
  from loguru import logger
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, Field, conint

  # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

  # from fish_speech.models.vqgan.lit_module import VQGAN
  from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
  from fish_speech.utils import autocast_exclude_mps
- # from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+ from tools.commons import ServeReferenceAudio, ServeTTSRequest
+ from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
  from tools.llama.generate import (
  GenerateRequest,
  GenerateResponse,
@@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"):

  def load_audio(reference_audio, sr):
  if len(reference_audio) > 255 or not Path(reference_audio).exists():
- try:
- audio_data = base64.b64decode(reference_audio)
- reference_audio = io.BytesIO(audio_data)
- except base64.binascii.Error:
- raise ValueError("Invalid path or base64 string")
+ audio_data = reference_audio
+ reference_audio = io.BytesIO(audio_data)

  waveform, original_sr = torchaudio.load(
  reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
@@ -145,7 +148,7 @@ def decode_vq_tokens(
  return decoder_model.decode(
  indices=codes[None],
  feature_lengths=feature_lengths,
- ).squeeze()
+ )[0].squeeze()

  raise ValueError(f"Unknown model type: {type(decoder_model)}")

@@ -153,58 +156,6 @@
  # routes = MultimethodRoutes(base_class=HttpView)


- def get_random_paths(base_path, data, speaker, emotion):
- if base_path and data and speaker and emotion and (Path(base_path).exists()):
- if speaker in data and emotion in data[speaker]:
- files = data[speaker][emotion]
- lab_files = [f for f in files if f.endswith(".lab")]
- wav_files = [f for f in files if f.endswith(".wav")]
-
- if lab_files and wav_files:
- selected_lab = random.choice(lab_files)
- selected_wav = random.choice(wav_files)
-
- lab_path = Path(base_path) / speaker / emotion / selected_lab
- wav_path = Path(base_path) / speaker / emotion / selected_wav
- if lab_path.exists() and wav_path.exists():
- return lab_path, wav_path
-
- return None, None
-
-
- def load_json(json_file):
- if not json_file:
- logger.info("Not using a json file")
- return None
- try:
- with open(json_file, "r", encoding="utf-8") as file:
- data = json.load(file)
- except FileNotFoundError:
- logger.warning(f"ref json not found: {json_file}")
- data = None
- except Exception as e:
- logger.warning(f"Loading json failed: {e}")
- data = None
- return data
-
-
- class InvokeRequest(BaseModel):
- text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
- reference_text: Optional[str] = None
- reference_audio: Optional[str] = None
- max_new_tokens: int = 1024
- chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
- top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
- repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
- temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
- emotion: Optional[str] = None
- format: Literal["wav", "mp3", "flac"] = "wav"
- streaming: bool = False
- ref_json: Optional[str] = "ref_data.json"
- ref_base: Optional[str] = "ref_data"
- speaker: Optional[str] = None
-
-
  def get_content_type(audio_format):
  if audio_format == "wav":
  return "audio/wav"
@@ -217,35 +168,52 @@ def get_content_type(audio_format):


  @torch.inference_mode()
- def inference(req: InvokeRequest):
- # Parse reference audio aka prompt
- prompt_tokens = None
-
- ref_data = load_json(req.ref_json)
- ref_base = req.ref_base
-
- lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
-
- if lab_path and wav_path:
- with open(lab_path, "r", encoding="utf-8") as lab_file:
- ref_text = lab_file.read()
- req.reference_audio = wav_path
- req.reference_text = ref_text
- logger.info("ref_path: " + str(wav_path))
- logger.info("ref_text: " + ref_text)
-
- # Parse reference audio aka prompt
- prompt_tokens = encode_reference(
- decoder_model=decoder_model,
- reference_audio=req.reference_audio,
- enable_reference_audio=req.reference_audio is not None,
- )
- logger.info(f"ref_text: {req.reference_text}")
+ def inference(req: ServeTTSRequest):
+
+ idstr: str | None = req.reference_id
+ if idstr is not None:
+ ref_folder = Path("references") / idstr
+ ref_folder.mkdir(parents=True, exist_ok=True)
+ ref_audios = list_files(
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+ )
+ prompt_tokens = [
+ encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=audio_to_bytes(str(ref_audio)),
+ enable_reference_audio=True,
+ )
+ for ref_audio in ref_audios
+ ]
+ prompt_texts = [
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
+ for ref_audio in ref_audios
+ ]
+
+ else:
+ # Parse reference audio aka prompt
+ refs = req.references
+ if refs is None:
+ refs = []
+ prompt_tokens = [
+ encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=ref.audio,
+ enable_reference_audio=True,
+ )
+ for ref in refs
+ ]
+ prompt_texts = [ref.text for ref in refs]
+
  # LLAMA Inference
  request = dict(
  device=decoder_model.device,
  max_new_tokens=req.max_new_tokens,
- text=req.text,
+ text=(
+ req.text
+ if not req.normalize
+ else ChnNormedText(raw_text=req.text).normalize()
+ ),
  top_p=req.top_p,
  repetition_penalty=req.repetition_penalty,
  temperature=req.temperature,
@@ -254,7 +222,7 @@ def inference(req: InvokeRequest):
  chunk_length=req.chunk_length,
  max_length=2048,
  prompt_tokens=prompt_tokens,
- prompt_text=req.reference_text,
+ prompt_text=prompt_texts,
  )

  response_queue = queue.Queue()
@@ -307,40 +275,7 @@ def inference(req: InvokeRequest):
  yield fake_audios


- def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
- if not use_auto_rerank:
- # 如果不使用 auto_rerank,直接调用原始的 inference 函数
- return inference(req)
-
- zh_model, en_model = load_model()
- max_attempts = 5
- best_wer = float("inf")
- best_audio = None
-
- for attempt in range(max_attempts):
- # 调用原始的 inference 函数
- audio_generator = inference(req)
- fake_audios = next(audio_generator)
-
- asr_result = batch_asr(
- zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
- )[0]
- wer = calculate_wer(req.text, asr_result["text"])
-
- if wer <= 0.1 and not asr_result["huge_gap"]:
- return fake_audios
-
- if wer < best_wer:
- best_wer = wer
- best_audio = fake_audios
-
- if attempt == max_attempts - 1:
- break
-
- return best_audio
-
-
- async def inference_async(req: InvokeRequest):
+ async def inference_async(req: ServeTTSRequest):
  for chunk in inference(req):
  yield chunk

@@ -349,9 +284,9 @@ async def buffer_to_async_generator(buffer):
  yield buffer


- # @routes.http.post("/v1/invoke")
+ # @routes.http.post("/v1/tts")
  # async def api_invoke_model(
- # req: Annotated[InvokeRequest, Body(exclusive=True)],
+ # req: Annotated[ServeTTSRequest, Body(exclusive=True)],
  # ):
  # """
  # Invoke model and generate audio
@@ -410,21 +345,20 @@ def parse_args():
  parser.add_argument(
  "--llama-checkpoint-path",
  type=str,
- default="checkpoints/fish-speech-1.2-sft",
+ default="checkpoints/fish-speech-1.4",
  )
  parser.add_argument(
  "--decoder-checkpoint-path",
  type=str,
- default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  )
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
  parser.add_argument("--device", type=str, default="cuda")
  parser.add_argument("--half", action="store_true")
  parser.add_argument("--compile", action="store_true")
  parser.add_argument("--max-text-length", type=int, default=0)
- parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
  parser.add_argument("--workers", type=int, default=1)
- parser.add_argument("--use-auto-rerank", type=bool, default=True)

  return parser.parse_args()

@@ -436,18 +370,30 @@ def parse_args():
  # },
  # ).routes
  #
+ #
+ # class MsgPackRequest(HttpRequest):
+ # async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+ # if self.content_type == "application/msgpack":
+ # return ormsgpack.unpackb(await self.body)
+ #
+ # raise HTTPException(
+ # HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+ # headers={"Accept": "application/msgpack"},
+ # )
+ #
+ #
  # app = Kui(
  # routes=routes + openapi[1:],  # Remove the default route
  # exception_handlers={
  # HTTPException: http_execption_handler,
  # Exception: other_exception_handler,
  # },
+ # factory_class=FactoryClass(http=MsgPackRequest),
  # cors_config={},
  # )


  if __name__ == "__main__":
- import threading

  import uvicorn

@@ -474,18 +420,17 @@ if __name__ == "__main__":
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
  list(
  inference(
- InvokeRequest(
+ ServeTTSRequest(
  text="Hello world.",
- reference_text=None,
- reference_audio=None,
- max_new_tokens=0,
+ references=[],
+ reference_id=None,
+ max_new_tokens=1024,
+ chunk_length=200,
  top_p=0.7,
  repetition_penalty=1.2,
  temperature=0.7,
  emotion=None,
  format="wav",
- ref_base=None,
- ref_json=None,
  )
  )
  )

xinference/thirdparty/fish_speech/tools/commons.py (new file)

@@ -0,0 +1,35 @@
+ from typing import Annotated, Literal, Optional
+
+ from pydantic import BaseModel, Field, conint
+
+
+ class ServeReferenceAudio(BaseModel):
+ audio: bytes
+ text: str
+
+
+ class ServeTTSRequest(BaseModel):
+ text: str
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+ # Audio format
+ format: Literal["wav", "pcm", "mp3"] = "wav"
+ mp3_bitrate: Literal[64, 128, 192] = 128
+ # References audios for in-context learning
+ references: list[ServeReferenceAudio] = []
+ # Reference id
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+ reference_id: str | None = None
+ # Normalize text for en & zh, this increase stability for numbers
+ normalize: bool = True
+ mp3_bitrate: Optional[int] = 64
+ opus_bitrate: Optional[int] = -1000
+ # Balance mode will reduce latency to 300ms, but may decrease stability
+ latency: Literal["normal", "balanced"] = "normal"
+ # not usually used below
+ streaming: bool = False
+ emotion: Optional[str] = None
+ max_new_tokens: int = 1024
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
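
For orientation, a hedged client sketch against the new msgpack TTS flow: the request dict below mirrors the ServeTTSRequest fields above, the /v1/tts route and application/msgpack content type come from the commented-out route and MsgPackRequest handler in tools/api.py, and 127.0.0.1:8080 is the new --listen default. The shipped tools/msgpack_api.py is not shown in this diff, so the exact script may differ; ref.wav and its transcript are placeholder inputs.

# Hypothetical client sketch (not the shipped tools/msgpack_api.py): pack a
# ServeTTSRequest-shaped dict with ormsgpack and POST it to the TTS route.
import ormsgpack
import requests

with open("ref.wav", "rb") as f:  # placeholder reference clip
    ref_audio = f.read()

payload = {
    "text": "Hello world.",
    "references": [{"audio": ref_audio, "text": "transcript of ref.wav"}],
    "reference_id": None,
    "format": "wav",
    "chunk_length": 200,
    "normalize": True,
    "max_new_tokens": 1024,
    "top_p": 0.7,
    "repetition_penalty": 1.2,
    "temperature": 0.7,
    "streaming": False,
}

resp = requests.post(
    "http://127.0.0.1:8080/v1/tts",  # new --listen default
    data=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},
)
resp.raise_for_status()
with open("generated.wav", "wb") as f:
    f.write(resp.content)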

xinference/thirdparty/fish_speech/tools/download_models.py

@@ -22,8 +22,8 @@ def check_and_download_files(repo_id, file_list, local_dir):


  # 1st
- repo_id_1 = "fishaudio/fish-speech-1.2-sft"
- local_dir_1 = "./checkpoints/fish-speech-1.2-sft"
+ repo_id_1 = "fishaudio/fish-speech-1.4"
+ local_dir_1 = "./checkpoints/fish-speech-1.4"
  files_1 = [
  "model.pth",
  "README.md",
@@ -31,7 +31,7 @@ files_1 = [
  "tokenizer_config.json",
  "tokenizer.json",
  "config.json",
- "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  ]

  # 3rd

xinference/thirdparty/fish_speech/tools/file.py

@@ -1,3 +1,4 @@
+ import base64
  from pathlib import Path
  from typing import Union

@@ -23,6 +24,22 @@ VIDEO_EXTENSIONS = {
  }


+ def audio_to_bytes(file_path):
+ if not file_path or not Path(file_path).exists():
+ return None
+ with open(file_path, "rb") as wav_file:
+ wav = wav_file.read()
+ return wav
+
+
+ def read_ref_text(ref_text):
+ path = Path(ref_text)
+ if path.exists() and path.is_file():
+ with path.open("r", encoding="utf-8") as file:
+ return file.read()
+ return ref_text
+
+
  def list_files(
  path: Union[Path, str],
  extensions: set[str] = None,
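
For context, a minimal sketch of how the two helpers added above behave, assuming a local prompt.wav/prompt.lab pair like the ones the reference_id flow in tools/api.py looks up:

# Behaviour sketch of the new tools/file.py helpers (paths are placeholders).
from tools.file import audio_to_bytes, read_ref_text

audio = audio_to_bytes("prompt.wav")    # raw bytes of the file, or None if the path is missing
text = read_ref_text("prompt.lab")      # file contents when the path exists and is a file
inline = read_ref_text("Hello there.")  # otherwise the argument is returned unchanged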

xinference/thirdparty/fish_speech/tools/llama/build_dataset.py

@@ -13,7 +13,7 @@ from tqdm import tqdm

  from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
- from fish_speech.utils.file import load_filelist
+ from tools.file import load_filelist

  # To avoid CPU overload
  os.environ["MKL_NUM_THREADS"] = "1"

xinference/thirdparty/fish_speech/tools/llama/generate.py

@@ -2,6 +2,7 @@ import os
  import queue
  import threading
  import time
+ from contextlib import nullcontext
  from dataclasses import dataclass
  from pathlib import Path
  from typing import Literal, Optional, Tuple, Union
@@ -93,15 +94,20 @@ def decode_one_token_ar(
  **sampling_kwargs,
  ) -> torch.Tensor:
  x = model.forward_generate(x, input_pos)
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ sampling_kwargs_main["temperature"] = 0.1
+ sampling_kwargs_main["top_p"] = 0.1
+ sampling_kwargs_main["repetition_penalty"] = 1.0
+
  codebooks = [
  sample(
  x.logits,
- previous_tokens=(
- previous_tokens[0] if previous_tokens is not None else None
- ), # Disable repetition penalty for the token codebook
- **sampling_kwargs,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs_main,
  )[0]
  ]
+
  x = x.hidden_states

  # Cleanup the cache
@@ -136,11 +142,16 @@ def decode_one_token_naive(
  ) -> torch.Tensor:
  x = model.forward_generate(x, input_pos)

+ sampling_kwargs_main = sampling_kwargs.copy()
+ sampling_kwargs_main["temperature"] = 0.1
+ sampling_kwargs_main["top_p"] = 0.1
+ sampling_kwargs_main["repetition_penalty"] = 1.0
+
  codebooks = [
  sample(
- x.token_logits,
+ x.logits,
  previous_tokens=None, # Disable repetition penalty for the token codebook
- **sampling_kwargs,
+ **sampling_kwargs_main,
  )[0]
  ]

@@ -181,8 +192,12 @@ def decode_n_tokens(
  else:
  window = previous_tokens[:, i - win_size : i]

- with torch.backends.cuda.sdp_kernel(
- enable_flash=False, enable_mem_efficient=False, enable_math=True
+ with (
+ torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ )
+ if torch.cuda.is_available()
+ else nullcontext()
  ): # Actually better for Inductor to codegen attention here
  next_token = decode_one_token(
  model=model,
@@ -222,25 +237,11 @@ def generate(
  # create an empty tensor of the expected final shape and fill in the current tokens
  T = prompt.size(1)

- if max_new_tokens:
- if T + max_new_tokens > model.config.max_seq_len:
- max_new_tokens = model.config.max_seq_len - T
- logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
- T_new = T + max_new_tokens
- else:
- T_new = model.config.max_seq_len
- max_new_tokens = T_new - T
-
  device, dtype = prompt.device, prompt.dtype
- with torch.device(device):
- model.setup_caches(
- max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
- )

  codebook_dim = 1 + model.config.num_codebooks
  # create an empty tensor of the expected final shape and fill in the current tokens
- empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
+ empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
  empty[:, :T] = prompt
  seq = empty
  input_pos = torch.arange(0, T, device=device)
@@ -560,6 +561,10 @@ def launch_thread_safe_queue(
  model, decode_one_token = load_model(
  checkpoint_path, device, precision, compile=compile
  )
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+ )
  init_event.set()

  while True:
@@ -607,7 +612,7 @@
  @click.option(
  "--checkpoint-path",
  type=click.Path(path_type=Path, exists=True),
- default="checkpoints/fish-speech-1.2-sft",
+ default="checkpoints/fish-speech-1.4",
  )
  @click.option("--device", type=str, default="cuda")
  @click.option("--compile/--no-compile", default=False)

xinference/thirdparty/fish_speech/tools/llama/merge_lora.py

@@ -15,7 +15,7 @@ from fish_speech.models.text2semantic.lora import get_merged_state_dict

  @click.command()
  @click.option("--lora-config", type=str, default="r_8_alpha_16")
- @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
+ @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
  @click.option("--lora-weight", type=str, required=True)
  @click.option("--output", type=str, required=True)
  def merge(lora_config, base_weight, lora_weight, output):

xinference/thirdparty/fish_speech/tools/llama/quantize.py

@@ -428,7 +428,7 @@ def generate_folder_name():
  @click.option(
  "--checkpoint-path",
  type=click.Path(path_type=Path, exists=True),
- default="checkpoints/fish-speech-1.2-sft",
+ default="checkpoints/fish-speech-1.4",
  )
  @click.option(
  "--mode", type=str, default="int8", help="type of quantization to perform"
@@ -451,7 +451,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
  precision=precision,
  compile=False,
  )
- vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+ vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
  now = timestamp if timestamp != "None" else generate_folder_name()

  if mode == "int8":