xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +23 -1
- xinference/core/model.py +1 -6
- xinference/core/utils.py +10 -6
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/llm_family.json +47 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +49 -0
- xinference/model/llm/mlx/core.py +68 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/utils.py +1 -0
- xinference/model/llm/vllm/core.py +11 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/train.py

@@ -6,7 +6,7 @@ from typing import Optional
 
 import hydra
 import lightning as L
-
+import pyrootutils
 import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers import Logger
@@ -18,7 +18,7 @@ os.environ.pop("SLURM_JOB_NAME", None)
 os.environ.pop("SLURM_NTASKS_PER_NODE", None)
 
 # register eval resolver and root
-
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py}

@@ -69,10 +69,6 @@ def parse_args():
     parser.add_argument(
         "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
     )
-    parser.add_argument(
-        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
-    )
-    parser.add_argument("--opus_bitrate", type=int, default=-1000)
     parser.add_argument(
         "--latency",
         type=str,
@@ -83,7 +79,7 @@ def parse_args():
     parser.add_argument(
         "--max_new_tokens",
         type=int,
-        default=
+        default=1024,
         help="Maximum new tokens to generate. \n0 means no limit.",
     )
     parser.add_argument(
@@ -112,11 +108,9 @@ def parse_args():
     parser.add_argument(
         "--use_memory_cache",
         type=str,
-        default="
-        choices=["on
-        help="Cache encoded references codes in memory.\n"
-        "If `on-demand`, the server will use cached encodings\n "
-        "instead of encoding reference audio again.",
+        default="off",
+        choices=["on", "off"],
+        help="Cache encoded references codes in memory.\n",
     )
     parser.add_argument(
         "--seed",
@@ -154,14 +148,14 @@ if __name__ == "__main__":
     data = {
         "text": args.text,
         "references": [
-            ServeReferenceAudio(
+            ServeReferenceAudio(
+                audio=ref_audio if ref_audio is not None else b"", text=ref_text
+            )
             for ref_text, ref_audio in zip(ref_texts, byte_audios)
         ],
         "reference_id": idstr,
         "normalize": args.normalize,
         "format": args.format,
-        "mp3_bitrate": args.mp3_bitrate,
-        "opus_bitrate": args.opus_bitrate,
         "max_new_tokens": args.max_new_tokens,
         "chunk_length": args.chunk_length,
         "top_p": args.top_p,
xinference/thirdparty/fish_speech/tools/api_server.py

@@ -0,0 +1,98 @@
+from threading import Lock
+
+import pyrootutils
+import uvicorn
+from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.server.api_utils import MsgPackRequest, parse_args
+from tools.server.exception_handler import ExceptionHandler
+from tools.server.model_manager import ModelManager
+from tools.server.views import (
+    ASRView,
+    ChatView,
+    HealthView,
+    TTSView,
+    VQGANDecodeView,
+    VQGANEncodeView,
+)
+
+
+class API(ExceptionHandler):
+    def __init__(self):
+        self.args = parse_args()
+        self.routes = [
+            ("/v1/health", HealthView),
+            ("/v1/vqgan/encode", VQGANEncodeView),
+            ("/v1/vqgan/decode", VQGANDecodeView),
+            ("/v1/asr", ASRView),
+            ("/v1/tts", TTSView),
+            ("/v1/chat", ChatView),
+        ]
+        self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
+
+        self.openapi = OpenAPI(
+            {
+                "title": "Fish Speech API",
+                "version": "1.5.0",
+            },
+        ).routes
+
+        # Initialize the app
+        self.app = Kui(
+            routes=self.routes + self.openapi[1:],  # Remove the default route
+            exception_handlers={
+                HTTPException: self.http_exception_handler,
+                Exception: self.other_exception_handler,
+            },
+            factory_class=FactoryClass(http=MsgPackRequest),
+            cors_config={},
+        )
+
+        # Add the state variables
+        self.app.state.lock = Lock()
+        self.app.state.device = self.args.device
+        self.app.state.max_text_length = self.args.max_text_length
+
+        # Associate the app with the model manager
+        self.app.on_startup(self.initialize_app)
+
+    async def initialize_app(self, app: Kui):
+        # Make the ModelManager available to the views
+        app.state.model_manager = ModelManager(
+            mode=self.args.mode,
+            device=self.args.device,
+            half=self.args.half,
+            compile=self.args.compile,
+            asr_enabled=self.args.load_asr_model,
+            llama_checkpoint_path=self.args.llama_checkpoint_path,
+            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
+            decoder_config_name=self.args.decoder_config_name,
+        )
+
+        logger.info(f"Startup done, listening server at http://{self.args.listen}")
+
+
+# Each worker process created by Uvicorn has its own memory space,
+# meaning that models and variables are not shared between processes.
+# Therefore, any variables (like `llama_queue` or `decoder_model`)
+# will not be shared across workers.
+
+# Multi-threading for deep learning can cause issues, such as inconsistent
+# outputs if multiple threads access the same buffers simultaneously.
+# Instead, it's better to use multiprocessing or independent models per thread.
+
+if __name__ == "__main__":
+
+    api = API()
+    host, port = api.args.listen.split(":")
+
+    uvicorn.run(
+        api.app,
+        host=host,
+        port=int(port),
+        workers=api.args.workers,
+        log_level="info",
+    )
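Note (not part of the diff): the new tools/api_server.py above serves msgpack-encoded requests (FactoryClass(http=MsgPackRequest)) on routes such as /v1/tts. A minimal client sketch under those assumptions, mirroring the payload fields that tools/api_client.py builds in the hunks above; the address, header, and field values are illustrative, not taken from this release:

    import ormsgpack
    import requests

    # Fields mirror the "data" dict in tools/api_client.py; values here are illustrative.
    payload = {
        "text": "Hello from xinference 1.1.1",
        "references": [],
        "reference_id": None,
        "normalize": True,
        "format": "wav",
        "max_new_tokens": 1024,
        "chunk_length": 200,
        "top_p": 0.7,
    }

    # The server is built with FactoryClass(http=MsgPackRequest), so the body is msgpack-encoded.
    resp = requests.post(
        "http://127.0.0.1:8080/v1/tts",  # assumed listen address
        data=ormsgpack.packb(payload),
        headers={"Content-Type": "application/msgpack"},  # assumed content type
    )
    with open("out.wav", "wb") as f:
        f.write(resp.content)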
xinference/thirdparty/fish_speech/tools/download_models.py

@@ -22,14 +22,14 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 # 1st
-repo_id_1 = "fishaudio/fish-speech-1.
-local_dir_1 = "./checkpoints/fish-speech-1.
+repo_id_1 = "fishaudio/fish-speech-1.5"
+local_dir_1 = "./checkpoints/fish-speech-1.5"
 files_1 = [
+    "gitattributes",
     "model.pth",
     "README.md",
-    "
-    "
-    "tokenizer.json",
+    "special_tokens.json",
+    "tokenizer.tiktoken",
     "config.json",
     "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
 ]
xinference/thirdparty/fish_speech/tools/fish_e2e.py

@@ -14,8 +14,8 @@ import ormsgpack
 import soundfile as sf
 
 from .schema import (
+    ServeChatRequest,
     ServeMessage,
-    ServeRequest,
     ServeTextPart,
     ServeVQGANDecodeRequest,
     ServeVQGANEncodeRequest,
@@ -163,7 +163,7 @@ class FishE2EAgent:
         else:
             user_codes = None
 
-        request =
+        request = ServeChatRequest(
             messages=prev_messages
             + (
                 [
xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py

@@ -0,0 +1,192 @@
+import gc
+import queue
+from typing import Generator
+
+import numpy as np
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps, set_seed
+from tools.inference_engine.reference_loader import ReferenceLoader
+from tools.inference_engine.utils import InferenceResult, wav_chunk_header
+from tools.inference_engine.vq_manager import VQManager
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+)
+from tools.schema import ServeTTSRequest
+
+
+class TTSInferenceEngine(ReferenceLoader, VQManager):
+
+    def __init__(
+        self,
+        llama_queue: queue.Queue,
+        decoder_model: FireflyArchitecture,
+        precision: torch.dtype,
+        compile: bool,
+    ) -> None:
+
+        super().__init__()
+
+        self.llama_queue = llama_queue
+        self.decoder_model = decoder_model
+        self.precision = precision
+        self.compile = compile
+
+    @torch.inference_mode()
+    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
+        """
+        Main inference function:
+        - Loads the reference audio and text.
+        - Calls the LLAMA model for inference.
+        - Decodes the VQ tokens to audio.
+        """
+
+        ref_id: str | None = req.reference_id
+        prompt_tokens, prompt_texts = [], []
+        # Load the reference audio and text based on id or hash
+        if ref_id is not None:
+            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
+
+        elif req.references:
+            prompt_tokens, prompt_texts = self.load_by_hash(
+                req.references, req.use_memory_cache
+            )
+
+        # Set the random seed if provided
+        if req.seed is not None:
+            set_seed(req.seed)
+            logger.warning(f"set seed: {req.seed}")
+
+        # Get the symbolic tokens from the LLAMA model
+        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)
+
+        # Get the sample rate from the decoder model
+        sample_rate = self.decoder_model.spec_transform.sample_rate
+
+        # If streaming, send the header
+        # if req.streaming:
+        #     yield InferenceResult(
+        #         code="header",
+        #         audio=(sample_rate, wav_chunk_header(sample_rate=sample_rate)),
+        #         error=None,
+        #     )
+
+        segments = []
+
+        while True:
+            # Get the response from the LLAMA model
+            wrapped_result: WrappedGenerateResponse = response_queue.get()
+            if wrapped_result.status == "error":
+                yield InferenceResult(
+                    code="error",
+                    audio=None,
+                    error=(
+                        wrapped_result.response
+                        if isinstance(wrapped_result.response, Exception)
+                        else Exception("Unknown error")
+                    ),
+                )
+                break
+
+            # Check the response type
+            if not isinstance(wrapped_result.response, GenerateResponse):
+                raise TypeError(
+                    "Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
+                )
+
+            result: GenerateResponse = wrapped_result.response
+            if result.action != "next":
+                segment = self.get_audio_segment(result)
+
+                if req.streaming:  # Used only by the API server
+                    yield InferenceResult(
+                        code="segment",
+                        audio=(sample_rate, segment),
+                        error=None,
+                    )
+                segments.append(segment)
+            else:
+                break
+
+        # Clean up the memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        # Edge case: no audio generated
+        if len(segments) == 0:
+            yield InferenceResult(
+                code="error",
+                audio=None,
+                error=RuntimeError("No audio generated, please check the input text."),
+            )
+        else:
+            # Streaming or not, return the final audio
+            audio = np.concatenate(segments, axis=0)
+            yield InferenceResult(
+                code="final",
+                audio=(sample_rate, audio),
+                error=None,
+            )
+
+        return None
+
+    def send_Llama_request(
+        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
+    ) -> queue.Queue:
+        """
+        Send a request to the LLAMA model to generate the symbolic tokens.
+        """
+
+        # Prepare the request
+        request = dict(
+            device=self.decoder_model.device,
+            max_new_tokens=req.max_new_tokens,
+            text=(
+                req.text
+                if not req.normalize
+                else ChnNormedText(raw_text=req.text).normalize()
+            ),
+            top_p=req.top_p,
+            repetition_penalty=req.repetition_penalty,
+            temperature=req.temperature,
+            compile=self.compile,
+            iterative_prompt=req.chunk_length > 0,
+            chunk_length=req.chunk_length,
+            max_length=4096,
+            prompt_tokens=prompt_tokens,
+            prompt_text=prompt_texts,
+        )
+
+        # Create a queue to get the response
+        response_queue = queue.Queue()
+
+        # Send the request to the LLAMA model
+        self.llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        return response_queue
+
+    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
+        """
+        Decode the VQ tokens to audio.
+        """
+
+        # Don't use autocast on MPS devices
+        with autocast_exclude_mps(
+            device_type=self.decoder_model.device.type, dtype=self.precision
+        ):
+            # Decode the symbolic tokens to audio
+            segment = self.decode_vq_tokens(codes=result.codes)
+
+        # Convert the audio to numpy
+        return segment.float().cpu().numpy()
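Note (not part of the diff): TTSInferenceEngine.inference is a generator of InferenceResult values with codes "segment", "final", and "error". A rough consumer sketch under that interface; the llama_queue/decoder_model setup is omitted and assumed to come from the existing fish_speech loaders, and the ServeTTSRequest fields shown are illustrative:

    import soundfile as sf

    from tools.schema import ServeTTSRequest

    def synthesize(engine, text: str, out_path: str = "out.wav") -> None:
        # engine is assumed to be a configured TTSInferenceEngine instance.
        req = ServeTTSRequest(text=text, streaming=False)  # other fields left at their defaults (assumed)
        for result in engine.inference(req):
            if result.code == "error" and result.error is not None:
                raise result.error
            if result.code == "final":
                sample_rate, audio = result.audio  # (int, np.ndarray) per InferenceResult
                sf.write(out_path, audio, sample_rate)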
xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py

@@ -0,0 +1,125 @@
+import io
+from hashlib import sha256
+from pathlib import Path
+from typing import Callable, Literal, Tuple
+
+import torch
+import torchaudio
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
+from tools.schema import ServeReferenceAudio
+
+
+class ReferenceLoader:
+
+    def __init__(self) -> None:
+        """
+        Component of the TTSInferenceEngine class.
+        Loads and manages the cache for the reference audio and text.
+        """
+        self.ref_by_id: dict = {}
+        self.ref_by_hash: dict = {}
+
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.encode_reference: Callable
+
+        # Define the torchaudio backend
+        backends = torchaudio.list_audio_backends()
+        if "ffmpeg" in backends:
+            self.backend = "ffmpeg"
+        else:
+            self.backend = "soundfile"
+
+    def load_by_id(
+        self,
+        id: str,
+        use_cache: Literal["on", "off"],
+    ) -> Tuple:
+
+        # Load the references audio and text by id
+        ref_folder = Path("references") / id
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+
+        if use_cache == "off" or id not in self.ref_by_id:
+            # If the references are not already loaded, encode them
+            prompt_tokens = [
+                self.encode_reference(
+                    # decoder_model=self.decoder_model,
+                    reference_audio=audio_to_bytes(str(ref_audio)),
+                    enable_reference_audio=True,
+                )
+                for ref_audio in ref_audios
+            ]
+            prompt_texts = [
+                read_ref_text(str(ref_audio.with_suffix(".lab")))
+                for ref_audio in ref_audios
+            ]
+            self.ref_by_id[id] = (prompt_tokens, prompt_texts)
+
+        else:
+            # Reuse already encoded references
+            logger.info("Use same references")
+            prompt_tokens, prompt_texts = self.ref_by_id[id]
+
+        return prompt_tokens, prompt_texts
+
+    def load_by_hash(
+        self,
+        references: list[ServeReferenceAudio],
+        use_cache: Literal["on", "off"],
+    ) -> Tuple:
+
+        # Load the references audio and text by hash
+        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
+
+        cache_used = False
+        prompt_tokens, prompt_texts = [], []
+        for i, ref in enumerate(references):
+            if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
+                # If the references are not already loaded, encode them
+                prompt_tokens.append(
+                    self.encode_reference(
+                        reference_audio=ref.audio,
+                        enable_reference_audio=True,
+                    )
+                )
+                prompt_texts.append(ref.text)
+                self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
+
+            else:
+                # Reuse already encoded references
+                prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
+                cache_used = True
+
+        if cache_used:
+            logger.info("Use same references")
+
+        return prompt_tokens, prompt_texts
+
+    def load_audio(self, reference_audio, sr):
+        """
+        Load the audio data from a file or bytes.
+        """
+        if len(reference_audio) > 255 or not Path(reference_audio).exists():
+            audio_data = reference_audio
+            reference_audio = io.BytesIO(audio_data)
+
+        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
+
+        if waveform.shape[0] > 1:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+        if original_sr != sr:
+            resampler = torchaudio.transforms.Resample(
+                orig_freq=original_sr, new_freq=sr
+            )
+            waveform = resampler(waveform)
+
+        audio = waveform.squeeze().numpy()
+        return audio
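Note (not part of the diff): load_by_id above implies an on-disk layout of references/<id>/ holding audio clips with sibling .lab transcript files (read via read_ref_text). A tiny sketch of preparing such a folder; the id and file names are illustrative:

    from pathlib import Path
    import shutil

    ref_id = "my-voice"                    # illustrative id
    ref_dir = Path("references") / ref_id  # folder scanned by ReferenceLoader.load_by_id
    ref_dir.mkdir(parents=True, exist_ok=True)

    # One audio clip plus a .lab file carrying its transcript.
    shutil.copy("sample.wav", ref_dir / "sample.wav")
    (ref_dir / "sample.lab").write_text("Transcript of sample.wav", encoding="utf-8")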
xinference/thirdparty/fish_speech/tools/inference_engine/utils.py

@@ -0,0 +1,39 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+
+@dataclass
+class InferenceResult:
+    code: Literal["header", "segment", "error", "final"]
+    audio: Optional[Tuple[int, np.ndarray | bytes]]
+    error: Optional[Exception]
+
+
+def normalize_text(user_input: str, use_normalization: bool) -> str:
+    """Normalize user input text if needed."""
+    if use_normalization:
+        return ChnNormedText(raw_text=user_input).normalize()
+    else:
+        return user_input
+
+
+def wav_chunk_header(
+    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> bytes:
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+
+    return wav_header_bytes
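Note (not part of the diff): wav_chunk_header above returns only the RIFF/WAV header bytes, so a streaming response can send it once and follow with raw PCM. A small sketch assuming 16-bit mono output and float segments in [-1, 1]:

    import numpy as np

    from tools.inference_engine.utils import wav_chunk_header

    def stream_wav(segments, sample_rate: int = 44100):
        # First chunk: the WAV header produced by the helper above.
        yield wav_chunk_header(sample_rate=sample_rate, bit_depth=16, channels=1)
        # Then each float segment converted to 16-bit PCM bytes.
        for segment in segments:
            pcm = (np.clip(segment, -1.0, 1.0) * 32767.0).astype(np.int16)
            yield pcm.tobytes()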
xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py

@@ -0,0 +1,57 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+
+
+class VQManager:
+
+    def __init__(self):
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.load_audio: Callable
+
+    def decode_vq_tokens(self, codes):
+        feature_lengths = torch.tensor(
+            [codes.shape[1]], device=self.decoder_model.device
+        )
+        logger.info(f"VQ features: {codes.shape}")
+
+        if isinstance(self.decoder_model, FireflyArchitecture):
+            return self.decoder_model.decode(
+                indices=codes[None],
+                feature_lengths=feature_lengths,
+            )[0].squeeze()
+
+        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+    def encode_reference(self, reference_audio, enable_reference_audio):
+        if enable_reference_audio and reference_audio is not None:
+            # Load audios, and prepare basic info here
+            reference_audio_content = self.load_audio(
+                reference_audio, self.decoder_model.spec_transform.sample_rate
+            )
+
+            audios = torch.from_numpy(reference_audio_content).to(
+                self.decoder_model.device
+            )[None, None, :]
+            audio_lengths = torch.tensor(
+                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+            )
+            logger.info(
+                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+            )
+
+            # VQ Encoder
+            if isinstance(self.decoder_model, FireflyArchitecture):
+                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+            else:
+                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+        else:
+            prompt_tokens = None
+            logger.info("No reference audio provided")
+
+        return prompt_tokens
xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py

@@ -1,11 +1,11 @@
-
+import pyrootutils
 import torch
 import torch.nn.functional as F
 from matplotlib import pyplot as plt
 from transformers import AutoTokenizer
 
 # register eval resolver and root
-
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 from torch.utils.data import DataLoader
 