xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +99 -5
- xinference/client/restful/restful_client.py +98 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +85 -26
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/chattts.py +40 -8
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/core.py +3 -0
- xinference/model/image/model_spec.json +21 -0
- xinference/model/image/stable_diffusion/core.py +49 -7
- xinference/model/llm/llm_family.json +1065 -106
- xinference/model/llm/llm_family.py +26 -6
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +460 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/sglang/core.py +7 -2
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +11 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/model/audio/core.py
CHANGED
@@ -20,6 +20,7 @@ from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
+from .cosyvoice import CosyVoiceModel
 from .whisper import WhisperModel

 MAX_ATTEMPTS = 3
@@ -150,14 +151,16 @@ def create_audio_model_instance(
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
     model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel]
+    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "CosyVoice":
+        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py
ADDED
@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class CosyVoiceModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from cosyvoice.cli.cosyvoice import CosyVoice
+
+        self._model = CosyVoice(self._model_path)
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        if stream:
+            raise Exception("CosyVoiceModel does not support stream.")
+
+        import torchaudio
+        from cosyvoice.utils.file_utils import load_wav
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+
+        if "SFT" in self._model_spec.model_name:
+            # inference_sft
+            assert (
+                prompt_speech is None
+            ), "CosyVoice SFT model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice SFT model does not support prompt_text"
+            assert (
+                instruct_text is None
+            ), "CosyVoice SFT model does not support instruct_text"
+        elif "Instruct" in self._model_spec.model_name:
+            # inference_instruct
+            assert (
+                prompt_speech is None
+            ), "CosyVoice Instruct model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice Instruct model does not support prompt_text"
+            assert (
+                instruct_text is not None
+            ), "CosyVoice Instruct model expect a instruct_text"
+        else:
+            # inference_zero_shot
+            # inference_cross_lingual
+            assert prompt_speech is not None, "CosyVoice model expect a prompt_speech"
+            assert (
+                instruct_text is None
+            ), "CosyVoice model does not support instruct_text"
+
+        assert self._model is not None
+        if prompt_speech:
+            assert not voice, "voice can't be set with prompt speech."
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
+            return out.getvalue()
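Taken together, speech() picks one of four CosyVoice inference paths from the model variant and the extra kwargs. A minimal sketch of the call shapes, assuming base, sft, and instruct are loaded CosyVoiceModel instances of the corresponding variants; the file names, transcripts, and the "中文女" speaker below are illustrative, with real speaker names coming from list_avaliable_spks():

    # Illustrative only: each call assumes the matching model variant is loaded.
    ref_wav: bytes = open("reference.wav", "rb").read()  # voice prompt, read at 16 kHz by load_wav

    # Base CosyVoice-300M, zero-shot cloning: prompt_speech plus its transcript
    audio = base.speech("Hello there.", voice="", prompt_speech=ref_wav,
                        prompt_text="transcript of reference.wav")

    # Base CosyVoice-300M, cross-lingual cloning: prompt_speech only
    audio = base.speech("你好。", voice="", prompt_speech=ref_wav)

    # CosyVoice-300M-SFT: built-in speaker; an empty voice falls back to the first speaker
    audio = sft.speech("Hello there.", voice="")

    # CosyVoice-300M-Instruct: built-in speaker plus a natural-language style instruction
    audio = instruct.speech("Hello there.", voice="中文女",
                            instruct_text="Speak with a calm, slow tone.")

Each call returns the encoded audio bytes in response_format (mp3 by default).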
xinference/model/audio/model_spec.json
CHANGED
@@ -102,5 +102,29 @@
         "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
         "ability": "text-to-audio",
         "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M",
+        "model_family": "CosyVoice",
+        "model_id": "model-scope/CosyVoice-300M",
+        "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+        "ability": "audio-to-audio",
+        "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M-SFT",
+        "model_family": "CosyVoice",
+        "model_id": "model-scope/CosyVoice-300M-SFT",
+        "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+        "ability": "text-to-audio",
+        "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M-Instruct",
+        "model_family": "CosyVoice",
+        "model_id": "model-scope/CosyVoice-300M-Instruct",
+        "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+        "ability": "text-to-audio",
+        "multilingual": true
     }
 ]
xinference/model/audio/model_spec_modelscope.json
CHANGED
@@ -16,5 +16,32 @@
         "model_revision": "master",
         "ability": "text-to-audio",
         "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M",
+        "model_family": "CosyVoice",
+        "model_hub": "modelscope",
+        "model_id": "iic/CosyVoice-300M",
+        "model_revision": "master",
+        "ability": "audio-to-audio",
+        "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M-SFT",
+        "model_family": "CosyVoice",
+        "model_hub": "modelscope",
+        "model_id": "iic/CosyVoice-300M-SFT",
+        "model_revision": "master",
+        "ability": "text-to-audio",
+        "multilingual": true
+    },
+    {
+        "model_name": "CosyVoice-300M-Instruct",
+        "model_family": "CosyVoice",
+        "model_hub": "modelscope",
+        "model_id": "iic/CosyVoice-300M-Instruct",
+        "model_revision": "master",
+        "ability": "text-to-audio",
+        "multilingual": true
     }
 ]
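With these specs registered on both hubs, the CosyVoice models launch like any other built-in audio model. A rough sketch against a running server, assuming the RESTful client handle exposes a speech() method mirroring the signature in cosyvoice.py above (endpoint and output path are illustrative):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
    model = client.get_model(model_uid)

    audio_bytes = model.speech("Hello from CosyVoice.", voice="")  # mp3 by default
    with open("out.mp3", "wb") as f:
        f.write(audio_bytes)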
xinference/model/flexible/launchers/image_process_launcher.py
ADDED
@@ -0,0 +1,70 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from io import BytesIO
+
+import PIL.Image
+import PIL.ImageOps
+
+from ....types import Image
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ImageRemoveBackgroundModel(FlexibleModel):
+    def infer(self, **kwargs):
+        invert = kwargs.get("invert", False)
+        b64_image: str = kwargs.get("image")  # type: ignore
+        only_mask = kwargs.pop("only_mask", True)
+        image_format = kwargs.pop("image_format", "PNG")
+        if not b64_image:
+            raise ValueError("No image found to remove background")
+        image = base64.b64decode(b64_image)
+
+        try:
+            from rembg import remove
+        except ImportError:
+            error_message = "Failed to import module 'rembg'"
+            installation_guide = [
+                "Please make sure 'rembg' is installed. ",
+                "You can install it by visiting the installation section of the git repo:\n",
+                "https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        im = PIL.Image.open(BytesIO(image))
+        om = remove(im, only_mask=only_mask, **kwargs)
+        if invert:
+            om = PIL.ImageOps.invert(om)
+
+        buffered = BytesIO()
+        om.save(buffered, format=image_format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return Image(url=None, b64_json=img_str)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    if task == "remove_background":
+        return ImageRemoveBackgroundModel(
+            model_uid=model_uid,
+            model_path=model_spec.model_uri,  # type: ignore
+            device=device,
+            config=kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown Task for image processing: {task}")
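For reference, a sketch of driving the new flexible task directly, assuming model is the ImageRemoveBackgroundModel instance returned by launcher(..., task="remove_background"); file names are placeholders, and any unconsumed kwargs are forwarded to rembg.remove:

    import base64

    with open("photo.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    # only_mask=False returns the foreground cutout rather than the alpha mask
    result = model.infer(image=b64, only_mask=False, image_format="PNG")
    with open("photo_no_bg.png", "wb") as f:
        f.write(base64.b64decode(result.b64_json))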
xinference/model/image/core.py
CHANGED
@@ -45,6 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
+    ability: Optional[str]
     controlnet: Optional[List["ImageModelFamilyV1"]]


@@ -71,6 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
+            "ability": self._model_spec.ability,
             "controlnet": controlnet,
         }

@@ -234,6 +236,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
+        ability=model_spec.ability,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json
CHANGED
@@ -92,5 +92,26 @@
                 "model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
             }
         ]
+    },
+    {
+        "model_name": "stable-diffusion-inpainting",
+        "model_family": "stable_diffusion",
+        "model_id": "runwayml/stable-diffusion-inpainting",
+        "model_revision": "51388a731f57604945fddd703ecb5c50e8e7b49d",
+        "ability": "inpainting"
+    },
+    {
+        "model_name": "stable-diffusion-2-inpainting",
+        "model_family": "stable_diffusion",
+        "model_id": "stabilityai/stable-diffusion-2-inpainting",
+        "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
+        "ability": "inpainting"
+    },
+    {
+        "model_name": "stable-diffusion-xl-inpainting",
+        "model_family": "stable_diffusion",
+        "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+        "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
+        "ability": "inpainting"
     }
 ]
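These entries make the inpainting checkpoints launchable by name; the ability field added in xinference/model/image/core.py above routes them to the right diffusers pipeline. A short sketch with the standard client (endpoint illustrative):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="stable-diffusion-inpainting", model_type="image")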
xinference/model/image/stable_diffusion/core.py
CHANGED
@@ -16,6 +16,7 @@ import base64
 import logging
 import os
 import re
+import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -39,6 +40,7 @@ class DiffusionModel:
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
+        ability: Optional[str] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
@@ -48,6 +50,7 @@ class DiffusionModel:
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
+        self._ability = ability
         self._kwargs = kwargs

     def _apply_lora(self):
@@ -64,8 +67,14 @@ class DiffusionModel:
         logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")

     def load(self):
-        # import torch
-        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        if self._ability in [None, "text2image", "image2image"]:
+            from diffusers import AutoPipelineForText2Image as AutoPipelineModel
+        elif self._ability == "inpainting":
+            from diffusers import AutoPipelineForInpainting as AutoPipelineModel
+        else:
+            raise ValueError(f"Unknown ability: {self._ability}")

         controlnet = self._kwargs.get("controlnet")
         if controlnet is not None:
@@ -74,14 +83,23 @@ class DiffusionModel:
             logger.debug("Loading controlnet %s", controlnet)
             self._kwargs["controlnet"] = ControlNetModel.from_pretrained(controlnet)

-        self._model = AutoPipelineForText2Image.from_pretrained(
+        torch_dtype = self._kwargs.get("torch_dtype")
+        if sys.platform != "darwin" and torch_dtype is None:
+            # The following params crashes on Mac M2
+            self._kwargs["torch_dtype"] = torch.float16
+            self._kwargs["use_safetensors"] = True
+
+        logger.debug("Loading model %s", AutoPipelineModel)
+        self._model = AutoPipelineModel.from_pretrained(
             self._model_path,
             **self._kwargs,
-            # The following params crashes on Mac M2
-            # torch_dtype=torch.float16,
-            # use_safetensors=True,
         )
-        self._model = move_model_to_available_device(self._model)
+        if self._kwargs.get("cpu_offload", False):
+            logger.debug("CPU offloading model")
+            self._model.enable_model_cpu_offload()
+        else:
+            logger.debug("Loading model to available device")
+            self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
         self._apply_lora()
@@ -174,3 +192,27 @@ class DiffusionModel:
             response_format=response_format,
             **kwargs,
         )
+
+    def inpainting(
+        self,
+        image: bytes,
+        mask_image: bytes,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        n: int = 1,
+        size: str = "1024*1024",
+        response_format: str = "url",
+        **kwargs,
+    ):
+        width, height = map(int, re.split(r"[^\d]+", size))
+        return self._call_model(
+            image=image,
+            mask_image=mask_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_images_per_prompt=n,
+            response_format=response_format,
+            **kwargs,
+        )