xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
@@ -1,943 +0,0 @@
- import io
- import os
- import queue
- import re
- import time
- import traceback
- import wave
- from argparse import ArgumentParser
- from http import HTTPStatus
- from pathlib import Path
- from typing import Annotated, Any
-
- import librosa
- import numpy as np
- import ormsgpack
- # import pyrootutils
- import soundfile as sf
- import torch
- import torchaudio
- # from baize.datastructures import ContentType
- # from kui.asgi import (
- #     Body,
- #     FactoryClass,
- #     HTTPException,
- #     HttpRequest,
- #     HttpView,
- #     JSONResponse,
- #     Kui,
- #     OpenAPI,
- #     StreamResponse,
- #     request,
- # )
- # from kui.asgi.routing import MultimethodRoutes
- from loguru import logger
- from transformers import AutoTokenizer
-
- # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
- import struct
- from threading import Lock
-
- import httpx
- from cachetools import LRUCache, cached
- from funasr import AutoModel
- from silero_vad import get_speech_timestamps, load_silero_vad
-
- from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
- from fish_speech.models.text2semantic.llama import BaseModelArgs
-
- # from fish_speech.models.vqgan.lit_module import VQGAN
- from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
- from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
- from fish_speech.utils import autocast_exclude_mps, set_seed
- from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
- from tools.llama.generate import (
-     GenerateRequest,
-     GenerateResponse,
-     WrappedGenerateResponse,
-     launch_thread_safe_queue,
-     launch_thread_safe_queue_agent,
- )
- from tools.schema import (
-     GLOBAL_NUM_SAMPLES,
-     ASRPackRequest,
-     ServeASRRequest,
-     ServeASRResponse,
-     ServeASRSegment,
-     ServeAudioPart,
-     ServeForwardMessage,
-     ServeMessage,
-     ServeRequest,
-     ServeResponse,
-     ServeStreamDelta,
-     ServeStreamResponse,
-     ServeTextPart,
-     ServeTimedASRResponse,
-     ServeTTSRequest,
-     ServeVQGANDecodeRequest,
-     ServeVQGANDecodeResponse,
-     ServeVQGANEncodeRequest,
-     ServeVQGANEncodeResponse,
-     ServeVQPart,
- )
- from tools.vqgan.inference import load_model as load_decoder_model
-
- global_lock = Lock()
-
- # Whether to disable keepalive (which is helpful if the server is in the same cluster)
- DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
- async_client = httpx.AsyncClient(
-     timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
- )
- backends = torchaudio.list_audio_backends()
-
- if "ffmpeg" in backends:
-     backend = "ffmpeg"
- else:
-     backend = "soundfile"
-
-
- def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
-     buffer = io.BytesIO()
-
-     with wave.open(buffer, "wb") as wav_file:
-         wav_file.setnchannels(channels)
-         wav_file.setsampwidth(bit_depth // 8)
-         wav_file.setframerate(sample_rate)
-
-     wav_header_bytes = buffer.getvalue()
-     buffer.close()
-     return wav_header_bytes
-
-
- # Define utils for web server
- # async def http_execption_handler(exc: HTTPException):
- #     return JSONResponse(
- #         dict(
- #             statusCode=exc.status_code,
- #             message=exc.content,
- #             error=HTTPStatus(exc.status_code).phrase,
- #         ),
- #         exc.status_code,
- #         exc.headers,
- #     )
-
-
- async def other_exception_handler(exc: "Exception"):
-     traceback.print_exc()
-
-     status = HTTPStatus.INTERNAL_SERVER_ERROR
-     return JSONResponse(
-         dict(statusCode=status, message=str(exc), error=status.phrase),
-         status,
-     )
-
-
- def load_audio(reference_audio, sr):
-     if len(reference_audio) > 255 or not Path(reference_audio).exists():
-         audio_data = reference_audio
-         reference_audio = io.BytesIO(audio_data)
-
-     waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
-
-     if waveform.shape[0] > 1:
-         waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-     if original_sr != sr:
-         resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
-         waveform = resampler(waveform)
-
-     audio = waveform.squeeze().numpy()
-     return audio
-
-
- def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
-     if enable_reference_audio and reference_audio is not None:
-         # Load audios, and prepare basic info here
-         reference_audio_content = load_audio(
-             reference_audio, decoder_model.spec_transform.sample_rate
-         )
-
-         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
-             None, None, :
-         ]
-         audio_lengths = torch.tensor(
-             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
-         )
-         logger.info(
-             f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
-         )
-
-         # VQ Encoder
-         if isinstance(decoder_model, FireflyArchitecture):
-             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
-
-         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
-     else:
-         prompt_tokens = None
-         logger.info("No reference audio provided")
-
-     return prompt_tokens
-
-
- def decode_vq_tokens(
-     *,
-     decoder_model,
-     codes,
- ):
-     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
-     logger.info(f"VQ features: {codes.shape}")
-
-     if isinstance(decoder_model, FireflyArchitecture):
-         # VQGAN Inference
-         return decoder_model.decode(
-             indices=codes[None],
-             feature_lengths=feature_lengths,
-         )[0].squeeze()
-
-     raise ValueError(f"Unknown model type: {type(decoder_model)}")
-
-
- # routes = MultimethodRoutes(base_class=HttpView)
-
-
- def get_content_type(audio_format):
-     if audio_format == "wav":
-         return "audio/wav"
-     elif audio_format == "flac":
-         return "audio/flac"
-     elif audio_format == "mp3":
-         return "audio/mpeg"
-     else:
-         return "application/octet-stream"
-
-
- @torch.no_grad()
- @torch.autocast(device_type="cuda", dtype=torch.half)
- def batch_encode(model, audios: list[bytes | torch.Tensor]):
-     audios = [
-         (
-             torch.from_numpy(
-                 librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
-             )[None]
-             if isinstance(audio, bytes)
-             else audio
-         )
-         for audio in audios
-     ]
-
-     # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
-     #     raise ValueError("Single audio length is too long (>120s)")
-
-     max_length = max(audio.shape[-1] for audio in audios)
-     print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
-
-     lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
-     max_length = lengths.max().item()
-     padded = torch.stack(
-         [
-             torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
-             for audio in audios
-         ]
-     ).to(model.device)
-
-     features, feature_lengths = model.encode(padded, audio_lengths=lengths)
-     features, feature_lengths = features.cpu(), feature_lengths.cpu()
-
-     return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
-
-
- @cached(
-     cache=LRUCache(maxsize=10000),
-     key=lambda model, audios: (model.device, tuple(audios)),
- )
- def cached_vqgan_batch_encode(model, audios: list[bytes]):
-     return batch_encode(model, audios)
-
-
- # @routes.http.post("/v1/vqgan/encode")
- # def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
- #
- #     start_time = time.time()
- #     tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
- #     logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
- #
- #     return ormsgpack.packb(
- #         ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
- #         option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
- #     )
-
-
- @torch.no_grad()
- @torch.autocast(device_type="cuda", dtype=torch.half)
- def vqgan_decode(model, features):
-     lengths = torch.tensor(
-         [feature.shape[-1] for feature in features], device=model.device
-     )
-     max_length = lengths.max().item()
-     padded = torch.stack(
-         [
-             torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
-             for feature in features
-         ]
-     ).to(model.device)
-
-     # If bs too large, we do micro batch decode
-     audios, audio_lengths = [], []
-     for i in range(0, padded.shape[0], 8):
-         audio, audio_length = model.decode(
-             padded[i : i + 8], feature_lengths=lengths[i : i + 8]
-         )
-         audios.append(audio)
-         audio_lengths.append(audio_length)
-     audios = torch.cat(audios, dim=0)
-     audio_lengths = torch.cat(audio_lengths, dim=0)
-     audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
-
-     return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
-
-
- # @routes.http.post("/v1/vqgan/decode")
- # def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
- #     tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
- #     start_time = time.time()
- #     audios = vqgan_decode(decoder_model, tokens)
- #     logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
- #     audios = [audio.astype(np.float16).tobytes() for audio in audios]
- #     return ormsgpack.packb(
- #         ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
- #     )
-
-
- @torch.no_grad()
- def batch_asr(model, audios, sr, language="auto"):
-     resampled_audios = []
-     for audio in audios:
-         audio = torchaudio.functional.resample(audio, sr, 16000)
-         assert audio.ndim == 1
-         resampled_audios.append(audio)
-
-     with global_lock:
-         res = model.generate(
-             input=resampled_audios,
-             batch_size=len(resampled_audios),
-             language=language,
-             use_itn=True,
-         )
-
-     results = []
-     for r, audio in zip(res, audios):
-         text = r["text"]
-         text = re.sub(r"<\|.*?\|>", "", text)
-         duration = len(audio) / sr * 1000
-         huge_gap = False
-
-         if "timestamp" in r and len(r["timestamp"]) > 2:
-             for timestamp_a, timestamp_b in zip(
-                 r["timestamp"][:-1], r["timestamp"][1:]
-             ):
-                 # If there is a gap of more than 5 seconds, we consider it as a huge gap
-                 if timestamp_b[0] - timestamp_a[1] > 5000:
-                     huge_gap = True
-                     break
-
-             # Doesn't make sense to have a huge gap at the end
-             if duration - r["timestamp"][-1][1] > 3000:
-                 huge_gap = True
-
-         results.append(
-             {
-                 "text": text,
-                 "duration": duration,
-                 "huge_gap": huge_gap,
-             }
-         )
-
-     return results
-
-
- # @routes.http.post("/v1/asr")
- # def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
- #     start_time = time.time()
- #     audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
- #     audios = [torch.from_numpy(audio).float() for audio in audios]
- #
- #     if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
- #         raise HTTPException(status_code=400, detail="Audio length is too long")
- #
- #     transcriptions = batch_asr(
- #         asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
- #     )
- #     logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
- #
- #     return ormsgpack.packb(
- #         ServeASRResponse(transcriptions=transcriptions),
- #         option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
- #     )
-
-
- from fish_speech.conversation import Conversation, Message
-
-
- def execute_request(
-     input_queue: queue.Queue,
-     tokenizer: AutoTokenizer,
-     config: BaseModelArgs,
-     request: ServeRequest,
-     device: str = "cuda:0",
- ):
-     semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
-         [SEMANTIC_TOKEN, IM_END_TOKEN]
-     )
-     messages = []
-     for message in request.messages:
-         messages.append(message.to_conversation_message())
-
-     assert len(messages) >= 1, "At least one message is required"
-     # assert messages[-1].role == "user", "The last message must be from the user"
-
-     if messages[-1].role == "user":
-         messages.append(Message(role="assistant", parts=[], add_im_end=False))
-     else:
-         assert (
-             messages[-1].role == "assistant"
-         ), "The last message must be from the assistant"
-         messages[-1].add_im_end = False
-
-     conv = Conversation(messages=messages)
-     prompt = conv.encode_for_inference(
-         tokenizer=tokenizer, num_codebooks=config.num_codebooks
-     ).to(device)
-
-     if request.streaming:
-         for i in range(request.num_samples):
-             yield ServeStreamResponse(
-                 sample_id=i,
-                 delta=ServeStreamDelta(
-                     role="assistant",
-                 ),
-             )
-
-     req = {
-         "prompt": prompt,
-         "max_new_tokens": request.max_new_tokens,
-         "im_end_id": im_end_id,
-         "semantic_id": semantic_id,
-         "temperature": request.temperature,
-         "top_p": request.top_p,
-         "repetition_penalty": request.repetition_penalty,
-         "num_samples": request.num_samples,
-         "early_stop_threshold": request.early_stop_threshold,
-     }
-
-     start = time.time()
-     response_queue = queue.Queue()
-     input_queue.put(GenerateRequest(req, response_queue))
-
-     # Decoding
-     decode_buffer = [[] for _ in range(request.num_samples)]
-     parts = [[] for _ in range(request.num_samples)]
-
-     def send_reset_buffer(sample_id):
-         nonlocal decode_buffer
-         if len(decode_buffer[sample_id]) == 0:
-             return
-
-         decoded = tokenizer.decode(decode_buffer[sample_id])
-         part = ServeTextPart(text=decoded)
-
-         if request.streaming:
-             yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
-         else:
-             parts[sample_id].append(part)
-
-         decode_buffer[sample_id] = []
-
-     # Decode process
-     finished = [False for _ in range(request.num_samples)]
-     stats = {}
-     idx = 0
-     while True:
-         response = response_queue.get()
-
-         if response in ["stop", "error"]:
-             break
-
-         for sample_id, tokens in enumerate(response):
-             if finished[sample_id]:
-                 continue
-
-             if tokens[0] == im_end_id:
-                 finished[sample_id] = True
-                 if request.streaming:
-                     yield from send_reset_buffer(sample_id)
-                     yield ServeStreamResponse(
-                         sample_id=sample_id,
-                         finish_reason="stop",
-                         stats=stats,
-                     )
-                 continue
-
-             if tokens[0] == semantic_id and request.streaming:
-                 yield from send_reset_buffer(sample_id)
-                 # Streaming vq
-                 _tokens = tokens[1:].clone() - 1
-
-                 if config.share_codebook_embeddings is False:
-                     for i in range(len(_tokens)):
-                         _tokens[i] -= config.codebook_size * i
-
-                 yield ServeStreamResponse(
-                     sample_id=sample_id,
-                     delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
-                 )
-                 continue
-
-             # Not streaming vq
-             if tokens[0] == semantic_id:
-                 yield from send_reset_buffer(sample_id)
-                 # None streaming vq
-                 if len(parts[sample_id]) == 0 or not isinstance(
-                     parts[sample_id][-1], ServeVQPart
-                 ):
-                     _tokens = tokens[1:].clone() - 1
-
-                     if config.share_codebook_embeddings is False:
-                         for i in range(len(_tokens)):
-                             _tokens[i] -= config.codebook_size * i
-
-                     parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
-                 else:
-                     for codebook_id, value in enumerate(tokens[1:, :]):
-                         val = value.item() - 1
-                         if config.share_codebook_embeddings is False:
-                             val -= config.codebook_size * codebook_id
-
-                         parts[sample_id][-1].codes[codebook_id].append(val)
-                 continue
-
-             if tokens[0] != semantic_id:
-                 # Stream text decode is not supported now
-                 decode_buffer[sample_id].append(tokens[0, 0])
-
-         if idx == 0:
-             stats["time_to_first_token"] = (time.time() - start) * 1000
-
-         idx += 1
-
-     for sample_id in range(request.num_samples):
-         yield from send_reset_buffer(sample_id)
-
-     stats["total_time"] = (time.time() - start) * 1000
-     stats["total_tokens"] = idx
-
-     if request.streaming:
-         for sample_id in range(request.num_samples):
-             if finished[sample_id]:
-                 continue
-             yield ServeStreamResponse(
-                 finish_reason=response, stats=stats, sample_id=sample_id
-             )
-         return
-
-     yield ServeResponse(
-         messages=[
-             ServeMessage(role="assistant", parts=parts[i])
-             for i in range(request.num_samples)
-         ],
-         finish_reason=response,
-         stats=stats,
-     )
-
-
- # @routes.http.post("/v1/chat")
- # def api_invoke_chat(
- #     req: Annotated[ServeRequest, Body(exclusive=True)],
- # ):
- #     """
- #     Invoke model and generate audio
- #     """
- #
- #     # This makes torch compile happy
- #     assert (
- #         req.num_samples == GLOBAL_NUM_SAMPLES
- #     ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
- #
- #     content_type = request.headers.get("Content-Type", "application/json")
- #     json_mode = "application/json" in content_type
- #
- #     async def wrapped_generator():
- #         generator = execute_request(llama_queue, tokenizer, config, req, args.device)
- #
- #         for i in generator:
- #             if json_mode:
- #                 body = i.model_dump_json().encode("utf-8")
- #                 yield b"data: " + body + b"\n\n"
- #             else:
- #                 body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
- #                 yield struct.pack("I", len(body)) + body
- #
- #     # Naive mode
- #     if req.streaming is False:
- #         result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
- #
- #         if json_mode:
- #             return JSONResponse(result.model_dump())
- #         else:
- #             return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
- #
- #     return StreamResponse(
- #         iterable=wrapped_generator(), content_type="text/event-stream"
- #     )
-
-
- @torch.inference_mode()
- def inference(req: ServeTTSRequest):
-
-     global prompt_tokens, prompt_texts
-
-     idstr: str | None = req.reference_id
-     if idstr is not None:
-         ref_folder = Path("references") / idstr
-         ref_folder.mkdir(parents=True, exist_ok=True)
-         ref_audios = list_files(
-             ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
-         )
-
-         if req.use_memory_cache == "never" or (
-             req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-         ):
-             prompt_tokens = [
-                 encode_reference(
-                     decoder_model=decoder_model,
-                     reference_audio=audio_to_bytes(str(ref_audio)),
-                     enable_reference_audio=True,
-                 )
-                 for ref_audio in ref_audios
-             ]
-             prompt_texts = [
-                 read_ref_text(str(ref_audio.with_suffix(".lab")))
-                 for ref_audio in ref_audios
-             ]
-         else:
-             logger.info("Use same references")
-
-     else:
-         # Parse reference audio aka prompt
-         refs = req.references
-
-         if req.use_memory_cache == "never" or (
-             req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
-         ):
-             prompt_tokens = [
-                 encode_reference(
-                     decoder_model=decoder_model,
-                     reference_audio=ref.audio,
-                     enable_reference_audio=True,
-                 )
-                 for ref in refs
-             ]
-             prompt_texts = [ref.text for ref in refs]
-         else:
-             logger.info("Use same references")
-
-     if req.seed is not None:
-         set_seed(req.seed)
-         logger.warning(f"set seed: {req.seed}")
-
-     # LLAMA Inference
-     request = dict(
-         device=decoder_model.device,
-         max_new_tokens=req.max_new_tokens,
-         text=(
-             req.text
-             if not req.normalize
-             else ChnNormedText(raw_text=req.text).normalize()
-         ),
-         top_p=req.top_p,
-         repetition_penalty=req.repetition_penalty,
-         temperature=req.temperature,
-         compile=args.compile,
-         iterative_prompt=req.chunk_length > 0,
-         chunk_length=req.chunk_length,
-         max_length=4096,
-         prompt_tokens=prompt_tokens,
-         prompt_text=prompt_texts,
-     )
-
-     response_queue = queue.Queue()
-     llama_queue.put(
-         GenerateRequest(
-             request=request,
-             response_queue=response_queue,
-         )
-     )
-
-     if req.streaming:
-         yield wav_chunk_header()
-
-     segments = []
-     while True:
-         result: WrappedGenerateResponse = response_queue.get()
-         if result.status == "error":
-             raise result.response
-             break
-
-         result: GenerateResponse = result.response
-         if result.action == "next":
-             break
-
-         with autocast_exclude_mps(
-             device_type=decoder_model.device.type, dtype=args.precision
-         ):
-             fake_audios = decode_vq_tokens(
-                 decoder_model=decoder_model,
-                 codes=result.codes,
-             )
-
-         fake_audios = fake_audios.float().cpu().numpy()
-
-         if req.streaming:
-             yield (fake_audios * 32768).astype(np.int16).tobytes()
-         else:
-             segments.append(fake_audios)
-
-     if req.streaming:
-         return
-
-     if len(segments) == 0:
-         raise HTTPException(
-             HTTPStatus.INTERNAL_SERVER_ERROR,
-             content="No audio generated, please check the input text.",
-         )
-
-     fake_audios = np.concatenate(segments, axis=0)
-     yield fake_audios
-
-
- async def inference_async(req: ServeTTSRequest):
-     for chunk in inference(req):
-         yield chunk
-
-
- async def buffer_to_async_generator(buffer):
-     yield buffer
-
-
- # @routes.http.post("/v1/tts")
- # async def api_invoke_model(
- #     req: Annotated[ServeTTSRequest, Body(exclusive=True)],
- # ):
- #     """
- #     Invoke model and generate audio
- #     """
- #
- #     if args.max_text_length > 0 and len(req.text) > args.max_text_length:
- #         raise HTTPException(
- #             HTTPStatus.BAD_REQUEST,
- #             content=f"Text is too long, max length is {args.max_text_length}",
- #         )
- #
- #     if req.streaming and req.format != "wav":
- #         raise HTTPException(
- #             HTTPStatus.BAD_REQUEST,
- #             content="Streaming only supports WAV format",
- #         )
- #
- #     if req.streaming:
- #         return StreamResponse(
- #             iterable=inference_async(req),
- #             headers={
- #                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
- #             },
- #             content_type=get_content_type(req.format),
- #         )
- #     else:
- #         fake_audios = next(inference(req))
- #         buffer = io.BytesIO()
- #         sf.write(
- #             buffer,
- #             fake_audios,
- #             decoder_model.spec_transform.sample_rate,
- #             format=req.format,
- #         )
- #
- #         return StreamResponse(
- #             iterable=buffer_to_async_generator(buffer.getvalue()),
- #             headers={
- #                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
- #             },
- #             content_type=get_content_type(req.format),
- #         )
- #
- #
- # @routes.http.post("/v1/health")
- # async def api_health():
- #     """
- #     Health check
- #     """
- #
- #     return JSONResponse({"status": "ok"})
-
-
- def parse_args():
-     parser = ArgumentParser()
-     parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
-     parser.add_argument("--load-asr-model", action="store_true")
-     parser.add_argument(
-         "--llama-checkpoint-path",
-         type=str,
-         default="checkpoints/fish-speech-1.4",
-     )
-     parser.add_argument(
-         "--decoder-checkpoint-path",
-         type=str,
-         default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-     )
-     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
-     parser.add_argument("--device", type=str, default="cuda")
-     parser.add_argument("--half", action="store_true")
-     parser.add_argument("--compile", action="store_true")
-     parser.add_argument("--max-text-length", type=int, default=0)
-     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
-     parser.add_argument("--workers", type=int, default=1)
-
-     return parser.parse_args()
-
-
- # Define Kui app
- # openapi = OpenAPI(
- #     {
- #         "title": "Fish Speech API",
- #         "version": "1.4.2",
- #     },
- # ).routes
- #
- #
- # class MsgPackRequest(HttpRequest):
- #     async def data(
- #         self,
- #     ) -> Annotated[
- #         Any, ContentType("application/msgpack"), ContentType("application/json")
- #     ]:
- #         if self.content_type == "application/msgpack":
- #             return ormsgpack.unpackb(await self.body)
- #
- #         elif self.content_type == "application/json":
- #             return await self.json
- #
- #         raise HTTPException(
- #             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
- #             headers={"Accept": "application/msgpack, application/json"},
- #         )
- #
- #
- # app = Kui(
- #     routes=routes + openapi[1:],  # Remove the default route
- #     exception_handlers={
- #         HTTPException: http_execption_handler,
- #         Exception: other_exception_handler,
- #     },
- #     factory_class=FactoryClass(http=MsgPackRequest),
- #     cors_config={},
- # )
-
-
- def load_asr_model(*, device="cuda", hub="ms"):
-     return AutoModel(
-         model="iic/SenseVoiceSmall",
-         device=device,
-         disable_pbar=True,
-         hub=hub,
-     )
-
-
- # Each worker process created by Uvicorn has its own memory space,
- # meaning that models and variables are not shared between processes.
- # Therefore, any global variables (like `llama_queue` or `decoder_model`)
- # will not be shared across workers.
-
-
- # Multi-threading for deep learning can cause issues, such as inconsistent
- # outputs if multiple threads access the same buffers simultaneously.
- # Instead, it's better to use multiprocessing or independent models per thread.
- # @app.on_startup
- # def initialize_app(app: Kui):
- #
- #     global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
- #
- #     prompt_tokens, prompt_texts = [], []
- #
- #     args = parse_args()  # args same as ones in other processes
- #     args.precision = torch.half if args.half else torch.bfloat16
- #
- #     if args.load_asr_model:
- #         logger.info(f"Loading ASR model...")
- #         asr_model = load_asr_model(device=args.device)
- #
- #     logger.info("Loading Llama model...")
- #
- #     if args.mode == "tts":
- #         llama_queue = launch_thread_safe_queue(
- #             checkpoint_path=args.llama_checkpoint_path,
- #             device=args.device,
- #             precision=args.precision,
- #             compile=args.compile,
- #         )
- #     else:
- #         llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
- #             checkpoint_path=args.llama_checkpoint_path,
- #             device=args.device,
- #             precision=args.precision,
- #             compile=args.compile,
- #         )
- #
- #     logger.info("Llama model loaded, loading VQ-GAN model...")
- #
- #     decoder_model = load_decoder_model(
- #         config_name=args.decoder_config_name,
- #         checkpoint_path=args.decoder_checkpoint_path,
- #         device=args.device,
- #     )
- #
- #     logger.info("VQ-GAN model loaded, warming up...")
- #
- #     vad_model = load_silero_vad()
- #
- #     logger.info("VAD model loaded, warming up...")
- #
- #     if args.mode == "tts":
- #         # Dry run to ensure models work and avoid first-time latency
- #         list(
- #             inference(
- #                 ServeTTSRequest(
- #                     text="Hello world.",
- #                     references=[],
- #                     reference_id=None,
- #                     max_new_tokens=0,
- #                     chunk_length=200,
- #                     top_p=0.7,
- #                     repetition_penalty=1.2,
- #                     temperature=0.7,
- #                     emotion=None,
- #                     format="wav",
- #                 )
- #             )
- #         )
- #
- #     logger.info(f"Warming up done, starting server at http://{args.listen}")
-
-
- if __name__ == "__main__":
-
-     import uvicorn
-
-     args = parse_args()
-     host, port = args.listen.split(":")
-     uvicorn.run(
-         "tools.api:app",
-         host=host,
-         port=int(port),
-         workers=args.workers,
-         log_level="info",
-     )