xinference 1.7.1.post1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
@@ -21,15 +21,9 @@ from collections.abc import Callable
21
21
  import numpy as np
22
22
  import torch
23
23
 
24
- from .core import AudioModelFamilyV1
25
-
26
24
  logger = logging.getLogger(__name__)
27
25
 
28
26
 
29
- def get_model_version(audio_model: AudioModelFamilyV1) -> str:
30
- return audio_model.model_name
31
-
32
-
33
27
  def _extract_pcm_from_wav_bytes(wav_bytes):
34
28
  with io.BytesIO(wav_bytes) as wav_io:
35
29
  with wave.open(wav_io, "rb") as wav_file:
@@ -26,7 +26,7 @@ from ...device_utils import (
26
26
  )
27
27
 
28
28
  if TYPE_CHECKING:
29
- from .core import AudioModelFamilyV1
29
+ from .core import AudioModelFamilyV2
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
 
@@ -43,11 +43,12 @@ class WhisperModel:
43
43
  self,
44
44
  model_uid: str,
45
45
  model_path: str,
46
- model_spec: "AudioModelFamilyV1",
46
+ model_spec: "AudioModelFamilyV2",
47
47
  device: Optional[str] = None,
48
48
  max_new_tokens: Optional[int] = 128,
49
49
  **kwargs,
50
50
  ):
51
+ self.model_family = model_spec
51
52
  self._model_uid = model_uid
52
53
  self._model_path = model_path
53
54
  self._model_spec = model_spec
@@ -18,7 +18,7 @@ import tempfile
18
18
  from typing import TYPE_CHECKING, List, Optional
19
19
 
20
20
  if TYPE_CHECKING:
21
- from .core import AudioModelFamilyV1
21
+ from .core import AudioModelFamilyV2
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
@@ -28,10 +28,11 @@ class WhisperMLXModel:
28
28
  self,
29
29
  model_uid: str,
30
30
  model_path: str,
31
- model_spec: "AudioModelFamilyV1",
31
+ model_spec: "AudioModelFamilyV2",
32
32
  device: Optional[str] = None,
33
33
  **kwargs,
34
34
  ):
35
+ self.model_family = model_spec
35
36
  self._model_uid = model_uid
36
37
  self._model_path = model_path
37
38
  self._model_spec = model_spec
@@ -0,0 +1,141 @@
1
import logging
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .core import CacheableModelSpec


logger = logging.getLogger(__name__)


class CacheManager:
    """Manage the on-disk cache entry and custom-model JSON file of one model.

    The cache entry lives at ``<XINFERENCE_CACHE_DIR>/v2/<model_name>`` (dots in
    the model name are replaced by underscores); custom-model definitions are
    persisted under ``<XINFERENCE_MODEL_DIR>/v2/<model_type>/<model_name>.json``.
    """

    def __init__(self, model_family: "CacheableModelSpec"):
        """Create both ``v2`` root directories if needed and derive the cache path.

        :param model_family: spec of the model this manager is responsible for.
        """
        from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR

        self._model_family = model_family
        self._v2_cache_dir_prefix = os.path.join(XINFERENCE_CACHE_DIR, "v2")
        self._v2_custom_dir_prefix = os.path.join(XINFERENCE_MODEL_DIR, "v2")
        os.makedirs(self._v2_cache_dir_prefix, exist_ok=True)
        os.makedirs(self._v2_custom_dir_prefix, exist_ok=True)
        self._cache_dir = os.path.join(
            self._v2_cache_dir_prefix, self._model_family.model_name.replace(".", "_")
        )

    def get_cache_dir(self):
        """Return the path of this model's cache entry (it may not exist yet)."""
        return self._cache_dir

    def get_cache_status(self):
        """Return True when the cache entry exists.

        ``os.path.exists`` follows symlinks, so a dangling link counts as
        "not cached".
        """
        cache_dir = self.get_cache_dir()
        return os.path.exists(cache_dir)

    def _cache_from_uri(self, model_spec: "CacheableModelSpec") -> str:
        """Expose a local ``file`` URI as the cache entry through a symlink.

        :param model_spec: spec whose ``model_uri`` points at the model files.
        :return: the cache directory path.
        :raises ValueError: if the URI is a relative path or uses a scheme
            other than ``file``.
        """
        from .utils import parse_uri

        cache_dir = self.get_cache_dir()
        if os.path.exists(cache_dir):
            logger.info("cache %s exists", cache_dir)
            return cache_dir

        assert model_spec.model_uri is not None
        src_scheme, src_root = parse_uri(model_spec.model_uri)
        if src_root.endswith("/"):
            # remove trailing path separator.
            src_root = src_root[:-1]

        if src_scheme != "file":
            raise ValueError(f"Unsupported URL scheme: {src_scheme}")
        if not os.path.isabs(src_root):
            raise ValueError(
                f"Model URI cannot be a relative path: {model_spec.model_uri}"
            )

        if os.path.islink(cache_dir):
            # exists() was False above, so this link is dangling; drop it,
            # otherwise os.symlink would fail with FileExistsError.
            os.remove(cache_dir)
        os.symlink(src_root, cache_dir, target_is_directory=True)
        return cache_dir

    def _cache(self) -> str:
        """Make the model available in the cache directory and return its path.

        A spec carrying a non-None ``model_uri`` is linked via
        :meth:`_cache_from_uri`. Otherwise the snapshot is fetched from
        ModelScope or Hugging Face (selected by ``model_hub``) unless the cache
        entry already exists.
        """
        from .utils import IS_NEW_HUGGINGFACE_HUB, create_symlink, retry_download

        if getattr(self._model_family, "model_uri", None) is not None:
            logger.info(f"Model caching from URI: {self._model_family.model_uri}")
            return self._cache_from_uri(model_spec=self._model_family)

        cache_dir = self.get_cache_dir()
        if self.get_cache_status():
            return cache_dir

        from_modelscope: bool = self._model_family.model_hub == "modelscope"
        # Work on a copy so the spec's own cache_config is never mutated.
        cache_config = (
            self._model_family.cache_config.copy()
            if self._model_family.cache_config
            else {}
        )
        if from_modelscope:
            from modelscope.hub.snapshot_download import (
                snapshot_download as ms_download,
            )

            download_dir = retry_download(
                ms_download,
                self._model_family.model_name,
                None,
                self._model_family.model_id,
                revision=self._model_family.model_revision,
                **cache_config,
            )
            create_symlink(download_dir, cache_dir)
        else:
            from huggingface_hub import snapshot_download as hf_download

            # Keep the user-supplied cache_config on both code paths; the
            # legacy-hub options are layered on top instead of replacing it.
            download_kwargs = dict(cache_config)
            if not IS_NEW_HUGGINGFACE_HUB:
                download_kwargs.update(
                    {"local_dir_use_symlinks": True, "local_dir": cache_dir}
                )
            download_dir = retry_download(
                hf_download,
                self._model_family.model_name,
                None,
                self._model_family.model_id,
                revision=self._model_family.model_revision,
                **download_kwargs,
            )
            if IS_NEW_HUGGINGFACE_HUB:
                create_symlink(download_dir, cache_dir)
        return cache_dir

    def cache(self) -> str:
        """Public entry point; delegates to :meth:`_cache`."""
        return self._cache()

    def register_custom_model(self, model_type: str):
        """Write the spec's JSON to ``<custom root>/<model_type>/<model_name>.json``.

        :param model_type: sub-directory name under the v2 custom-model root.
        """
        persist_path = os.path.join(
            self._v2_custom_dir_prefix,
            model_type,
            f"{self._model_family.model_name}.json",
        )
        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
        # Explicit UTF-8 so the file does not depend on the platform default.
        with open(persist_path, mode="w", encoding="utf-8") as fd:
            fd.write(self._model_family.json())

    def unregister_custom_model(self, model_type: str):
        """Delete the persisted JSON file and the cache symlink, if present.

        A cache entry that is a real directory (not a symlink) is left in place
        and only reported through a warning.

        :param model_type: sub-directory name under the v2 custom-model root.
        """
        persist_path = os.path.join(
            self._v2_custom_dir_prefix,
            model_type,
            f"{self._model_family.model_name}.json",
        )
        if os.path.exists(persist_path):
            os.remove(persist_path)

        cache_dir = self.get_cache_dir()
        # islink() is checked separately because get_cache_status() is False
        # for a dangling symlink, which still has to be cleaned up.
        cache_is_link = os.path.islink(cache_dir)
        if not (cache_is_link or self.get_cache_status()):
            return

        logger.warning(
            f"Remove the cache of user-defined model {self._model_family.model_name}. "
            f"Cache directory: {cache_dir}"
        )
        if cache_is_link:
            os.remove(cache_dir)
        else:
            logger.warning(
                "Cache directory is not a soft link, please remove it manually."
            )
xinference/model/core.py CHANGED
@@ -11,47 +11,13 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
- from abc import ABC, abstractmethod
16
- from typing import Any, List, Literal, Optional, Tuple, Union
14
+ from typing import Any, List, Literal, Optional, Union
17
15
 
18
16
  from .._compat import BaseModel
19
17
  from ..types import PeftModelConfig
20
18
 
21
19
 
22
- class ModelDescription(ABC):
23
- def __init__(
24
- self,
25
- address: Optional[str],
26
- devices: Optional[List[str]],
27
- model_path: Optional[str] = None,
28
- ):
29
- self.address = address
30
- self.devices = devices
31
- self._model_path = model_path
32
-
33
- @property
34
- @abstractmethod
35
- def spec(self):
36
- pass
37
-
38
- def to_dict(self):
39
- """
40
- Return a dict to describe some information about model.
41
- :return:
42
- """
43
- raise NotImplementedError
44
-
45
- @abstractmethod
46
- def to_version_info(self):
47
- """
48
- Return a dict to describe version info about a model instance
49
- """
50
-
51
-
52
20
  def create_model_instance(
53
- subpool_addr: str,
54
- devices: List[str],
55
21
  model_uid: str,
56
22
  model_type: str,
57
23
  model_name: str,
@@ -65,7 +31,7 @@ def create_model_instance(
65
31
  ] = None,
66
32
  model_path: Optional[str] = None,
67
33
  **kwargs,
68
- ) -> Tuple[Any, ModelDescription]:
34
+ ) -> Any:
69
35
  from .audio.core import create_audio_model_instance
70
36
  from .embedding.core import create_embedding_model_instance
71
37
  from .flexible.core import create_flexible_model_instance
@@ -76,8 +42,6 @@ def create_model_instance(
76
42
 
77
43
  if model_type == "LLM":
78
44
  return create_llm_model_instance(
79
- subpool_addr,
80
- devices,
81
45
  model_uid,
82
46
  model_name,
83
47
  model_engine,
@@ -93,11 +57,11 @@ def create_model_instance(
93
57
  # embedding model doesn't accept trust_remote_code
94
58
  kwargs.pop("trust_remote_code", None)
95
59
  return create_embedding_model_instance(
96
- subpool_addr,
97
- devices,
98
60
  model_uid,
99
61
  model_name,
100
62
  model_engine,
63
+ model_format,
64
+ quantization,
101
65
  download_hub,
102
66
  model_path,
103
67
  **kwargs,
@@ -105,8 +69,6 @@ def create_model_instance(
105
69
  elif model_type == "image":
106
70
  kwargs.pop("trust_remote_code", None)
107
71
  return create_image_model_instance(
108
- subpool_addr,
109
- devices,
110
72
  model_uid,
111
73
  model_name,
112
74
  peft_model_config,
@@ -117,8 +79,6 @@ def create_model_instance(
117
79
  elif model_type == "rerank":
118
80
  kwargs.pop("trust_remote_code", None)
119
81
  return create_rerank_model_instance(
120
- subpool_addr,
121
- devices,
122
82
  model_uid,
123
83
  model_name,
124
84
  download_hub,
@@ -128,8 +88,6 @@ def create_model_instance(
128
88
  elif model_type == "audio":
129
89
  kwargs.pop("trust_remote_code", None)
130
90
  return create_audio_model_instance(
131
- subpool_addr,
132
- devices,
133
91
  model_uid,
134
92
  model_name,
135
93
  download_hub,
@@ -139,8 +97,6 @@ def create_model_instance(
139
97
  elif model_type == "video":
140
98
  kwargs.pop("trust_remote_code", None)
141
99
  return create_video_model_instance(
142
- subpool_addr,
143
- devices,
144
100
  model_uid,
145
101
  model_name,
146
102
  download_hub,
@@ -150,7 +106,7 @@ def create_model_instance(
150
106
  elif model_type == "flexible":
151
107
  kwargs.pop("trust_remote_code", None)
152
108
  return create_flexible_model_instance(
153
- subpool_addr, devices, model_uid, model_name, model_path, **kwargs
109
+ model_uid, model_name, model_path, **kwargs
154
110
  )
155
111
  else:
156
112
  raise ValueError(f"Unsupported model type: {model_type}.")
@@ -161,6 +117,7 @@ class CacheableModelSpec(BaseModel):
161
117
  model_id: str
162
118
  model_revision: Optional[str]
163
119
  model_hub: str = "huggingface"
120
+ cache_config: Optional[dict]
164
121
 
165
122
 
166
123
  class VirtualEnvSettings(BaseModel):
@@ -0,0 +1,174 @@
1
+ # Copyright 2022-2025 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import codecs
16
+ import json
17
+ import logging
18
+ import os
19
+ import threading
20
+ import warnings
21
+ from typing import TYPE_CHECKING, Dict, List, Type
22
+
23
+ if TYPE_CHECKING:
24
+ from .core import CacheableModelSpec
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class ModelRegistry:
30
+ model_type = "unknown"
31
+
32
+ def __init__(self) -> None:
33
+ self.lock = threading.Lock()
34
+ self.models: List["CacheableModelSpec"] = []
35
+ self.builtin_models: List[str] = []
36
+
37
+ def find_model(self, model_name: str):
38
+ model_spec = None
39
+ for f in self.models:
40
+ if f.model_name == model_name:
41
+ model_spec = f
42
+ break
43
+ return model_spec
44
+
45
+ def get_custom_models(self):
46
+ with self.lock:
47
+ return self.models.copy()
48
+
49
+ def check_model_uri(self, model_spec: "CacheableModelSpec"):
50
+ from .utils import is_valid_model_uri
51
+
52
+ model_uri = model_spec.model_uri
53
+ if model_uri and not is_valid_model_uri(model_uri):
54
+ raise ValueError(f"Invalid model URI {model_uri}.")
55
+
56
+ def add_ud_model(self, model_spec):
57
+ self.models.append(model_spec)
58
+
59
+ def register(self, model_spec: "CacheableModelSpec", persist: bool):
60
+ from .cache_manager import CacheManager
61
+ from .utils import is_valid_model_name
62
+
63
+ if not is_valid_model_name(model_spec.model_name):
64
+ raise ValueError(f"Invalid model name {model_spec.model_name}.")
65
+
66
+ self.check_model_uri(model_spec)
67
+
68
+ with self.lock:
69
+ for model_name in self.builtin_models + [
70
+ spec.model_name for spec in self.models
71
+ ]:
72
+ if model_spec.model_name == model_name:
73
+ raise ValueError(
74
+ f"Model name conflicts with existing model {model_spec.model_name}"
75
+ )
76
+
77
+ self.add_ud_model(model_spec)
78
+
79
+ if persist:
80
+ cache_manager = CacheManager(model_spec)
81
+ cache_manager.register_custom_model(self.model_type)
82
+
83
+ def remove_ud_model(self, model_spec):
84
+ self.models.remove(model_spec)
85
+
86
+ def remove_ud_model_files(self, model_spec):
87
+ from .cache_manager import CacheManager
88
+
89
+ cache_manager = CacheManager(model_spec)
90
+ cache_manager.unregister_custom_model(self.model_type)
91
+
92
+ def unregister(
93
+ self, model_name: str, raise_error: bool = True, remove_file: bool = True
94
+ ):
95
+ with self.lock:
96
+ model_spec = self.find_model(model_name)
97
+ if model_spec:
98
+ self.remove_ud_model(model_spec)
99
+ if remove_file:
100
+ self.remove_ud_model_files(model_spec)
101
+ else:
102
+ if raise_error:
103
+ raise ValueError(f"Model {model_name} not found")
104
+ else:
105
+ logger.warning(
106
+ f"Custom {self.model_type} model {model_name} not found"
107
+ )
108
+
109
+
110
+ class RegistryManager:
111
+ _instances: Dict[str, ModelRegistry] = {}
112
+
113
+ @classmethod
114
+ def get_registry(cls, model_type: str) -> ModelRegistry:
115
+ from .audio.custom import AudioModelRegistry
116
+ from .embedding.custom import EmbeddingModelRegistry
117
+ from .flexible.custom import FlexibleModelRegistry
118
+ from .image.custom import ImageModelRegistry
119
+ from .llm.custom import LLMModelRegistry
120
+ from .rerank.custom import RerankModelRegistry
121
+
122
+ if model_type not in cls._instances:
123
+ if model_type == "rerank":
124
+ cls._instances[model_type] = RerankModelRegistry()
125
+ elif model_type == "image":
126
+ cls._instances[model_type] = ImageModelRegistry()
127
+ elif model_type == "audio":
128
+ cls._instances[model_type] = AudioModelRegistry()
129
+ elif model_type == "llm":
130
+ cls._instances[model_type] = LLMModelRegistry()
131
+ elif model_type == "flexible":
132
+ cls._instances[model_type] = FlexibleModelRegistry()
133
+ elif model_type == "embedding":
134
+ cls._instances[model_type] = EmbeddingModelRegistry()
135
+ else:
136
+ raise ValueError(f"Unknown model type: {model_type}")
137
+ return cls._instances[model_type]
138
+
139
+
140
+ def migrate_from_v1_to_v2(model_type: str, model_spec_cls: Type):
141
+ from ..constants import XINFERENCE_MODEL_DIR
142
+
143
+ v1_user_defined_model_dir = os.path.join(XINFERENCE_MODEL_DIR, model_type)
144
+ v2_user_defined_model_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", model_type)
145
+ if os.path.isdir(v1_user_defined_model_dir):
146
+ for f in os.listdir(v1_user_defined_model_dir):
147
+ if os.path.exists(os.path.join(v2_user_defined_model_dir, f)):
148
+ # skip if v2 has already
149
+ continue
150
+
151
+ try:
152
+ with codecs.open(
153
+ os.path.join(v1_user_defined_model_dir, f), encoding="utf-8"
154
+ ) as fd:
155
+ v1_model_json = json.load(fd)
156
+
157
+ v1_model_json["version"] = 2
158
+ for spec in v1_model_json.get("model_specs", []):
159
+ if "quantizations" in spec:
160
+ # change quantizations to quantization
161
+ spec["quantization"] = spec["quantizations"][0]
162
+
163
+ user_defined_model_family = model_spec_cls(**v1_model_json)
164
+ registry = RegistryManager.get_registry(model_type)
165
+ # register custom model file to v2
166
+ registry.register(user_defined_model_family, persist=True)
167
+ # unregister since it will be registered by v2
168
+ registry.unregister(
169
+ user_defined_model_family.model_name, remove_file=False
170
+ )
171
+ except Exception as e:
172
+ warnings.warn(
173
+ f"Fail to migrate {v1_user_defined_model_dir}/{f}, error: {e}"
174
+ )