xinference 0.14.3__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/core/worker.py +18 -9
- xinference/model/audio/chattts.py +4 -3
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/embedding/core.py +2 -0
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/stable_diffusion/core.py +21 -6
- xinference/model/llm/llm_family.py +5 -6
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/utils.py +3 -0
- xinference/model/llm/vllm/core.py +0 -33
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/tools/api.py +1 -1
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/METADATA +20 -12
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/RECORD +70 -28
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-08-…",
+ "date": "2024-08-30T18:54:16+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "…",
- "version": "0.14.3"
+ "full-revisionid": "f3d510eceffbbbc41ce065919fd2c48411017573",
+ "version": "0.14.4"
 }
 '''  # END VERSION_JSON

xinference/core/worker.py
CHANGED
@@ -73,15 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
-        self._status_guard_ref: xo.ActorRefType[  # type: ignore
-            "StatusGuardActor"
-        ] = None
+        self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
+            None
+        )
         self._event_collector_ref: xo.ActorRefType[  # type: ignore
             EventCollectorActor
         ] = None
-        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
-            CacheTrackerActor
-        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
+            None
+        )

         # internal states.
         # temporary placeholder during model launch process:
@@ -146,7 +146,7 @@ class WorkerActor(xo.StatelessActor):
         else:
             recover_count = self._model_uid_to_recover_count.get(model_uid)
             try:
-                await self.terminate_model(model_uid)
+                await self.terminate_model(model_uid, is_model_die=True)
             except Exception:
                 pass
             if recover_count is not None:
@@ -664,6 +664,8 @@ class WorkerActor(xo.StatelessActor):

             ret.sort(key=sort_helper)
             return ret
+        elif model_type == "video":
+            return []
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks

@@ -703,6 +705,8 @@ class WorkerActor(xo.StatelessActor):
             for f in get_user_defined_audios():
                 if f.model_name == model_name:
                     return f
+        elif model_type == "video":
+            return None
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks

@@ -914,7 +918,7 @@ class WorkerActor(xo.StatelessActor):
         )

     @log_async(logger=logger)
-    async def terminate_model(self, model_uid: str):
+    async def terminate_model(self, model_uid: str, is_model_die=False):
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
@@ -963,11 +967,16 @@ class WorkerActor(xo.StatelessActor):
         self._model_uid_to_recover_count.pop(model_uid, None)
         self._model_uid_to_launch_args.pop(model_uid, None)

+        if is_model_die:
+            status = LaunchStatus.ERROR.name
+        else:
+            status = LaunchStatus.TERMINATED.name
+
         if self._status_guard_ref is None:
             _ = await self.get_supervisor_ref()
         assert self._status_guard_ref is not None
         await self._status_guard_ref.update_instance_info(
-            origin_uid, {"status": LaunchStatus.TERMINATED.name}
+            origin_uid, {"status": status}
         )

         # Provide an interface for future version of supervisor to call
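With the new `is_model_die` flag, the recovery path records a crashed model as ERROR while a user-initiated termination still ends in TERMINATED. A minimal sketch of the added status selection, assuming `LaunchStatus` is the enum defined in `xinference.core.status_guard` with `ERROR` and `TERMINATED` members:

    from xinference.core.status_guard import LaunchStatus

    def final_status(is_model_die: bool = False) -> str:
        # Mirrors the branch added to terminate_model above.
        return (LaunchStatus.ERROR if is_model_die else LaunchStatus.TERMINATED).name

    assert final_status() == "TERMINATED"
    assert final_status(is_model_die=True) == "ERROR"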
xinference/model/audio/chattts.py
CHANGED

@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional

+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1

@@ -78,9 +81,7 @@ class ChatTTSModel:
         if rnd_spk_emb is None:
             seed = xxhash.xxh32_intdigest(voice)

-            torch.manual_seed(seed)
-            np.random.seed(seed)
-            torch.cuda.manual_seed(seed)
+            set_all_random_seed(seed)
             torch.backends.cudnn.deterministic = True
             torch.backends.cudnn.benchmark = False

xinference/model/audio/cosyvoice.py
CHANGED

@@ -16,6 +16,8 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional

+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1

@@ -67,6 +69,7 @@ class CosyVoiceModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
         instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+        seed: Optional[int] = kwargs.pop("seed", 0)

         if "SFT" in self._model_spec.model_name:
             # inference_sft
@@ -87,9 +90,6 @@ class CosyVoiceModel:
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
-            assert (
-                instruct_text is not None
-            ), "CosyVoice Instruct model expect a instruct_text"
         else:
             # inference_zero_shot
             # inference_cross_lingual
@@ -99,6 +99,7 @@ class CosyVoiceModel:
             ), "CosyVoice model does not support instruct_text"

         assert self._model is not None
+        set_all_random_seed(seed)
         if prompt_speech:
             assert not voice, "voice can't be set with prompt speech."
             with io.BytesIO(prompt_speech) as prompt_speech_io:
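Because `seed` is popped from `kwargs` (default 0) and fed to `set_all_random_seed` right before inference, it can be passed straight through a client speech call. A hedged usage sketch, assuming a running server on localhost:9997 and a launched CosyVoice SFT model whose uid is `my-cosyvoice` (both hypothetical):

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model = client.get_model("my-cosyvoice")  # hypothetical model uid
    # Same seed and same inputs restore the same random state inside
    # CosyVoice, so repeated calls should yield identical audio bytes.
    audio = model.speech("你好，世界", voice="中文女", seed=42)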
xinference/model/audio/custom.py
CHANGED
@@ -88,6 +88,10 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_AUDIO_LOCK:
         for model_name in (
             list(BUILTIN_AUDIO_MODELS.keys())
@@ -102,11 +106,6 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
             UD_AUDIOS.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
         )
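The same four-line guard appears in every `register_*` function touched by this release: the URI is now validated up front, so `register_audio(spec, persist=False)` fails just as fast as the persisted path used to. A small behavior sketch, assuming `is_valid_model_uri` keeps its existing contract of accepting only resolvable local paths / file URIs:

    import tempfile

    from xinference.model.utils import is_valid_model_uri

    with tempfile.TemporaryDirectory() as d:
        assert is_valid_model_uri(d)                 # existing directory: valid
    assert not is_valid_model_uri("/no/such/path")   # missing path: invalid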
xinference/model/embedding/core.py
CHANGED

@@ -124,6 +124,7 @@ class EmbeddingModel:
         model_path: str,
         model_spec: EmbeddingModelSpec,
         device: Optional[str] = None,
+        **kwargs,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
@@ -131,6 +132,7 @@ class EmbeddingModel:
         self._model = None
         self._counter = 0
         self._model_spec = model_spec
+        self._kwargs = kwargs

     def load(self):
         try:
xinference/model/embedding/custom.py
CHANGED

@@ -47,6 +47,10 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_EMBEDDING_LOCK:
         for model_name in (
             list(BUILTIN_EMBEDDING_MODELS.keys())
@@ -61,11 +65,6 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
             UD_EMBEDDINGS.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "embedding", f"{model_spec.model_name}.json"
         )
xinference/model/flexible/core.py
CHANGED

@@ -99,11 +99,15 @@ def get_flexible_model_descriptions():


 def register_flexible_model(model_spec: FlexibleModelSpec, persist: bool):
-    from ..utils import is_valid_model_name
+    from ..utils import is_valid_model_name, is_valid_model_uri

     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     if model_spec.launcher_args:
         try:
             model_spec.parser_args()
xinference/model/image/custom.py
CHANGED
@@ -47,6 +47,10 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}")
+
     with UD_IMAGE_LOCK:
         for model_name in (
             list(BUILTIN_IMAGE_MODELS.keys())
@@ -60,11 +64,6 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
             UD_IMAGES.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_name}.json"
         )
xinference/model/image/stable_diffusion/core.py
CHANGED

@@ -257,15 +257,19 @@ class DiffusionModel:
             self._i2i_model = model = AutoPipelineForImage2Image.from_pipe(
                 self._model
             )
-
-        width, height = map(int, re.split(r"[^\d]+", size))
-        kwargs["width"] = width
-        kwargs["height"] = height
+
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 image to image requires image's height and width is times of 16
             # padding the image if specified
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))

+        if size:
+            width, height = map(int, re.split(r"[^\d]+", size))
+            if padding_image_to_multiple:
+                width, height = image.size
+            kwargs["width"] = width
+            kwargs["height"] = height
+
         self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
@@ -279,8 +283,8 @@ class DiffusionModel:

     def inpainting(
         self,
-        image: …,
-        mask_image: …,
+        image: PIL.Image,
+        mask_image: PIL.Image,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
@@ -306,6 +310,17 @@ class DiffusionModel:
         model = self._model

         width, height = map(int, re.split(r"[^\d]+", size))
+
+        if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
+            # Model like SD3 inpainting requires image's height and width is times of 16
+            # padding the image if specified
+            image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
+            mask_image = self.pad_to_multiple(
+                mask_image, multiple=int(padding_image_to_multiple)
+            )
+            # calculate actual image size after padding
+            width, height = image.size
+
         return self._call_model(
             image=image,
             mask_image=mask_image,
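After `pad_to_multiple` runs, the pipeline must be told the padded dimensions rather than the ones parsed from `size`, which is why both hunks recompute `width, height = image.size`. An illustrative stand-alone sketch of padding to a multiple of 16 (my own sketch, not the library's `pad_to_multiple`):

    import PIL.Image

    def pad_to_multiple(image: PIL.Image.Image, multiple: int = 16) -> PIL.Image.Image:
        # Round each dimension up to the next multiple and paste the original
        # into the top-left corner of the enlarged canvas.
        w, h = image.size
        new_w = (w + multiple - 1) // multiple * multiple
        new_h = (h + multiple - 1) // multiple * multiple
        canvas = PIL.Image.new(image.mode, (new_w, new_h))
        canvas.paste(image, (0, 0))
        return canvas

    img = PIL.Image.new("RGB", (1000, 750))
    assert pad_to_multiple(img).size == (1008, 752)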
xinference/model/llm/llm_family.py
CHANGED

@@ -1004,6 +1004,11 @@ def register_llm(llm_family: LLMFamilyV1, persist: bool):
     if not is_valid_model_name(llm_family.model_name):
         raise ValueError(f"Invalid model name {llm_family.model_name}.")

+    for spec in llm_family.model_specs:
+        model_uri = spec.model_uri
+        if model_uri and not is_valid_model_uri(model_uri):
+            raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_LLM_FAMILIES_LOCK:
         for family in BUILTIN_LLM_FAMILIES + UD_LLM_FAMILIES:
             if llm_family.model_name == family.model_name:
@@ -1015,12 +1020,6 @@ def register_llm(llm_family: LLMFamilyV1, persist: bool):
             generate_engine_config_by_model_family(llm_family)

     if persist:
-        # We only validate model URL when persist is True.
-        for spec in llm_family.model_specs:
-            model_uri = spec.model_uri
-            if model_uri and not is_valid_model_uri(model_uri):
-                raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
         )
xinference/model/llm/sglang/core.py
CHANGED

@@ -113,6 +113,13 @@ class SGLANGModel(LLM):
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

         self._model_config = self._sanitize_model_config(self._model_config)
+
+        # Fix: GH#2169
+        if sgl.__version__ >= "0.2.14":
+            self._model_config.setdefault("triton_attention_reduce_in_fp32", False)
+        else:
+            self._model_config.setdefault("attention_reduce_in_fp32", False)
+
         logger.info(
             f"Loading {self.model_uid} with following model config: {self._model_config}"
         )
@@ -152,7 +159,6 @@ class SGLANGModel(LLM):
         else:
             model_config["mem_fraction_static"] = 0.88
         model_config.setdefault("log_level", "info")
-        model_config.setdefault("attention_reduce_in_fp32", False)

         return model_config

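Note the gate compares `sgl.__version__` lexicographically, which happens to work for the versions in play but would misorder, say, "0.2.9" against "0.2.14". A version-aware variant (a hypothetical refactor, not what this release ships) would parse the strings first:

    import sglang as sgl
    from packaging import version

    model_config: dict = {}
    key = (
        "triton_attention_reduce_in_fp32"
        if version.parse(sgl.__version__) >= version.parse("0.2.14")
        else "attention_reduce_in_fp32"
    )
    model_config.setdefault(key, False)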
xinference/model/llm/utils.py
CHANGED
@@ -32,6 +32,7 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from ..utils import ensure_cache_cleared
 from .llm_family import (
     LlamaCppLLMSpecV1,
     LLMFamilyV1,
@@ -576,6 +577,7 @@ Begin!"""
         return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
+    @ensure_cache_cleared
     def _to_chat_completion_chunks(
         cls,
         chunks: Iterator[CompletionChunk],
@@ -608,6 +610,7 @@ Begin!"""
             i += 1

     @staticmethod
+    @ensure_cache_cleared
     def _to_chat_completion(completion: Completion) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
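`ensure_cache_cleared` (added to xinference/model/utils.py further down in this diff) wraps plain functions and generators differently: for a generator, garbage collection and cache emptying run only after the stream is exhausted. A minimal behavior sketch, assuming torch is installed (on CPU-only machines `empty_cache()` is effectively a no-op):

    from xinference.model.utils import ensure_cache_cleared

    @ensure_cache_cleared
    def stream():
        yield from ("a", "b")
        # gc.collect() / empty_cache() run after the last chunk is consumed

    @ensure_cache_cleared
    def convert(x):
        return {"converted": x}  # cleanup runs in a finally block

    assert list(stream()) == ["a", "b"]
    assert convert(1) == {"converted": 1}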
xinference/model/llm/vllm/core.py
CHANGED

@@ -643,39 +643,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):


 class VLLMVisionModel(VLLMModel, ChatModelMixin):
-    def load(self):
-        try:
-            import vllm
-            from vllm.engine.arg_utils import AsyncEngineArgs
-            from vllm.engine.async_llm_engine import AsyncLLMEngine
-        except ImportError:
-            error_message = "Failed to import module 'vllm'"
-            installation_guide = [
-                "Please make sure 'vllm' is installed. ",
-                "You can install it by `pip install vllm`\n",
-            ]
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        if vllm.__version__ >= "0.3.1":
-            # from vllm v0.3.1, it uses cupy as NCCL backend
-            # in which cupy will fork a process
-            # only for xoscar >= 0.3.0, new process is allowed in subpool
-            # besides, xinference set start method as forkserver for unix
-            # we need to set it to fork to make cupy NCCL work
-            multiprocessing.set_start_method("fork", force=True)
-
-        self._model_config = self._sanitize_model_config(self._model_config)
-
-        logger.info(
-            f"Loading {self.model_uid} with following model config: {self._model_config}"
-        )
-
-        engine_args = AsyncEngineArgs(
-            model=self.model_path,
-            **self._model_config,
-        )
-        self._engine = AsyncLLMEngine.from_engine_args(engine_args)
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
xinference/model/rerank/custom.py
CHANGED

@@ -48,6 +48,10 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_RERANK_LOCK:
         for model_name in (
             list(BUILTIN_RERANK_MODELS.keys())
@@ -62,11 +66,6 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
             UD_RERANKS.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "rerank", f"{model_spec.model_name}.json"
         )
xinference/model/utils.py
CHANGED
@@ -11,17 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import functools
+import gc
+import inspect
 import json
 import logging
 import os
+import random
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import huggingface_hub
+import numpy as np
+import torch

 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import get_available_device, is_device_available
+from ..device_utils import empty_cache, get_available_device, is_device_available
 from .core import CacheableModelSpec

 logger = logging.getLogger(__name__)
@@ -348,3 +355,36 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
         return int(model_size)
     else:
         return str(model_size)
+
+
+def ensure_cache_cleared(func: Callable):
+    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
+        func
+    )
+    if inspect.isgeneratorfunction(func):
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            for obj in func(*args, **kwargs):
+                yield obj
+            gc.collect()
+            empty_cache()
+
+    else:
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            finally:
+                gc.collect()
+                empty_cache()
+
+    return inner
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
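`set_all_random_seed` seeds Python's `random`, NumPy, and torch (CPU plus every CUDA device) in one call; `torch.cuda.manual_seed_all` silently does nothing on CPU-only installs. A quick determinism check:

    import torch

    from xinference.model.utils import set_all_random_seed

    set_all_random_seed(42)
    a = torch.randn(3)
    set_all_random_seed(42)
    b = torch.randn(3)
    assert torch.equal(a, b)  # same seed, same draw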
xinference/model/video/core.py
CHANGED
@@ -14,7 +14,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple

 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
@@ -44,6 +44,8 @@ class VideoModelFamilyV1(CacheableModelSpec):
     model_revision: str
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
+    default_model_config: Optional[Dict[str, Any]]
+    default_generate_config: Optional[Dict[str, Any]]


 class VideoModelDescription(ModelDescription):
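The two new optional fields give each model_spec.json entry a place for loader defaults (merged into the `from_pretrained` kwargs) and generation defaults (merged into the pipeline call), with user-supplied kwargs winning on conflict. A hypothetical entry shaped like the new schema; the field values here are illustrative, not copied from the shipped JSON:

    spec_entry = {
        "model_name": "CogVideoX-2b",          # illustrative
        "model_family": "CogVideoX",
        "model_id": "THUDM/CogVideoX-2b",
        "model_revision": "…",
        "default_model_config": {"torch_dtype": "bfloat16"},
        "default_generate_config": {"guidance_scale": 6},
    }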
xinference/model/video/diffusers.py
CHANGED

@@ -15,7 +15,6 @@
 import base64
 import logging
 import os
-import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -24,10 +23,9 @@ from typing import TYPE_CHECKING, List, Union

 import numpy as np
 import PIL.Image
-import torch

 from ...constants import XINFERENCE_VIDEO_DIR
-from ...device_utils import move_model_to_available_device
+from ...device_utils import gpu_count, move_model_to_available_device
 from ...types import Video, VideoList

 if TYPE_CHECKING:
@@ -76,41 +74,58 @@ class DiffUsersVideoModel:
     def load(self):
         import torch

-        …
-        …
-        …
-        …
-        …
-        …
+        kwargs = self._model_spec.default_model_config.copy()
+        kwargs.update(self._kwargs)
+
+        scheduler_cls_name = kwargs.pop("scheduler", None)
+
+        torch_dtype = kwargs.get("torch_dtype")
         if isinstance(torch_dtype, str):
-            …
+            kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+        logger.debug("Loading video model with kwargs: %s", kwargs)

         if self._model_spec.model_family == "CogVideoX":
+            import diffusers
             from diffusers import CogVideoXPipeline

-            self._model = CogVideoXPipeline.from_pretrained(
-                self._model_path, **…
+            pipeline = self._model = CogVideoXPipeline.from_pretrained(
+                self._model_path, **kwargs
             )
         else:
             raise Exception(
                 f"Unsupported model family: {self._model_spec.model_family}"
             )

-        if …:
+        if scheduler_cls_name:
+            logger.debug("Using scheduler: %s", scheduler_cls_name)
+            pipeline.scheduler = getattr(diffusers, scheduler_cls_name).from_config(
+                pipeline.scheduler.config, timestep_spacing="trailing"
+            )
+        if kwargs.get("compile_graph", False):
+            pipeline.transformer = torch.compile(
+                pipeline.transformer, mode="max-autotune", fullgraph=True
+            )
+        if kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
-            …
-            …
+            pipeline.enable_model_cpu_offload()
+            if kwargs.get("sequential_cpu_offload", True):
+                pipeline.enable_sequential_cpu_offload()
+            pipeline.vae.enable_slicing()
+            pipeline.vae.enable_tiling()
+        elif not kwargs.get("device_map"):
             logger.debug("Loading model to available device")
-            …
+            if gpu_count() > 1:
+                kwargs["device_map"] = "balanced"
+            else:
+                pipeline = move_model_to_available_device(self._model)
             # Recommended if your computer has < 64 GB of RAM
-            …
+            pipeline.enable_attention_slicing()

     def text_to_video(
         self,
         prompt: str,
         n: int = 1,
         num_inference_steps: int = 50,
-        guidance_scale: int = 6,
         response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
@@ -121,31 +136,19 @@ class DiffUsersVideoModel:
         # from diffusers.utils import export_to_video
         from ...device_utils import empty_cache

+        assert self._model is not None
+        assert callable(self._model)
+        generate_kwargs = self._model_spec.default_generate_config.copy()
+        generate_kwargs.update(kwargs)
+        generate_kwargs["num_videos_per_prompt"] = n
         logger.debug(
             "diffusers text_to_video args: %s",
-            kwargs,
+            generate_kwargs,
         )
-        assert self._model is not None
-        if self._kwargs.get("cpu_offload"):
-            # if enabled cpu offload,
-            # the model.device would be CPU
-            device = "cuda"
-        else:
-            device = self._model.device
-        prompt_embeds, _ = self._model.encode_prompt(
-            prompt=prompt,
-            do_classifier_free_guidance=True,
-            num_videos_per_prompt=n,
-            max_sequence_length=226,
-            device=device,
-            dtype=torch.float16,
-        )
-        assert callable(self._model)
         output = self._model(
+            prompt=prompt,
             num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            prompt_embeds=prompt_embeds,
-            **kwargs,
+            **generate_kwargs,
         )

         # clean cache