xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +48 -41
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +2 -0
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +10 -1
- xinference/model/llm/vllm/core.py +1 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-…
+ "date": "2024-08-23T18:14:53+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "…
- "version": "0.14.…
+ "full-revisionid": "b5002242e04634bca7e75cac9df0cdc6c0bf407a",
+ "version": "0.14.3"
 }
 ''' # END VERSION_JSON
 
xinference/core/chat_interface.py
CHANGED
@@ -340,7 +340,7 @@ class GradioInterface:
         state = gr.State([])
         with gr.Row():
             chatbot = gr.Chatbot(
-                elem_id="chatbot", label=self.model_name, height=…
+                elem_id="chatbot", label=self.model_name, height=700, scale=7
             )
             with gr.Column(scale=3):
                 imagebox = gr.Image(type="filepath")
xinference/core/image_interface.py
CHANGED
@@ -163,6 +163,7 @@ class ImageInterface:
         size_width: int,
         size_height: int,
         num_inference_steps: int,
+        padding_image_to_multiple: int,
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -178,6 +179,7 @@ class ImageInterface:
         num_inference_steps = (
             None if num_inference_steps == -1 else num_inference_steps  # type: ignore
         )
+        padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
 
         bio = io.BytesIO()
         image.save(bio, format="png")
@@ -190,6 +192,7 @@ class ImageInterface:
             size=size,
             response_format="b64_json",
             num_inference_steps=num_inference_steps,
+            padding_image_to_multiple=padding_image_to_multiple,
         )
 
         images = []
@@ -222,9 +225,14 @@ class ImageInterface:
                 n = gr.Number(label="Number of image", value=1)
                 size_width = gr.Number(label="Width", value=-1)
                 size_height = gr.Number(label="Height", value=-1)
+
+            with gr.Row():
                 num_inference_steps = gr.Number(
                     label="Inference Step Number", value=-1
                 )
+                padding_image_to_multiple = gr.Number(
+                    label="Padding image to multiple", value=-1
+                )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -242,6 +250,7 @@ class ImageInterface:
                 size_width,
                 size_height,
                 num_inference_steps,
+                padding_image_to_multiple,
             ],
             outputs=output_gallery,
         )
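The new `padding_image_to_multiple` control flows from the Gradio form through `RESTfulClient` into the image request, with the form's -1 sentinel mapped to None so the model default applies. Below is a hedged client-side sketch of the same call; the endpoint, model UID, input file, and the exact `image_to_image` signature are assumptions to verify against your client version:

```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint
model = client.get_model("my-image-model")       # assumed model UID

with open("input.png", "rb") as f:               # assumed input file
    image_bytes = f.read()

# Mirrors the handler above: omit a knob (or pass None) to keep the model
# default rather than sending the -1 sentinel the form uses internally.
result = model.image_to_image(
    image=image_bytes,
    prompt="a watercolor landscape",
    size="1024*1024",
    response_format="b64_json",
    num_inference_steps=25,
    padding_image_to_multiple=8,  # pad the input so each side is a multiple of 8
)
```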
xinference/core/model.py
CHANGED
@@ -177,6 +177,7 @@ class ModelActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
     ):
         super().__init__()
+        from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
@@ -192,7 +193,9 @@ class ModelActor(xo.StatelessActor):
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(…
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
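The change above adds `LMDeployModel` to the set of engines that bypass the actor-level lock. A minimal sketch of this conditional-lock pattern, with hypothetical names (the real code keys the check off `isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel))`):

```python
import asyncio


class HandlerSketch:
    """Serialize calls only for backends that cannot schedule concurrent requests."""

    def __init__(self, model, model_handles_concurrency: bool):
        self._model = model
        # None means "no serialization needed", matching the diff's use of a
        # lock only for models outside the engine allowlist.
        self._lock = None if model_handles_concurrency else asyncio.Lock()

    async def call(self, *args):
        if self._lock is None:
            return await self._model(*args)
        async with self._lock:  # one request at a time for fragile backends
            return await self._model(*args)
```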
xinference/core/worker.py
CHANGED
@@ -39,9 +39,11 @@ from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
 from ..types import PeftModelConfig
+from .cache_tracker import CacheTrackerActor
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
+from .status_guard import StatusGuardActor
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
 
 logger = getLogger(__name__)
@@ -71,6 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
+        self._status_guard_ref: xo.ActorRefType[  # type: ignore
+            "StatusGuardActor"
+        ] = None
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
+            CacheTrackerActor
+        ] = None
 
         # internal states.
         # temporary placeholder during model launch process:
@@ -308,56 +319,50 @@ class WorkerActor(xo.StatelessActor):
         Params:
             add_worker: By default will call supervisor.add_worker after first connect
         """
-        from .status_guard import StatusGuardActor
         from .supervisor import SupervisorActor
 
         if self._supervisor_ref is not None:
             return self._supervisor_ref
-        …
+        supervisor_ref = await xo.actor_ref(  # type: ignore
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
+        # Prevent concurrent operations leads to double initialization, check again.
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref = supervisor_ref
         if add_worker and len(self._model_uid_to_model) == 0:
             # Newly started (or restarted), has no model, notify supervisor
             await self._supervisor_ref.add_worker(self.address)
             logger.info("Connected to supervisor as a fresh worker")
 
-        …
-        )
-        …
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        self._status_guard_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+        self._event_collector_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        self._cache_tracker_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
         return self._supervisor_ref
 
     @staticmethod
@@ -734,7 +739,7 @@ class WorkerActor(xo.StatelessActor):
         elif model_type == "image":
             return ["text_to_image"]
         elif model_type == "audio":
-            return […
+            return [model._model_spec.ability]
         elif model_type == "video":
             return ["text_to_video"]
         elif model_type == "flexible":
@@ -793,6 +798,7 @@ class WorkerActor(xo.StatelessActor):
             logger.exception(e)
             raise
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -914,6 +920,7 @@ class WorkerActor(xo.StatelessActor):
             raise ValueError(f"{model_uid} is launching")
         origin_uid, _, __ = parse_replica_model_uid(model_uid)
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -1081,7 +1088,7 @@ class WorkerActor(xo.StatelessActor):
         paths.update([os.path.realpath(path) for path in paths])
 
         # get tensorizer path
-        from ..model.llm.…
+        from ..model.llm.transformers.tensorizer_utils import get_tensorizer_dir
 
         tensorizer_path = get_tensorizer_dir(path)
         if os.path.isdir(tensorizer_path):
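The rewritten `get_supervisor_ref` re-checks `self._supervisor_ref` after the `await`, since another coroutine may have finished initialization while this one was suspended, and the new `_ = await self.get_supervisor_ref()` calls ensure the lazily-fetched collector refs exist before event reporting. A minimal, self-contained sketch of that double-check pattern (names are hypothetical):

```python
import asyncio

_ref = None


async def _connect():
    await asyncio.sleep(0.01)  # stands in for the remote xo.actor_ref(...) call
    return object()


async def get_ref():
    global _ref
    if _ref is not None:    # fast path: already initialized
        return _ref
    ref = await _connect()  # other coroutines may run during this await
    if _ref is not None:    # double-check: another coroutine won the race
        return _ref
    _ref = ref
    return _ref
```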
xinference/model/audio/chattts.py
CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
@@ -61,16 +62,31 @@ class ChatTTSModel:
         import torchaudio
         import xxhash
 
-        …
+        rnd_spk_emb = None
 
-        …
+        if len(voice) > 400:
+            try:
+                assert self._model is not None
+                b = base64.b64decode(voice)
+                bio = BytesIO(b)
+                tensor = torch.load(bio, map_location="cpu")
+                rnd_spk_emb = self._model._encode_spk_emb(tensor)
+                logger.info("Speech by input speaker")
+            except Exception as e:
+                logger.info("Fallback to random speaker due to %s", e)
 
-        …
+        if rnd_spk_emb is None:
+            seed = xxhash.xxh32_intdigest(voice)
+
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+            torch.cuda.manual_seed(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+            assert self._model is not None
+            rnd_spk_emb = self._model.sample_random_speaker()
+            logger.info("Speech by voice %s", voice)
 
         default = 5
         infer_speed = int(default * speed)
@@ -100,7 +116,6 @@ class ChatTTSModel:
             if new_last_pos != last_pos:
                 out.seek(last_pos)
                 encoded_bytes = out.read()
-                print(len(encoded_bytes))
                 yield encoded_bytes
                 last_pos = new_last_pos
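The new branch lets callers pass a serialized ChatTTS speaker embedding as `voice`: any string longer than 400 characters is treated as a base64-encoded `torch.save` payload, and decode failures fall back to the seeded path, which seeds the RNG from the voice string and samples a speaker deterministically. A hedged sketch of how a caller could build that payload (the helper name is ours, not xinference's):

```python
import base64
from io import BytesIO

import torch


def encode_speaker(spk_emb: torch.Tensor) -> str:
    """Serialize a ChatTTS speaker-embedding tensor for use as `voice`."""
    buf = BytesIO()
    torch.save(spk_emb, buf)  # mirrors torch.load(bio, map_location="cpu") above
    return base64.b64encode(buf.getvalue()).decode("ascii")


# Usage sketch: pass the returned string as `voice` in a speech request.
# Short strings (e.g. "1234") keep the old deterministic-seed behavior.
```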
xinference/model/audio/core.py
CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
 
@@ -46,6 +47,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     multilingual: bool
+    ability: str
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
 
@@ -156,13 +158,15 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
    AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
-    model: Union[…
+    model: Union[
+        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+    ]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
@@ -171,6 +175,8 @@ def create_audio_model_instance(
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "CosyVoice":
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "FishAudio":
+        model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
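The new required `ability` field is what the worker hunk above reads via `model._model_spec.ability`, replacing a hard-coded ability list for audio models. A dependency-free sketch of how a spec entry parses into that field (the dataclass is a stand-in for illustration; the real `AudioModelFamilyV1` is a `CacheableModelSpec`):

```python
import json
from dataclasses import dataclass


@dataclass
class AudioSpecSketch:  # stand-in for AudioModelFamilyV1
    model_name: str
    model_family: str
    model_id: str
    model_revision: str
    multilingual: bool
    ability: str  # e.g. "text-to-audio", per the spec entries below


entry = json.loads("""
{
  "model_name": "FishSpeech-1.2-SFT",
  "model_family": "FishAudio",
  "model_id": "fishaudio/fish-speech-1.2-sft",
  "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
  "ability": "text-to-audio",
  "multilingual": true
}
""")
spec = AudioSpecSketch(**entry)
# The worker can now report abilities per model instead of a fixed list:
assert [spec.ability] == ["text-to-audio"]
```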
xinference/model/audio/fish_speech.py
ADDED
@@ -0,0 +1,228 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import logging
+import os.path
+import queue
+import sys
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    import wave
+
+    buffer = BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+class FishSpeechModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._llama_queue = None
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        # There are too many imports from fish_speech.
+        sys.path.insert(
+            0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
+        )
+
+        from tools.llama.generate import launch_thread_safe_queue
+        from tools.vqgan.inference import load_model as load_decoder_model
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        logger.info("Loading Llama model...")
+        self._llama_queue = launch_thread_safe_queue(
+            checkpoint_path=self._model_path,
+            device=self._device,
+            precision=torch.bfloat16,
+            compile=False,
+        )
+        logger.info("Llama model loaded, loading VQ-GAN model...")
+
+        checkpoint_path = os.path.join(
+            self._model_path,
+            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+        )
+        self._model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path=checkpoint_path,
+            device=self._device,
+        )
+
+    @torch.inference_mode()
+    def _inference(
+        self,
+        text,
+        enable_reference_audio,
+        reference_audio,
+        reference_text,
+        max_new_tokens,
+        chunk_length,
+        top_p,
+        repetition_penalty,
+        temperature,
+        streaming=False,
+    ):
+        from fish_speech.utils import autocast_exclude_mps
+        from tools.api import decode_vq_tokens, encode_reference
+        from tools.llama.generate import (
+            GenerateRequest,
+            GenerateResponse,
+            WrappedGenerateResponse,
+        )
+
+        # Parse reference audio aka prompt
+        prompt_tokens = encode_reference(
+            decoder_model=self._model,
+            reference_audio=reference_audio,
+            enable_reference_audio=enable_reference_audio,
+        )
+
+        # LLAMA Inference
+        request = dict(
+            device=self._model.device,
+            max_new_tokens=max_new_tokens,
+            text=text,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            compile=False,
+            iterative_prompt=chunk_length > 0,
+            chunk_length=chunk_length,
+            max_length=2048,
+            prompt_tokens=prompt_tokens if enable_reference_audio else None,
+            prompt_text=reference_text if enable_reference_audio else None,
+        )
+
+        response_queue = queue.Queue()
+        self._llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        if streaming:
+            yield wav_chunk_header(), None, None
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                raise Exception(str(result.response))
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=self._model.device.type, dtype=torch.bfloat16
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=self._model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+            if streaming:
+                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+        if len(segments) == 0:
+            raise Exception("No audio generated, please check the input text.")
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (self._model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        logger.warning("Fish speech does not support setting voice: %s.", voice)
+        if speed != 1.0:
+            logger.warning("Fish speech does not support setting speed: %s.", speed)
+        if stream is True:
+            logger.warning("stream mode is not implemented.")
+        import torchaudio
+
+        result = list(
+            self._inference(
+                text=input,
+                enable_reference_audio=False,
+                reference_audio=None,
+                reference_text="",
+                max_new_tokens=0,
+                chunk_length=100,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+            )
+        )
+        sample_rate, audio = result[0][1]
+        audio = np.array([audio])
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(audio), sample_rate, format=response_format
+            )
+            return out.getvalue()
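End to end, the new `FishSpeechModel` loads the Llama text2semantic model and the firefly VQ-GAN decoder from the bundled fish_speech third-party tree, then exposes the standard `speech()` interface (voice, speed, and stream are accepted but only warned about). A hedged usage sketch, assuming a running server and that the REST audio handle exposes `speech()` as it does for other audio models; adjust the endpoint and model name for your deployment:

```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(
    model_name="FishSpeech-1.2-SFT",  # matches the new model_spec.json entry
    model_type="audio",
)
model = client.get_model(model_uid)

# voice/speed/stream are ignored with a warning, per speech() above.
audio_bytes = model.speech(input="Hello from FishSpeech!", response_format="mp3")
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)
```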
xinference/model/audio/model_spec.json
CHANGED
@@ -146,5 +146,13 @@
     "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "FishSpeech-1.2-SFT",
+    "model_family": "FishAudio",
+    "model_id": "fishaudio/fish-speech-1.2-sft",
+    "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/embedding/core.py
CHANGED
@@ -154,10 +154,32 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
+            import torch
+
+            torch_dtype_str = self._kwargs.get("torch_dtype")
+            if torch_dtype_str is not None:
+                try:
+                    torch_dtype = getattr(torch, torch_dtype_str)
+                    if torch_dtype not in [
+                        torch.float16,
+                        torch.float32,
+                        torch.bfloat16,
+                    ]:
+                        logger.warning(
+                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                        )
+                        torch_dtype = torch.float32
+                except AttributeError:
+                    logger.warning(
+                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            else:
+                torch_dtype = "auto"
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto"},
+                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
             )
         else:
             self._model = SentenceTransformer(self._model_path, device=self._device)
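The dtype block applies only to the GTE-Qwen2 branch: a `torch_dtype` string passed at launch time reaches `EmbeddingModel` via `self._kwargs`, is resolved with `getattr(torch, ...)`, and falls back to fp32 on unknown or unsupported values (otherwise `"auto"` is used). A hedged launch sketch; the model name is illustrative and the kwarg pass-through is assumed to behave as the diff's `self._kwargs.get("torch_dtype")` implies:

```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(
    model_name="gte-Qwen2-7B-instruct",  # any "gte" + "qwen2" embedding model
    model_type="embedding",
    torch_dtype="bfloat16",  # resolved via getattr(torch, "bfloat16")
)
model = client.get_model(model_uid)
embedding = model.create_embedding("What is the capital of China?")
```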