xinference 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/client/restful/async_restful_client.py +8 -13
- xinference/client/restful/restful_client.py +6 -2
- xinference/core/chat_interface.py +6 -4
- xinference/core/media_interface.py +5 -0
- xinference/core/model.py +1 -5
- xinference/core/supervisor.py +117 -68
- xinference/core/worker.py +49 -37
- xinference/deploy/test/test_cmdline.py +2 -6
- xinference/model/audio/__init__.py +26 -23
- xinference/model/audio/chattts.py +3 -2
- xinference/model/audio/core.py +49 -98
- xinference/model/audio/cosyvoice.py +3 -2
- xinference/model/audio/custom.py +28 -73
- xinference/model/audio/f5tts.py +3 -2
- xinference/model/audio/f5tts_mlx.py +3 -2
- xinference/model/audio/fish_speech.py +3 -2
- xinference/model/audio/funasr.py +17 -4
- xinference/model/audio/kokoro.py +3 -2
- xinference/model/audio/megatts.py +3 -2
- xinference/model/audio/melotts.py +3 -2
- xinference/model/audio/model_spec.json +572 -171
- xinference/model/audio/utils.py +0 -6
- xinference/model/audio/whisper.py +3 -2
- xinference/model/audio/whisper_mlx.py +3 -2
- xinference/model/cache_manager.py +141 -0
- xinference/model/core.py +6 -49
- xinference/model/custom.py +174 -0
- xinference/model/embedding/__init__.py +67 -56
- xinference/model/embedding/cache_manager.py +35 -0
- xinference/model/embedding/core.py +104 -84
- xinference/model/embedding/custom.py +55 -78
- xinference/model/embedding/embed_family.py +80 -31
- xinference/model/embedding/flag/core.py +21 -5
- xinference/model/embedding/llama_cpp/__init__.py +0 -0
- xinference/model/embedding/llama_cpp/core.py +234 -0
- xinference/model/embedding/model_spec.json +968 -103
- xinference/model/embedding/sentence_transformers/core.py +30 -20
- xinference/model/embedding/vllm/core.py +11 -5
- xinference/model/flexible/__init__.py +8 -2
- xinference/model/flexible/core.py +26 -119
- xinference/model/flexible/custom.py +69 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -0
- xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
- xinference/model/flexible/launchers/transformers_launcher.py +15 -3
- xinference/model/flexible/launchers/yolo_launcher.py +5 -1
- xinference/model/image/__init__.py +20 -20
- xinference/model/image/cache_manager.py +62 -0
- xinference/model/image/core.py +70 -182
- xinference/model/image/custom.py +28 -72
- xinference/model/image/model_spec.json +402 -119
- xinference/model/image/ocr/got_ocr2.py +3 -2
- xinference/model/image/stable_diffusion/core.py +22 -7
- xinference/model/image/stable_diffusion/mlx.py +6 -6
- xinference/model/image/utils.py +2 -2
- xinference/model/llm/__init__.py +71 -94
- xinference/model/llm/cache_manager.py +292 -0
- xinference/model/llm/core.py +37 -111
- xinference/model/llm/custom.py +88 -0
- xinference/model/llm/llama_cpp/core.py +5 -7
- xinference/model/llm/llm_family.json +16260 -8151
- xinference/model/llm/llm_family.py +138 -839
- xinference/model/llm/lmdeploy/core.py +5 -7
- xinference/model/llm/memory.py +3 -4
- xinference/model/llm/mlx/core.py +6 -8
- xinference/model/llm/reasoning_parser.py +3 -1
- xinference/model/llm/sglang/core.py +32 -14
- xinference/model/llm/transformers/chatglm.py +3 -7
- xinference/model/llm/transformers/core.py +49 -27
- xinference/model/llm/transformers/deepseek_v2.py +2 -2
- xinference/model/llm/transformers/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
- xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
- xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
- xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
- xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
- xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
- xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
- xinference/model/llm/transformers/opt.py +3 -7
- xinference/model/llm/utils.py +34 -49
- xinference/model/llm/vllm/core.py +77 -27
- xinference/model/llm/vllm/xavier/engine.py +5 -3
- xinference/model/llm/vllm/xavier/scheduler.py +10 -6
- xinference/model/llm/vllm/xavier/transfer.py +1 -1
- xinference/model/rerank/__init__.py +26 -25
- xinference/model/rerank/core.py +47 -87
- xinference/model/rerank/custom.py +25 -71
- xinference/model/rerank/model_spec.json +158 -33
- xinference/model/rerank/utils.py +2 -2
- xinference/model/utils.py +115 -54
- xinference/model/video/__init__.py +13 -17
- xinference/model/video/core.py +44 -102
- xinference/model/video/diffusers.py +4 -3
- xinference/model/video/model_spec.json +90 -21
- xinference/types.py +5 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
- xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
- xinference/web/ui/src/locales/en.json +0 -1
- xinference/web/ui/src/locales/ja.json +0 -1
- xinference/web/ui/src/locales/ko.json +0 -1
- xinference/web/ui/src/locales/zh.json +0 -1
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
- xinference/model/audio/model_spec_modelscope.json +0 -231
- xinference/model/embedding/model_spec_modelscope.json +0 -293
- xinference/model/embedding/utils.py +0 -18
- xinference/model/image/model_spec_modelscope.json +0 -375
- xinference/model/llm/llama_cpp/memory.py +0 -457
- xinference/model/llm/llm_family_csghub.json +0 -56
- xinference/model/llm/llm_family_modelscope.json +0 -8700
- xinference/model/llm/llm_family_openmind_hub.json +0 -1019
- xinference/model/rerank/model_spec_modelscope.json +0 -85
- xinference/model/video/model_spec_modelscope.json +0 -184
- xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
- /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0

xinference/model/embedding/sentence_transformers/core.py

@@ -14,20 +14,16 @@
 
 import importlib.util
 import logging
-from collections import defaultdict
 from typing import List, Optional, Union, no_type_check
 
 import numpy as np
 import torch
 
-from ....
-from
+from ....device_utils import is_device_available
+from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
-
-# Used for check whether the model is cached.
-# Init when registering all the builtin models.
-MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 SENTENCE_TRANSFORMER_MODEL_LIST: List[str] = []
 
 
@@ -76,8 +72,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             torch_dtype = torch.float32
 
         if (
-            "gte" in self.
-            and "qwen2" in self.
+            "gte" in self.model_family.model_name.lower()
+            and "qwen2" in self.model_family.model_name.lower()
         ):
             model_kwargs = {"device_map": "auto"}
             if torch_dtype:
@@ -87,10 +83,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
             )
-        elif "qwen3" in self.
+        elif "qwen3" in self.model_family.model_name.lower():
            # qwen3 embedding
            flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-            flash_attn_enabled = self._kwargs.get(
+            flash_attn_enabled = self._kwargs.get(
+                "enable_flash_attn", is_device_available("cuda")
+            )
            model_kwargs = {"device_map": "auto"}
            tokenizer_kwargs = {}
            if flash_attn_installed and flash_attn_enabled:
@@ -119,7 +117,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             trust_remote_code=True,
         )
 
-
+        if hasattr(self._model, "tokenizer"):
+            self._tokenizer = self._model.tokenizer
 
     def create_embedding(
         self,
@@ -227,8 +226,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         device = model._target_device
 
         if (
-            "gte" in self.
-            and "qwen2" in self.
+            "gte" in self.model_family.model_name.lower()
+            and "qwen2" in self.model_family.model_name.lower()
         ):
             model.to(device)
 
@@ -254,7 +253,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             features.update(extra_features)
             # when batching, the attention mask 1 means there is a token
             # thus we just sum up it to get the total number of tokens
-            if
+            if (
+                "clip" in self.model_family.model_name.lower()
+                or "jina-embeddings-v4" in self.model_family.model_name.lower()
+            ):
                 if "input_ids" in features and hasattr(
                     features["input_ids"], "numel"
                 ):
@@ -322,8 +324,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
 
         # seems already support prompt in embedding model
         if (
-            "gte" in self.
-            and "qwen2" in self.
+            "gte" in self.model_family.model_name.lower()
+            and "qwen2" in self.model_family.model_name.lower()
         ):
             all_embeddings, all_token_nums = encode(
                 self._model,
@@ -332,7 +334,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 convert_to_numpy=False,
                 **kwargs,
             )
-        elif
+        elif (
+            "clip" in self.model_family.model_name.lower()
+            or "jina-embeddings-v4" in self.model_family.model_name.lower()
+        ):
             import base64
             import re
             from io import BytesIO
@@ -409,6 +414,11 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         return importlib.util.find_spec("sentence_transformers") is not None
 
     @classmethod
-    def match_json(
+    def match_json(
+        cls,
+        model_family: EmbeddingModelFamilyV2,
+        model_spec: EmbeddingSpecV1,
+        quantization: str,
+    ) -> bool:
         # As default embedding engine, sentence-transformer support all models
-        return
+        return model_spec.model_format in ["pytorch"]

xinference/model/embedding/vllm/core.py

@@ -17,7 +17,7 @@ import logging
 from typing import List, Union
 
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
-from ..core import EmbeddingModel,
+from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
 SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
@@ -88,8 +88,14 @@ class VLLMEmbeddingModel(EmbeddingModel):
         return importlib.util.find_spec("vllm") is not None
 
     @classmethod
-    def match_json(
-
-
-
+    def match_json(
+        cls,
+        model_family: EmbeddingModelFamilyV2,
+        model_spec: EmbeddingSpecV1,
+        quantization: str,
+    ) -> bool:
+        if model_spec.model_format in ["pytorch"]:
+            prefix = model_family.model_name.split("-", 1)[0]
+            if prefix in SUPPORTED_MODELS_PREFIXES:
+                return True
         return False
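
The vLLM matcher keys off the model name's leading token. A worked example of that prefix rule (list values copied from this diff, model names chosen for illustration):

    # Prefix check as in VLLMEmbeddingModel.match_json above.
    SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]

    for name in ["bge-m3", "Qwen3-Embedding-0.6B", "jina-embeddings-v4"]:
        prefix = name.split("-", 1)[0]
        print(name, "->", prefix in SUPPORTED_MODELS_PREFIXES)
    # bge-m3 -> True, Qwen3-Embedding-0.6B -> True, jina-embeddings-v4 -> False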

xinference/model/flexible/__init__.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import codecs
 import json
 import logging
@@ -25,6 +24,8 @@ from .core import (
     FlexibleModelSpec,
     generate_flexible_model_description,
     get_flexible_model_descriptions,
+)
+from .custom import (
     get_flexible_models,
     register_flexible_model,
     unregister_flexible_model,
@@ -34,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 def register_custom_model():
-
+    from ..custom import migrate_from_v1_to_v2
+
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("flexible", FlexibleModelSpec)
+
+    model_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "flexible")
     if os.path.isdir(model_dir):
         for f in os.listdir(model_dir):
             try:
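
register_custom_model now runs the v1-to-v2 migration before scanning the on-disk registry, so the v2 directory becomes the only place specs are loaded from. A sketch of the resulting scan, assuming the conventional XINFERENCE_MODEL_DIR location (the default path here is an assumption):

    import json
    import os

    XINFERENCE_MODEL_DIR = os.path.expanduser("~/.xinference/model")  # assumed default
    model_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "flexible")
    if os.path.isdir(model_dir):
        for f in os.listdir(model_dir):
            # one persisted FlexibleModelSpec JSON per registered model
            with open(os.path.join(model_dir, f), encoding="utf-8") as fd:
                print(json.load(fd).get("model_name"))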

xinference/model/flexible/core.py

@@ -14,21 +14,19 @@
 
 import json
 import logging
-import os
 from collections import defaultdict
-from
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
-from ...
-from ..core import CacheableModelSpec,
+from ..._compat import Literal
+from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..utils import ModelInstanceInfoMixin
 from .utils import get_launcher
 
 logger = logging.getLogger(__name__)
 
-FLEXIBLE_MODEL_LOCK = Lock()
 
-
-
+class FlexibleModelSpec(CacheableModelSpec, ModelInstanceInfoMixin):
+    version: Literal[1, 2] = 2
     model_id: Optional[str]  # type: ignore
     model_description: Optional[str]
     model_uri: Optional[str]
@@ -39,42 +37,26 @@ class FlexibleModelSpec(CacheableModelSpec):
     def parser_args(self):
         return json.loads(self.launcher_args)
 
+    class Config:
+        extra = "allow"
 
-
-    def __init__(
-        self,
-        address: Optional[str],
-        devices: Optional[List[str]],
-        model_spec: FlexibleModelSpec,
-        model_path: Optional[str] = None,
-    ):
-        super().__init__(address, devices, model_path=model_path)
-        self._model_spec = model_spec
-
-    @property
-    def spec(self):
-        return self._model_spec
-
-    def to_dict(self):
+    def to_description(self):
         return {
             "model_type": "flexible",
-            "address": self
-            "accelerators": self
-            "model_name": self.
-            "launcher": self.
-            "launcher_args": self.
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "model_name": self.model_name,
+            "launcher": self.launcher,
+            "launcher_args": self.launcher_args,
         }
 
-    def get_model_version(self) -> str:
-        return f"{self._model_spec.model_name}"
-
     def to_version_info(self):
         return {
-            "model_version": self.
+            "model_version": self.model_name,
             "cache_status": True,
-            "model_file_location": self.
-            "launcher": self.
-            "launcher_args": self.
+            "model_file_location": self.model_uri,
+            "launcher": self.launcher,
+            "launcher_args": self.launcher_args,
         }
 
 
@@ -82,9 +64,7 @@ def generate_flexible_model_description(
     model_spec: FlexibleModelSpec,
 ) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[model_spec.model_name].append(
-        FlexibleModelDescription(None, None, model_spec).to_version_info()
-    )
+    res[model_spec.model_name].append(model_spec.to_version_info())
     return res
 
 
@@ -92,93 +72,22 @@ FLEXIBLE_MODELS: List[FlexibleModelSpec] = []
 FLEXIBLE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 
 
-def get_flexible_models():
-    with FLEXIBLE_MODEL_LOCK:
-        return FLEXIBLE_MODELS.copy()
-
-
 def get_flexible_model_descriptions():
     import copy
 
     return copy.deepcopy(FLEXIBLE_MODEL_DESCRIPTIONS)
 
 
-def register_flexible_model(model_spec: FlexibleModelSpec, persist: bool):
-    from ..utils import is_valid_model_name, is_valid_model_uri
-
-    if not is_valid_model_name(model_spec.model_name):
-        raise ValueError(f"Invalid model name {model_spec.model_name}.")
-
-    model_uri = model_spec.model_uri
-    if model_uri and not is_valid_model_uri(model_uri):
-        raise ValueError(f"Invalid model URI {model_uri}.")
-
-    if model_spec.launcher_args:
-        try:
-            model_spec.parser_args()
-        except Exception:
-            raise ValueError(f"Invalid model launcher args {model_spec.launcher_args}.")
-
-    with FLEXIBLE_MODEL_LOCK:
-        for model_name in [spec.model_name for spec in FLEXIBLE_MODELS]:
-            if model_spec.model_name == model_name:
-                raise ValueError(
-                    f"Model name conflicts with existing model {model_spec.model_name}"
-                )
-        FLEXIBLE_MODELS.append(model_spec)
-
-    if persist:
-        persist_path = os.path.join(
-            XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
-        )
-        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
-        with open(persist_path, mode="w") as fd:
-            fd.write(model_spec.json())
-
-
-def unregister_flexible_model(model_name: str, raise_error: bool = True):
-    with FLEXIBLE_MODEL_LOCK:
-        model_spec = None
-        for i, f in enumerate(FLEXIBLE_MODELS):
-            if f.model_name == model_name:
-                model_spec = f
-                break
-        if model_spec:
-            FLEXIBLE_MODELS.remove(model_spec)
-
-            persist_path = os.path.join(
-                XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
-            )
-            if os.path.exists(persist_path):
-                os.remove(persist_path)
-
-            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-            if os.path.exists(cache_dir):
-                logger.warning(
-                    f"Remove the cache of user-defined model {model_spec.model_name}. "
-                    f"Cache directory: {cache_dir}"
-                )
-                if os.path.islink(cache_dir):
-                    os.remove(cache_dir)
-                else:
-                    logger.warning(
-                        f"Cache directory is not a soft link, please remove it manually."
-                    )
-        else:
-            if raise_error:
-                raise ValueError(f"Model {model_name} not found")
-            else:
-                logger.warning(f"Model {model_name} not found")
-
-
 class FlexibleModel:
     def __init__(
         self,
         model_uid: str,
         model_path: str,
+        model_family: FlexibleModelSpec,
         device: Optional[str] = None,
         config: Optional[Dict] = None,
     ):
+        self.model_family = model_family
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
@@ -213,19 +122,20 @@ class FlexibleModel:
 
 
 def match_flexible_model(model_name):
+    from .custom import get_flexible_models
+
     for model_spec in get_flexible_models():
         if model_name == model_spec.model_name:
             return model_spec
+    return None
 
 
 def create_flexible_model_instance(
-    subpool_addr: str,
-    devices: List[str],
     model_uid: str,
     model_name: str,
     model_path: Optional[str] = None,
     **kwargs,
-) ->
+) -> FlexibleModel:
     model_spec = match_flexible_model(model_name)
     if not model_path:
         model_path = model_spec.model_uri
@@ -237,7 +147,4 @@ def create_flexible_model_instance(
         model_uid=model_uid, model_spec=model_spec, **kwargs
     )
 
-
-        subpool_addr, devices, model_spec, model_path=model_path
-    )
-    return model, model_description
+    return model
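
With FlexibleModelDescription removed, the spec itself now carries description and version info. A rough sketch of the new surface (field values are invented, and additional required fields inherited from CacheableModelSpec may apply):

    from xinference.model.flexible.core import FlexibleModelSpec

    spec = FlexibleModelSpec(
        model_name="my-classifier",  # hypothetical custom model
        model_id=None,
        model_description="demo flexible model",
        model_uri="/models/my-classifier",
        launcher="xinference.model.flexible.launchers.transformers_launcher",
        launcher_args='{"task": "text-classification"}',
    )
    print(spec.to_description()["launcher_args"])
    print(spec.to_version_info()["model_file_location"])  # -> /models/my-classifier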

xinference/model/flexible/custom.py (new file)

@@ -0,0 +1,69 @@
+from typing import TYPE_CHECKING
+
+from ..custom import ModelRegistry
+
+if TYPE_CHECKING:
+    from .core import FlexibleModelSpec
+
+
+class FlexibleModelRegistry(ModelRegistry):
+    model_type = "flexible"
+
+    def __init__(self):
+        from .core import FLEXIBLE_MODELS
+
+        super().__init__()
+        self.models = FLEXIBLE_MODELS
+        self.builtin_models = []
+
+    def register(self, model_spec: "FlexibleModelSpec", persist: bool):
+        from ..cache_manager import CacheManager
+        from ..utils import is_valid_model_name, is_valid_model_uri
+
+        if not is_valid_model_name(model_spec.model_name):
+            raise ValueError(f"Invalid model name {model_spec.model_name}.")
+
+        model_uri = model_spec.model_uri
+        if model_uri and not is_valid_model_uri(model_uri):
+            raise ValueError(f"Invalid model URI {model_uri}.")
+
+        if model_spec.launcher_args:
+            try:
+                model_spec.parser_args()
+            except Exception:
+                raise ValueError(
+                    f"Invalid model launcher args {model_spec.launcher_args}."
+                )
+
+        with self.lock:
+            for model_name in [spec.model_name for spec in self.models]:
+                if model_spec.model_name == model_name:
+                    raise ValueError(
+                        f"Model name conflicts with existing model {model_spec.model_name}"
+                    )
+            self.models.append(model_spec)
+
+        if persist:
+            cache_manager = CacheManager(model_spec)
+            cache_manager.register_custom_model(self.model_type)
+
+
+def get_flexible_models():
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    return registry.get_custom_models()
+
+
+def register_flexible_model(model_spec: "FlexibleModelSpec", persist: bool):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    registry.register(model_spec, persist)
+
+
+def unregister_flexible_model(model_name: str, raise_error: bool = True):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    registry.unregister(model_name, raise_error)
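
The module-level helpers keep the pre-1.8 public API while delegating to the shared RegistryManager. Usage is unchanged; a sketch, reusing the spec from the previous example:

    from xinference.model.flexible.custom import (
        get_flexible_models,
        register_flexible_model,
        unregister_flexible_model,
    )

    # validates name/URI/launcher_args, then appends to the registry;
    # persist=True would also write the spec through CacheManager
    register_flexible_model(spec, persist=False)
    print([s.model_name for s in get_flexible_models()])
    unregister_flexible_model("my-classifier", raise_error=False)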

xinference/model/flexible/launchers/image_process_launcher.py

@@ -63,6 +63,7 @@ launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
     return ImageRemoveBackgroundModel(
         model_uid=model_uid,
         model_path=model_spec.model_uri,  # type: ignore
+        model_family=model_spec,
         device=device,
         config=kwargs,
     )

xinference/model/flexible/launchers/modelscope_launcher.py

@@ -43,5 +43,9 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
         raise ValueError("model_path required")
 
     return ModelScopePipelineModel(
-        model_uid=model_uid,
+        model_uid=model_uid,
+        model_path=model_path,
+        model_family=model_spec,
+        device=device,
+        config=kwargs,
     )

xinference/model/flexible/launchers/transformers_launcher.py

@@ -51,13 +51,25 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
 
     if task == "text-classification":
         return TransformersTextClassificationModel(
-            model_uid=model_uid,
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )
     elif task == "mock":
         return MockModel(
-            model_uid=model_uid,
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )
     else:
         return AutoModel(
-            model_uid=model_uid,
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )

xinference/model/flexible/launchers/yolo_launcher.py

@@ -58,5 +58,9 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
         raise ValueError("model_path required")
 
     return UltralyticsModel(
-        model_uid=model_uid,
+        model_uid=model_uid,
+        model_path=model_path,
+        model_family=model_spec,
+        device=device,
+        config=kwargs,
     )
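
All four launchers now pass model_family=model_spec through to the model constructor, alongside model_path, device, and config. A custom launcher should follow the same shape; a minimal sketch (how device and model_path arrive in kwargs is an assumption):

    from xinference.model.flexible.core import FlexibleModel, FlexibleModelSpec

    def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
        device = kwargs.pop("device", None)  # assumed to arrive via kwargs
        model_path = kwargs.pop("model_path", None) or model_spec.model_uri
        if not model_path:
            raise ValueError("model_path required")
        return FlexibleModel(
            model_uid=model_uid,
            model_path=model_path,
            model_family=model_spec,  # the new wiring introduced in 1.8.0
            device=device,
            config=kwargs,
        )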

xinference/model/image/__init__.py

@@ -16,20 +16,17 @@ import codecs
 import json
 import os
 import warnings
-from itertools import chain
 
+from ..utils import flatten_model_src
 from .core import (
     BUILTIN_IMAGE_MODELS,
     IMAGE_MODEL_DESCRIPTIONS,
-
-    MODELSCOPE_IMAGE_MODELS,
-    ImageModelFamilyV1,
+    ImageModelFamilyV2,
     generate_image_description,
-    get_cache_status,
     get_image_model_descriptions,
 )
 from .custom import (
-
+    CustomImageModelFamilyV2,
     get_user_defined_images,
     register_image,
     unregister_image,
@@ -38,15 +35,19 @@ from .custom import (
 
 def register_custom_model():
     from ...constants import XINFERENCE_MODEL_DIR
+    from ..custom import migrate_from_v1_to_v2
 
-
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("image", CustomImageModelFamilyV2)
+
+    user_defined_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "image")
     if os.path.isdir(user_defined_image_dir):
         for f in os.listdir(user_defined_image_dir):
             try:
                 with codecs.open(
                     os.path.join(user_defined_image_dir, f), encoding="utf-8"
                 ) as fd:
-                    user_defined_image_family =
+                    user_defined_image_family = CustomImageModelFamilyV2.parse_obj(
                         json.load(fd)
                     )
                     register_image(user_defined_image_family, persist=False)
@@ -56,12 +57,10 @@ def register_custom_model():
 
 def _install():
     load_model_family_from_json("model_spec.json", BUILTIN_IMAGE_MODELS)
-    load_model_family_from_json("model_spec_modelscope.json", MODELSCOPE_IMAGE_MODELS)
 
     # register model description
-    for model_name,
-
-    ):
+    for model_name, model_specs in BUILTIN_IMAGE_MODELS.items():
+        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
         IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(model_spec))
 
     register_custom_model()
@@ -72,13 +71,14 @@ def _install():
 
 def load_model_family_from_json(json_filename, target_families):
     json_path = os.path.join(os.path.dirname(__file__), json_filename)
-
-
-
-
-
-
-
-
+    flattened_model_specs = []
+    for spec in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        flattened_model_specs.extend(flatten_model_src(spec))
+
+    for spec in flattened_model_specs:
+        if spec["model_name"] not in target_families:
+            target_families[spec["model_name"]] = [ImageModelFamilyV2(**spec)]
+        else:
+            target_families[spec["model_name"]].append(ImageModelFamilyV2(**spec))
 
     del json_path
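
The hub-specific model_spec_modelscope.json is gone; a single model_spec.json now lists every hub under one entry, and flatten_model_src expands it into one flat spec per hub before ImageModelFamilyV2 instances are built. A sketch of the assumed shape (the "model_src" key layout is an assumption inferred from this diff):

    # Assumed input: one JSON entry carrying all hubs under "model_src".
    entry = {
        "model_name": "my-image-model",  # hypothetical entry
        "model_src": {
            "huggingface": {"model_id": "org/my-image-model"},
            "modelscope": {"model_id": "org/my-image-model"},
        },
    }

    # flatten_model_src is expected to yield one dict per hub, each tagged
    # with model_hub, which _install then groups into BUILTIN_IMAGE_MODELS.
    flattened = [
        {"model_name": entry["model_name"], "model_hub": hub, **src}
        for hub, src in entry["model_src"].items()
    ]
    print([f["model_hub"] for f in flattened])  # ['huggingface', 'modelscope']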

xinference/model/image/cache_manager.py (new file)

@@ -0,0 +1,62 @@
+import os
+from typing import Optional
+
+from ..cache_manager import CacheManager
+
+
+class ImageCacheManager(CacheManager):
+    def cache_gguf(self, quantization: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not quantization:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.gguf_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support GGUF quantization"
+            )
+        if quantization not in (self._model_family.gguf_quantizations or []):
+            raise ValueError(
+                f"Cannot support quantization {quantization}, "
+                f"available quantizations: {self._model_family.gguf_quantizations}"
+            )
+
+        filename = self._model_family.gguf_model_file_name_template.format(quantization=quantization)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.gguf_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.gguf_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
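
A sketch of exercising the new GGUF cache path, assuming an ImageModelFamilyV2 whose spec declares gguf_quantizations and a gguf_model_file_name_template (all concrete values here are hypothetical):

    from xinference.model.image.cache_manager import ImageCacheManager

    # `family` is an ImageModelFamilyV2 with, for example:
    #   gguf_quantizations = ["Q4_0", "Q8_0"]
    #   gguf_model_file_name_template = "my-model-{quantization}.gguf"
    manager = ImageCacheManager(family)
    path = manager.cache_gguf("Q4_0")  # downloads once, symlinks into the cache dir
    print(path)  # <cache_dir>/my-model-Q4_0.gguf
    # cache_gguf(None) is a no-op that returns None.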