xinference 1.7.1.post1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/model/embedding/sentence_transformers/core.py

@@ -14,20 +14,16 @@
 
 import importlib.util
 import logging
-from collections import defaultdict
 from typing import List, Optional, Union, no_type_check
 
 import numpy as np
 import torch
 
-from ....types import Dict, Embedding, EmbeddingData, EmbeddingUsage
-from ..core import EmbeddingModel, EmbeddingModelSpec
+from ....device_utils import is_device_available
+from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
-
-# Used for check whether the model is cached.
-# Init when registering all the builtin models.
-MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
 SENTENCE_TRANSFORMER_MODEL_LIST: List[str] = []
 
 
@@ -76,8 +72,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         torch_dtype = torch.float32
 
         if (
-            "gte" in self._model_spec.model_name.lower()
-            and "qwen2" in self._model_spec.model_name.lower()
+            "gte" in self.model_family.model_name.lower()
+            and "qwen2" in self.model_family.model_name.lower()
         ):
             model_kwargs = {"device_map": "auto"}
             if torch_dtype:
@@ -87,10 +83,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
             )
-        elif "qwen3" in self._model_spec.model_name.lower():
+        elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
             flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-            flash_attn_enabled = self._kwargs.get("enable_flash_attn", True)
+            flash_attn_enabled = self._kwargs.get(
+                "enable_flash_attn", is_device_available("cuda")
+            )
             model_kwargs = {"device_map": "auto"}
             tokenizer_kwargs = {}
             if flash_attn_installed and flash_attn_enabled:
@@ -119,7 +117,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 trust_remote_code=True,
             )
 
-        self._tokenizer = self._model.tokenizer
+        if hasattr(self._model, "tokenizer"):
+            self._tokenizer = self._model.tokenizer
 
     def create_embedding(
         self,
@@ -227,8 +226,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
            device = model._target_device
 
            if (
-               "gte" in self._model_spec.model_name.lower()
-               and "qwen2" in self._model_spec.model_name.lower()
+               "gte" in self.model_family.model_name.lower()
+               and "qwen2" in self.model_family.model_name.lower()
            ):
                model.to(device)
 
@@ -254,7 +253,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                    features.update(extra_features)
                # when batching, the attention mask 1 means there is a token
                # thus we just sum up it to get the total number of tokens
-               if "clip" in self._model_spec.model_name.lower():
+               if (
+                   "clip" in self.model_family.model_name.lower()
+                   or "jina-embeddings-v4" in self.model_family.model_name.lower()
+               ):
                    if "input_ids" in features and hasattr(
                        features["input_ids"], "numel"
                    ):
@@ -322,8 +324,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
 
        # seems already support prompt in embedding model
        if (
-           "gte" in self._model_spec.model_name.lower()
-           and "qwen2" in self._model_spec.model_name.lower()
+           "gte" in self.model_family.model_name.lower()
+           and "qwen2" in self.model_family.model_name.lower()
        ):
            all_embeddings, all_token_nums = encode(
                self._model,
@@ -332,7 +334,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                convert_to_numpy=False,
                **kwargs,
            )
-       elif "clip" in self._model_spec.model_name.lower():
+       elif (
+           "clip" in self.model_family.model_name.lower()
+           or "jina-embeddings-v4" in self.model_family.model_name.lower()
+       ):
            import base64
            import re
            from io import BytesIO
@@ -409,6 +414,11 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
        return importlib.util.find_spec("sentence_transformers") is not None
 
    @classmethod
-   def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
+   def match_json(
+       cls,
+       model_family: EmbeddingModelFamilyV2,
+       model_spec: EmbeddingSpecV1,
+       quantization: str,
+   ) -> bool:
        # As default embedding engine, sentence-transformer support all models
-       return True
+       return model_spec.model_format in ["pytorch"]
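
Taken together with the import changes above, match_json now receives the model family, the concrete spec, and the quantization instead of a bare spec. A minimal sketch of the new contract, with stub classes standing in for EmbeddingModelFamilyV2 and EmbeddingSpecV1 (whose full definitions are not shown in this diff):

from dataclasses import dataclass


@dataclass
class FamilyStub:  # stand-in for EmbeddingModelFamilyV2
    model_name: str


@dataclass
class SpecStub:  # stand-in for EmbeddingSpecV1
    model_format: str


def match_json(family: FamilyStub, spec: SpecStub, quantization: str) -> bool:
    # sentence-transformers remains the default engine, but after this change
    # it only claims pytorch-format specs; other formats (e.g. GGUF) fall
    # through to other backends such as the new llama.cpp embedding engine.
    return spec.model_format in ["pytorch"]


assert match_json(FamilyStub("bge-m3"), SpecStub("pytorch"), "none")
assert not match_json(FamilyStub("bge-m3"), SpecStub("ggufv2"), "Q4_K_M")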
xinference/model/embedding/vllm/core.py

@@ -17,7 +17,7 @@ import logging
 from typing import List, Union
 
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
-from ..core import EmbeddingModel, EmbeddingModelSpec
+from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
 SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
@@ -88,8 +88,14 @@ class VLLMEmbeddingModel(EmbeddingModel):
        return importlib.util.find_spec("vllm") is not None
 
    @classmethod
-   def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
-       prefix = model_spec.model_name.split("-", 1)[0]
-       if prefix in SUPPORTED_MODELS_PREFIXES:
-           return True
+   def match_json(
+       cls,
+       model_family: EmbeddingModelFamilyV2,
+       model_spec: EmbeddingSpecV1,
+       quantization: str,
+   ) -> bool:
+       if model_spec.model_format in ["pytorch"]:
+           prefix = model_family.model_name.split("-", 1)[0]
+           if prefix in SUPPORTED_MODELS_PREFIXES:
+               return True
        return False
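
The vLLM engine keeps its prefix allow-list but now also gates on model format. The prefix rule itself is unchanged; for illustration (the model names here are examples, not an exhaustive list):

# The prefix is everything before the first hyphen of the model name.
assert "Qwen3-Embedding-0.6B".split("-", 1)[0] == "Qwen3"
assert "bge-large-zh".split("-", 1)[0] == "bge"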
xinference/model/flexible/__init__.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import codecs
 import json
 import logging
@@ -25,6 +24,8 @@ from .core import (
     FlexibleModelSpec,
     generate_flexible_model_description,
     get_flexible_model_descriptions,
+)
+from .custom import (
     get_flexible_models,
     register_flexible_model,
     unregister_flexible_model,
@@ -34,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 def register_custom_model():
-    model_dir = os.path.join(XINFERENCE_MODEL_DIR, "flexible")
+    from ..custom import migrate_from_v1_to_v2
+
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("flexible", FlexibleModelSpec)
+
+    model_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "flexible")
     if os.path.isdir(model_dir):
         for f in os.listdir(model_dir):
             try:
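
The practical effect of the migration call: custom flexible-model registrations move from a per-type directory into a versioned v2 subtree. A sketch of the layout, assuming the default XINFERENCE_MODEL_DIR under the xinference home directory (the exact default path is an assumption, not shown in this diff):

import os

XINFERENCE_MODEL_DIR = os.path.expanduser("~/.xinference/model")  # assumed default

v1_dir = os.path.join(XINFERENCE_MODEL_DIR, "flexible")        # pre-1.8.0 registrations
v2_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "flexible")  # where 1.8.0 now reads from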
xinference/model/flexible/core.py

@@ -14,21 +14,19 @@
 
 import json
 import logging
-import os
 from collections import defaultdict
-from threading import Lock
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
-from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
-from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
+from ..._compat import Literal
+from ..core import CacheableModelSpec, VirtualEnvSettings
+from ..utils import ModelInstanceInfoMixin
 from .utils import get_launcher
 
 logger = logging.getLogger(__name__)
 
-FLEXIBLE_MODEL_LOCK = Lock()
 
-
-class FlexibleModelSpec(CacheableModelSpec):
+class FlexibleModelSpec(CacheableModelSpec, ModelInstanceInfoMixin):
+    version: Literal[1, 2] = 2
     model_id: Optional[str]  # type: ignore
     model_description: Optional[str]
     model_uri: Optional[str]
@@ -39,42 +37,26 @@ class FlexibleModelSpec(CacheableModelSpec):
     def parser_args(self):
         return json.loads(self.launcher_args)
 
+    class Config:
+        extra = "allow"
 
-class FlexibleModelDescription(ModelDescription):
-    def __init__(
-        self,
-        address: Optional[str],
-        devices: Optional[List[str]],
-        model_spec: FlexibleModelSpec,
-        model_path: Optional[str] = None,
-    ):
-        super().__init__(address, devices, model_path=model_path)
-        self._model_spec = model_spec
-
-    @property
-    def spec(self):
-        return self._model_spec
-
-    def to_dict(self):
+    def to_description(self):
         return {
             "model_type": "flexible",
-            "address": self.address,
-            "accelerators": self.devices,
-            "model_name": self._model_spec.model_name,
-            "launcher": self._model_spec.launcher,
-            "launcher_args": self._model_spec.launcher_args,
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "model_name": self.model_name,
+            "launcher": self.launcher,
+            "launcher_args": self.launcher_args,
         }
 
-    def get_model_version(self) -> str:
-        return f"{self._model_spec.model_name}"
-
     def to_version_info(self):
         return {
-            "model_version": self.get_model_version(),
+            "model_version": self.model_name,
             "cache_status": True,
-            "model_file_location": self._model_spec.model_uri,
-            "launcher": self._model_spec.launcher,
-            "launcher_args": self._model_spec.launcher_args,
+            "model_file_location": self.model_uri,
+            "launcher": self.launcher,
+            "launcher_args": self.launcher_args,
         }
 
 
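
Because the spec now doubles as its own description object, runtime-only attributes such as address and accelerators may not exist on a freshly parsed spec; Config.extra = "allow" permits attaching them later, and getattr with a None default keeps to_description() safe either way. A small pydantic illustration of that pattern (stub model, assumed semantics):

from pydantic import BaseModel


class SpecStub(BaseModel):
    launcher: str

    class Config:
        extra = "allow"  # unknown attributes may be attached at runtime


s = SpecStub(launcher="transformers")
s.address = "127.0.0.1:9997"  # runtime-only attribute, allowed by extra="allow"
assert getattr(s, "address", None) == "127.0.0.1:9997"
assert getattr(s, "accelerators", None) is None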
@@ -82,9 +64,7 @@ def generate_flexible_model_description(
     model_spec: FlexibleModelSpec,
 ) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[model_spec.model_name].append(
-        FlexibleModelDescription(None, None, model_spec).to_version_info()
-    )
+    res[model_spec.model_name].append(model_spec.to_version_info())
     return res
 
 
@@ -92,93 +72,22 @@ FLEXIBLE_MODELS: List[FlexibleModelSpec] = []
 FLEXIBLE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 
 
-def get_flexible_models():
-    with FLEXIBLE_MODEL_LOCK:
-        return FLEXIBLE_MODELS.copy()
-
-
 def get_flexible_model_descriptions():
     import copy
 
     return copy.deepcopy(FLEXIBLE_MODEL_DESCRIPTIONS)
 
 
-def register_flexible_model(model_spec: FlexibleModelSpec, persist: bool):
-    from ..utils import is_valid_model_name, is_valid_model_uri
-
-    if not is_valid_model_name(model_spec.model_name):
-        raise ValueError(f"Invalid model name {model_spec.model_name}.")
-
-    model_uri = model_spec.model_uri
-    if model_uri and not is_valid_model_uri(model_uri):
-        raise ValueError(f"Invalid model URI {model_uri}.")
-
-    if model_spec.launcher_args:
-        try:
-            model_spec.parser_args()
-        except Exception:
-            raise ValueError(f"Invalid model launcher args {model_spec.launcher_args}.")
-
-    with FLEXIBLE_MODEL_LOCK:
-        for model_name in [spec.model_name for spec in FLEXIBLE_MODELS]:
-            if model_spec.model_name == model_name:
-                raise ValueError(
-                    f"Model name conflicts with existing model {model_spec.model_name}"
-                )
-        FLEXIBLE_MODELS.append(model_spec)
-
-    if persist:
-        persist_path = os.path.join(
-            XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
-        )
-        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
-        with open(persist_path, mode="w") as fd:
-            fd.write(model_spec.json())
-
-
-def unregister_flexible_model(model_name: str, raise_error: bool = True):
-    with FLEXIBLE_MODEL_LOCK:
-        model_spec = None
-        for i, f in enumerate(FLEXIBLE_MODELS):
-            if f.model_name == model_name:
-                model_spec = f
-                break
-        if model_spec:
-            FLEXIBLE_MODELS.remove(model_spec)
-
-            persist_path = os.path.join(
-                XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
-            )
-            if os.path.exists(persist_path):
-                os.remove(persist_path)
-
-            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
-            if os.path.exists(cache_dir):
-                logger.warning(
-                    f"Remove the cache of user-defined model {model_spec.model_name}. "
-                    f"Cache directory: {cache_dir}"
-                )
-                if os.path.islink(cache_dir):
-                    os.remove(cache_dir)
-                else:
-                    logger.warning(
-                        f"Cache directory is not a soft link, please remove it manually."
-                    )
-        else:
-            if raise_error:
-                raise ValueError(f"Model {model_name} not found")
-            else:
-                logger.warning(f"Model {model_name} not found")
-
-
 class FlexibleModel:
     def __init__(
         self,
         model_uid: str,
         model_path: str,
+        model_family: FlexibleModelSpec,
         device: Optional[str] = None,
         config: Optional[Dict] = None,
     ):
+        self.model_family = model_family
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
@@ -213,19 +122,20 @@
 
 
 def match_flexible_model(model_name):
+    from .custom import get_flexible_models
+
     for model_spec in get_flexible_models():
         if model_name == model_spec.model_name:
             return model_spec
+    return None
 
 
 def create_flexible_model_instance(
-    subpool_addr: str,
-    devices: List[str],
     model_uid: str,
     model_name: str,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[FlexibleModel, FlexibleModelDescription]:
+) -> FlexibleModel:
     model_spec = match_flexible_model(model_name)
     if not model_path:
         model_path = model_spec.model_uri
@@ -237,7 +147,4 @@
         model_uid=model_uid, model_spec=model_spec, **kwargs
     )
 
-    model_description = FlexibleModelDescription(
-        subpool_addr, devices, model_spec, model_path=model_path
-    )
-    return model, model_description
+    return model
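
Downstream, worker code that previously unpacked (model, description) now receives just the model and derives description data from the spec. A hedged sketch of the new call site (the model name is a placeholder and must refer to a previously registered flexible model):

from xinference.model.flexible.core import create_flexible_model_instance

model = create_flexible_model_instance(
    model_uid="my-flexible-model-0",  # placeholder uid
    model_name="my-flexible-model",   # placeholder, must be registered first
)
info = model.model_family.to_description()  # replaces FlexibleModelDescription.to_dict()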
xinference/model/flexible/custom.py (new file)

@@ -0,0 +1,69 @@
+from typing import TYPE_CHECKING
+
+from ..custom import ModelRegistry
+
+if TYPE_CHECKING:
+    from .core import FlexibleModelSpec
+
+
+class FlexibleModelRegistry(ModelRegistry):
+    model_type = "flexible"
+
+    def __init__(self):
+        from .core import FLEXIBLE_MODELS
+
+        super().__init__()
+        self.models = FLEXIBLE_MODELS
+        self.builtin_models = []
+
+    def register(self, model_spec: "FlexibleModelSpec", persist: bool):
+        from ..cache_manager import CacheManager
+        from ..utils import is_valid_model_name, is_valid_model_uri
+
+        if not is_valid_model_name(model_spec.model_name):
+            raise ValueError(f"Invalid model name {model_spec.model_name}.")
+
+        model_uri = model_spec.model_uri
+        if model_uri and not is_valid_model_uri(model_uri):
+            raise ValueError(f"Invalid model URI {model_uri}.")
+
+        if model_spec.launcher_args:
+            try:
+                model_spec.parser_args()
+            except Exception:
+                raise ValueError(
+                    f"Invalid model launcher args {model_spec.launcher_args}."
+                )
+
+        with self.lock:
+            for model_name in [spec.model_name for spec in self.models]:
+                if model_spec.model_name == model_name:
+                    raise ValueError(
+                        f"Model name conflicts with existing model {model_spec.model_name}"
+                    )
+            self.models.append(model_spec)
+
+        if persist:
+            cache_manager = CacheManager(model_spec)
+            cache_manager.register_custom_model(self.model_type)
+
+
+def get_flexible_models():
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    return registry.get_custom_models()
+
+
+def register_flexible_model(model_spec: "FlexibleModelSpec", persist: bool):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    registry.register(model_spec, persist)
+
+
+def unregister_flexible_model(model_name: str, raise_error: bool = True):
+    from ..custom import RegistryManager
+
+    registry = RegistryManager.get_registry("flexible")
+    registry.unregister(model_name, raise_error)
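
A usage sketch for the module-level helpers above; the field values are illustrative, and the assumption is that FlexibleModelSpec accepts this subset of fields (CacheableModelSpec may require more):

from xinference.model.flexible import (
    register_flexible_model,
    unregister_flexible_model,
)
from xinference.model.flexible.core import FlexibleModelSpec

spec = FlexibleModelSpec(
    model_name="my-classifier",        # placeholder name
    model_uri="/path/to/local/model",  # placeholder local path
    launcher="transformers",           # placeholder launcher name
    launcher_args='{"task": "text-classification"}',
)

# persist=True would route through CacheManager.register_custom_model()
# and write the spec under the v2 model directory.
register_flexible_model(spec, persist=False)
unregister_flexible_model("my-classifier", raise_error=False)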
xinference/model/flexible/launchers/image_process_launcher.py

@@ -63,6 +63,7 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
     return ImageRemoveBackgroundModel(
         model_uid=model_uid,
         model_path=model_spec.model_uri,  # type: ignore
+        model_family=model_spec,
         device=device,
         config=kwargs,
     )
xinference/model/flexible/launchers/modelscope_launcher.py

@@ -43,5 +43,9 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
         raise ValueError("model_path required")
 
     return ModelScopePipelineModel(
-        model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+        model_uid=model_uid,
+        model_path=model_path,
+        model_family=model_spec,
+        device=device,
+        config=kwargs,
     )
xinference/model/flexible/launchers/transformers_launcher.py

@@ -51,13 +51,25 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
 
     if task == "text-classification":
         return TransformersTextClassificationModel(
-            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )
     elif task == "mock":
         return MockModel(
-            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )
     else:
         return AutoModel(
-            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+            model_uid=model_uid,
+            model_path=model_path,
+            model_family=model_spec,
+            device=device,
+            config=kwargs,
         )
xinference/model/flexible/launchers/yolo_launcher.py

@@ -58,5 +58,9 @@ def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> Flexibl
         raise ValueError("model_path required")
 
     return UltralyticsModel(
-        model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+        model_uid=model_uid,
+        model_path=model_path,
+        model_family=model_spec,
+        device=device,
+        config=kwargs,
     )
xinference/model/image/__init__.py

@@ -16,20 +16,17 @@ import codecs
 import json
 import os
 import warnings
-from itertools import chain
 
+from ..utils import flatten_model_src
 from .core import (
     BUILTIN_IMAGE_MODELS,
     IMAGE_MODEL_DESCRIPTIONS,
-    MODEL_NAME_TO_REVISION,
-    MODELSCOPE_IMAGE_MODELS,
-    ImageModelFamilyV1,
+    ImageModelFamilyV2,
     generate_image_description,
-    get_cache_status,
     get_image_model_descriptions,
 )
 from .custom import (
-    CustomImageModelFamilyV1,
+    CustomImageModelFamilyV2,
     get_user_defined_images,
     register_image,
     unregister_image,
@@ -38,15 +35,19 @@ from .custom import (
 
 
 def register_custom_model():
     from ...constants import XINFERENCE_MODEL_DIR
+    from ..custom import migrate_from_v1_to_v2
 
-    user_defined_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "image")
+    # migrate from v1 to v2 first
+    migrate_from_v1_to_v2("image", CustomImageModelFamilyV2)
+
+    user_defined_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "image")
     if os.path.isdir(user_defined_image_dir):
         for f in os.listdir(user_defined_image_dir):
             try:
                 with codecs.open(
                     os.path.join(user_defined_image_dir, f), encoding="utf-8"
                 ) as fd:
-                    user_defined_image_family = CustomImageModelFamilyV1.parse_obj(
+                    user_defined_image_family = CustomImageModelFamilyV2.parse_obj(
                         json.load(fd)
                     )
                     register_image(user_defined_image_family, persist=False)
@@ -56,12 +57,10 @@ def register_custom_model():
 
 
 def _install():
     load_model_family_from_json("model_spec.json", BUILTIN_IMAGE_MODELS)
-    load_model_family_from_json("model_spec_modelscope.json", MODELSCOPE_IMAGE_MODELS)
 
     # register model description
-    for model_name, model_spec in chain(
-        MODELSCOPE_IMAGE_MODELS.items(), BUILTIN_IMAGE_MODELS.items()
-    ):
+    for model_name, model_specs in BUILTIN_IMAGE_MODELS.items():
+        model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
         IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(model_spec))
 
     register_custom_model()
@@ -72,13 +71,14 @@ def _install():
 
 def load_model_family_from_json(json_filename, target_families):
     json_path = os.path.join(os.path.dirname(__file__), json_filename)
-    target_families.update(
-        dict(
-            (spec["model_name"], ImageModelFamilyV1(**spec))
-            for spec in json.load(codecs.open(json_path, "r", encoding="utf-8"))
-        )
-    )
-    for model_name, model_spec in target_families.items():
-        MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+    flattened_model_specs = []
+    for spec in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        flattened_model_specs.extend(flatten_model_src(spec))
+
+    for spec in flattened_model_specs:
+        if spec["model_name"] not in target_families:
+            target_families[spec["model_name"]] = [ImageModelFamilyV2(**spec)]
+        else:
+            target_families[spec["model_name"]].append(ImageModelFamilyV2(**spec))
 
     del json_path
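
The loader now leans on flatten_model_src (added in xinference/model/utils.py, not shown in this diff). Its assumed behavior, inferred from this hunk: expand a spec whose model_src block lists several hubs into one flat spec per hub, each tagged with model_hub, which is why BUILTIN_IMAGE_MODELS now maps a model name to a list of per-hub specs. A hypothetical re-implementation for illustration only:

from typing import Any, Dict, List


def flatten_model_src_sketch(spec: Dict[str, Any]) -> List[Dict[str, Any]]:
    flattened = []
    for hub, src in spec.get("model_src", {}).items():
        flat = {k: v for k, v in spec.items() if k != "model_src"}
        flat.update(src)          # hub-specific fields, e.g. model_id
        flat["model_hub"] = hub   # tag the originating hub
        flattened.append(flat)
    return flattened


spec = {
    "model_name": "sd3.5-medium",  # illustrative entry
    "model_src": {
        "huggingface": {"model_id": "stabilityai/stable-diffusion-3.5-medium"},
        "modelscope": {"model_id": "AI-ModelScope/stable-diffusion-3.5-medium"},
    },
}
assert [s["model_hub"] for s in flatten_model_src_sketch(spec)] == [
    "huggingface",
    "modelscope",
]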
xinference/model/image/cache_manager.py (new file)

@@ -0,0 +1,62 @@
+import os
+from typing import Optional
+
+from ..cache_manager import CacheManager
+
+
+class ImageCacheManager(CacheManager):
+    def cache_gguf(self, quantization: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not quantization:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.gguf_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support GGUF quantization"
+            )
+        if quantization not in (self._model_family.gguf_quantizations or []):
+            raise ValueError(
+                f"Cannot support quantization {quantization}, "
+                f"available quantizations: {self._model_family.gguf_quantizations}"
+            )
+
+        filename = self._model_family.gguf_model_file_name_template.format(quantization=quantization)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.gguf_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.gguf_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
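
A hedged usage sketch for the new manager; the family lookup key and quantization are illustrative, and assume the family's JSON entry declares gguf_quantizations and gguf_model_file_name_template:

from xinference.model.image.cache_manager import ImageCacheManager
from xinference.model.image.core import BUILTIN_IMAGE_MODELS

model_family = BUILTIN_IMAGE_MODELS["FLUX.1-dev"][0]  # hypothetical key
cache_manager = ImageCacheManager(model_family)
gguf_path = cache_manager.cache_gguf(quantization="Q4_0")  # downloads once, returns the local path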