xinference-1.7.1-py3-none-any.whl → xinference-1.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
@@ -14,30 +14,21 @@
14
14
 
15
15
  import collections.abc
16
16
  import logging
17
- import os
18
17
  import platform
19
18
  from collections import defaultdict
20
- from typing import Dict, List, Literal, Optional, Tuple, Union
19
+ from typing import Dict, List, Literal, Optional, Union
21
20
 
22
- from ...constants import XINFERENCE_CACHE_DIR
23
21
  from ...types import PeftModelConfig
24
- from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
25
- from ..utils import (
26
- IS_NEW_HUGGINGFACE_HUB,
27
- retry_download,
28
- symlink_local_file,
29
- valid_model_revision,
30
- )
22
+ from ..core import CacheableModelSpec, VirtualEnvSettings
23
+ from ..utils import ModelInstanceInfoMixin
31
24
  from .ocr.got_ocr2 import GotOCR2Model
32
25
  from .stable_diffusion.core import DiffusionModel
33
26
  from .stable_diffusion.mlx import MLXDiffusionModel
34
27
 
35
28
  logger = logging.getLogger(__name__)
36
29
 
37
- MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
38
30
  IMAGE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
39
- BUILTIN_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
40
- MODELSCOPE_IMAGE_MODELS: Dict[str, "ImageModelFamilyV1"] = {}
31
+ BUILTIN_IMAGE_MODELS: Dict[str, List["ImageModelFamilyV2"]] = {}
41
32
 
42
33
 
43
34
  def get_image_model_descriptions():
@@ -46,14 +37,15 @@ def get_image_model_descriptions():
46
37
  return copy.deepcopy(IMAGE_MODEL_DESCRIPTIONS)
47
38
 
48
39
 
49
- class ImageModelFamilyV1(CacheableModelSpec):
40
+ class ImageModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
41
+ version: Literal[2] = 2
50
42
  model_family: str
51
43
  model_name: str
52
44
  model_id: str
53
45
  model_revision: str
54
46
  model_hub: str = "huggingface"
55
47
  model_ability: Optional[List[str]]
56
- controlnet: Optional[List["ImageModelFamilyV1"]]
48
+ controlnet: Optional[List["ImageModelFamilyV2"]]
57
49
  default_model_config: Optional[dict] = {}
58
50
  default_generate_config: Optional[dict] = {}
59
51
  gguf_model_id: Optional[str]
@@ -61,65 +53,48 @@ class ImageModelFamilyV1(CacheableModelSpec):
61
53
  gguf_model_file_name_template: Optional[str]
62
54
  virtualenv: Optional[VirtualEnvSettings]
63
55
 
56
+ class Config:
57
+ extra = "allow"
64
58
 
65
- class ImageModelDescription(ModelDescription):
66
- def __init__(
67
- self,
68
- address: Optional[str],
69
- devices: Optional[List[str]],
70
- model_spec: ImageModelFamilyV1,
71
- model_path: Optional[str] = None,
72
- ):
73
- super().__init__(address, devices, model_path=model_path)
74
- self._model_spec = model_spec
75
-
76
- @property
77
- def spec(self):
78
- return self._model_spec
79
-
80
- def to_dict(self):
81
- if self._model_spec.controlnet is not None:
82
- controlnet = [cn.dict() for cn in self._model_spec.controlnet]
59
+ def to_description(self):
60
+ if self.controlnet is not None:
61
+ controlnet = [cn.dict() for cn in self.controlnet]
83
62
  else:
84
- controlnet = self._model_spec.controlnet
63
+ controlnet = self.controlnet
85
64
  return {
86
65
  "model_type": "image",
87
- "address": self.address,
88
- "accelerators": self.devices,
89
- "model_name": self._model_spec.model_name,
90
- "model_family": self._model_spec.model_family,
91
- "model_revision": self._model_spec.model_revision,
92
- "model_ability": self._model_spec.model_ability,
66
+ "address": getattr(self, "address", None),
67
+ "accelerators": getattr(self, "accelerators", None),
68
+ "model_name": self.model_name,
69
+ "model_family": self.model_family,
70
+ "model_revision": self.model_revision,
71
+ "model_ability": self.model_ability,
93
72
  "controlnet": controlnet,
94
73
  }
95
74
 
96
75
  def to_version_info(self):
76
+ from .cache_manager import ImageCacheManager
97
77
  from .utils import get_model_version
98
78
 
99
- if self._model_path is None:
100
- is_cached = get_cache_status(self._model_spec)
101
- file_location = get_cache_dir(self._model_spec)
102
- else:
103
- is_cached = True
104
- file_location = self._model_path
79
+ cache_manager = ImageCacheManager(self)
105
80
 
106
- if self._model_spec.controlnet is None:
81
+ if not self.controlnet:
107
82
  return [
108
83
  {
109
- "model_version": get_model_version(self._model_spec, None),
110
- "model_file_location": file_location,
111
- "cache_status": is_cached,
84
+ "model_version": get_model_version(self, None),
85
+ "model_file_location": cache_manager.get_cache_dir(),
86
+ "cache_status": cache_manager.get_cache_status(),
112
87
  "controlnet": "zoe-depth",
113
88
  }
114
89
  ]
115
90
  else:
116
91
  res = []
117
- for cn in self._model_spec.controlnet:
92
+ for cn in self.controlnet:
118
93
  res.append(
119
94
  {
120
- "model_version": get_model_version(self._model_spec, cn),
121
- "model_file_location": file_location,
122
- "cache_status": is_cached,
95
+ "model_version": get_model_version(self, cn),
96
+ "model_file_location": cache_manager.get_cache_dir(),
97
+ "cache_status": cache_manager.get_cache_status(),
123
98
  "controlnet": cn.model_name,
124
99
  }
125
100
  )
@@ -127,12 +102,10 @@ class ImageModelDescription(ModelDescription):
127
102
 
128
103
 
129
104
  def generate_image_description(
130
- image_model: ImageModelFamilyV1,
105
+ image_model: ImageModelFamilyV2,
131
106
  ) -> Dict[str, List[Dict]]:
132
107
  res = defaultdict(list)
133
- res[image_model.model_name].extend(
134
- ImageModelDescription(None, None, image_model).to_version_info()
135
- )
108
+ res[image_model.model_name].extend(image_model.to_version_info())
136
109
  return res
137
110
 
138
111
 
@@ -141,27 +114,35 @@ def match_diffusion(
141
114
  download_hub: Optional[
142
115
  Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
143
116
  ] = None,
144
- ) -> ImageModelFamilyV1:
117
+ ) -> ImageModelFamilyV2:
145
118
  from ..utils import download_from_modelscope
146
- from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
119
+ from . import BUILTIN_IMAGE_MODELS
147
120
  from .custom import get_user_defined_images
148
121
 
149
122
  for model_spec in get_user_defined_images():
150
123
  if model_spec.model_name == model_name:
151
124
  return model_spec
152
125
 
153
- if download_hub == "modelscope" and model_name in MODELSCOPE_IMAGE_MODELS:
154
- logger.debug(f"Image model {model_name} found in ModelScope.")
155
- return MODELSCOPE_IMAGE_MODELS[model_name]
156
- elif download_hub == "huggingface" and model_name in BUILTIN_IMAGE_MODELS:
157
- logger.debug(f"Image model {model_name} found in Huggingface.")
158
- return BUILTIN_IMAGE_MODELS[model_name]
159
- elif download_from_modelscope() and model_name in MODELSCOPE_IMAGE_MODELS:
160
- logger.debug(f"Image model {model_name} found in ModelScope.")
161
- return MODELSCOPE_IMAGE_MODELS[model_name]
162
- elif model_name in BUILTIN_IMAGE_MODELS:
163
- logger.debug(f"Image model {model_name} found in Huggingface.")
164
- return BUILTIN_IMAGE_MODELS[model_name]
126
+ if model_name in BUILTIN_IMAGE_MODELS:
127
+ if download_hub == "modelscope" or download_from_modelscope():
128
+ return (
129
+ [
130
+ x
131
+ for x in BUILTIN_IMAGE_MODELS[model_name]
132
+ if x.model_hub == "modelscope"
133
+ ]
134
+ + [
135
+ x
136
+ for x in BUILTIN_IMAGE_MODELS[model_name]
137
+ if x.model_hub == "huggingface"
138
+ ]
139
+ )[0]
140
+ else:
141
+ return [
142
+ x
143
+ for x in BUILTIN_IMAGE_MODELS[model_name]
144
+ if x.model_hub == "huggingface"
145
+ ][0]
165
146
  else:
166
147
  raise ValueError(
167
148
  f"Image model {model_name} not found, available"
@@ -169,117 +150,27 @@ def match_diffusion(
169
150
  )
170
151
 
171
152
 
172
- def cache(model_spec: ImageModelFamilyV1):
173
- from ..utils import cache
174
-
175
- return cache(model_spec, ImageModelDescription)
176
-
177
-
178
- def get_cache_dir(model_spec: ImageModelFamilyV1):
179
- return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
180
-
181
-
182
- def get_cache_status(
183
- model_spec: ImageModelFamilyV1,
184
- ) -> bool:
185
- cache_dir = get_cache_dir(model_spec)
186
- meta_path = os.path.join(cache_dir, "__valid_download")
187
-
188
- model_name = model_spec.model_name
189
- if model_name in BUILTIN_IMAGE_MODELS and model_name in MODELSCOPE_IMAGE_MODELS:
190
- hf_spec = BUILTIN_IMAGE_MODELS[model_name]
191
- ms_spec = MODELSCOPE_IMAGE_MODELS[model_name]
192
-
193
- return any(
194
- [
195
- valid_model_revision(meta_path, hf_spec.model_revision),
196
- valid_model_revision(meta_path, ms_spec.model_revision),
197
- ]
198
- )
199
- else: # Usually for UT
200
- return valid_model_revision(meta_path, model_spec.model_revision)
201
-
202
-
203
- def cache_gguf(spec: ImageModelFamilyV1, quantization: Optional[str] = None):
204
- if not quantization:
205
- return
206
-
207
- cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, spec.model_name))
208
- if not os.path.exists(cache_dir):
209
- os.makedirs(cache_dir, exist_ok=True)
210
-
211
- if not spec.gguf_model_file_name_template:
212
- raise NotImplementedError(
213
- f"{spec.model_name} does not support GGUF quantization"
214
- )
215
- if quantization not in (spec.gguf_quantizations or []):
216
- raise ValueError(
217
- f"Cannot support quantization {quantization}, "
218
- f"available quantizations: {spec.gguf_quantizations}"
219
- )
220
-
221
- filename = spec.gguf_model_file_name_template.format(quantization=quantization) # type: ignore
222
- full_path = os.path.join(cache_dir, filename)
223
-
224
- if spec.model_hub == "huggingface":
225
- import huggingface_hub
226
-
227
- use_symlinks = {}
228
- if not IS_NEW_HUGGINGFACE_HUB:
229
- use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
230
- download_file_path = retry_download(
231
- huggingface_hub.hf_hub_download,
232
- spec.model_name,
233
- None,
234
- spec.gguf_model_id,
235
- filename=filename,
236
- **use_symlinks,
237
- )
238
- if IS_NEW_HUGGINGFACE_HUB:
239
- symlink_local_file(download_file_path, cache_dir, filename)
240
- elif spec.model_hub == "modelscope":
241
- from modelscope.hub.file_download import model_file_download
242
-
243
- download_file_path = retry_download(
244
- model_file_download,
245
- spec.model_name,
246
- None,
247
- spec.gguf_model_id,
248
- filename,
249
- revision=spec.model_revision,
250
- )
251
- symlink_local_file(download_file_path, cache_dir, filename)
252
- else:
253
- raise NotImplementedError
254
-
255
- return full_path
256
-
257
-
258
153
  def create_ocr_model_instance(
259
- subpool_addr: str,
260
- devices: List[str],
261
154
  model_uid: str,
262
- model_spec: ImageModelFamilyV1,
155
+ model_spec: ImageModelFamilyV2,
263
156
  model_path: Optional[str] = None,
264
157
  **kwargs,
265
- ) -> Tuple[GotOCR2Model, ImageModelDescription]:
158
+ ) -> GotOCR2Model:
159
+ from .cache_manager import ImageCacheManager
160
+
266
161
  if not model_path:
267
- model_path = cache(model_spec)
162
+ cache_manager = ImageCacheManager(model_spec)
163
+ model_path = cache_manager.cache()
268
164
  model = GotOCR2Model(
269
165
  model_uid,
270
166
  model_path,
271
167
  model_spec=model_spec,
272
168
  **kwargs,
273
169
  )
274
- model_description = ImageModelDescription(
275
- subpool_addr, devices, model_spec, model_path=model_path
276
- )
277
- return model, model_description
170
+ return model
278
171
 
279
172
 
280
173
  def create_image_model_instance(
281
- subpool_addr: str,
282
- devices: List[str],
283
174
  model_uid: str,
284
175
  model_name: str,
285
176
  peft_model_config: Optional[PeftModelConfig] = None,
@@ -290,14 +181,12 @@ def create_image_model_instance(
290
181
  gguf_quantization: Optional[str] = None,
291
182
  gguf_model_path: Optional[str] = None,
292
183
  **kwargs,
293
- ) -> Tuple[
294
- Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
295
- ]:
184
+ ) -> Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model]:
185
+ from .cache_manager import ImageCacheManager
186
+
296
187
  model_spec = match_diffusion(model_name, download_hub)
297
188
  if model_spec.model_ability and "ocr" in model_spec.model_ability:
298
189
  return create_ocr_model_instance(
299
- subpool_addr=subpool_addr,
300
- devices=devices,
301
190
  model_uid=model_uid,
302
191
  model_name=model_name,
303
192
  model_spec=model_spec,
@@ -327,7 +216,8 @@ def create_image_model_instance(
327
216
  for name in controlnet:
328
217
  for cn_model_spec in model_spec.controlnet:
329
218
  if cn_model_spec.model_name == name:
330
- controlnet_model_path = cache(cn_model_spec)
219
+ cn_cache_manager = ImageCacheManager(cn_model_spec)
220
+ controlnet_model_path = cn_cache_manager.cache()
331
221
  controlnet_model_paths.append(controlnet_model_path)
332
222
  break
333
223
  else:
@@ -340,10 +230,11 @@ def create_image_model_instance(
340
230
  kwargs["controlnet"] = [
341
231
  (n, path) for n, path in zip(controlnet, controlnet_model_paths)
342
232
  ]
233
+ cache_manager = ImageCacheManager(model_spec)
343
234
  if not model_path:
344
- model_path = cache(model_spec)
235
+ model_path = cache_manager.cache()
345
236
  if not gguf_model_path and gguf_quantization:
346
- gguf_model_path = cache_gguf(model_spec, gguf_quantization)
237
+ gguf_model_path = cache_manager.cache_gguf(gguf_quantization)
347
238
  if peft_model_config is not None:
348
239
  lora_model = peft_model_config.peft_model
349
240
  lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -356,7 +247,7 @@ def create_image_model_instance(
356
247
  if (
357
248
  platform.system() == "Darwin"
358
249
  and "arm" in platform.machine().lower()
359
- and model_name in MLXDiffusionModel.supported_models
250
+ and MLXDiffusionModel.support_model(model_name)
360
251
  ):
361
252
  # Mac with M series silicon chips
362
253
  model_cls = MLXDiffusionModel
@@ -373,7 +264,4 @@ def create_image_model_instance(
373
264
  gguf_model_path=gguf_model_path,
374
265
  **kwargs,
375
266
  )
376
- model_description = ImageModelDescription(
377
- subpool_addr, devices, model_spec, model_path=model_path
378
- )
379
- return model, model_description
267
+ return model
@@ -11,98 +11,54 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  import logging
16
- import os
17
- from threading import Lock
18
15
  from typing import List, Optional
19
16
 
20
- from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
21
- from .core import ImageModelFamilyV1
17
+ from ..._compat import Literal
18
+ from ..custom import ModelRegistry
19
+ from .core import ImageModelFamilyV2
22
20
 
23
21
  logger = logging.getLogger(__name__)
24
22
 
25
- UD_IMAGE_LOCK = Lock()
26
-
27
23
 
28
- class CustomImageModelFamilyV1(ImageModelFamilyV1):
24
+ class CustomImageModelFamilyV2(ImageModelFamilyV2):
25
+ version: Literal[2] = 2
29
26
  model_id: Optional[str] # type: ignore
30
27
  model_revision: Optional[str] # type: ignore
31
28
  model_uri: Optional[str]
32
- controlnet: Optional[List["CustomImageModelFamilyV1"]]
29
+ controlnet: Optional[List["CustomImageModelFamilyV2"]]
30
+
33
31
 
32
+ UD_IMAGES: List[CustomImageModelFamilyV2] = []
34
33
 
35
- UD_IMAGES: List[CustomImageModelFamilyV1] = []
36
34
 
35
+ class ImageModelRegistry(ModelRegistry):
36
+ model_type = "image"
37
37
 
38
- def get_user_defined_images() -> List[ImageModelFamilyV1]:
39
- with UD_IMAGE_LOCK:
40
- return UD_IMAGES.copy()
38
+ def __init__(self):
39
+ from .core import BUILTIN_IMAGE_MODELS
41
40
 
41
+ super().__init__()
42
+ self.models = UD_IMAGES
43
+ self.builtin_models = list(BUILTIN_IMAGE_MODELS.keys())
42
44
 
43
- def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
44
- from ..utils import is_valid_model_name, is_valid_model_uri
45
- from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
46
45
 
47
- if not is_valid_model_name(model_spec.model_name):
48
- raise ValueError(f"Invalid model name {model_spec.model_name}.")
46
+ def get_user_defined_images() -> List[ImageModelFamilyV2]:
47
+ from ..custom import RegistryManager
49
48
 
50
- model_uri = model_spec.model_uri
51
- if model_uri and not is_valid_model_uri(model_uri):
52
- raise ValueError(f"Invalid model URI {model_uri}")
49
+ registry = RegistryManager.get_registry("image")
50
+ return registry.get_custom_models()
53
51
 
54
- with UD_IMAGE_LOCK:
55
- for model_name in (
56
- list(BUILTIN_IMAGE_MODELS.keys())
57
- + list(MODELSCOPE_IMAGE_MODELS.keys())
58
- + [spec.model_name for spec in UD_IMAGES]
59
- ):
60
- if model_spec.model_name == model_name:
61
- raise ValueError(
62
- f"Model name conflicts with existing model {model_spec.model_name}"
63
- )
64
- UD_IMAGES.append(model_spec)
65
52
 
66
- if persist:
67
- persist_path = os.path.join(
68
- XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_name}.json"
69
- )
70
- os.makedirs(os.path.dirname(persist_path), exist_ok=True)
71
- with open(persist_path, "w") as f:
72
- f.write(model_spec.json())
53
+ def register_image(model_spec: CustomImageModelFamilyV2, persist: bool):
54
+ from ..custom import RegistryManager
55
+
56
+ registry = RegistryManager.get_registry("image")
57
+ registry.register(model_spec, persist)
73
58
 
74
59
 
75
60
  def unregister_image(model_name: str, raise_error: bool = True):
76
- with UD_IMAGE_LOCK:
77
- model_spec = None
78
- for i, f in enumerate(UD_IMAGES):
79
- if f.model_name == model_name:
80
- model_spec = f
81
- break
82
- if model_spec:
83
- UD_IMAGES.remove(model_spec)
84
-
85
- persist_path = os.path.join(
86
- XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_id}.json"
87
- )
88
-
89
- if os.path.exists(persist_path):
90
- os.remove(persist_path)
91
-
92
- cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
93
- if os.path.exists(cache_dir):
94
- logger.warning(
95
- f"Remove the cache of user-defined model {model_spec.model_name}. "
96
- f"Cache directory: {cache_dir}"
97
- )
98
- if os.path.islink(cache_dir):
99
- os.remove(cache_dir)
100
- else:
101
- logger.warning(
102
- f"Cache directory is not a soft link, please remove it manually."
103
- )
104
- else:
105
- if raise_error:
106
- raise ValueError(f"Model {model_name} not found.")
107
- else:
108
- logger.warning(f"Custom image model {model_name} not found.")
61
+ from ..custom import RegistryManager
62
+
63
+ registry = RegistryManager.get_registry("image")
64
+ registry.unregister(model_name, raise_error)