xinference 1.7.1.post1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/client/restful/async_restful_client.py +8 -13
- xinference/client/restful/restful_client.py +6 -2
- xinference/core/chat_interface.py +6 -4
- xinference/core/media_interface.py +5 -0
- xinference/core/model.py +1 -5
- xinference/core/supervisor.py +117 -68
- xinference/core/worker.py +49 -37
- xinference/deploy/test/test_cmdline.py +2 -6
- xinference/model/audio/__init__.py +26 -23
- xinference/model/audio/chattts.py +3 -2
- xinference/model/audio/core.py +49 -98
- xinference/model/audio/cosyvoice.py +3 -2
- xinference/model/audio/custom.py +28 -73
- xinference/model/audio/f5tts.py +3 -2
- xinference/model/audio/f5tts_mlx.py +3 -2
- xinference/model/audio/fish_speech.py +3 -2
- xinference/model/audio/funasr.py +17 -4
- xinference/model/audio/kokoro.py +3 -2
- xinference/model/audio/megatts.py +3 -2
- xinference/model/audio/melotts.py +3 -2
- xinference/model/audio/model_spec.json +572 -171
- xinference/model/audio/utils.py +0 -6
- xinference/model/audio/whisper.py +3 -2
- xinference/model/audio/whisper_mlx.py +3 -2
- xinference/model/cache_manager.py +141 -0
- xinference/model/core.py +6 -49
- xinference/model/custom.py +174 -0
- xinference/model/embedding/__init__.py +67 -56
- xinference/model/embedding/cache_manager.py +35 -0
- xinference/model/embedding/core.py +104 -84
- xinference/model/embedding/custom.py +55 -78
- xinference/model/embedding/embed_family.py +80 -31
- xinference/model/embedding/flag/core.py +21 -5
- xinference/model/embedding/llama_cpp/__init__.py +0 -0
- xinference/model/embedding/llama_cpp/core.py +234 -0
- xinference/model/embedding/model_spec.json +968 -103
- xinference/model/embedding/sentence_transformers/core.py +30 -20
- xinference/model/embedding/vllm/core.py +11 -5
- xinference/model/flexible/__init__.py +8 -2
- xinference/model/flexible/core.py +26 -119
- xinference/model/flexible/custom.py +69 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -0
- xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
- xinference/model/flexible/launchers/transformers_launcher.py +15 -3
- xinference/model/flexible/launchers/yolo_launcher.py +5 -1
- xinference/model/image/__init__.py +20 -20
- xinference/model/image/cache_manager.py +62 -0
- xinference/model/image/core.py +70 -182
- xinference/model/image/custom.py +28 -72
- xinference/model/image/model_spec.json +402 -119
- xinference/model/image/ocr/got_ocr2.py +3 -2
- xinference/model/image/stable_diffusion/core.py +22 -7
- xinference/model/image/stable_diffusion/mlx.py +6 -6
- xinference/model/image/utils.py +2 -2
- xinference/model/llm/__init__.py +71 -94
- xinference/model/llm/cache_manager.py +292 -0
- xinference/model/llm/core.py +37 -111
- xinference/model/llm/custom.py +88 -0
- xinference/model/llm/llama_cpp/core.py +5 -7
- xinference/model/llm/llm_family.json +16260 -8151
- xinference/model/llm/llm_family.py +138 -839
- xinference/model/llm/lmdeploy/core.py +5 -7
- xinference/model/llm/memory.py +3 -4
- xinference/model/llm/mlx/core.py +6 -8
- xinference/model/llm/reasoning_parser.py +3 -1
- xinference/model/llm/sglang/core.py +32 -14
- xinference/model/llm/transformers/chatglm.py +3 -7
- xinference/model/llm/transformers/core.py +49 -27
- xinference/model/llm/transformers/deepseek_v2.py +2 -2
- xinference/model/llm/transformers/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
- xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
- xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
- xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
- xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
- xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
- xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
- xinference/model/llm/transformers/opt.py +3 -7
- xinference/model/llm/utils.py +34 -49
- xinference/model/llm/vllm/core.py +77 -27
- xinference/model/llm/vllm/xavier/engine.py +5 -3
- xinference/model/llm/vllm/xavier/scheduler.py +10 -6
- xinference/model/llm/vllm/xavier/transfer.py +1 -1
- xinference/model/rerank/__init__.py +26 -25
- xinference/model/rerank/core.py +47 -87
- xinference/model/rerank/custom.py +25 -71
- xinference/model/rerank/model_spec.json +158 -33
- xinference/model/rerank/utils.py +2 -2
- xinference/model/utils.py +115 -54
- xinference/model/video/__init__.py +13 -17
- xinference/model/video/core.py +44 -102
- xinference/model/video/diffusers.py +4 -3
- xinference/model/video/model_spec.json +90 -21
- xinference/types.py +5 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
- xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
- xinference/web/ui/src/locales/en.json +0 -1
- xinference/web/ui/src/locales/ja.json +0 -1
- xinference/web/ui/src/locales/ko.json +0 -1
- xinference/web/ui/src/locales/zh.json +0 -1
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
- xinference/model/audio/model_spec_modelscope.json +0 -231
- xinference/model/embedding/model_spec_modelscope.json +0 -293
- xinference/model/embedding/utils.py +0 -18
- xinference/model/image/model_spec_modelscope.json +0 -375
- xinference/model/llm/llama_cpp/memory.py +0 -457
- xinference/model/llm/llm_family_csghub.json +0 -56
- xinference/model/llm/llm_family_modelscope.json +0 -8700
- xinference/model/llm/llm_family_openmind_hub.json +0 -1019
- xinference/model/rerank/model_spec_modelscope.json +0 -85
- xinference/model/video/model_spec_modelscope.json +0 -184
- xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
- /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/model/audio/utils.py
CHANGED
|
@@ -21,15 +21,9 @@ from collections.abc import Callable
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from .core import AudioModelFamilyV1
|
|
25
|
-
|
|
26
24
|
logger = logging.getLogger(__name__)
|
|
27
25
|
|
|
28
26
|
|
|
29
|
-
def get_model_version(audio_model: AudioModelFamilyV1) -> str:
|
|
30
|
-
return audio_model.model_name
|
|
31
|
-
|
|
32
|
-
|
|
33
27
|
def _extract_pcm_from_wav_bytes(wav_bytes):
|
|
34
28
|
with io.BytesIO(wav_bytes) as wav_io:
|
|
35
29
|
with wave.open(wav_io, "rb") as wav_file:
|
|
@@ -26,7 +26,7 @@ from ...device_utils import (
|
|
|
26
26
|
)
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
29
|
-
from .core import
|
|
29
|
+
from .core import AudioModelFamilyV2
|
|
30
30
|
|
|
31
31
|
logger = logging.getLogger(__name__)
|
|
32
32
|
|
|
@@ -43,11 +43,12 @@ class WhisperModel:
|
|
|
43
43
|
self,
|
|
44
44
|
model_uid: str,
|
|
45
45
|
model_path: str,
|
|
46
|
-
model_spec: "
|
|
46
|
+
model_spec: "AudioModelFamilyV2",
|
|
47
47
|
device: Optional[str] = None,
|
|
48
48
|
max_new_tokens: Optional[int] = 128,
|
|
49
49
|
**kwargs,
|
|
50
50
|
):
|
|
51
|
+
self.model_family = model_spec
|
|
51
52
|
self._model_uid = model_uid
|
|
52
53
|
self._model_path = model_path
|
|
53
54
|
self._model_spec = model_spec
|
|
@@ -18,7 +18,7 @@ import tempfile
|
|
|
18
18
|
from typing import TYPE_CHECKING, List, Optional
|
|
19
19
|
|
|
20
20
|
if TYPE_CHECKING:
|
|
21
|
-
from .core import
|
|
21
|
+
from .core import AudioModelFamilyV2
|
|
22
22
|
|
|
23
23
|
logger = logging.getLogger(__name__)
|
|
24
24
|
|
|
@@ -28,10 +28,11 @@ class WhisperMLXModel:
|
|
|
28
28
|
self,
|
|
29
29
|
model_uid: str,
|
|
30
30
|
model_path: str,
|
|
31
|
-
model_spec: "
|
|
31
|
+
model_spec: "AudioModelFamilyV2",
|
|
32
32
|
device: Optional[str] = None,
|
|
33
33
|
**kwargs,
|
|
34
34
|
):
|
|
35
|
+
self.model_family = model_spec
|
|
35
36
|
self._model_uid = model_uid
|
|
36
37
|
self._model_path = model_path
|
|
37
38
|
self._model_spec = model_spec
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from .core import CacheableModelSpec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CacheManager:
    """Manage the on-disk cache and persisted definition of one model spec.

    The cache directory is ``<XINFERENCE_CACHE_DIR>/v2/<model_name>`` (with ``.``
    in the model name replaced by ``_``). User-defined model definitions are
    persisted as ``<XINFERENCE_MODEL_DIR>/v2/<model_type>/<model_name>.json``.
    """

    def __init__(self, model_family: "CacheableModelSpec"):
        from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR

        self._model_family = model_family
        self._v2_cache_dir_prefix = os.path.join(XINFERENCE_CACHE_DIR, "v2")
        self._v2_custom_dir_prefix = os.path.join(XINFERENCE_MODEL_DIR, "v2")
        os.makedirs(self._v2_cache_dir_prefix, exist_ok=True)
        os.makedirs(self._v2_custom_dir_prefix, exist_ok=True)
        self._cache_dir = os.path.join(
            self._v2_cache_dir_prefix, self._model_family.model_name.replace(".", "_")
        )

    def get_cache_dir(self):
        """Return the path of this model's cache directory (it may not exist yet)."""
        return self._cache_dir

    def get_cache_status(self):
        """Return True if the cache directory exists.

        ``os.path.exists`` follows symlinks, so a dangling link reports False.
        """
        return os.path.exists(self.get_cache_dir())

    def _cache_from_uri(self, model_spec: "CacheableModelSpec") -> str:
        """Cache a model given by a local ``file`` URI by symlinking it.

        Returns the cache directory. If it already exists, it is reused as is.

        Raises:
            ValueError: if the spec has no ``model_uri``, the URI scheme is not
                ``file``, or the path is relative.
        """
        from .utils import parse_uri

        cache_dir = self.get_cache_dir()
        if os.path.exists(cache_dir):
            logger.info("cache %s exists", cache_dir)
            return cache_dir

        model_uri = getattr(model_spec, "model_uri", None)
        if model_uri is None:
            # Explicit error instead of ``assert`` (asserts vanish under -O).
            raise ValueError("Model URI is required to cache a model from URI.")

        src_scheme, src_root = parse_uri(model_uri)
        if src_root.endswith("/"):
            # remove trailing path separator.
            src_root = src_root[:-1]

        if src_scheme != "file":
            raise ValueError(f"Unsupported URL scheme: {src_scheme}")
        if not os.path.isabs(src_root):
            raise ValueError(f"Model URI cannot be a relative path: {model_uri}")

        if os.path.islink(cache_dir):
            # Reaching here means the link is dangling: ``os.path.exists`` said
            # False above. Left in place it would make ``os.symlink`` raise
            # FileExistsError, so replace it with a link to the current URI.
            logger.warning("Replacing dangling cache link %s", cache_dir)
            os.remove(cache_dir)
        os.symlink(src_root, cache_dir, target_is_directory=True)
        return cache_dir

    def _cache(self) -> str:
        """Make sure the model is available in the cache directory and return it.

        Specs with a ``model_uri`` are linked from that URI. Otherwise the model
        is downloaded (via ``retry_download``) from ModelScope when
        ``model_hub == "modelscope"``, or from the Hugging Face Hub. The download
        is then linked into the cache directory.
        """
        from .utils import IS_NEW_HUGGINGFACE_HUB, create_symlink, retry_download

        model_uri = getattr(self._model_family, "model_uri", None)
        if model_uri is not None:
            logger.info(f"Model caching from URI: {model_uri}")
            return self._cache_from_uri(model_spec=self._model_family)

        cache_dir = self.get_cache_dir()
        if self.get_cache_status():
            return cache_dir

        # Copy so that the spec's own config dict is never mutated downstream.
        cache_config = (
            self._model_family.cache_config.copy()
            if self._model_family.cache_config
            else {}
        )
        if self._model_family.model_hub == "modelscope":
            from modelscope.hub.snapshot_download import (
                snapshot_download as ms_download,
            )

            download_dir = retry_download(
                ms_download,
                self._model_family.model_name,
                None,
                self._model_family.model_id,
                revision=self._model_family.model_revision,
                **cache_config,
            )
            create_symlink(download_dir, cache_dir)
        else:
            from huggingface_hub import snapshot_download as hf_download

            download_kwargs = cache_config
            if not IS_NEW_HUGGINGFACE_HUB:
                # Older huggingface_hub: download straight into cache_dir.
                # NOTE(review): cache_config is dropped on this path -- confirm
                # that this is intended.
                download_kwargs = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
            download_dir = retry_download(
                hf_download,
                self._model_family.model_name,
                None,
                self._model_family.model_id,
                revision=self._model_family.model_revision,
                **download_kwargs,
            )
            if IS_NEW_HUGGINGFACE_HUB:
                create_symlink(download_dir, cache_dir)
        return cache_dir

    def cache(self) -> str:
        """Public entry point: populate the cache if needed and return its path."""
        return self._cache()

    def _custom_model_path(self, model_type: str) -> str:
        """Return the JSON path where this custom model's definition is persisted."""
        return os.path.join(
            self._v2_custom_dir_prefix,
            model_type,
            f"{self._model_family.model_name}.json",
        )

    def register_custom_model(self, model_type: str):
        """Persist the model spec as JSON under the v2 custom-model directory."""
        persist_path = self._custom_model_path(model_type)
        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
        # Explicit UTF-8 so the file reads back identically regardless of the
        # process locale (readers of these files decode them as UTF-8).
        with open(persist_path, mode="w", encoding="utf-8") as fd:
            fd.write(self._model_family.json())

    def unregister_custom_model(self, model_type: str):
        """Delete the persisted definition and, if it is a symlink, the cache link.

        A real (non-link) cache directory is never deleted. Only a warning is
        logged, asking the user to remove it manually.
        """
        persist_path = self._custom_model_path(model_type)
        if os.path.exists(persist_path):
            os.remove(persist_path)

        cache_dir = self.get_cache_dir()
        # islink() is True even for a dangling link, which get_cache_status()
        # misses; such a leftover link would break a later re-registration.
        is_link = os.path.islink(cache_dir)
        if not (is_link or self.get_cache_status()):
            return

        logger.warning(
            f"Remove the cache of user-defined model {self._model_family.model_name}. "
            f"Cache directory: {cache_dir}"
        )
        if is_link:
            os.remove(cache_dir)
        else:
            logger.warning(
                "Cache directory is not a soft link, please remove it manually."
            )
|
xinference/model/core.py
CHANGED
|
@@ -11,47 +11,13 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from abc import ABC, abstractmethod
|
|
16
|
-
from typing import Any, List, Literal, Optional, Tuple, Union
|
|
14
|
+
from typing import Any, List, Literal, Optional, Union
|
|
17
15
|
|
|
18
16
|
from .._compat import BaseModel
|
|
19
17
|
from ..types import PeftModelConfig
|
|
20
18
|
|
|
21
19
|
|
|
22
|
-
class ModelDescription(ABC):
|
|
23
|
-
def __init__(
|
|
24
|
-
self,
|
|
25
|
-
address: Optional[str],
|
|
26
|
-
devices: Optional[List[str]],
|
|
27
|
-
model_path: Optional[str] = None,
|
|
28
|
-
):
|
|
29
|
-
self.address = address
|
|
30
|
-
self.devices = devices
|
|
31
|
-
self._model_path = model_path
|
|
32
|
-
|
|
33
|
-
@property
|
|
34
|
-
@abstractmethod
|
|
35
|
-
def spec(self):
|
|
36
|
-
pass
|
|
37
|
-
|
|
38
|
-
def to_dict(self):
|
|
39
|
-
"""
|
|
40
|
-
Return a dict to describe some information about model.
|
|
41
|
-
:return:
|
|
42
|
-
"""
|
|
43
|
-
raise NotImplementedError
|
|
44
|
-
|
|
45
|
-
@abstractmethod
|
|
46
|
-
def to_version_info(self):
|
|
47
|
-
"""
|
|
48
|
-
Return a dict to describe version info about a model instance
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
|
|
52
20
|
def create_model_instance(
|
|
53
|
-
subpool_addr: str,
|
|
54
|
-
devices: List[str],
|
|
55
21
|
model_uid: str,
|
|
56
22
|
model_type: str,
|
|
57
23
|
model_name: str,
|
|
@@ -65,7 +31,7 @@ def create_model_instance(
|
|
|
65
31
|
] = None,
|
|
66
32
|
model_path: Optional[str] = None,
|
|
67
33
|
**kwargs,
|
|
68
|
-
) ->
|
|
34
|
+
) -> Any:
|
|
69
35
|
from .audio.core import create_audio_model_instance
|
|
70
36
|
from .embedding.core import create_embedding_model_instance
|
|
71
37
|
from .flexible.core import create_flexible_model_instance
|
|
@@ -76,8 +42,6 @@ def create_model_instance(
|
|
|
76
42
|
|
|
77
43
|
if model_type == "LLM":
|
|
78
44
|
return create_llm_model_instance(
|
|
79
|
-
subpool_addr,
|
|
80
|
-
devices,
|
|
81
45
|
model_uid,
|
|
82
46
|
model_name,
|
|
83
47
|
model_engine,
|
|
@@ -93,11 +57,11 @@ def create_model_instance(
|
|
|
93
57
|
# embedding model doesn't accept trust_remote_code
|
|
94
58
|
kwargs.pop("trust_remote_code", None)
|
|
95
59
|
return create_embedding_model_instance(
|
|
96
|
-
subpool_addr,
|
|
97
|
-
devices,
|
|
98
60
|
model_uid,
|
|
99
61
|
model_name,
|
|
100
62
|
model_engine,
|
|
63
|
+
model_format,
|
|
64
|
+
quantization,
|
|
101
65
|
download_hub,
|
|
102
66
|
model_path,
|
|
103
67
|
**kwargs,
|
|
@@ -105,8 +69,6 @@ def create_model_instance(
|
|
|
105
69
|
elif model_type == "image":
|
|
106
70
|
kwargs.pop("trust_remote_code", None)
|
|
107
71
|
return create_image_model_instance(
|
|
108
|
-
subpool_addr,
|
|
109
|
-
devices,
|
|
110
72
|
model_uid,
|
|
111
73
|
model_name,
|
|
112
74
|
peft_model_config,
|
|
@@ -117,8 +79,6 @@ def create_model_instance(
|
|
|
117
79
|
elif model_type == "rerank":
|
|
118
80
|
kwargs.pop("trust_remote_code", None)
|
|
119
81
|
return create_rerank_model_instance(
|
|
120
|
-
subpool_addr,
|
|
121
|
-
devices,
|
|
122
82
|
model_uid,
|
|
123
83
|
model_name,
|
|
124
84
|
download_hub,
|
|
@@ -128,8 +88,6 @@ def create_model_instance(
|
|
|
128
88
|
elif model_type == "audio":
|
|
129
89
|
kwargs.pop("trust_remote_code", None)
|
|
130
90
|
return create_audio_model_instance(
|
|
131
|
-
subpool_addr,
|
|
132
|
-
devices,
|
|
133
91
|
model_uid,
|
|
134
92
|
model_name,
|
|
135
93
|
download_hub,
|
|
@@ -139,8 +97,6 @@ def create_model_instance(
|
|
|
139
97
|
elif model_type == "video":
|
|
140
98
|
kwargs.pop("trust_remote_code", None)
|
|
141
99
|
return create_video_model_instance(
|
|
142
|
-
subpool_addr,
|
|
143
|
-
devices,
|
|
144
100
|
model_uid,
|
|
145
101
|
model_name,
|
|
146
102
|
download_hub,
|
|
@@ -150,7 +106,7 @@ def create_model_instance(
|
|
|
150
106
|
elif model_type == "flexible":
|
|
151
107
|
kwargs.pop("trust_remote_code", None)
|
|
152
108
|
return create_flexible_model_instance(
|
|
153
|
-
|
|
109
|
+
model_uid, model_name, model_path, **kwargs
|
|
154
110
|
)
|
|
155
111
|
else:
|
|
156
112
|
raise ValueError(f"Unsupported model type: {model_type}.")
|
|
@@ -161,6 +117,7 @@ class CacheableModelSpec(BaseModel):
|
|
|
161
117
|
model_id: str
|
|
162
118
|
model_revision: Optional[str]
|
|
163
119
|
model_hub: str = "huggingface"
|
|
120
|
+
cache_config: Optional[dict]
|
|
164
121
|
|
|
165
122
|
|
|
166
123
|
class VirtualEnvSettings(BaseModel):
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import codecs
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import threading
|
|
20
|
+
import warnings
|
|
21
|
+
from typing import TYPE_CHECKING, Dict, List, Type
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from .core import CacheableModelSpec
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ModelRegistry:
    """In-memory registry of user-defined (custom) model specs for one model type.

    Subclasses set ``model_type``. It is also the sub-directory name used when
    a definition is persisted through ``CacheManager``.
    """

    model_type = "unknown"

    def __init__(self) -> None:
        # Guards ``models`` in register / unregister / get_custom_models.
        self.lock = threading.Lock()
        # User-defined specs, in registration order.
        self.models: List["CacheableModelSpec"] = []
        # Names taken by built-in models; custom models may not reuse them.
        self.builtin_models: List[str] = []

    def find_model(self, model_name: str):
        """Return the first registered custom spec named *model_name*, or None.

        Does not take ``self.lock``; callers needing consistency must hold it.
        """
        return next(
            (spec for spec in self.models if spec.model_name == model_name), None
        )

    def get_custom_models(self):
        """Return a shallow copy of the registered custom specs."""
        with self.lock:
            return self.models.copy()

    def check_model_uri(self, model_spec: "CacheableModelSpec"):
        """Raise ValueError if the spec carries an invalid ``model_uri``.

        Specs with no ``model_uri`` attribute, or with an empty one, pass. Not
        every spec type declares the field.
        """
        model_uri = getattr(model_spec, "model_uri", None)
        if not model_uri:
            return

        from .utils import is_valid_model_uri

        if not is_valid_model_uri(model_uri):
            raise ValueError(f"Invalid model URI {model_uri}.")

    def add_ud_model(self, model_spec):
        """Append *model_spec* to the in-memory list (hook for subclasses)."""
        self.models.append(model_spec)

    def register(self, model_spec: "CacheableModelSpec", persist: bool):
        """Validate and register a custom model, optionally persisting it to disk.

        Raises:
            ValueError: on an invalid name or URI, or when the name clashes with
                a built-in or already-registered custom model.
        """
        from .cache_manager import CacheManager
        from .utils import is_valid_model_name

        if not is_valid_model_name(model_spec.model_name):
            raise ValueError(f"Invalid model name {model_spec.model_name}.")

        self.check_model_uri(model_spec)

        with self.lock:
            taken_names = {*self.builtin_models, *(s.model_name for s in self.models)}
            if model_spec.model_name in taken_names:
                raise ValueError(
                    f"Model name conflicts with existing model {model_spec.model_name}"
                )

            self.add_ud_model(model_spec)

            if persist:
                try:
                    CacheManager(model_spec).register_custom_model(self.model_type)
                except Exception:
                    # Keep memory and disk consistent: don't leave a model
                    # registered whose definition could not be persisted.
                    self.remove_ud_model(model_spec)
                    raise

    def remove_ud_model(self, model_spec):
        """Remove *model_spec* from the in-memory list (hook for subclasses)."""
        self.models.remove(model_spec)

    def remove_ud_model_files(self, model_spec):
        """Delete the persisted definition and cache link of *model_spec*."""
        from .cache_manager import CacheManager

        cache_manager = CacheManager(model_spec)
        cache_manager.unregister_custom_model(self.model_type)

    def unregister(
        self, model_name: str, raise_error: bool = True, remove_file: bool = True
    ):
        """Unregister the custom model *model_name*.

        Args:
            model_name: name of the custom model to remove.
            raise_error: raise ValueError when not found (otherwise log a warning).
            remove_file: also delete the persisted definition and cache link.
        """
        with self.lock:
            model_spec = self.find_model(model_name)
            if model_spec is not None:
                self.remove_ud_model(model_spec)
                if remove_file:
                    self.remove_ud_model_files(model_spec)
            elif raise_error:
                raise ValueError(f"Model {model_name} not found")
            else:
                logger.warning(
                    f"Custom {self.model_type} model {model_name} not found"
                )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class RegistryManager:
    """Hand out one shared ``ModelRegistry`` instance per model type."""

    # model type -> registry instance, created on first request
    _instances: Dict[str, ModelRegistry] = {}

    @classmethod
    def get_registry(cls, model_type: str) -> ModelRegistry:
        """Return the registry for *model_type*, creating it on first use.

        Supported types: ``"rerank"``, ``"image"``, ``"audio"``, ``"llm"``,
        ``"flexible"`` and ``"embedding"``.

        Raises:
            ValueError: if *model_type* is not a supported type.
        """
        # Imported inside the method -- presumably to avoid import cycles with
        # the per-type ``custom`` modules (TODO confirm).
        from .audio.custom import AudioModelRegistry
        from .embedding.custom import EmbeddingModelRegistry
        from .flexible.custom import FlexibleModelRegistry
        from .image.custom import ImageModelRegistry
        from .llm.custom import LLMModelRegistry
        from .rerank.custom import RerankModelRegistry

        if model_type in cls._instances:
            return cls._instances[model_type]

        registry_classes = {
            "rerank": RerankModelRegistry,
            "image": ImageModelRegistry,
            "audio": AudioModelRegistry,
            "llm": LLMModelRegistry,
            "flexible": FlexibleModelRegistry,
            "embedding": EmbeddingModelRegistry,
        }
        if model_type not in registry_classes:
            raise ValueError(f"Unknown model type: {model_type}")

        cls._instances[model_type] = registry_classes[model_type]()
        return cls._instances[model_type]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def migrate_from_v1_to_v2(model_type: str, model_spec_cls: Type):
    """Migrate v1 user-defined model JSON files of *model_type* to the v2 layout.

    Each entry in ``<XINFERENCE_MODEL_DIR>/<model_type>/`` is processed unless a
    file of the same name already exists in
    ``<XINFERENCE_MODEL_DIR>/v2/<model_type>/``. Processing loads the JSON, sets
    ``version`` to 2, and copies ``quantizations[0]`` into ``quantization`` for
    every entry of ``model_specs``. The result is built with *model_spec_cls*
    and persisted through the type's registry. The spec is then removed from the
    in-memory registry again, keeping the files. Per-file failures are reported
    with ``warnings.warn`` and do not stop the migration.
    """
    from ..constants import XINFERENCE_MODEL_DIR

    v1_user_defined_model_dir = os.path.join(XINFERENCE_MODEL_DIR, model_type)
    v2_user_defined_model_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", model_type)
    if not os.path.isdir(v1_user_defined_model_dir):
        return

    # os.listdir() order is arbitrary; sort for a deterministic migration order.
    for f in sorted(os.listdir(v1_user_defined_model_dir)):
        if os.path.exists(os.path.join(v2_user_defined_model_dir, f)):
            # skip if v2 has already
            continue

        try:
            # Built-in open() instead of codecs.open(), which is deprecated
            # since Python 3.14.
            with open(
                os.path.join(v1_user_defined_model_dir, f), encoding="utf-8"
            ) as fd:
                v1_model_json = json.load(fd)

            v1_model_json["version"] = 2
            for spec in v1_model_json.get("model_specs", []):
                if "quantizations" in spec:
                    # change quantizations to quantization
                    spec["quantization"] = spec["quantizations"][0]

            user_defined_model_family = model_spec_cls(**v1_model_json)
            registry = RegistryManager.get_registry(model_type)
            # register custom model file to v2
            registry.register(user_defined_model_family, persist=True)
            # unregister since it will be registered by v2
            registry.unregister(
                user_defined_model_family.model_name, remove_file=False
            )
        except Exception as e:
            warnings.warn(
                f"Fail to migrate {v1_user_defined_model_dir}/{f}, error: {e}"
            )
|