xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +26 -4
- xinference/client/restful/restful_client.py +16 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/model_spec.json +7 -0
- xinference/model/image/stable_diffusion/core.py +6 -1
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +8 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/__init__.py
CHANGED
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-07-
+ "date": "2024-07-26T18:42:50+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.13.2"
+ "full-revisionid": "aa51ff22dbfb5644554436270deaf57a7ebaf066",
+ "version": "0.13.3"
 }
 '''  # END VERSION_JSON

xinference/api/restful_api.py
CHANGED
@@ -130,6 +130,7 @@ class SpeechRequest(BaseModel):
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
     stream: Optional[bool] = False
+    kwargs: Optional[str] = None


 class RegisterModelRequest(BaseModel):
@@ -1309,8 +1310,18 @@ class RESTfulAPI:
            await self._report_error_event(model_uid, str(e))
            raise HTTPException(status_code=500, detail=str(e))

-    async def create_speech(
-
+    async def create_speech(
+        self,
+        request: Request,
+        prompt_speech: Optional[UploadFile] = File(
+            None, media_type="application/octet-stream"
+        ),
+    ) -> Response:
+        if prompt_speech:
+            f = await request.form()
+        else:
+            f = await request.json()
+        body = SpeechRequest.parse_obj(f)
         model_uid = body.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1324,12 +1335,19 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))

         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
+            if prompt_speech is not None:
+                parsed_kwargs["prompt_speech"] = await prompt_speech.read()
             out = await model.speech(
                 input=body.input,
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
                 stream=body.stream,
+                **parsed_kwargs,
             )
             if body.stream:
                 return EventSourceResponse(
@@ -1626,10 +1644,14 @@ class RESTfulAPI:
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()

-            if not
+            if not (
+                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
+            ):
                 raise HTTPException(
                     status_code=400,
-                    detail="Streaming support for tool calls is available only when using
+                    detail="Streaming support for tool calls is available only when using "
+                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )

         if body.stream:
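
With these changes, /v1/audio/speech accepts either a plain JSON body (as before) or, when a reference clip is attached, a multipart form whose extra options travel in the new string-encoded kwargs field. A minimal sketch of both request shapes, assuming a local endpoint and a made-up model UID:

import json
import requests

base_url = "http://127.0.0.1:9997"  # hypothetical local Xinference endpoint

# Plain JSON body, handled by the `request.json()` branch
resp = requests.post(
    f"{base_url}/v1/audio/speech",
    json={"model": "my-tts", "input": "Hello!", "voice": ""},
)

# Multipart form, handled by the `request.form()` branch; the reference
# clip rides as an upload and extra options are JSON-encoded into "kwargs"
with open("reference.wav", "rb") as f:
    resp = requests.post(
        f"{base_url}/v1/audio/speech",
        data={
            "model": "my-tts",
            "input": "Hello!",
            "voice": "",
            "kwargs": json.dumps({"prompt_text": "transcript of the clip"}),
        },
        files={"prompt_speech": ("prompt_speech", f, "application/octet-stream")},
    )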

xinference/client/restful/restful_client.py
CHANGED

@@ -768,6 +768,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        prompt_speech: Optional[bytes] = None,
+        **kwargs,
     ):
         """
         Generates audio from the input text.
@@ -799,8 +801,21 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             "response_format": response_format,
             "speed": speed,
             "stream": stream,
+            "kwargs": json.dumps(kwargs),
         }
-        response = requests.post(url, json=params, headers=self.auth_headers)
+        if prompt_speech:
+            files: List[Any] = []
+            files.append(
+                (
+                    "prompt_speech",
+                    ("prompt_speech", prompt_speech, "application/octet-stream"),
+                )
+            )
+            response = requests.post(
+                url, data=params, files=files, headers=self.auth_headers
+            )
+        else:
+            response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
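
The client handle mirrors the server change: extra keyword arguments are JSON-encoded into the kwargs form field, and prompt_speech bytes switch the request to multipart. A hedged usage sketch (endpoint and model UID are made up):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # hypothetical endpoint
model = client.get_model("my-audio-model-uid")   # an audio model handle

# Voice cloning from a reference clip: the handle sends multipart and
# forwards prompt_text through the JSON-encoded "kwargs" field
with open("reference.wav", "rb") as f:
    audio_bytes = model.speech(
        "Text to synthesize",
        voice="",
        prompt_speech=f.read(),
        prompt_text="Transcript of the reference clip",
    )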
xinference/core/chat_interface.py
CHANGED

@@ -428,7 +428,7 @@ class GradioInterface:
             }

             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
@@ -467,7 +467,7 @@ class GradioInterface:
             }

             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
xinference/core/model.py
CHANGED
@@ -646,7 +646,10 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating translations."
             )

-    @log_async(
+    @log_async(
+        logger=logger,
+        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
+    )
     @request_limit
     @xo.generator
     async def speech(
@@ -656,6 +659,7 @@ class ModelActor(xo.StatelessActor):
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        **kwargs,
     ):
         if hasattr(self._model, "speech"):
             return await self._call_wrapper_binary(
@@ -665,6 +669,7 @@ class ModelActor(xo.StatelessActor):
             response_format,
             speed,
             stream,
+            **kwargs,
         )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."
@@ -735,7 +740,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "inpainting"):
-            return await self.
+            return await self._call_wrapper_json(
                 self._model.inpainting,
                 image,
                 mask_image,
@@ -758,7 +763,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "infer"):
-            return await self.
+            return await self._call_wrapper_json(
                 self._model.infer,
                 **kwargs,
             )
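
The args_formatter lambda on @log_async pops the binary prompt_speech payload out of kwargs so raw audio bytes never end up in a log line. A small illustration of the effect (log_async's internals are not shown in this diff, so this is just the lambda applied by hand):

kwargs = {"input": "hi", "prompt_speech": b"\x00" * 1_000_000}
formatter = lambda _, kw: kw.pop("prompt_speech", None)
formatter(None, kwargs)  # strips the megabyte of audio bytes in place
print(kwargs)            # {'input': 'hi'} -- now cheap and safe to log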
xinference/core/scheduler.py
CHANGED
@@ -81,7 +81,7 @@ class InferenceRequest:
         self.future_or_queue = future_or_queue
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
-        self.error_msg: Optional[str] = None
+        self.error_msg: Optional[str] = None  # type: ignore
         # For compatibility. Record some extra parameters for some special cases.
         self.extra_kwargs = {}

@@ -295,11 +295,11 @@ class SchedulerActor(xo.StatelessActor):

     def __init__(self):
         super().__init__()
-        self._waiting_queue: deque[InferenceRequest] = deque()
-        self._running_queue: deque[InferenceRequest] = deque()
+        self._waiting_queue: deque[InferenceRequest] = deque()  # type: ignore
+        self._running_queue: deque[InferenceRequest] = deque()  # type: ignore
         self._model = None
         self._id_to_req = {}
-        self._abort_req_ids: Set[str] = set()
+        self._abort_req_ids: Set[str] = set()  # type: ignore
         self._isolation = None

     async def __post_create__(self):
xinference/model/audio/core.py
CHANGED
@@ -20,6 +20,7 @@ from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
+from .cosyvoice import CosyVoiceModel
 from .whisper import WhisperModel

 MAX_ATTEMPTS = 3
@@ -150,14 +151,16 @@ def create_audio_model_instance(
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
     model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel]
+    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "CosyVoice":
+        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py
ADDED

@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class CosyVoiceModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from the model has hard-coded the import paths. Please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from cosyvoice.cli.cosyvoice import CosyVoice
+
+        self._model = CosyVoice(self._model_path)
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        if stream:
+            raise Exception("CosyVoiceModel does not support stream.")
+
+        import torchaudio
+        from cosyvoice.utils.file_utils import load_wav
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+
+        if "SFT" in self._model_spec.model_name:
+            # inference_sft
+            assert (
+                prompt_speech is None
+            ), "CosyVoice SFT model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice SFT model does not support prompt_text"
+            assert (
+                instruct_text is None
+            ), "CosyVoice SFT model does not support instruct_text"
+        elif "Instruct" in self._model_spec.model_name:
+            # inference_instruct
+            assert (
+                prompt_speech is None
+            ), "CosyVoice Instruct model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice Instruct model does not support prompt_text"
+            assert (
+                instruct_text is not None
+            ), "CosyVoice Instruct model expects an instruct_text"
+        else:
+            # inference_zero_shot
+            # inference_cross_lingual
+            assert prompt_speech is not None, "CosyVoice model expects a prompt_speech"
+            assert (
+                instruct_text is None
+            ), "CosyVoice model does not support instruct_text"
+
+        assert self._model is not None
+        if prompt_speech:
+            assert not voice, "voice can't be set with prompt speech."
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
+            return out.getvalue()
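
Which CosyVoice inference path runs is decided by the model variant and the supplied kwargs. A hedged summary with hypothetical client calls (the speaker names and files below are illustrative, not from this diff):

# `model` is assumed to be a launched CosyVoice model's client handle.

# SFT variant: pick (or default to) a built-in speaker -> inference_sft
audio = model.speech("你好，世界", voice="中文女")

# Instruct variant: requires instruct_text -> inference_instruct
audio = model.speech("Hello", voice="中文男", instruct_text="Speak cheerfully.")

# Base 300M variant: requires a reference clip
with open("reference.wav", "rb") as f:
    ref = f.read()
audio = model.speech("Hello", voice="", prompt_speech=ref)  # inference_cross_lingual
audio = model.speech(
    "Hello", voice="", prompt_speech=ref, prompt_text="Hi"
)  # inference_zero_shot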
xinference/model/audio/model_spec.json
CHANGED

@@ -102,5 +102,29 @@
     "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M",
+    "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-SFT",
+    "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-Instruct",
+    "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/audio/model_spec_modelscope.json
CHANGED

@@ -16,5 +16,32 @@
     "model_revision": "master",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M",
+    "model_revision": "master",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-SFT",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-Instruct",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/flexible/launchers/image_process_launcher.py
ADDED

@@ -0,0 +1,70 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from io import BytesIO
+
+import PIL.Image
+import PIL.ImageOps
+
+from ....types import Image
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ImageRemoveBackgroundModel(FlexibleModel):
+    def infer(self, **kwargs):
+        invert = kwargs.get("invert", False)
+        b64_image: str = kwargs.get("image")  # type: ignore
+        only_mask = kwargs.pop("only_mask", True)
+        image_format = kwargs.pop("image_format", "PNG")
+        if not b64_image:
+            raise ValueError("No image found to remove background")
+        image = base64.b64decode(b64_image)
+
+        try:
+            from rembg import remove
+        except ImportError:
+            error_message = "Failed to import module 'rembg'"
+            installation_guide = [
+                "Please make sure 'rembg' is installed. ",
+                "You can install it by visiting the installation section of the git repo:\n",
+                "https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        im = PIL.Image.open(BytesIO(image))
+        om = remove(im, only_mask=only_mask, **kwargs)
+        if invert:
+            om = PIL.ImageOps.invert(om)
+
+        buffered = BytesIO()
+        om.save(buffered, format=image_format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return Image(url=None, b64_json=img_str)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    if task == "remove_background":
+        return ImageRemoveBackgroundModel(
+            model_uid=model_uid,
+            model_path=model_spec.model_uri,  # type: ignore
+            device=device,
+            config=kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown Task for image processing: {task}")
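
A hedged sketch of driving this launcher through the flexible-model path; the registration/launch step is assumed to have happened already (with task="remove_background"), and `model` is the resulting client handle:

import base64

with open("photo.png", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode()

result = model.infer(
    image=b64_image,
    only_mask=False,     # False returns the cut-out image instead of the mask
    image_format="PNG",
)
# `result` carries the processed image back as base64 (b64_json)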
xinference/model/image/model_spec.json
CHANGED

@@ -106,5 +106,12 @@
     "model_id": "stabilityai/stable-diffusion-2-inpainting",
     "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
     "ability": "inpainting"
+  },
+  {
+    "model_name": "stable-diffusion-xl-inpainting",
+    "model_family": "stable_diffusion",
+    "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+    "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
+    "ability": "inpainting"
   }
 ]
xinference/model/image/stable_diffusion/core.py
CHANGED

@@ -94,7 +94,12 @@ class DiffusionModel:
             self._model_path,
             **self._kwargs,
         )
-        self.
+        if self._kwargs.get("cpu_offload", False):
+            logger.debug("CPU offloading model")
+            self._model.enable_model_cpu_offload()
+        else:
+            logger.debug("Loading model to available device")
+            self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
         self._apply_lora()
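
The new cpu_offload flag is read out of the kwargs captured at model load time, so it can be passed straight through launch_model. A hedged example (the endpoint is made up; the flag name comes from the diff above):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # hypothetical endpoint
model_uid = client.launch_model(
    model_name="stable-diffusion-xl-inpainting",
    model_type="image",
    cpu_offload=True,  # lands in DiffusionModel._kwargs and triggers
                       # enable_model_cpu_offload() instead of moving the
                       # whole pipeline onto the accelerator up front
)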