xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
--- a/xinference/thirdparty/fish_speech/tools/api.py
+++ b/xinference/thirdparty/fish_speech/tools/api.py
@@ -9,16 +9,20 @@ import wave
 from argparse import ArgumentParser
 from http import HTTPStatus
 from pathlib import Path
-from typing import Annotated, Literal, Optional
+from typing import Annotated, Any, Literal, Optional
 
 import numpy as np
+import ormsgpack
 # import pyrootutils
 import soundfile as sf
 import torch
 import torchaudio
+# from baize.datastructures import ContentType
 # from kui.asgi import (
 #     Body,
+#     FactoryClass,
 #     HTTPException,
+#     HttpRequest,
 #     HttpView,
 #     JSONResponse,
 #     Kui,
@@ -27,14 +31,16 @@ import torchaudio
 # )
 # from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, conint
 
 # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
-
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"):
 
 def load_audio(reference_audio, sr):
     if len(reference_audio) > 255 or not Path(reference_audio).exists():
-        try:
-            audio_data = base64.b64decode(reference_audio)
-            reference_audio = io.BytesIO(audio_data)
-        except base64.binascii.Error:
-            raise ValueError("Invalid path or base64 string")
+        audio_data = reference_audio
+        reference_audio = io.BytesIO(audio_data)
 
     waveform, original_sr = torchaudio.load(
         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
@@ -145,7 +148,7 @@ def decode_vq_tokens(
         return decoder_model.decode(
             indices=codes[None],
             feature_lengths=feature_lengths,
-        ).squeeze()
+        )[0].squeeze()
 
     raise ValueError(f"Unknown model type: {type(decoder_model)}")
 
@@ -153,58 +156,6 @@ def decode_vq_tokens(
 # routes = MultimethodRoutes(base_class=HttpView)
 
 
-def get_random_paths(base_path, data, speaker, emotion):
-    if base_path and data and speaker and emotion and (Path(base_path).exists()):
-        if speaker in data and emotion in data[speaker]:
-            files = data[speaker][emotion]
-            lab_files = [f for f in files if f.endswith(".lab")]
-            wav_files = [f for f in files if f.endswith(".wav")]
-
-            if lab_files and wav_files:
-                selected_lab = random.choice(lab_files)
-                selected_wav = random.choice(wav_files)
-
-                lab_path = Path(base_path) / speaker / emotion / selected_lab
-                wav_path = Path(base_path) / speaker / emotion / selected_wav
-                if lab_path.exists() and wav_path.exists():
-                    return lab_path, wav_path
-
-    return None, None
-
-
-def load_json(json_file):
-    if not json_file:
-        logger.info("Not using a json file")
-        return None
-    try:
-        with open(json_file, "r", encoding="utf-8") as file:
-            data = json.load(file)
-    except FileNotFoundError:
-        logger.warning(f"ref json not found: {json_file}")
-        data = None
-    except Exception as e:
-        logger.warning(f"Loading json failed: {e}")
-        data = None
-    return data
-
-
-class InvokeRequest(BaseModel):
-    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
-    reference_text: Optional[str] = None
-    reference_audio: Optional[str] = None
-    max_new_tokens: int = 1024
-    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    emotion: Optional[str] = None
-    format: Literal["wav", "mp3", "flac"] = "wav"
-    streaming: bool = False
-    ref_json: Optional[str] = "ref_data.json"
-    ref_base: Optional[str] = "ref_data"
-    speaker: Optional[str] = None
-
-
 def get_content_type(audio_format):
     if audio_format == "wav":
         return "audio/wav"
@@ -217,35 +168,52 @@ def get_content_type(audio_format):
 
 
 @torch.inference_mode()
-def inference(req: InvokeRequest):
-    # ... 23 lines of the old implementation, elided in the source diff view ...
+def inference(req: ServeTTSRequest):
+
+    idstr: str | None = req.reference_id
+    if idstr is not None:
+        ref_folder = Path("references") / idstr
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=audio_to_bytes(str(ref_audio)),
+                enable_reference_audio=True,
+            )
+            for ref_audio in ref_audios
+        ]
+        prompt_texts = [
+            read_ref_text(str(ref_audio.with_suffix(".lab")))
+            for ref_audio in ref_audios
+        ]
+
+    else:
+        # Parse reference audio aka prompt
+        refs = req.references
+        if refs is None:
+            refs = []
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
+
     # LLAMA Inference
     request = dict(
         device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
-        text=
+        text=(
+            req.text
+            if not req.normalize
+            else ChnNormedText(raw_text=req.text).normalize()
+        ),
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
@@ -254,7 +222,7 @@ def inference(req: InvokeRequest):
         chunk_length=req.chunk_length,
         max_length=2048,
         prompt_tokens=prompt_tokens,
-        prompt_text=
+        prompt_text=prompt_texts,
     )
 
     response_queue = queue.Queue()
@@ -307,40 +275,7 @@ def inference(req: InvokeRequest):
         yield fake_audios
 
 
-def
-    if not use_auto_rerank:
-        # If auto_rerank is disabled, call the original inference function directly
-        return inference(req)
-
-    zh_model, en_model = load_model()
-    max_attempts = 5
-    best_wer = float("inf")
-    best_audio = None
-
-    for attempt in range(max_attempts):
-        # Call the original inference function
-        audio_generator = inference(req)
-        fake_audios = next(audio_generator)
-
-        asr_result = batch_asr(
-            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
-        )[0]
-        wer = calculate_wer(req.text, asr_result["text"])
-
-        if wer <= 0.1 and not asr_result["huge_gap"]:
-            return fake_audios
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = fake_audios
-
-        if attempt == max_attempts - 1:
-            break
-
-    return best_audio
-
-
-async def inference_async(req: InvokeRequest):
+async def inference_async(req: ServeTTSRequest):
     for chunk in inference(req):
         yield chunk
 
@@ -349,9 +284,9 @@ async def buffer_to_async_generator(buffer):
     yield buffer
 
 
-# @routes.http.post("/v1/
+# @routes.http.post("/v1/tts")
 # async def api_invoke_model(
-#     req: Annotated[
+#     req: Annotated[ServeTTSRequest, Body(exclusive=True)],
 # ):
 #     """
 #     Invoke model and generate audio
@@ -410,21 +345,20 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/fish-speech-1.
+        default="checkpoints/fish-speech-1.4",
     )
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=str,
-        default="checkpoints/fish-speech-1.
+        default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
-    parser.add_argument("--listen", type=str, default="127.0.0.1:
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
-    parser.add_argument("--use-auto-rerank", type=bool, default=True)
 
     return parser.parse_args()
 
@@ -436,18 +370,30 @@ def parse_args():
 #     },
 # ).routes
 #
+#
+# class MsgPackRequest(HttpRequest):
+#     async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+#         if self.content_type == "application/msgpack":
+#             return ormsgpack.unpackb(await self.body)
+#
+#         raise HTTPException(
+#             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+#             headers={"Accept": "application/msgpack"},
+#         )
+#
+#
 # app = Kui(
 #     routes=routes + openapi[1:],  # Remove the default route
 #     exception_handlers={
 #         HTTPException: http_execption_handler,
 #         Exception: other_exception_handler,
 #     },
+#     factory_class=FactoryClass(http=MsgPackRequest),
 #     cors_config={},
 # )
 
 
 if __name__ == "__main__":
-    import threading
 
     import uvicorn
 
@@ -474,18 +420,17 @@ if __name__ == "__main__":
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
         inference(
-            InvokeRequest(
+            ServeTTSRequest(
                 text="Hello world.",
-
-
-                max_new_tokens=
+                references=[],
+                reference_id=None,
+                max_new_tokens=1024,
+                chunk_length=200,
                 top_p=0.7,
                 repetition_penalty=1.2,
                 temperature=0.7,
                 emotion=None,
                 format="wav",
-                ref_base=None,
-                ref_json=None,
             )
         )
     )
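For orientation, the reworked inference() above resolves references in one of two ways: a reference_id naming a folder of reference clips, or inline references carried in the request. With a reference_id, the server scans references/<id>/ for audio files and pairs each with a same-named .lab transcript via with_suffix(".lab"), so a layout like the following is expected (hypothetical id, borrowed from the ServeTTSRequest comment below):

references/
└── 7f92f8afb8ec43bf81429cc1c9199cb1/
    ├── sample1.wav
    └── sample1.lab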
--- /dev/null
+++ b/xinference/thirdparty/fish_speech/tools/commons.py
@@ -0,0 +1,35 @@
+from typing import Annotated, Literal, Optional
+
+from pydantic import BaseModel, Field, conint
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # References audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increase stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # not usually used below
+    streaming: bool = False
+    emotion: Optional[str] = None
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
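A minimal sketch of how the new request model behaves, assuming only pydantic and the tools/commons.py module above (the values are illustrative):

from tools.commons import ServeReferenceAudio, ServeTTSRequest

req = ServeTTSRequest(
    text="Hello world.",
    references=[ServeReferenceAudio(audio=b"<wav bytes>", text="reference transcript")],
)
print(req.format)  # "wav" by default

# chunk_length is checked by conint(ge=100, le=300, strict=True), so an
# out-of-range value raises a pydantic ValidationError:
# ServeTTSRequest(text="hi", chunk_length=50)  # -> ValidationError

Note that mp3_bitrate is declared twice in the class body; under normal Python class semantics the second declaration (Optional[int] = 64) is the one pydantic ends up using.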
--- a/xinference/thirdparty/fish_speech/tools/download_models.py
+++ b/xinference/thirdparty/fish_speech/tools/download_models.py
@@ -22,8 +22,8 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 # 1st
-repo_id_1 = "fishaudio/fish-speech-1.
-local_dir_1 = "./checkpoints/fish-speech-1.
+repo_id_1 = "fishaudio/fish-speech-1.4"
+local_dir_1 = "./checkpoints/fish-speech-1.4"
 files_1 = [
     "model.pth",
     "README.md",
@@ -31,7 +31,7 @@ files_1 = [
     "tokenizer_config.json",
     "tokenizer.json",
     "config.json",
-    "firefly-gan-vq-fsq-
+    "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
 ]
 
 # 3rd
--- a/xinference/thirdparty/fish_speech/tools/file.py
+++ b/xinference/thirdparty/fish_speech/tools/file.py
@@ -1,3 +1,4 @@
+import base64
 from pathlib import Path
 from typing import Union
 
@@ -23,6 +24,22 @@ VIDEO_EXTENSIONS = {
 }
 
 
+def audio_to_bytes(file_path):
+    if not file_path or not Path(file_path).exists():
+        return None
+    with open(file_path, "rb") as wav_file:
+        wav = wav_file.read()
+    return wav
+
+
+def read_ref_text(ref_text):
+    path = Path(ref_text)
+    if path.exists() and path.is_file():
+        with path.open("r", encoding="utf-8") as file:
+            return file.read()
+    return ref_text
+
+
 def list_files(
     path: Union[Path, str],
     extensions: set[str] = None,
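A quick illustration of the two new helpers, with hypothetical paths; read_ref_text falls back to returning its argument verbatim, so it accepts either a transcript file path or the transcript itself:

from tools.file import audio_to_bytes, read_ref_text

wav = audio_to_bytes("references/demo/sample.wav")  # bytes, or None if the file is missing
txt = read_ref_text("references/demo/sample.lab")   # file contents when the path exists
lit = read_ref_text("a literal transcript")         # returned unchanged otherwise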
--- a/xinference/thirdparty/fish_speech/tools/llama/build_dataset.py
+++ b/xinference/thirdparty/fish_speech/tools/llama/build_dataset.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
 
 from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
 from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
-from
+from tools.file import load_filelist
 
 # To avoid CPU overload
 os.environ["MKL_NUM_THREADS"] = "1"
--- a/xinference/thirdparty/fish_speech/tools/llama/generate.py
+++ b/xinference/thirdparty/fish_speech/tools/llama/generate.py
@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Literal, Optional, Tuple, Union
@@ -93,15 +94,20 @@ def decode_one_token_ar(
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=
-
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
         )[0]
     ]
+
     x = x.hidden_states
 
     # Cleanup the cache
@@ -136,11 +142,16 @@ def decode_one_token_naive(
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
 
+    sampling_kwargs_main = sampling_kwargs.copy()
+    sampling_kwargs_main["temperature"] = 0.1
+    sampling_kwargs_main["top_p"] = 0.1
+    sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
-            x.
+            x.logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **
+            **sampling_kwargs_main,
         )[0]
     ]
 
@@ -181,8 +192,12 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        with (
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            )
+            if torch.cuda.is_available()
+            else nullcontext()
         ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,
@@ -222,25 +237,11 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
 
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
     device, dtype = prompt.device, prompt.dtype
-    with torch.device(device):
-        model.setup_caches(
-            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
-        )
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty((codebook_dim,
+    empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
@@ -560,6 +561,10 @@ def launch_thread_safe_queue(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+        )
     init_event.set()
 
     while True:
@@ -607,7 +612,7 @@ def launch_thread_safe_queue(
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/fish-speech-1.
+    default="checkpoints/fish-speech-1.4",
 )
 @click.option("--device", type=str, default="cuda")
 @click.option("--compile/--no-compile", default=False)
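The decode_n_tokens change above picks its context manager at runtime so the same with-block also runs on machines without CUDA. The pattern in isolation (a sketch assuming torch is installed; the sdp_kernel arguments are copied from the diff):

from contextlib import nullcontext

import torch

# Use the math SDPA kernel when CUDA is present; otherwise degrade to a
# no-op context instead of touching the CUDA backend on CPU-only machines.
ctx = (
    torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    )
    if torch.cuda.is_available()
    else nullcontext()
)
with ctx:
    ...  # attention/decoding work goes here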
--- a/xinference/thirdparty/fish_speech/tools/llama/merge_lora.py
+++ b/xinference/thirdparty/fish_speech/tools/llama/merge_lora.py
@@ -15,7 +15,7 @@ from fish_speech.models.text2semantic.lora import get_merged_state_dict
 
 @click.command()
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
-@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)
 def merge(lora_config, base_weight, lora_weight, output):
--- a/xinference/thirdparty/fish_speech/tools/llama/quantize.py
+++ b/xinference/thirdparty/fish_speech/tools/llama/quantize.py
@@ -428,7 +428,7 @@ def generate_folder_name():
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/fish-speech-1.
+    default="checkpoints/fish-speech-1.4",
 )
 @click.option(
     "--mode", type=str, default="int8", help="type of quantization to perform"
@@ -451,7 +451,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
         precision=precision,
         compile=False,
     )
-    vq_model = "firefly-gan-vq-fsq-
+    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
     now = timestamp if timestamp != "None" else generate_folder_name()
 
     if mode == "int8":
--- /dev/null
+++ b/xinference/thirdparty/fish_speech/tools/msgpack_api.py
@@ -0,0 +1,34 @@
+import httpx
+import ormsgpack
+
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
+
+# priority: ref_id > references
+request = ServeTTSRequest(
+    text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+    # reference_id="114514",
+    references=[
+        ServeReferenceAudio(
+            audio=open("lengyue.wav", "rb").read(),
+            text=open("lengyue.lab", "r", encoding="utf-8").read(),
+        )
+    ],
+    streaming=True,
+)
+
+with (
+    httpx.Client() as client,
+    open("hello.wav", "wb") as f,
+):
+    with client.stream(
+        "POST",
+        "http://127.0.0.1:8080/v1/tts",
+        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+        headers={
+            "authorization": "Bearer YOUR_API_KEY",
+            "content-type": "application/msgpack",
+        },
+        timeout=None,
+    ) as response:
+        for chunk in response.iter_bytes():
+            f.write(chunk)