xinference 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/client/restful/async_restful_client.py +8 -13
- xinference/client/restful/restful_client.py +6 -2
- xinference/core/chat_interface.py +6 -4
- xinference/core/media_interface.py +5 -0
- xinference/core/model.py +1 -5
- xinference/core/supervisor.py +117 -68
- xinference/core/worker.py +49 -37
- xinference/deploy/test/test_cmdline.py +2 -6
- xinference/model/audio/__init__.py +26 -23
- xinference/model/audio/chattts.py +3 -2
- xinference/model/audio/core.py +49 -98
- xinference/model/audio/cosyvoice.py +3 -2
- xinference/model/audio/custom.py +28 -73
- xinference/model/audio/f5tts.py +3 -2
- xinference/model/audio/f5tts_mlx.py +3 -2
- xinference/model/audio/fish_speech.py +3 -2
- xinference/model/audio/funasr.py +17 -4
- xinference/model/audio/kokoro.py +3 -2
- xinference/model/audio/megatts.py +3 -2
- xinference/model/audio/melotts.py +3 -2
- xinference/model/audio/model_spec.json +572 -171
- xinference/model/audio/utils.py +0 -6
- xinference/model/audio/whisper.py +3 -2
- xinference/model/audio/whisper_mlx.py +3 -2
- xinference/model/cache_manager.py +141 -0
- xinference/model/core.py +6 -49
- xinference/model/custom.py +174 -0
- xinference/model/embedding/__init__.py +67 -56
- xinference/model/embedding/cache_manager.py +35 -0
- xinference/model/embedding/core.py +104 -84
- xinference/model/embedding/custom.py +55 -78
- xinference/model/embedding/embed_family.py +80 -31
- xinference/model/embedding/flag/core.py +21 -5
- xinference/model/embedding/llama_cpp/__init__.py +0 -0
- xinference/model/embedding/llama_cpp/core.py +234 -0
- xinference/model/embedding/model_spec.json +968 -103
- xinference/model/embedding/sentence_transformers/core.py +30 -20
- xinference/model/embedding/vllm/core.py +11 -5
- xinference/model/flexible/__init__.py +8 -2
- xinference/model/flexible/core.py +26 -119
- xinference/model/flexible/custom.py +69 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -0
- xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
- xinference/model/flexible/launchers/transformers_launcher.py +15 -3
- xinference/model/flexible/launchers/yolo_launcher.py +5 -1
- xinference/model/image/__init__.py +20 -20
- xinference/model/image/cache_manager.py +62 -0
- xinference/model/image/core.py +70 -182
- xinference/model/image/custom.py +28 -72
- xinference/model/image/model_spec.json +402 -119
- xinference/model/image/ocr/got_ocr2.py +3 -2
- xinference/model/image/stable_diffusion/core.py +22 -7
- xinference/model/image/stable_diffusion/mlx.py +6 -6
- xinference/model/image/utils.py +2 -2
- xinference/model/llm/__init__.py +71 -94
- xinference/model/llm/cache_manager.py +292 -0
- xinference/model/llm/core.py +37 -111
- xinference/model/llm/custom.py +88 -0
- xinference/model/llm/llama_cpp/core.py +5 -7
- xinference/model/llm/llm_family.json +16260 -8151
- xinference/model/llm/llm_family.py +138 -839
- xinference/model/llm/lmdeploy/core.py +5 -7
- xinference/model/llm/memory.py +3 -4
- xinference/model/llm/mlx/core.py +6 -8
- xinference/model/llm/reasoning_parser.py +3 -1
- xinference/model/llm/sglang/core.py +32 -14
- xinference/model/llm/transformers/chatglm.py +3 -7
- xinference/model/llm/transformers/core.py +49 -27
- xinference/model/llm/transformers/deepseek_v2.py +2 -2
- xinference/model/llm/transformers/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
- xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
- xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
- xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
- xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
- xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
- xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
- xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
- xinference/model/llm/transformers/opt.py +3 -7
- xinference/model/llm/utils.py +34 -49
- xinference/model/llm/vllm/core.py +77 -27
- xinference/model/llm/vllm/xavier/engine.py +5 -3
- xinference/model/llm/vllm/xavier/scheduler.py +10 -6
- xinference/model/llm/vllm/xavier/transfer.py +1 -1
- xinference/model/rerank/__init__.py +26 -25
- xinference/model/rerank/core.py +47 -87
- xinference/model/rerank/custom.py +25 -71
- xinference/model/rerank/model_spec.json +158 -33
- xinference/model/rerank/utils.py +2 -2
- xinference/model/utils.py +115 -54
- xinference/model/video/__init__.py +13 -17
- xinference/model/video/core.py +44 -102
- xinference/model/video/diffusers.py +4 -3
- xinference/model/video/model_spec.json +90 -21
- xinference/types.py +5 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
- xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
- xinference/web/ui/src/locales/en.json +0 -1
- xinference/web/ui/src/locales/ja.json +0 -1
- xinference/web/ui/src/locales/ko.json +0 -1
- xinference/web/ui/src/locales/zh.json +0 -1
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
- xinference/model/audio/model_spec_modelscope.json +0 -231
- xinference/model/embedding/model_spec_modelscope.json +0 -293
- xinference/model/embedding/utils.py +0 -18
- xinference/model/image/model_spec_modelscope.json +0 -375
- xinference/model/llm/llama_cpp/memory.py +0 -457
- xinference/model/llm/llm_family_csghub.json +0 -56
- xinference/model/llm/llm_family_modelscope.json +0 -8700
- xinference/model/llm/llm_family_openmind_hub.json +0 -1019
- xinference/model/rerank/model_spec_modelscope.json +0 -85
- xinference/model/video/model_spec_modelscope.json +0 -184
- xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
- /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
|
@@ -18,16 +18,15 @@ import os
|
|
|
18
18
|
import warnings
|
|
19
19
|
from typing import Any, Dict, List
|
|
20
20
|
|
|
21
|
+
from ..utils import flatten_quantizations
|
|
21
22
|
from .core import (
|
|
22
23
|
EMBEDDING_MODEL_DESCRIPTIONS,
|
|
23
|
-
|
|
24
|
-
EmbeddingModelSpec,
|
|
24
|
+
EmbeddingModelFamilyV2,
|
|
25
25
|
generate_embedding_description,
|
|
26
|
-
get_cache_status,
|
|
27
26
|
get_embedding_model_descriptions,
|
|
28
27
|
)
|
|
29
28
|
from .custom import (
|
|
30
|
-
|
|
29
|
+
CustomEmbeddingModelFamilyV2,
|
|
31
30
|
get_user_defined_embeddings,
|
|
32
31
|
register_embedding,
|
|
33
32
|
unregister_embedding,
|
|
@@ -36,7 +35,7 @@ from .embed_family import (
|
|
|
36
35
|
BUILTIN_EMBEDDING_MODELS,
|
|
37
36
|
EMBEDDING_ENGINES,
|
|
38
37
|
FLAG_EMBEDDER_CLASSES,
|
|
39
|
-
|
|
38
|
+
LLAMA_CPP_CLASSES,
|
|
40
39
|
SENTENCE_TRANSFORMER_CLASSES,
|
|
41
40
|
SUPPORTED_ENGINES,
|
|
42
41
|
VLLM_CLASSES,
|
|
@@ -45,15 +44,19 @@ from .embed_family import (
|
|
|
45
44
|
|
|
46
45
|
def register_custom_model():
|
|
47
46
|
from ...constants import XINFERENCE_MODEL_DIR
|
|
47
|
+
from ..custom import migrate_from_v1_to_v2
|
|
48
48
|
|
|
49
|
-
|
|
49
|
+
# migrate from v1 to v2 first
|
|
50
|
+
migrate_from_v1_to_v2("embedding", CustomEmbeddingModelFamilyV2)
|
|
51
|
+
|
|
52
|
+
user_defined_embedding_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "embedding")
|
|
50
53
|
if os.path.isdir(user_defined_embedding_dir):
|
|
51
54
|
for f in os.listdir(user_defined_embedding_dir):
|
|
52
55
|
try:
|
|
53
56
|
with codecs.open(
|
|
54
57
|
os.path.join(user_defined_embedding_dir, f), encoding="utf-8"
|
|
55
58
|
) as fd:
|
|
56
|
-
user_defined_llm_family =
|
|
59
|
+
user_defined_llm_family = CustomEmbeddingModelFamilyV2.parse_obj(
|
|
57
60
|
json.load(fd)
|
|
58
61
|
)
|
|
59
62
|
register_embedding(user_defined_llm_family, persist=False)
|
|
@@ -61,80 +64,89 @@ def register_custom_model():
|
|
|
61
64
|
warnings.warn(f"{user_defined_embedding_dir}/{f} has error, {e}")
|
|
62
65
|
|
|
63
66
|
|
|
64
|
-
def
|
|
65
|
-
|
|
67
|
+
def check_format_with_engine(model_format, engine):
|
|
68
|
+
if model_format in ["ggufv2"] and engine not in ["llama.cpp"]:
|
|
69
|
+
return False
|
|
70
|
+
if model_format not in ["ggufv2"] and engine == "llama.cpp":
|
|
71
|
+
return False
|
|
72
|
+
return True
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def generate_engine_config_by_model_name(model_family: "EmbeddingModelFamilyV2"):
|
|
76
|
+
model_name = model_family.model_name
|
|
66
77
|
engines: Dict[str, List[Dict[str, Any]]] = EMBEDDING_ENGINES.get(
|
|
67
78
|
model_name, {}
|
|
68
79
|
) # structure for engine query
|
|
69
|
-
for
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
if
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
80
|
+
for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
|
|
81
|
+
model_format = spec.model_format
|
|
82
|
+
quantization = spec.quantization
|
|
83
|
+
for engine in SUPPORTED_ENGINES:
|
|
84
|
+
if not check_format_with_engine(model_format, engine):
|
|
85
|
+
continue
|
|
86
|
+
CLASSES = SUPPORTED_ENGINES[engine]
|
|
87
|
+
for cls in CLASSES:
|
|
88
|
+
# Every engine needs to implement match method
|
|
89
|
+
if cls.match(model_family, spec, quantization):
|
|
90
|
+
# we only match the first class for an engine
|
|
91
|
+
if engine not in engines:
|
|
92
|
+
engines[engine] = [
|
|
93
|
+
{
|
|
94
|
+
"model_name": model_name,
|
|
95
|
+
"model_format": model_format,
|
|
96
|
+
"quantization": quantization,
|
|
97
|
+
"embedding_class": cls,
|
|
98
|
+
}
|
|
99
|
+
]
|
|
100
|
+
else:
|
|
101
|
+
engines[engine].append(
|
|
102
|
+
{
|
|
103
|
+
"model_name": model_name,
|
|
104
|
+
"model_format": model_format,
|
|
105
|
+
"quantization": quantization,
|
|
106
|
+
"embedding_class": cls,
|
|
107
|
+
}
|
|
108
|
+
)
|
|
109
|
+
break
|
|
82
110
|
EMBEDDING_ENGINES[model_name] = engines
|
|
83
111
|
|
|
84
112
|
|
|
85
113
|
# will be called in xinference/model/__init__.py
|
|
86
114
|
def _install():
|
|
87
115
|
_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
116
|
+
|
|
117
|
+
for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
|
|
118
|
+
flattened = []
|
|
119
|
+
for spec in json_obj["model_specs"]:
|
|
120
|
+
flattened.extend(flatten_quantizations(spec))
|
|
121
|
+
json_obj["model_specs"] = flattened
|
|
122
|
+
BUILTIN_EMBEDDING_MODELS[json_obj["model_name"]] = EmbeddingModelFamilyV2(
|
|
123
|
+
**json_obj
|
|
96
124
|
)
|
|
97
|
-
|
|
125
|
+
|
|
98
126
|
for model_name, model_spec in BUILTIN_EMBEDDING_MODELS.items():
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
MODELSCOPE_EMBEDDING_MODELS.update(
|
|
103
|
-
dict(
|
|
104
|
-
(spec["model_name"], EmbeddingModelSpec(**spec))
|
|
105
|
-
for spec in json.load(
|
|
106
|
-
codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
|
|
127
|
+
if model_spec.model_name not in EMBEDDING_MODEL_DESCRIPTIONS:
|
|
128
|
+
EMBEDDING_MODEL_DESCRIPTIONS.update(
|
|
129
|
+
generate_embedding_description(model_spec)
|
|
107
130
|
)
|
|
108
|
-
)
|
|
109
|
-
)
|
|
110
|
-
for model_name, model_spec in MODELSCOPE_EMBEDDING_MODELS.items():
|
|
111
|
-
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
|
|
112
|
-
|
|
113
|
-
# TODO: consider support more download hub in future...
|
|
114
|
-
# register model description after recording model revision
|
|
115
|
-
for model_spec_info in [BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS]:
|
|
116
|
-
for model_name, model_spec in model_spec_info.items():
|
|
117
|
-
if model_spec.model_name not in EMBEDDING_MODEL_DESCRIPTIONS:
|
|
118
|
-
EMBEDDING_MODEL_DESCRIPTIONS.update(
|
|
119
|
-
generate_embedding_description(model_spec)
|
|
120
|
-
)
|
|
121
131
|
|
|
122
132
|
from .flag.core import FlagEmbeddingModel
|
|
133
|
+
from .llama_cpp.core import XllamaCppEmbeddingModel
|
|
123
134
|
from .sentence_transformers.core import SentenceTransformerEmbeddingModel
|
|
124
135
|
from .vllm.core import VLLMEmbeddingModel
|
|
125
136
|
|
|
126
137
|
SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerEmbeddingModel])
|
|
127
138
|
FLAG_EMBEDDER_CLASSES.extend([FlagEmbeddingModel])
|
|
128
139
|
VLLM_CLASSES.extend([VLLMEmbeddingModel])
|
|
140
|
+
LLAMA_CPP_CLASSES.extend([XllamaCppEmbeddingModel])
|
|
129
141
|
|
|
130
142
|
SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
|
|
131
143
|
SUPPORTED_ENGINES["flag"] = FLAG_EMBEDDER_CLASSES
|
|
132
144
|
SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES
|
|
145
|
+
SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CPP_CLASSES
|
|
133
146
|
|
|
134
147
|
# Init embedding engine
|
|
135
|
-
for
|
|
136
|
-
|
|
137
|
-
generate_engine_config_by_model_name(model_spec)
|
|
148
|
+
for model_spec in BUILTIN_EMBEDDING_MODELS.values():
|
|
149
|
+
generate_engine_config_by_model_name(model_spec)
|
|
138
150
|
|
|
139
151
|
register_custom_model()
|
|
140
152
|
|
|
@@ -145,4 +157,3 @@ def _install():
|
|
|
145
157
|
)
|
|
146
158
|
|
|
147
159
|
del _model_spec_json
|
|
148
|
-
del _model_spec_modelscope_json
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from ..cache_manager import CacheManager
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from .core import EmbeddingModelFamilyV2
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EmbeddingCacheManager(CacheManager):
|
|
11
|
+
def __init__(self, model_family: "EmbeddingModelFamilyV2"):
|
|
12
|
+
from ..llm.cache_manager import LLMCacheManager
|
|
13
|
+
|
|
14
|
+
super().__init__(model_family)
|
|
15
|
+
# Composition design mode for avoiding duplicate code
|
|
16
|
+
self.cache_helper = LLMCacheManager(model_family)
|
|
17
|
+
|
|
18
|
+
spec = self._model_family.model_specs[0]
|
|
19
|
+
model_dir_name = (
|
|
20
|
+
f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
|
|
21
|
+
)
|
|
22
|
+
self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
|
|
23
|
+
self.cache_helper._cache_dir = self._cache_dir
|
|
24
|
+
|
|
25
|
+
def cache(self) -> str:
|
|
26
|
+
spec = self._model_family.model_specs[0]
|
|
27
|
+
if spec.model_uri is not None:
|
|
28
|
+
return self.cache_helper.cache_uri()
|
|
29
|
+
else:
|
|
30
|
+
if spec.model_hub == "huggingface":
|
|
31
|
+
return self.cache_helper.cache_from_huggingface()
|
|
32
|
+
elif spec.model_hub == "modelscope":
|
|
33
|
+
return self.cache_helper.cache_from_modelscope()
|
|
34
|
+
else:
|
|
35
|
+
raise ValueError(f"Unknown model hub: {spec.model_hub}")
|
|
@@ -12,23 +12,24 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
import abc
|
|
15
16
|
import gc
|
|
16
17
|
import logging
|
|
17
18
|
import os
|
|
19
|
+
from abc import abstractmethod
|
|
18
20
|
from collections import defaultdict
|
|
19
|
-
from typing import Dict, List, Literal, Optional,
|
|
21
|
+
from typing import Annotated, Dict, List, Literal, Optional, Union
|
|
20
22
|
|
|
21
|
-
from ..._compat import ROOT_KEY, ErrorWrapper, ValidationError
|
|
23
|
+
from ..._compat import ROOT_KEY, BaseModel, ErrorWrapper, Field, ValidationError
|
|
22
24
|
from ...device_utils import empty_cache
|
|
23
|
-
from ..core import
|
|
24
|
-
from ..utils import
|
|
25
|
+
from ..core import VirtualEnvSettings
|
|
26
|
+
from ..utils import ModelInstanceInfoMixin
|
|
25
27
|
from .embed_family import match_embedding
|
|
26
28
|
|
|
27
29
|
logger = logging.getLogger(__name__)
|
|
28
30
|
|
|
29
31
|
# Used for check whether the model is cached.
|
|
30
32
|
# Init when registering all the builtin models.
|
|
31
|
-
MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
|
|
32
33
|
EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
|
|
33
34
|
EMBEDDING_EMPTY_CACHE_COUNT = int(
|
|
34
35
|
os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
|
|
@@ -46,96 +47,100 @@ def get_embedding_model_descriptions():
|
|
|
46
47
|
return copy.deepcopy(EMBEDDING_MODEL_DESCRIPTIONS)
|
|
47
48
|
|
|
48
49
|
|
|
50
|
+
class TransformersEmbeddingSpecV1(BaseModel):
|
|
51
|
+
model_format: Literal["pytorch"]
|
|
52
|
+
model_hub: str = "huggingface"
|
|
53
|
+
model_id: Optional[str]
|
|
54
|
+
model_uri: Optional[str]
|
|
55
|
+
model_revision: Optional[str]
|
|
56
|
+
quantization: str
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class LlamaCppEmbeddingSpecV1(BaseModel):
|
|
60
|
+
model_format: Literal["ggufv2"]
|
|
61
|
+
model_hub: str = "huggingface"
|
|
62
|
+
model_id: Optional[str]
|
|
63
|
+
model_uri: Optional[str]
|
|
64
|
+
model_revision: Optional[str]
|
|
65
|
+
quantization: str
|
|
66
|
+
model_file_name_template: str
|
|
67
|
+
model_file_name_split_template: Optional[str]
|
|
68
|
+
quantization_parts: Optional[Dict[str, List[str]]]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
EmbeddingSpecV1 = Annotated[
|
|
72
|
+
Union[TransformersEmbeddingSpecV1, LlamaCppEmbeddingSpecV1],
|
|
73
|
+
Field(discriminator="model_format"),
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
|
|
49
77
|
# this class define the basic info of embedding model
|
|
50
|
-
class
|
|
78
|
+
class EmbeddingModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
|
|
79
|
+
version: Literal[2]
|
|
51
80
|
model_name: str
|
|
52
81
|
dimensions: int
|
|
53
82
|
max_tokens: int
|
|
54
83
|
language: List[str]
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
model_hub: str = "huggingface"
|
|
84
|
+
model_specs: List["EmbeddingSpecV1"]
|
|
85
|
+
cache_config: Optional[dict]
|
|
58
86
|
virtualenv: Optional[VirtualEnvSettings]
|
|
59
87
|
|
|
88
|
+
class Config:
|
|
89
|
+
extra = "allow"
|
|
60
90
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
self,
|
|
64
|
-
address: Optional[str],
|
|
65
|
-
devices: Optional[List[str]],
|
|
66
|
-
model_spec: EmbeddingModelSpec,
|
|
67
|
-
model_path: Optional[str] = None,
|
|
68
|
-
):
|
|
69
|
-
super().__init__(address, devices, model_path=model_path)
|
|
70
|
-
self._model_spec = model_spec
|
|
71
|
-
|
|
72
|
-
@property
|
|
73
|
-
def spec(self):
|
|
74
|
-
return self._model_spec
|
|
75
|
-
|
|
76
|
-
def to_dict(self):
|
|
91
|
+
def to_description(self):
|
|
92
|
+
spec = self.model_specs[0]
|
|
77
93
|
return {
|
|
78
94
|
"model_type": "embedding",
|
|
79
|
-
"address": self
|
|
80
|
-
"accelerators": self
|
|
81
|
-
"model_name": self.
|
|
82
|
-
"dimensions": self.
|
|
83
|
-
"max_tokens": self.
|
|
84
|
-
"language": self.
|
|
85
|
-
"
|
|
95
|
+
"address": getattr(self, "address", None),
|
|
96
|
+
"accelerators": getattr(self, "accelerators", None),
|
|
97
|
+
"model_name": self.model_name,
|
|
98
|
+
"dimensions": self.dimensions,
|
|
99
|
+
"max_tokens": self.max_tokens,
|
|
100
|
+
"language": self.language,
|
|
101
|
+
"model_hub": spec.model_hub,
|
|
102
|
+
"model_revision": spec.model_revision,
|
|
103
|
+
"quantization": spec.quantization,
|
|
86
104
|
}
|
|
87
105
|
|
|
88
106
|
def to_version_info(self):
|
|
89
|
-
from .
|
|
107
|
+
from .cache_manager import EmbeddingCacheManager
|
|
90
108
|
|
|
91
|
-
|
|
92
|
-
is_cached = get_cache_status(self._model_spec)
|
|
93
|
-
file_location = get_cache_dir(self._model_spec)
|
|
94
|
-
else:
|
|
95
|
-
is_cached = True
|
|
96
|
-
file_location = self._model_path
|
|
109
|
+
cache_manager = EmbeddingCacheManager(self)
|
|
97
110
|
|
|
98
111
|
return {
|
|
99
|
-
"model_version": get_model_version(self
|
|
100
|
-
"model_file_location":
|
|
101
|
-
"cache_status":
|
|
102
|
-
"dimensions": self.
|
|
103
|
-
"max_tokens": self.
|
|
112
|
+
"model_version": get_model_version(self),
|
|
113
|
+
"model_file_location": cache_manager.get_cache_dir(),
|
|
114
|
+
"cache_status": cache_manager.get_cache_status(),
|
|
115
|
+
"dimensions": self.dimensions,
|
|
116
|
+
"max_tokens": self.max_tokens,
|
|
104
117
|
}
|
|
105
118
|
|
|
106
119
|
|
|
120
|
+
def get_model_version(embedding_model: EmbeddingModelFamilyV2) -> str:
|
|
121
|
+
spec = embedding_model.model_specs[0]
|
|
122
|
+
return f"{embedding_model.model_name}--{embedding_model.max_tokens}--{embedding_model.dimensions}--{spec.model_format}--{spec.quantization}"
|
|
123
|
+
|
|
124
|
+
|
|
107
125
|
def generate_embedding_description(
|
|
108
|
-
|
|
126
|
+
model_family: EmbeddingModelFamilyV2,
|
|
109
127
|
) -> Dict[str, List[Dict]]:
|
|
110
128
|
res = defaultdict(list)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
129
|
+
specs = [x for x in model_family.model_specs if x.model_hub == "huggingface"]
|
|
130
|
+
for spec in specs:
|
|
131
|
+
family = model_family.copy()
|
|
132
|
+
family.model_specs = [spec]
|
|
133
|
+
res[model_family.model_name].append(family.to_version_info())
|
|
114
134
|
return res
|
|
115
135
|
|
|
116
136
|
|
|
117
|
-
def cache(model_spec: EmbeddingModelSpec):
|
|
118
|
-
from ..utils import cache
|
|
119
|
-
|
|
120
|
-
return cache(model_spec, EmbeddingModelDescription)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def get_cache_status(
|
|
124
|
-
model_spec: EmbeddingModelSpec,
|
|
125
|
-
) -> bool:
|
|
126
|
-
return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
import abc
|
|
130
|
-
from abc import abstractmethod
|
|
131
|
-
|
|
132
|
-
|
|
133
137
|
class EmbeddingModel(abc.ABC):
|
|
134
138
|
def __init__(
|
|
135
139
|
self,
|
|
136
140
|
model_uid: str,
|
|
137
141
|
model_path: str,
|
|
138
|
-
|
|
142
|
+
model_family: EmbeddingModelFamilyV2,
|
|
143
|
+
quantization: Optional[str] = None,
|
|
139
144
|
device: Optional[str] = None,
|
|
140
145
|
**kwargs,
|
|
141
146
|
):
|
|
@@ -145,8 +150,10 @@ class EmbeddingModel(abc.ABC):
|
|
|
145
150
|
self._model = None
|
|
146
151
|
self._tokenizer = None
|
|
147
152
|
self._counter = 0
|
|
148
|
-
self.
|
|
149
|
-
self.
|
|
153
|
+
self.model_family = model_family
|
|
154
|
+
self._model_spec = model_family.model_specs[0]
|
|
155
|
+
self._quantization = quantization
|
|
156
|
+
self._model_name = self.model_family.model_name
|
|
150
157
|
self._kwargs = kwargs
|
|
151
158
|
|
|
152
159
|
@classmethod
|
|
@@ -156,17 +163,27 @@ class EmbeddingModel(abc.ABC):
|
|
|
156
163
|
|
|
157
164
|
@classmethod
|
|
158
165
|
@abstractmethod
|
|
159
|
-
def match_json(
|
|
166
|
+
def match_json(
|
|
167
|
+
cls,
|
|
168
|
+
model_family: EmbeddingModelFamilyV2,
|
|
169
|
+
model_spec: EmbeddingSpecV1,
|
|
170
|
+
quantization: str,
|
|
171
|
+
) -> bool:
|
|
160
172
|
pass
|
|
161
173
|
|
|
162
174
|
@classmethod
|
|
163
|
-
def match(
|
|
175
|
+
def match(
|
|
176
|
+
cls,
|
|
177
|
+
model_family: EmbeddingModelFamilyV2,
|
|
178
|
+
model_spec: EmbeddingSpecV1,
|
|
179
|
+
quantization: str,
|
|
180
|
+
):
|
|
164
181
|
"""
|
|
165
182
|
Return if the model_spec can be matched.
|
|
166
183
|
"""
|
|
167
184
|
if not cls.check_lib():
|
|
168
185
|
return False
|
|
169
|
-
return cls.match_json(model_spec)
|
|
186
|
+
return cls.match_json(model_family, model_spec, quantization)
|
|
170
187
|
|
|
171
188
|
@abstractmethod
|
|
172
189
|
def load(self):
|
|
@@ -290,36 +307,39 @@ class EmbeddingModel(abc.ABC):
|
|
|
290
307
|
|
|
291
308
|
|
|
292
309
|
def create_embedding_model_instance(
|
|
293
|
-
subpool_addr: str,
|
|
294
|
-
devices: Optional[List[str]],
|
|
295
310
|
model_uid: str,
|
|
296
311
|
model_name: str,
|
|
297
312
|
model_engine: Optional[str],
|
|
313
|
+
model_format: Optional[str] = None,
|
|
314
|
+
quantization: Optional[str] = None,
|
|
298
315
|
download_hub: Optional[
|
|
299
316
|
Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
|
|
300
317
|
] = None,
|
|
301
318
|
model_path: Optional[str] = None,
|
|
302
319
|
**kwargs,
|
|
303
|
-
) ->
|
|
304
|
-
|
|
320
|
+
) -> EmbeddingModel:
|
|
321
|
+
from .cache_manager import EmbeddingCacheManager
|
|
322
|
+
|
|
323
|
+
model_family = match_embedding(model_name, model_format, quantization, download_hub)
|
|
305
324
|
if model_path is None:
|
|
306
|
-
|
|
325
|
+
cache_manager = EmbeddingCacheManager(model_family)
|
|
326
|
+
model_path = cache_manager.cache()
|
|
307
327
|
|
|
308
328
|
if model_engine is None:
|
|
309
|
-
# unlike LLM and for compatibility
|
|
329
|
+
# unlike LLM and for compatibility,
|
|
310
330
|
# we use sentence_transformers as the default engine for all models
|
|
311
331
|
model_engine = "sentence_transformers"
|
|
312
332
|
|
|
313
333
|
from .embed_family import check_engine_by_model_name_and_engine
|
|
314
334
|
|
|
315
335
|
embedding_cls = check_engine_by_model_name_and_engine(
|
|
316
|
-
model_name,
|
|
317
|
-
model_engine,
|
|
336
|
+
model_engine, model_name, model_format, quantization
|
|
318
337
|
)
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
338
|
+
model = embedding_cls(
|
|
339
|
+
model_uid,
|
|
340
|
+
model_path,
|
|
341
|
+
model_family,
|
|
342
|
+
quantization,
|
|
343
|
+
**kwargs,
|
|
324
344
|
)
|
|
325
|
-
return model
|
|
345
|
+
return model
|
|
@@ -11,103 +11,80 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
14
15
|
import logging
|
|
15
|
-
import
|
|
16
|
-
from threading import Lock
|
|
17
|
-
from typing import List, Optional
|
|
16
|
+
from typing import List
|
|
18
17
|
|
|
19
|
-
from ...
|
|
20
|
-
from
|
|
18
|
+
from ..._compat import Literal
|
|
19
|
+
from ..custom import ModelRegistry
|
|
20
|
+
from .core import EmbeddingModelFamilyV2
|
|
21
21
|
|
|
22
22
|
logger = logging.getLogger(__name__)
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
class CustomEmbeddingModelFamilyV2(EmbeddingModelFamilyV2):
|
|
26
|
+
version: Literal[2] = 2
|
|
26
27
|
|
|
27
28
|
|
|
28
|
-
|
|
29
|
-
model_id: Optional[str] # type: ignore
|
|
30
|
-
model_revision: Optional[str] # type: ignore
|
|
31
|
-
model_uri: Optional[str]
|
|
29
|
+
UD_EMBEDDINGS: List[CustomEmbeddingModelFamilyV2] = []
|
|
32
30
|
|
|
33
31
|
|
|
34
|
-
|
|
32
|
+
class EmbeddingModelRegistry(ModelRegistry):
|
|
33
|
+
model_type = "embedding"
|
|
35
34
|
|
|
35
|
+
def __init__(self):
|
|
36
|
+
from .embed_family import BUILTIN_EMBEDDING_MODELS
|
|
36
37
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
38
|
+
super().__init__()
|
|
39
|
+
self.models = UD_EMBEDDINGS
|
|
40
|
+
self.builtin_models = list(BUILTIN_EMBEDDING_MODELS.keys())
|
|
40
41
|
|
|
42
|
+
def add_ud_model(self, model_spec):
|
|
43
|
+
from . import generate_engine_config_by_model_name
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
from ..utils import is_valid_model_name, is_valid_model_uri
|
|
45
|
-
from . import (
|
|
46
|
-
BUILTIN_EMBEDDING_MODELS,
|
|
47
|
-
MODELSCOPE_EMBEDDING_MODELS,
|
|
48
|
-
generate_engine_config_by_model_name,
|
|
49
|
-
)
|
|
45
|
+
UD_EMBEDDINGS.append(model_spec)
|
|
46
|
+
generate_engine_config_by_model_name(model_spec)
|
|
50
47
|
|
|
51
|
-
|
|
52
|
-
|
|
48
|
+
def check_model_uri(self, model_family: "EmbeddingModelFamilyV2"):
|
|
49
|
+
from ..utils import is_valid_model_uri
|
|
53
50
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
51
|
+
for spec in model_family.model_specs:
|
|
52
|
+
model_uri = spec.model_uri
|
|
53
|
+
if model_uri and not is_valid_model_uri(model_uri):
|
|
54
|
+
raise ValueError(f"Invalid model URI {model_uri}.")
|
|
57
55
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
list(BUILTIN_EMBEDDING_MODELS.keys())
|
|
61
|
-
+ list(MODELSCOPE_EMBEDDING_MODELS.keys())
|
|
62
|
-
+ [spec.model_name for spec in UD_EMBEDDINGS]
|
|
63
|
-
):
|
|
64
|
-
if model_spec.model_name == model_name:
|
|
65
|
-
raise ValueError(
|
|
66
|
-
f"Model name conflicts with existing model {model_spec.model_name}"
|
|
67
|
-
)
|
|
56
|
+
def remove_ud_model(self, model_family: "CustomEmbeddingModelFamilyV2"):
|
|
57
|
+
from .embed_family import EMBEDDING_ENGINES
|
|
68
58
|
|
|
69
|
-
UD_EMBEDDINGS.
|
|
70
|
-
|
|
59
|
+
UD_EMBEDDINGS.remove(model_family)
|
|
60
|
+
del EMBEDDING_ENGINES[model_family.model_name]
|
|
61
|
+
|
|
62
|
+
def remove_ud_model_files(self, model_family: "CustomEmbeddingModelFamilyV2"):
|
|
63
|
+
from .cache_manager import EmbeddingCacheManager
|
|
64
|
+
|
|
65
|
+
_model_family = model_family.copy()
|
|
66
|
+
for spec in model_family.model_specs:
|
|
67
|
+
_model_family.model_specs = [spec]
|
|
68
|
+
cache_manager = EmbeddingCacheManager(_model_family)
|
|
69
|
+
cache_manager.unregister_custom_model(self.model_type)
|
|
71
70
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
71
|
+
|
|
72
|
+
def get_user_defined_embeddings() -> List[EmbeddingModelFamilyV2]:
|
|
73
|
+
from ..custom import RegistryManager
|
|
74
|
+
|
|
75
|
+
registry = RegistryManager.get_registry("embedding")
|
|
76
|
+
return registry.get_custom_models()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def register_embedding(model_family: CustomEmbeddingModelFamilyV2, persist: bool):
|
|
80
|
+
from ..custom import RegistryManager
|
|
81
|
+
|
|
82
|
+
registry = RegistryManager.get_registry("embedding")
|
|
83
|
+
registry.register(model_family, persist)
|
|
79
84
|
|
|
80
85
|
|
|
81
86
|
def unregister_embedding(model_name: str, raise_error: bool = True):
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
model_spec = f
|
|
87
|
-
break
|
|
88
|
-
if model_spec:
|
|
89
|
-
UD_EMBEDDINGS.remove(model_spec)
|
|
90
|
-
|
|
91
|
-
persist_path = os.path.join(
|
|
92
|
-
XINFERENCE_MODEL_DIR, "embedding", f"{model_spec.model_name}.json"
|
|
93
|
-
)
|
|
94
|
-
if os.path.exists(persist_path):
|
|
95
|
-
os.remove(persist_path)
|
|
96
|
-
|
|
97
|
-
cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
|
|
98
|
-
if os.path.exists(cache_dir):
|
|
99
|
-
logger.warning(
|
|
100
|
-
f"Remove the cache of user-defined model {model_spec.model_name}. "
|
|
101
|
-
f"Cache directory: {cache_dir}"
|
|
102
|
-
)
|
|
103
|
-
if os.path.islink(cache_dir):
|
|
104
|
-
os.remove(cache_dir)
|
|
105
|
-
else:
|
|
106
|
-
logger.warning(
|
|
107
|
-
f"Cache directory is not a soft link, please remove it manually."
|
|
108
|
-
)
|
|
109
|
-
else:
|
|
110
|
-
if raise_error:
|
|
111
|
-
raise ValueError(f"Model {model_name} not found")
|
|
112
|
-
else:
|
|
113
|
-
logger.warning(f"Custom embedding model {model_name} not found")
|
|
87
|
+
from ..custom import RegistryManager
|
|
88
|
+
|
|
89
|
+
registry = RegistryManager.get_registry("embedding")
|
|
90
|
+
registry.unregister(model_name, raise_error)
|