xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
-    "date": "2024-08-…",
+    "date": "2024-08-30T18:54:16+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…",
-    "version": "0.14.2"
+    "full-revisionid": "f3d510eceffbbbc41ce065919fd2c48411017573",
+    "version": "0.14.4"
 }
 ''' # END VERSION_JSON

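Note: the version_json blob above is the standard versioneer pattern — each build stamps a JSON literal into _version.py, and the module parses it at runtime, which is why every release shows a small change here. A minimal sketch of how the stamped blob is consumed (versioneer's generated module works this way; xinference may carry extra machinery around it):

import json

version_json = '''
{
    "date": "2024-08-30T18:54:16+0800",
    "dirty": false,
    "error": null,
    "full-revisionid": "f3d510eceffbbbc41ce065919fd2c48411017573",
    "version": "0.14.4"
}
'''  # END VERSION_JSON


def get_versions():
    # Parse the stamped JSON written at build time.
    return json.loads(version_json)


assert get_versions()["version"] == "0.14.4"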
xinference/core/chat_interface.py
CHANGED
@@ -340,7 +340,7 @@ class GradioInterface:
             state = gr.State([])
             with gr.Row():
                 chatbot = gr.Chatbot(
-                    elem_id="chatbot", label=self.model_name, height=…
+                    elem_id="chatbot", label=self.model_name, height=700, scale=7
                 )
                 with gr.Column(scale=3):
                     imagebox = gr.Image(type="filepath")
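Note: the single chat_interface.py change resizes the chatbot pane in the vision-language chat UI — an explicit height=700 plus scale=7, so the chatbot takes roughly 70% of the row next to the scale=3 image column. A standalone sketch of the same layout, assuming a recent Gradio; the names here are illustrative, not the full GradioInterface:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        # scale=7 vs. scale=3 splits the row roughly 70/30.
        chatbot = gr.Chatbot(elem_id="chatbot", label="my-model", height=700, scale=7)
        with gr.Column(scale=3):
            imagebox = gr.Image(type="filepath")

demo.launch()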
xinference/core/image_interface.py
CHANGED
@@ -163,6 +163,7 @@ class ImageInterface:
             size_width: int,
             size_height: int,
             num_inference_steps: int,
+            padding_image_to_multiple: int,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient

@@ -178,6 +179,7 @@
             num_inference_steps = (
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
+            padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore

             bio = io.BytesIO()
             image.save(bio, format="png")
@@ -190,6 +192,7 @@
                 size=size,
                 response_format="b64_json",
                 num_inference_steps=num_inference_steps,
+                padding_image_to_multiple=padding_image_to_multiple,
             )

             images = []
@@ -222,9 +225,14 @@
                 n = gr.Number(label="Number of image", value=1)
                 size_width = gr.Number(label="Width", value=-1)
                 size_height = gr.Number(label="Height", value=-1)
+
+            with gr.Row():
                 num_inference_steps = gr.Number(
                     label="Inference Step Number", value=-1
                 )
+                padding_image_to_multiple = gr.Number(
+                    label="Padding image to multiple", value=-1
+                )

             with gr.Row():
                 with gr.Column(scale=1):
@@ -242,6 +250,7 @@
                     size_width,
                     size_height,
                     num_inference_steps,
+                    padding_image_to_multiple,
                 ],
                 outputs=output_gallery,
             )
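Note: the new padding_image_to_multiple control follows the same sentinel convention as the other numeric fields — -1 in the gr.Number widget means "unset" and is mapped to None before the REST call. Padding to a multiple matters for image-to-image pipelines because the UNet/VAE downsampling stages need side lengths divisible by a fixed factor (commonly 8, or 64 for some inpainting models). A hedged sketch of the idea, not xinference's internal implementation:

from PIL import Image

def pad_to_multiple(image: Image.Image, multiple: int) -> Image.Image:
    # Pad right/bottom so both sides are divisible by `multiple`.
    w, h = image.size
    new_w = (w + multiple - 1) // multiple * multiple
    new_h = (h + multiple - 1) // multiple * multiple
    padded = Image.new(image.mode, (new_w, new_h))
    padded.paste(image, (0, 0))
    return padded

# UI sentinel handling, as in the diff above:
padding_image_to_multiple = -1
padding = None if padding_image_to_multiple == -1 else padding_image_to_multiple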
xinference/core/model.py
CHANGED
@@ -177,6 +177,7 @@ class ModelActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
     ):
         super().__init__()
+        from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
@@ -192,7 +193,9 @@
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(…
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
xinference/core/worker.py
CHANGED
@@ -39,9 +39,11 @@ from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
 from ..types import PeftModelConfig
+from .cache_tracker import CacheTrackerActor
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
+from .status_guard import StatusGuardActor
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir

 logger = getLogger(__name__)
@@ -71,6 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
+        self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
+            None
+        )
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
+            None
+        )

         # internal states.
         # temporary placeholder during model launch process:
@@ -135,7 +146,7 @@
         else:
             recover_count = self._model_uid_to_recover_count.get(model_uid)
             try:
-                await self.terminate_model(model_uid)
+                await self.terminate_model(model_uid, is_model_die=True)
             except Exception:
                 pass
             if recover_count is not None:
@@ -308,56 +319,50 @@
         Params:
             add_worker: By default will call supervisor.add_worker after first connect
         """
-        from .status_guard import StatusGuardActor
         from .supervisor import SupervisorActor

         if self._supervisor_ref is not None:
             return self._supervisor_ref
-        …
+        supervisor_ref = await xo.actor_ref(  # type: ignore
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
+        # Prevent concurrent operations leads to double initialization, check again.
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref = supervisor_ref
         if add_worker and len(self._model_uid_to_model) == 0:
             # Newly started (or restarted), has no model, notify supervisor
             await self._supervisor_ref.add_worker(self.address)
             logger.info("Connected to supervisor as a fresh worker")

-        … (old lines 324-351: deleted block not recoverable from this extraction)
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        self._status_guard_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+        self._event_collector_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        self._cache_tracker_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
         return self._supervisor_ref

     @staticmethod
@@ -659,6 +664,8 @@
             ret.sort(key=sort_helper)
             return ret
+        elif model_type == "video":
+            return []
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
@@ -698,6 +705,8 @@
             for f in get_user_defined_audios():
                 if f.model_name == model_name:
                     return f
+        elif model_type == "video":
+            return None
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
@@ -734,7 +743,7 @@
         elif model_type == "image":
             return ["text_to_image"]
         elif model_type == "audio":
-            return […
+            return [model._model_spec.ability]
         elif model_type == "video":
             return ["text_to_video"]
         elif model_type == "flexible":
@@ -793,6 +802,7 @@
             logger.exception(e)
             raise
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -908,12 +918,13 @@
         )

     @log_async(logger=logger)
-    async def terminate_model(self, model_uid: str):
+    async def terminate_model(self, model_uid: str, is_model_die=False):
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
         origin_uid, _, __ = parse_replica_model_uid(model_uid)
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -956,11 +967,16 @@
             self._model_uid_to_recover_count.pop(model_uid, None)
             self._model_uid_to_launch_args.pop(model_uid, None)

+            if is_model_die:
+                status = LaunchStatus.ERROR.name
+            else:
+                status = LaunchStatus.TERMINATED.name
+
             if self._status_guard_ref is None:
                 _ = await self.get_supervisor_ref()
             assert self._status_guard_ref is not None
             await self._status_guard_ref.update_instance_info(
-                origin_uid, {"status": …
+                origin_uid, {"status": status}
             )

         # Provide an interface for future version of supervisor to call
@@ -1081,7 +1097,7 @@
         paths.update([os.path.realpath(path) for path in paths])

         # get tensorizer path
-        from ..model.llm.…
+        from ..model.llm.transformers.tensorizer_utils import get_tensorizer_dir

         tensorizer_path = get_tensorizer_dir(path)
         if os.path.isdir(tensorizer_path):
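Note: the rewritten get_supervisor_ref is double-checked initialization on a single event loop — the method awaits between the first cache check and the assignment, so two concurrent callers could both resolve the ref and re-run the one-time side effects (add_worker, record_model_version). Resolving into a local, re-checking, then publishing closes that window. A distilled sketch:

import asyncio
from typing import Awaitable, Callable, Optional

class LazyRef:
    def __init__(self, resolve: Callable[[], Awaitable[object]]):
        self._resolve = resolve  # async factory, e.g. xo.actor_ref(...)
        self._ref: Optional[object] = None

    async def get(self) -> object:
        if self._ref is not None:
            return self._ref
        ref = await self._resolve()  # suspension point: other tasks may run here
        if self._ref is not None:  # double-check after the await
            return self._ref
        self._ref = ref
        # one-time side effects (e.g. add_worker, record_model_version) go here
        return self._ref

No real parallelism is needed to hit the bug; interleaving across the await is enough to run the side effects twice.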
if os.path.isdir(tensorizer_path):
|
|
@@ -11,10 +11,14 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import base64
|
|
14
16
|
import logging
|
|
15
17
|
from io import BytesIO
|
|
16
18
|
from typing import TYPE_CHECKING, Optional
|
|
17
19
|
|
|
20
|
+
from ..utils import set_all_random_seed
|
|
21
|
+
|
|
18
22
|
if TYPE_CHECKING:
|
|
19
23
|
from .core import AudioModelFamilyV1
|
|
20
24
|
|
|
@@ -61,16 +65,29 @@ class ChatTTSModel:
|
|
|
61
65
|
import torchaudio
|
|
62
66
|
import xxhash
|
|
63
67
|
|
|
64
|
-
|
|
68
|
+
rnd_spk_emb = None
|
|
65
69
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
70
|
+
if len(voice) > 400:
|
|
71
|
+
try:
|
|
72
|
+
assert self._model is not None
|
|
73
|
+
b = base64.b64decode(voice)
|
|
74
|
+
bio = BytesIO(b)
|
|
75
|
+
tensor = torch.load(bio, map_location="cpu")
|
|
76
|
+
rnd_spk_emb = self._model._encode_spk_emb(tensor)
|
|
77
|
+
logger.info("Speech by input speaker")
|
|
78
|
+
except Exception as e:
|
|
79
|
+
logger.info("Fallback to random speaker due to %s", e)
|
|
71
80
|
|
|
72
|
-
|
|
73
|
-
|
|
81
|
+
if rnd_spk_emb is None:
|
|
82
|
+
seed = xxhash.xxh32_intdigest(voice)
|
|
83
|
+
|
|
84
|
+
set_all_random_seed(seed)
|
|
85
|
+
torch.backends.cudnn.deterministic = True
|
|
86
|
+
torch.backends.cudnn.benchmark = False
|
|
87
|
+
|
|
88
|
+
assert self._model is not None
|
|
89
|
+
rnd_spk_emb = self._model.sample_random_speaker()
|
|
90
|
+
logger.info("Speech by voice %s", voice)
|
|
74
91
|
|
|
75
92
|
default = 5
|
|
76
93
|
infer_speed = int(default * speed)
|
|
@@ -100,7 +117,6 @@ class ChatTTSModel:
|
|
|
100
117
|
if new_last_pos != last_pos:
|
|
101
118
|
out.seek(last_pos)
|
|
102
119
|
encoded_bytes = out.read()
|
|
103
|
-
print(len(encoded_bytes))
|
|
104
120
|
yield encoded_bytes
|
|
105
121
|
last_pos = new_last_pos
|
|
106
122
|
|
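Note: ChatTTS voice selection now accepts two encodings of voice — a long value (more than 400 chars) is treated as a base64-encoded speaker-embedding tensor and decoded directly, while anything shorter is hashed with xxhash into a seed so the "random" speaker is reproducible for the same string. A condensed sketch of that dispatch; _encode_spk_emb and sample_random_speaker come from the ChatTTS model object, and the diff additionally routes seeding through set_all_random_seed plus the cuDNN determinism flags, simplified here to torch.manual_seed:

import base64
from io import BytesIO

import torch
import xxhash

def resolve_speaker(model, voice: str):
    if len(voice) > 400:  # looks like a base64-encoded speaker embedding
        try:
            tensor = torch.load(BytesIO(base64.b64decode(voice)), map_location="cpu")
            return model._encode_spk_emb(tensor)
        except Exception:
            pass  # fall back to a seeded random speaker
    torch.manual_seed(xxhash.xxh32_intdigest(voice))  # same voice -> same speaker
    return model.sample_random_speaker()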
xinference/model/audio/core.py
CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel

@@ -46,6 +47,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     multilingual: bool
+    ability: str
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]

@@ -156,13 +158,15 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
-    model: Union[…
+    model: Union[
+        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+    ]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
@@ -171,6 +175,8 @@
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "CosyVoice":
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "FishAudio":
+        model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
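Note: AudioModelFamilyV1 gains a required ability field, which the worker change above surfaces as model._model_spec.ability instead of a hard-coded list; the +8 lines in model_spec.json register the fish speech model with it. An illustrative spec shape only — the field values below are assumptions, not the exact registry entries:

spec = {
    "model_name": "FishSpeech-1.2-SFT",  # assumed name
    "model_family": "FishAudio",
    "ability": "text-to-audio",  # new required field read by the worker
    "multilingual": True,
}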
xinference/model/audio/cosyvoice.py
CHANGED
@@ -16,6 +16,8 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional

+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1

@@ -67,6 +69,7 @@ class CosyVoiceModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
         instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+        seed: Optional[int] = kwargs.pop("seed", 0)

         if "SFT" in self._model_spec.model_name:
             # inference_sft
@@ -87,9 +90,6 @@
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
-            assert (
-                instruct_text is not None
-            ), "CosyVoice Instruct model expect a instruct_text"
         else:
             # inference_zero_shot
             # inference_cross_lingual
@@ -99,6 +99,7 @@
             ), "CosyVoice model does not support instruct_text"

         assert self._model is not None
+        set_all_random_seed(seed)
         if prompt_speech:
             assert not voice, "voice can't be set with prompt speech."
             with io.BytesIO(prompt_speech) as prompt_speech_io:
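Note: both ChatTTS and CosyVoice now funnel seeding through set_all_random_seed from xinference/model/utils.py (part of its +41 in this release), and CosyVoice additionally pops a seed kwarg (default 0) so zero-shot cloning is repeatable. The helper's body is not visible in this diff; a plausible implementation, stated as an assumption:

import random

import numpy as np
import torch

def set_all_random_seed(seed: int) -> None:
    # Seed every RNG a TTS pipeline may touch so inference is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)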
xinference/model/audio/custom.py
CHANGED
@@ -88,6 +88,10 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_AUDIO_LOCK:
         for model_name in (
             list(BUILTIN_AUDIO_MODELS.keys())
@@ -102,11 +106,6 @@
             UD_AUDIOS.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
         )
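Note: the behavioral change here is that URI validation used to run only under persist=True; it now runs on every register_audio call, so a bad model_uri fails fast even for in-memory registrations (the embedding, image, and rerank custom.py files in the list above get the same +4/-5 reshuffle). A self-contained toy showing the new contract; is_valid_model_uri below is a stand-in, not xinference's checker:

from urllib.parse import urlparse

def is_valid_model_uri(uri: str) -> bool:
    # Toy stand-in: accept bare paths and file:// URIs only.
    return urlparse(uri).scheme in ("", "file")

def register(registry: list, spec: dict, persist: bool) -> None:
    uri = spec.get("model_uri")
    # 0.14.4: validate before touching the registry, persisted or not.
    if uri and not is_valid_model_uri(uri):
        raise ValueError(f"Invalid model URI {uri}.")
    registry.append(spec)
    # (0.14.2 performed the URI check only inside the `if persist:` branch.)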
xinference/model/audio/fish_speech.py
ADDED
@@ -0,0 +1,228 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import logging
+import os.path
+import queue
+import sys
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    import wave
+
+    buffer = BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+class FishSpeechModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._llama_queue = None
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        # There are too many imports from fish_speech.
+        sys.path.insert(
+            0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
+        )
+
+        from tools.llama.generate import launch_thread_safe_queue
+        from tools.vqgan.inference import load_model as load_decoder_model
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        logger.info("Loading Llama model...")
+        self._llama_queue = launch_thread_safe_queue(
+            checkpoint_path=self._model_path,
+            device=self._device,
+            precision=torch.bfloat16,
+            compile=False,
+        )
+        logger.info("Llama model loaded, loading VQ-GAN model...")
+
+        checkpoint_path = os.path.join(
+            self._model_path,
+            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+        )
+        self._model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path=checkpoint_path,
+            device=self._device,
+        )
+
+    @torch.inference_mode()
+    def _inference(
+        self,
+        text,
+        enable_reference_audio,
+        reference_audio,
+        reference_text,
+        max_new_tokens,
+        chunk_length,
+        top_p,
+        repetition_penalty,
+        temperature,
+        streaming=False,
+    ):
+        from fish_speech.utils import autocast_exclude_mps
+        from tools.api import decode_vq_tokens, encode_reference
+        from tools.llama.generate import (
+            GenerateRequest,
+            GenerateResponse,
+            WrappedGenerateResponse,
+        )
+
+        # Parse reference audio aka prompt
+        prompt_tokens = encode_reference(
+            decoder_model=self._model,
+            reference_audio=reference_audio,
+            enable_reference_audio=enable_reference_audio,
+        )
+
+        # LLAMA Inference
+        request = dict(
+            device=self._model.device,
+            max_new_tokens=max_new_tokens,
+            text=text,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            compile=False,
+            iterative_prompt=chunk_length > 0,
+            chunk_length=chunk_length,
+            max_length=2048,
+            prompt_tokens=prompt_tokens if enable_reference_audio else None,
+            prompt_text=reference_text if enable_reference_audio else None,
+        )
+
+        response_queue = queue.Queue()
+        self._llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        if streaming:
+            yield wav_chunk_header(), None, None
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                raise Exception(str(result.response))
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=self._model.device.type, dtype=torch.bfloat16
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=self._model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+            if streaming:
+                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+        if len(segments) == 0:
+            raise Exception("No audio generated, please check the input text.")
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (self._model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        logger.warning("Fish speech does not support setting voice: %s.", voice)
+        if speed != 1.0:
+            logger.warning("Fish speech does not support setting speed: %s.", speed)
+        if stream is True:
+            logger.warning("stream mode is not implemented.")
+        import torchaudio
+
+        result = list(
+            self._inference(
+                text=input,
+                enable_reference_audio=False,
+                reference_audio=None,
+                reference_text="",
+                max_new_tokens=0,
+                chunk_length=100,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+            )
+        )
+        sample_rate, audio = result[0][1]
+        audio = np.array([audio])
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(audio), sample_rate, format=response_format
+            )
+            return out.getvalue()