xinference 1.7.1.post1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (136):
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/model/embedding/__init__.py

@@ -18,16 +18,15 @@ import os
 import warnings
 from typing import Any, Dict, List
 
+from ..utils import flatten_quantizations
 from .core import (
     EMBEDDING_MODEL_DESCRIPTIONS,
-    MODEL_NAME_TO_REVISION,
-    EmbeddingModelSpec,
+    EmbeddingModelFamilyV2,
     generate_embedding_description,
-    get_cache_status,
     get_embedding_model_descriptions,
 )
 from .custom import (
-    CustomEmbeddingModelSpec,
+    CustomEmbeddingModelFamilyV2,
     get_user_defined_embeddings,
     register_embedding,
     unregister_embedding,
@@ -36,7 +35,7 @@ from .embed_family import (
     BUILTIN_EMBEDDING_MODELS,
     EMBEDDING_ENGINES,
     FLAG_EMBEDDER_CLASSES,
-    MODELSCOPE_EMBEDDING_MODELS,
+    LLAMA_CPP_CLASSES,
     SENTENCE_TRANSFORMER_CLASSES,
     SUPPORTED_ENGINES,
     VLLM_CLASSES,
@@ -45,15 +44,19 @@ from .embed_family import (
 
 def register_custom_model():
     from ...constants import XINFERENCE_MODEL_DIR
+    from ..custom import migrate_from_v1_to_v2
 
-    user_defined_embedding_dir = os.path.join(XINFERENCE_MODEL_DIR, "embedding")
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("embedding", CustomEmbeddingModelFamilyV2)
+
+    user_defined_embedding_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "embedding")
     if os.path.isdir(user_defined_embedding_dir):
         for f in os.listdir(user_defined_embedding_dir):
             try:
                 with codecs.open(
                     os.path.join(user_defined_embedding_dir, f), encoding="utf-8"
                 ) as fd:
-                    user_defined_llm_family = CustomEmbeddingModelSpec.parse_obj(
+                    user_defined_llm_family = CustomEmbeddingModelFamilyV2.parse_obj(
                         json.load(fd)
                     )
                     register_embedding(user_defined_llm_family, persist=False)
@@ -61,80 +64,89 @@ def register_custom_model():
                 warnings.warn(f"{user_defined_embedding_dir}/{f} has error, {e}")
 
 
-def generate_engine_config_by_model_name(model_spec: "EmbeddingModelSpec"):
-    model_name = model_spec.model_name
+def check_format_with_engine(model_format, engine):
+    if model_format in ["ggufv2"] and engine not in ["llama.cpp"]:
+        return False
+    if model_format not in ["ggufv2"] and engine == "llama.cpp":
+        return False
+    return True
+
+
+def generate_engine_config_by_model_name(model_family: "EmbeddingModelFamilyV2"):
+    model_name = model_family.model_name
     engines: Dict[str, List[Dict[str, Any]]] = EMBEDDING_ENGINES.get(
         model_name, {}
     )  # structure for engine query
-    for engine in SUPPORTED_ENGINES:
-        CLASSES = SUPPORTED_ENGINES[engine]
-        for cls in CLASSES:
-            # Every engine needs to implement match method
-            if cls.match(model_spec):
-                # we only match the first class for an engine
-                engines[engine] = [
-                    {
-                        "model_name": model_name,
-                        "embedding_class": cls,
-                    }
-                ]
-                break
+    for spec in [x for x in model_family.model_specs if x.model_hub == "huggingface"]:
+        model_format = spec.model_format
+        quantization = spec.quantization
+        for engine in SUPPORTED_ENGINES:
+            if not check_format_with_engine(model_format, engine):
+                continue
+            CLASSES = SUPPORTED_ENGINES[engine]
+            for cls in CLASSES:
+                # Every engine needs to implement match method
+                if cls.match(model_family, spec, quantization):
+                    # we only match the first class for an engine
+                    if engine not in engines:
+                        engines[engine] = [
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "embedding_class": cls,
+                            }
+                        ]
+                    else:
+                        engines[engine].append(
+                            {
+                                "model_name": model_name,
+                                "model_format": model_format,
+                                "quantization": quantization,
+                                "embedding_class": cls,
+                            }
+                        )
+                    break
     EMBEDDING_ENGINES[model_name] = engines
 
 
 # will be called in xinference/model/__init__.py
 def _install():
     _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
-    _model_spec_modelscope_json = os.path.join(
-        os.path.dirname(__file__), "model_spec_modelscope.json"
-    )
-    ################### HuggingFace Model List Info Init ###################
-    BUILTIN_EMBEDDING_MODELS.update(
-        dict(
-            (spec["model_name"], EmbeddingModelSpec(**spec))
-            for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+
+    for json_obj in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")):
+        flattened = []
+        for spec in json_obj["model_specs"]:
+            flattened.extend(flatten_quantizations(spec))
+        json_obj["model_specs"] = flattened
+        BUILTIN_EMBEDDING_MODELS[json_obj["model_name"]] = EmbeddingModelFamilyV2(
+            **json_obj
         )
-    )
+
     for model_name, model_spec in BUILTIN_EMBEDDING_MODELS.items():
-        MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
-
-    ################### ModelScope Model List Info Init ###################
-    MODELSCOPE_EMBEDDING_MODELS.update(
-        dict(
-            (spec["model_name"], EmbeddingModelSpec(**spec))
-            for spec in json.load(
-                codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+        if model_spec.model_name not in EMBEDDING_MODEL_DESCRIPTIONS:
+            EMBEDDING_MODEL_DESCRIPTIONS.update(
+                generate_embedding_description(model_spec)
             )
-        )
-    )
-    for model_name, model_spec in MODELSCOPE_EMBEDDING_MODELS.items():
-        MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
-
-    # TODO: consider support more download hub in future...
-    # register model description after recording model revision
-    for model_spec_info in [BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS]:
-        for model_name, model_spec in model_spec_info.items():
-            if model_spec.model_name not in EMBEDDING_MODEL_DESCRIPTIONS:
-                EMBEDDING_MODEL_DESCRIPTIONS.update(
-                    generate_embedding_description(model_spec)
-                )
 
     from .flag.core import FlagEmbeddingModel
+    from .llama_cpp.core import XllamaCppEmbeddingModel
     from .sentence_transformers.core import SentenceTransformerEmbeddingModel
     from .vllm.core import VLLMEmbeddingModel
 
     SENTENCE_TRANSFORMER_CLASSES.extend([SentenceTransformerEmbeddingModel])
     FLAG_EMBEDDER_CLASSES.extend([FlagEmbeddingModel])
     VLLM_CLASSES.extend([VLLMEmbeddingModel])
+    LLAMA_CPP_CLASSES.extend([XllamaCppEmbeddingModel])
 
     SUPPORTED_ENGINES["sentence_transformers"] = SENTENCE_TRANSFORMER_CLASSES
     SUPPORTED_ENGINES["flag"] = FLAG_EMBEDDER_CLASSES
     SUPPORTED_ENGINES["vllm"] = VLLM_CLASSES
+    SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CPP_CLASSES
 
     # Init embedding engine
-    for model_infos in [BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS]:
-        for model_spec in model_infos.values():
-            generate_engine_config_by_model_name(model_spec)
+    for model_spec in BUILTIN_EMBEDDING_MODELS.values():
+        generate_engine_config_by_model_name(model_spec)
 
     register_custom_model()
 
@@ -145,4 +157,3 @@ def _install():
         )
 
     del _model_spec_json
-    del _model_spec_modelscope_json
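
To make the new loading path concrete: _install() now reads a single model_spec.json of v2 families (the separate model_spec_modelscope.json is deleted in this release) and flattens each raw spec so that every concrete quantization becomes its own entry before constructing EmbeddingModelFamilyV2. The sketch below is a hypothetical reading of flatten_quantizations, which is imported from ..utils but whose implementation is not shown in this diff, so the real code may differ:

# Hypothetical sketch of flatten_quantizations (the real helper lives in
# xinference/model/utils.py and is not part of the hunks shown here).
from copy import deepcopy
from typing import Any, Dict, List

def flatten_quantizations(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Assumption: a raw spec may list several quantizations; the loader in
    # _install() expects one spec dict per concrete quantization.
    quantizations = spec.pop("quantizations", None) or [spec.get("quantization", "none")]
    flattened = []
    for q in quantizations:
        item = deepcopy(spec)
        item["quantization"] = q
        flattened.append(item)
    return flattened

After _install() runs, EMBEDDING_ENGINES maps each model name to engine -> list of {"model_name", "model_format", "quantization", "embedding_class"} entries, which check_engine_by_model_name_and_engine later queries when an instance is created.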
xinference/model/embedding/cache_manager.py (new file)

@@ -0,0 +1,35 @@
+import os
+from typing import TYPE_CHECKING
+
+from ..cache_manager import CacheManager
+
+if TYPE_CHECKING:
+    from .core import EmbeddingModelFamilyV2
+
+
+class EmbeddingCacheManager(CacheManager):
+    def __init__(self, model_family: "EmbeddingModelFamilyV2"):
+        from ..llm.cache_manager import LLMCacheManager
+
+        super().__init__(model_family)
+        # Composition design mode for avoiding duplicate code
+        self.cache_helper = LLMCacheManager(model_family)
+
+        spec = self._model_family.model_specs[0]
+        model_dir_name = (
+            f"{self._model_family.model_name}-{spec.model_format}-{spec.quantization}"
+        )
+        self._cache_dir = os.path.join(self._v2_cache_dir_prefix, model_dir_name)
+        self.cache_helper._cache_dir = self._cache_dir
+
+    def cache(self) -> str:
+        spec = self._model_family.model_specs[0]
+        if spec.model_uri is not None:
+            return self.cache_helper.cache_uri()
+        else:
+            if spec.model_hub == "huggingface":
+                return self.cache_helper.cache_from_huggingface()
+            elif spec.model_hub == "modelscope":
+                return self.cache_helper.cache_from_modelscope()
+            else:
+                raise ValueError(f"Unknown model hub: {spec.model_hub}")
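
Usage-wise, the manager narrows a family to one concrete spec, derives a <model_name>-<format>-<quantization> cache directory, and delegates downloading to LLMCacheManager depending on whether the spec carries a model_uri or a hub id. A minimal sketch, assuming _install() has already populated BUILTIN_EMBEDDING_MODELS (the model name below is only illustrative):

# Illustrative only: assumes the named family exists among the builtins.
from xinference.model.embedding.embed_family import BUILTIN_EMBEDDING_MODELS
from xinference.model.embedding.cache_manager import EmbeddingCacheManager

family = BUILTIN_EMBEDDING_MODELS["bge-small-en-v1.5"].copy()  # hypothetical entry
family.model_specs = [family.model_specs[0]]  # cache one concrete (format, quantization)

manager = EmbeddingCacheManager(family)
print(manager.get_cache_dir())  # .../<name>-<format>-<quantization>
model_path = manager.cache()    # resolves model_uri or downloads from the hub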
xinference/model/embedding/core.py

@@ -12,23 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import gc
 import logging
 import os
+from abc import abstractmethod
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Dict, List, Literal, Optional, Union
 
-from ..._compat import ROOT_KEY, ErrorWrapper, ValidationError
+from ..._compat import ROOT_KEY, BaseModel, ErrorWrapper, Field, ValidationError
 from ...device_utils import empty_cache
-from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
-from ..utils import get_cache_dir, is_model_cached
+from ..core import VirtualEnvSettings
+from ..utils import ModelInstanceInfoMixin
 from .embed_family import match_embedding
 
 logger = logging.getLogger(__name__)
 
 # Used for check whether the model is cached.
 # Init when registering all the builtin models.
-MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
@@ -46,96 +47,100 @@ def get_embedding_model_descriptions():
     return copy.deepcopy(EMBEDDING_MODEL_DESCRIPTIONS)
 
 
+class TransformersEmbeddingSpecV1(BaseModel):
+    model_format: Literal["pytorch"]
+    model_hub: str = "huggingface"
+    model_id: Optional[str]
+    model_uri: Optional[str]
+    model_revision: Optional[str]
+    quantization: str
+
+
+class LlamaCppEmbeddingSpecV1(BaseModel):
+    model_format: Literal["ggufv2"]
+    model_hub: str = "huggingface"
+    model_id: Optional[str]
+    model_uri: Optional[str]
+    model_revision: Optional[str]
+    quantization: str
+    model_file_name_template: str
+    model_file_name_split_template: Optional[str]
+    quantization_parts: Optional[Dict[str, List[str]]]
+
+
+EmbeddingSpecV1 = Annotated[
+    Union[TransformersEmbeddingSpecV1, LlamaCppEmbeddingSpecV1],
+    Field(discriminator="model_format"),
+]
+
+
 # this class define the basic info of embedding model
-class EmbeddingModelSpec(CacheableModelSpec):
+class EmbeddingModelFamilyV2(BaseModel, ModelInstanceInfoMixin):
+    version: Literal[2]
     model_name: str
     dimensions: int
     max_tokens: int
     language: List[str]
-    model_id: str
-    model_revision: Optional[str]
-    model_hub: str = "huggingface"
+    model_specs: List["EmbeddingSpecV1"]
+    cache_config: Optional[dict]
     virtualenv: Optional[VirtualEnvSettings]
 
+    class Config:
+        extra = "allow"
 
-class EmbeddingModelDescription(ModelDescription):
-    def __init__(
-        self,
-        address: Optional[str],
-        devices: Optional[List[str]],
-        model_spec: EmbeddingModelSpec,
-        model_path: Optional[str] = None,
-    ):
-        super().__init__(address, devices, model_path=model_path)
-        self._model_spec = model_spec
-
-    @property
-    def spec(self):
-        return self._model_spec
-
-    def to_dict(self):
+    def to_description(self):
+        spec = self.model_specs[0]
         return {
             "model_type": "embedding",
-            "address": self.address,
-            "accelerators": self.devices,
-            "model_name": self._model_spec.model_name,
-            "dimensions": self._model_spec.dimensions,
-            "max_tokens": self._model_spec.max_tokens,
-            "language": self._model_spec.language,
-            "model_revision": self._model_spec.model_revision,
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "model_name": self.model_name,
+            "dimensions": self.dimensions,
+            "max_tokens": self.max_tokens,
+            "language": self.language,
+            "model_hub": spec.model_hub,
+            "model_revision": spec.model_revision,
+            "quantization": spec.quantization,
         }
 
     def to_version_info(self):
-        from .utils import get_model_version
+        from .cache_manager import EmbeddingCacheManager
 
-        if self._model_path is None:
-            is_cached = get_cache_status(self._model_spec)
-            file_location = get_cache_dir(self._model_spec)
-        else:
-            is_cached = True
-            file_location = self._model_path
+        cache_manager = EmbeddingCacheManager(self)
 
         return {
-            "model_version": get_model_version(self._model_spec),
-            "model_file_location": file_location,
-            "cache_status": is_cached,
-            "dimensions": self._model_spec.dimensions,
-            "max_tokens": self._model_spec.max_tokens,
+            "model_version": get_model_version(self),
+            "model_file_location": cache_manager.get_cache_dir(),
+            "cache_status": cache_manager.get_cache_status(),
+            "dimensions": self.dimensions,
+            "max_tokens": self.max_tokens,
        }
 
 
+def get_model_version(embedding_model: EmbeddingModelFamilyV2) -> str:
+    spec = embedding_model.model_specs[0]
+    return f"{embedding_model.model_name}--{embedding_model.max_tokens}--{embedding_model.dimensions}--{spec.model_format}--{spec.quantization}"
+
+
 def generate_embedding_description(
-    model_spec: EmbeddingModelSpec,
+    model_family: EmbeddingModelFamilyV2,
 ) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[model_spec.model_name].append(
-        EmbeddingModelDescription(None, None, model_spec).to_version_info()
-    )
+    specs = [x for x in model_family.model_specs if x.model_hub == "huggingface"]
+    for spec in specs:
+        family = model_family.copy()
+        family.model_specs = [spec]
+        res[model_family.model_name].append(family.to_version_info())
     return res
 
 
-def cache(model_spec: EmbeddingModelSpec):
-    from ..utils import cache
-
-    return cache(model_spec, EmbeddingModelDescription)
-
-
-def get_cache_status(
-    model_spec: EmbeddingModelSpec,
-) -> bool:
-    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
-
-
-import abc
-from abc import abstractmethod
-
-
 class EmbeddingModel(abc.ABC):
     def __init__(
         self,
         model_uid: str,
         model_path: str,
-        model_spec: EmbeddingModelSpec,
+        model_family: EmbeddingModelFamilyV2,
+        quantization: Optional[str] = None,
         device: Optional[str] = None,
         **kwargs,
     ):
@@ -145,8 +150,10 @@ class EmbeddingModel(abc.ABC):
         self._model = None
         self._tokenizer = None
         self._counter = 0
-        self._model_spec = model_spec
-        self._model_name = self._model_spec.model_name
+        self.model_family = model_family
+        self._model_spec = model_family.model_specs[0]
+        self._quantization = quantization
+        self._model_name = self.model_family.model_name
         self._kwargs = kwargs
 
     @classmethod
@@ -156,17 +163,27 @@
 
     @classmethod
     @abstractmethod
-    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
+    def match_json(
+        cls,
+        model_family: EmbeddingModelFamilyV2,
+        model_spec: EmbeddingSpecV1,
+        quantization: str,
+    ) -> bool:
         pass
 
     @classmethod
-    def match(cls, model_spec: EmbeddingModelSpec):
+    def match(
+        cls,
+        model_family: EmbeddingModelFamilyV2,
+        model_spec: EmbeddingSpecV1,
+        quantization: str,
+    ):
         """
         Return if the model_spec can be matched.
         """
         if not cls.check_lib():
             return False
-        return cls.match_json(model_spec)
+        return cls.match_json(model_family, model_spec, quantization)
 
     @abstractmethod
     def load(self):
@@ -290,36 +307,39 @@
 
 
 def create_embedding_model_instance(
-    subpool_addr: str,
-    devices: Optional[List[str]],
     model_uid: str,
     model_name: str,
     model_engine: Optional[str],
+    model_format: Optional[str] = None,
+    quantization: Optional[str] = None,
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
-    model_spec = match_embedding(model_name, download_hub)
+) -> EmbeddingModel:
+    from .cache_manager import EmbeddingCacheManager
+
+    model_family = match_embedding(model_name, model_format, quantization, download_hub)
     if model_path is None:
-        model_path = cache(model_spec)
+        cache_manager = EmbeddingCacheManager(model_family)
+        model_path = cache_manager.cache()
 
     if model_engine is None:
-        # unlike LLM and for compatibility
+        # unlike LLM and for compatibility,
         # we use sentence_transformers as the default engine for all models
        model_engine = "sentence_transformers"
 
     from .embed_family import check_engine_by_model_name_and_engine
 
     embedding_cls = check_engine_by_model_name_and_engine(
-        model_name,
-        model_engine,
+        model_engine, model_name, model_format, quantization
     )
-    devices = devices or ["cpu"]
-    # model class should be one of flag, fastembed, sentence_transformers
-    model = embedding_cls(model_uid, model_path, model_spec, **kwargs)
-    model_description = EmbeddingModelDescription(
-        subpool_addr, devices, model_spec, model_path=model_path
+    model = embedding_cls(
+        model_uid,
+        model_path,
+        model_family,
+        quantization,
+        **kwargs,
     )
-    return model, model_description
+    return model
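
The EmbeddingSpecV1 union above dispatches on model_format through a pydantic discriminated union, so a "pytorch" payload validates as TransformersEmbeddingSpecV1 and a "ggufv2" payload as LlamaCppEmbeddingSpecV1. A minimal sketch of that dispatch, assuming plain pydantic v1 in place of xinference's _compat shim; all field values are hypothetical:

# Sketch only: uses pydantic v1 directly, whereas xinference routes
# BaseModel/Field through its _compat module.
from pydantic import parse_obj_as

from xinference.model.embedding.core import EmbeddingSpecV1, LlamaCppEmbeddingSpecV1

spec = parse_obj_as(
    EmbeddingSpecV1,
    {
        "model_format": "ggufv2",      # discriminator picks LlamaCppEmbeddingSpecV1
        "model_id": "org/model-GGUF",  # hypothetical hub id
        "model_uri": None,
        "model_revision": None,
        "quantization": "Q4_K_M",
        "model_file_name_template": "model.{quantization}.gguf",
        "model_file_name_split_template": None,
        "quantization_parts": None,
    },
)
assert isinstance(spec, LlamaCppEmbeddingSpecV1)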
xinference/model/embedding/custom.py

@@ -11,103 +11,80 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
-import os
-from threading import Lock
-from typing import List, Optional
+from typing import List
 
-from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
-from .core import EmbeddingModelSpec
+from ..._compat import Literal
+from ..custom import ModelRegistry
+from .core import EmbeddingModelFamilyV2
 
 logger = logging.getLogger(__name__)
 
 
-UD_EMBEDDING_LOCK = Lock()
+class CustomEmbeddingModelFamilyV2(EmbeddingModelFamilyV2):
+    version: Literal[2] = 2
 
 
-class CustomEmbeddingModelSpec(EmbeddingModelSpec):
-    model_id: Optional[str]  # type: ignore
-    model_revision: Optional[str]  # type: ignore
-    model_uri: Optional[str]
+UD_EMBEDDINGS: List[CustomEmbeddingModelFamilyV2] = []
 
 
-UD_EMBEDDINGS: List[CustomEmbeddingModelSpec] = []
+class EmbeddingModelRegistry(ModelRegistry):
+    model_type = "embedding"
 
+    def __init__(self):
+        from .embed_family import BUILTIN_EMBEDDING_MODELS
 
-def get_user_defined_embeddings() -> List[EmbeddingModelSpec]:
-    with UD_EMBEDDING_LOCK:
-        return UD_EMBEDDINGS.copy()
+        super().__init__()
+        self.models = UD_EMBEDDINGS
+        self.builtin_models = list(BUILTIN_EMBEDDING_MODELS.keys())
 
+    def add_ud_model(self, model_spec):
+        from . import generate_engine_config_by_model_name
 
-def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
-    from ...constants import XINFERENCE_MODEL_DIR
-    from ..utils import is_valid_model_name, is_valid_model_uri
-    from . import (
-        BUILTIN_EMBEDDING_MODELS,
-        MODELSCOPE_EMBEDDING_MODELS,
-        generate_engine_config_by_model_name,
-    )
+        UD_EMBEDDINGS.append(model_spec)
+        generate_engine_config_by_model_name(model_spec)
 
-    if not is_valid_model_name(model_spec.model_name):
-        raise ValueError(f"Invalid model name {model_spec.model_name}.")
+    def check_model_uri(self, model_family: "EmbeddingModelFamilyV2"):
+        from ..utils import is_valid_model_uri
 
-    model_uri = model_spec.model_uri
-    if model_uri and not is_valid_model_uri(model_uri):
-        raise ValueError(f"Invalid model URI {model_uri}.")
+        for spec in model_family.model_specs:
+            model_uri = spec.model_uri
+            if model_uri and not is_valid_model_uri(model_uri):
+                raise ValueError(f"Invalid model URI {model_uri}.")
 
-    with UD_EMBEDDING_LOCK:
-        for model_name in (
-            list(BUILTIN_EMBEDDING_MODELS.keys())
-            + list(MODELSCOPE_EMBEDDING_MODELS.keys())
-            + [spec.model_name for spec in UD_EMBEDDINGS]
-        ):
-            if model_spec.model_name == model_name:
-                raise ValueError(
-                    f"Model name conflicts with existing model {model_spec.model_name}"
-                )
+    def remove_ud_model(self, model_family: "CustomEmbeddingModelFamilyV2"):
+        from .embed_family import EMBEDDING_ENGINES
 
-        UD_EMBEDDINGS.append(model_spec)
-        generate_engine_config_by_model_name(model_spec)
+        UD_EMBEDDINGS.remove(model_family)
+        del EMBEDDING_ENGINES[model_family.model_name]
+
+    def remove_ud_model_files(self, model_family: "CustomEmbeddingModelFamilyV2"):
+        from .cache_manager import EmbeddingCacheManager
+
+        _model_family = model_family.copy()
+        for spec in model_family.model_specs:
+            _model_family.model_specs = [spec]
+            cache_manager = EmbeddingCacheManager(_model_family)
+            cache_manager.unregister_custom_model(self.model_type)
 
-    if persist:
-        persist_path = os.path.join(
-            XINFERENCE_MODEL_DIR, "embedding", f"{model_spec.model_name}.json"
-        )
-        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
-        with open(persist_path, mode="w") as fd:
-            fd.write(model_spec.json())
+
+def get_user_defined_embeddings() -> List[EmbeddingModelFamilyV2]:
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("embedding")
+    return registry.get_custom_models()
+
+
+def register_embedding(model_family: CustomEmbeddingModelFamilyV2, persist: bool):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("embedding")
+    registry.register(model_family, persist)
 
 
 def unregister_embedding(model_name: str, raise_error: bool = True):
-    with UD_EMBEDDING_LOCK:
-        model_spec = None
-        for i, f in enumerate(UD_EMBEDDINGS):
-            if f.model_name == model_name:
-                model_spec = f
-                break
-        if model_spec:
-            UD_EMBEDDINGS.remove(model_spec)
-
-            persist_path = os.path.join(
-                XINFERENCE_MODEL_DIR, "embedding", f"{model_spec.model_name}.json"
-            )
-            if os.path.exists(persist_path):
-                os.remove(persist_path)
-
-            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-            if os.path.exists(cache_dir):
-                logger.warning(
-                    f"Remove the cache of user-defined model {model_spec.model_name}. "
-                    f"Cache directory: {cache_dir}"
-                )
-                if os.path.islink(cache_dir):
-                    os.remove(cache_dir)
-                else:
-                    logger.warning(
-                        f"Cache directory is not a soft link, please remove it manually."
-                    )
-        else:
-            if raise_error:
-                raise ValueError(f"Model {model_name} not found")
-            else:
-                logger.warning(f"Custom embedding model {model_name} not found")
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("embedding")
+    registry.unregister(model_name, raise_error)
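
With this, custom-model bookkeeping moves off module-level locks and hand-rolled JSON persistence onto the shared ModelRegistry/RegistryManager machinery in xinference/model/custom.py (a new file in this release, see the list above). A minimal sketch of registering a user-defined embedding under the new scheme; every field value below is hypothetical:

from xinference.model.embedding.custom import (
    CustomEmbeddingModelFamilyV2,
    register_embedding,
    unregister_embedding,
)

family = CustomEmbeddingModelFamilyV2(
    version=2,
    model_name="my-embedding",  # hypothetical model
    dimensions=768,
    max_tokens=512,
    language=["en"],
    model_specs=[
        {
            "model_format": "pytorch",
            "model_id": None,
            "model_uri": "file:///data/my-embedding",  # local weights
            "model_revision": None,
            "quantization": "none",
        }
    ],
    cache_config=None,
    virtualenv=None,
)

# The registry validates the URI, checks name conflicts against builtin
# models, and with persist=True writes the family JSON under the v2
# directory that register_custom_model() reads at startup.
register_embedding(family, persist=False)
unregister_embedding("my-embedding")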