xinference 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
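
Most of the line-count churn in llm_family.json and the various model_spec.json files lines up with the schema change visible in the llm_family.py hunks below: a spec now carries a single `quantization: str` instead of a `quantizations: List[str]`, the hub-specific spec files (llm_family_modelscope.json, llm_family_csghub.json, llm_family_openmind_hub.json and the model_spec_modelscope.json variants) are removed, and specs for different hubs appear to live side by side in the main JSON, distinguished by their model_hub field. A minimal sketch of what a V2-style GGUF spec entry might look like, expressed as a Python dict; the model name, id, size, and quantization values are hypothetical, not taken from the shipped JSON:

# Hedged sketch only: an illustrative V2-style spec entry as a Python dict.
# Field names mirror LlamaCppLLMSpecV2 in the diff below; the concrete values
# are made up for illustration.
example_v2_spec = {
    "model_format": "ggufv2",
    "model_size_in_billions": 7,
    # V1 carried a list instead: "quantizations": ["Q4_K_M", "Q8_0", ...]
    "quantization": "Q4_K_M",
    "model_id": "example-org/example-llm-7B-GGUF",
    "model_hub": "huggingface",
    "model_file_name_template": "example-llm-7b.{quantization}.gguf",
}

Under that layout, offering N quantizations of a model means N spec entries rather than one entry with an N-element list, which would account for much of the JSON growth.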
@@ -14,8 +14,7 @@
 
 import logging
 import os
-from threading import Lock
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Set, Type, Union
 
 from typing_extensions import Annotated, Literal
 
@@ -30,24 +29,14 @@ from ..._compat import (
     load_str_bytes,
     validator,
 )
-from ...constants import (
-    XINFERENCE_CACHE_DIR,
-    XINFERENCE_CSG_ENDPOINT,
-    XINFERENCE_ENV_CSG_TOKEN,
-    XINFERENCE_MODEL_DIR,
-)
+from ...constants import XINFERENCE_CACHE_DIR
 from ..core import VirtualEnvSettings
 from ..utils import (
-    IS_NEW_HUGGINGFACE_HUB,
-    create_symlink,
+    ModelInstanceInfoMixin,
     download_from_csghub,
     download_from_modelscope,
     download_from_openmind_hub,
-    is_valid_model_uri,
-    parse_uri,
     retry_download,
-    symlink_local_file,
-    valid_model_revision,
 )
 from . import LLM
 
@@ -60,11 +49,11 @@ BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set()
 BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES: Set[str] = set()
 
 
-class LlamaCppLLMSpecV1(BaseModel):
+class LlamaCppLLMSpecV2(BaseModel):
     model_format: Literal["ggufv2"]
     # Must in order that `str` first, then `int`
     model_size_in_billions: Union[str, int]
-    quantizations: List[str]
+    quantization: str
     multimodal_projectors: Optional[List[str]]
     model_id: Optional[str]
     model_file_name_template: str
@@ -88,11 +77,11 @@ class LlamaCppLLMSpecV1(BaseModel):
         return v
 
 
-class PytorchLLMSpecV1(BaseModel):
+class PytorchLLMSpecV2(BaseModel):
     model_format: Literal["pytorch", "gptq", "awq", "fp8"]
     # Must in order that `str` first, then `int`
     model_size_in_billions: Union[str, int]
-    quantizations: List[str]
+    quantization: str
     model_id: Optional[str]
     model_hub: str = "huggingface"
     model_uri: Optional[str]
@@ -112,11 +101,11 @@ class PytorchLLMSpecV1(BaseModel):
         return v
 
 
-class MLXLLMSpecV1(BaseModel):
+class MLXLLMSpecV2(BaseModel):
     model_format: Literal["mlx"]
     # Must in order that `str` first, then `int`
    model_size_in_billions: Union[str, int]
-    quantizations: List[str]
+    quantization: str
     model_id: Optional[str]
     model_hub: str = "huggingface"
     model_uri: Optional[str]
@@ -136,8 +125,8 @@ class MLXLLMSpecV1(BaseModel):
         return v
 
 
-class LLMFamilyV1(BaseModel):
-    version: Literal[1]
+class LLMFamilyV2(BaseModel, ModelInstanceInfoMixin):
+    version: Literal[2]
     context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH
     model_name: str
     model_lang: List[str]
@@ -163,10 +152,61 @@ class LLMFamilyV1(BaseModel):
     stop: Optional[List[str]]
     reasoning_start_tag: Optional[str]
     reasoning_end_tag: Optional[str]
+    cache_config: Optional[dict]
     virtualenv: Optional[VirtualEnvSettings]
 
+    class Config:
+        extra = "allow"
+
+    def to_description(self):
+        spec = self.model_specs[0]
+        return {
+            "model_type": "LLM",
+            "address": getattr(self, "address", None),
+            "accelerators": getattr(self, "accelerators", None),
+            "model_name": self.model_name,
+            "model_lang": self.model_lang,
+            "model_ability": self.model_ability,
+            "model_description": self.model_description,
+            "model_format": spec.model_format,
+            "model_size_in_billions": spec.model_size_in_billions,
+            "model_family": self.model_family or self.model_name,
+            "quantization": spec.quantization,
+            "multimodal_projector": getattr(self, "multimodal_projector", None),
+            "model_hub": spec.model_hub,
+            "revision": spec.model_revision,
+            "context_length": self.context_length,
+        }
+
+    def to_version_info(self):
+        """
+        Entering this function means it is already bound to a model instance,
+        so there is only one spec.
+        """
+        from .cache_manager import LLMCacheManager
+        from .utils import get_model_version
+
+        spec = self.model_specs[0]
+        multimodal_projector = getattr(self, "multimodal_projector", None)
+        cache_manager = LLMCacheManager(self, multimodal_projector)
+
+        return {
+            "model_version": get_model_version(
+                self.model_name,
+                spec.model_format,
+                spec.model_size_in_billions,
+                spec.quantization,
+            ),
+            "model_file_location": cache_manager.get_cache_dir(),
+            "cache_status": cache_manager.get_cache_status(),
+            "quantization": spec.quantization,
+            "multimodal_projector": multimodal_projector,
+            "model_format": spec.model_format,
+            "model_size_in_billions": spec.model_size_in_billions,
+        }
 
-class CustomLLMFamilyV1(LLMFamilyV1):
+
+class CustomLLMFamilyV2(LLMFamilyV2):
     @classmethod
     def parse_raw(
         cls: Any,
@@ -176,7 +216,7 @@ class CustomLLMFamilyV1(LLMFamilyV1):
         encoding: str = "utf8",
         proto: Protocol = None,
         allow_pickle: bool = False,
-    ) -> LLMFamilyV1:
+    ) -> LLMFamilyV2:
         # See source code of BaseModel.parse_raw
         try:
             obj = load_str_bytes(
@@ -189,7 +229,7 @@ class CustomLLMFamilyV1(LLMFamilyV1):
             )
         except (ValueError, TypeError, UnicodeDecodeError) as e:
             raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls)
-        llm_spec: CustomLLMFamilyV1 = cls.parse_obj(obj)
+        llm_spec: CustomLLMFamilyV2 = cls.parse_obj(obj)
         vision_model_names: Set[str] = {
             family.model_name
             for family in BUILTIN_LLM_FAMILIES
@@ -255,39 +295,27 @@ class CustomLLMFamilyV1(LLMFamilyV1):
 
 
 LLMSpecV1 = Annotated[
-    Union[LlamaCppLLMSpecV1, PytorchLLMSpecV1, MLXLLMSpecV1],
+    Union[LlamaCppLLMSpecV2, PytorchLLMSpecV2, MLXLLMSpecV2],
     Field(discriminator="model_format"),
 ]
 
-LLMFamilyV1.update_forward_refs()
-CustomLLMFamilyV1.update_forward_refs()
+LLMFamilyV2.update_forward_refs()
+CustomLLMFamilyV2.update_forward_refs()
 
 
 LLAMA_CLASSES: List[Type[LLM]] = []
 
-BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
-BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []
-BUILTIN_OPENMIND_HUB_LLM_FAMILIES: List["LLMFamilyV1"] = []
-BUILTIN_CSGHUB_LLM_FAMILIES: List["LLMFamilyV1"] = []
+BUILTIN_LLM_FAMILIES: List["LLMFamilyV2"] = []
 
 SGLANG_CLASSES: List[Type[LLM]] = []
 TRANSFORMERS_CLASSES: List[Type[LLM]] = []
-
-UD_LLM_FAMILIES: List["LLMFamilyV1"] = []
-
-UD_LLM_FAMILIES_LOCK = Lock()
-
 VLLM_CLASSES: List[Type[LLM]] = []
-
 MLX_CLASSES: List[Type[LLM]] = []
-
 LMDEPLOY_CLASSES: List[Type[LLM]] = []
 
 LLM_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
 SUPPORTED_ENGINES: Dict[str, List[Type[LLM]]] = {}
 
-LLM_LAUNCH_VERSIONS: Dict[str, List[str]] = {}
-
 
 # Add decorator definition
 def register_transformer(cls):
@@ -308,107 +336,16 @@ def register_transformer(cls):
     return cls
 
 
-def download_from_self_hosted_storage() -> bool:
-    from ...constants import XINFERENCE_ENV_MODEL_SRC
-
-    return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "xorbits"
-
-
-def get_legacy_cache_path(
-    model_name: str,
-    model_format: str,
-    model_size_in_billions: Optional[Union[str, int]] = None,
-    quantization: Optional[str] = None,
-) -> str:
-    full_name = f"{model_name}-{model_format}-{model_size_in_billions}b-{quantization}"
-    return os.path.join(XINFERENCE_CACHE_DIR, full_name, "model.bin")
-
-
-def cache(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> str:
-    legacy_cache_path = get_legacy_cache_path(
-        llm_family.model_name,
-        llm_spec.model_format,
-        llm_spec.model_size_in_billions,
-        quantization,
-    )
-    if os.path.exists(legacy_cache_path):
-        logger.info("Legacy cache path exists: %s", legacy_cache_path)
-        return os.path.dirname(legacy_cache_path)
-    else:
-        if llm_spec.model_uri is not None:
-            logger.info(f"Caching from URI: {llm_spec.model_uri}")
-            return cache_from_uri(llm_family, llm_spec)
-        else:
-            if llm_spec.model_hub == "huggingface":
-                logger.info(f"Caching from Hugging Face: {llm_spec.model_id}")
-                return cache_from_huggingface(
-                    llm_family, llm_spec, quantization, multimodal_projector
-                )
-            elif llm_spec.model_hub == "modelscope":
-                logger.info(f"Caching from Modelscope: {llm_spec.model_id}")
-                return cache_from_modelscope(
-                    llm_family, llm_spec, quantization, multimodal_projector
-                )
-            elif llm_spec.model_hub == "openmind_hub":
-                logger.info(f"Caching from openmind_hub: {llm_spec.model_id}")
-                return cache_from_openmind_hub(
-                    llm_family, llm_spec, quantization, multimodal_projector
-                )
-            elif llm_spec.model_hub == "csghub":
-                logger.info(f"Caching from CSGHub: {llm_spec.model_id}")
-                return cache_from_csghub(
-                    llm_family, llm_spec, quantization, multimodal_projector
-                )
-            else:
-                raise ValueError(f"Unknown model hub: {llm_spec.model_hub}")
-
-
-def cache_from_uri(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-) -> str:
-    cache_dir_name = (
-        f"{llm_family.model_name}-{llm_spec.model_format}"
-        f"-{llm_spec.model_size_in_billions}b"
-    )
-    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
-
-    assert llm_spec.model_uri is not None
-    src_scheme, src_root = parse_uri(llm_spec.model_uri)
-    if src_root.endswith("/"):
-        # remove trailing path separator.
-        src_root = src_root[:-1]
-
-    if src_scheme == "file":
-        if not os.path.isabs(src_root):
-            raise ValueError(
-                f"Model URI cannot be a relative path: {llm_spec.model_uri}"
-            )
-        os.makedirs(XINFERENCE_CACHE_DIR, exist_ok=True)
-        if os.path.exists(cache_dir):
-            logger.info(f"Cache {cache_dir} exists")
-            return cache_dir
-        else:
-            os.symlink(src_root, cache_dir, target_is_directory=True)
-            return cache_dir
-    else:
-        raise ValueError(f"Unsupported URL scheme: {src_scheme}")
-
-
 def cache_model_tokenizer_and_config(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
+    llm_family: LLMFamilyV2,
 ) -> str:
     """
     Download model config.json and tokenizers only
     """
+    llm_spec = llm_family.model_specs[0]
     cache_dir = _get_cache_dir_for_model_mem(llm_family, llm_spec, "tokenizer_config")
     os.makedirs(cache_dir, exist_ok=True)
+    patterns = ["tokenizer*", "config.json", "configuration*", "tokenization*"]
     if llm_spec.model_hub == "huggingface":
         from huggingface_hub import snapshot_download
 
@@ -421,7 +358,7 @@ def cache_model_tokenizer_and_config(
             },
             llm_spec.model_id,
             revision=llm_spec.model_revision,
-            allow_patterns=["tokenizer*", "config.json"],
+            allow_patterns=patterns,
             local_dir=cache_dir,
         )
     elif llm_spec.model_hub == "modelscope":
@@ -436,7 +373,7 @@ def cache_model_tokenizer_and_config(
             },
             llm_spec.model_id,
             revision=llm_spec.model_revision,
-            allow_patterns=["tokenizer*", "config.json"],
+            allow_patterns=patterns,
             local_dir=cache_dir,
         )
     else:
@@ -447,13 +384,11 @@ def cache_model_tokenizer_and_config(
     return download_dir
 
 
-def cache_model_config(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-):
+def cache_model_config(llm_family: LLMFamilyV2):
    """Download model config.json into cache_dir,
     returns local filepath
     """
+    llm_spec = llm_family.model_specs[0]
     cache_dir = _get_cache_dir_for_model_mem(llm_family, llm_spec, "model_mem")
     config_file = os.path.join(cache_dir, "config.json")
     if not os.path.islink(config_file) and not os.path.exists(config_file):
@@ -475,7 +410,7 @@ def cache_model_config(
 
 
 def _get_cache_dir_for_model_mem(
-    llm_family: LLMFamilyV1,
+    llm_family: LLMFamilyV2,
     llm_spec: "LLMSpecV1",
     category: str,
     create_if_not_exist=True,
@@ -486,597 +421,18 @@
     e.g. for cal-model-mem, (might called from supervisor / cli)
     Temporary use separate dir from worker's cache_dir, due to issue of different style of symlink.
     """
-    quant_suffix = ""
-    for q in llm_spec.quantizations:
-        if llm_spec.model_id and q in llm_spec.model_id:
-            quant_suffix = q
-            break
     cache_dir_name = (
         f"{llm_family.model_name}-{llm_spec.model_format}"
-        f"-{llm_spec.model_size_in_billions}b"
+        f"-{llm_spec.model_size_in_billions}b-{llm_spec.quantization}"
     )
-    if quant_suffix:
-        cache_dir_name += f"-{quant_suffix}"
     cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, category, cache_dir_name)
+        os.path.join(XINFERENCE_CACHE_DIR, "v2", category, cache_dir_name)
     )
     if create_if_not_exist and not os.path.exists(cache_dir):
         os.makedirs(cache_dir, exist_ok=True)
     return cache_dir
 
 
-def _get_cache_dir(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    create_if_not_exist=True,
-):
-    # If the model id contains quantization, then we should give each
-    # quantization a dedicated cache dir.
-    quant_suffix = ""
-    if llm_spec.model_id and "{" in llm_spec.model_id and quantization is not None:
-        quant_suffix = quantization
-    else:
-        for q in llm_spec.quantizations:
-            if llm_spec.model_id and q in llm_spec.model_id:
-                quant_suffix = q
-                break
-
-    # some model name includes ".", e.g. qwen1.5-chat
-    # if the model does not require trust_remote_code, it's OK
-    # because no need to import modeling_xxx.py from the path
-    # but when the model need to trust_remote_code,
-    # e.g. internlm2.5-chat, the import will fail,
-    # but before the model may have been downloaded,
-    # thus we check it first, if exist, return it,
-    # otherwise, we replace the "." with "_" in model name
-    old_cache_dir_name = (
-        f"{llm_family.model_name}-{llm_spec.model_format}"
-        f"-{llm_spec.model_size_in_billions}b"
-    )
-    if quant_suffix:
-        old_cache_dir_name += f"-{quant_suffix}"
-    old_cache_dir = os.path.realpath(
-        os.path.join(XINFERENCE_CACHE_DIR, old_cache_dir_name)
-    )
-    if os.path.exists(old_cache_dir):
-        return old_cache_dir
-    else:
-        cache_dir_name = (
-            f"{llm_family.model_name.replace('.', '_')}-{llm_spec.model_format}"
-            f"-{llm_spec.model_size_in_billions}b"
-        )
-        if quant_suffix:
-            cache_dir_name += f"-{quant_suffix}"
-        cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
-        if create_if_not_exist and not os.path.exists(cache_dir):
-            os.makedirs(cache_dir, exist_ok=True)
-        return cache_dir
-
-
-def _get_meta_path(
-    cache_dir: str,
-    model_format: str,
-    model_hub: str,
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-):
-    if model_format == "pytorch":
-        if model_hub == "huggingface":
-            return os.path.join(cache_dir, "__valid_download")
-        else:
-            return os.path.join(cache_dir, f"__valid_download_{model_hub}")
-    elif model_format == "ggufv2":
-        assert quantization is not None
-        if multimodal_projector is None:
-            # Compatible with old cache file to avoid re-download model.
-            if model_hub == "huggingface":
-                return os.path.join(cache_dir, f"__valid_download_{quantization}")
-            else:
-                return os.path.join(
-                    cache_dir, f"__valid_download_{model_hub}_{quantization}"
-                )
-        else:
-            if model_hub == "huggingface":
-                return os.path.join(
-                    cache_dir, f"__valid_download_{quantization}_{multimodal_projector}"
-                )
-            else:
-                return os.path.join(
-                    cache_dir,
-                    f"__valid_download_{model_hub}_{quantization}_{multimodal_projector}",
-                )
-    elif model_format in ["gptq", "awq", "fp8", "mlx"]:
-        assert quantization is not None
-        if model_hub == "huggingface":
-            return os.path.join(cache_dir, f"__valid_download_{quantization}")
-        else:
-            return os.path.join(
-                cache_dir, f"__valid_download_{model_hub}_{quantization}"
-            )
-    else:
-        raise ValueError(f"Unsupported format: {model_format}")
-
-
-def _skip_download(
-    cache_dir: str,
-    model_format: str,
-    model_hub: str,
-    model_revision: Optional[str],
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> bool:
-    if model_format in ["pytorch", "mindspore"]:
-        model_hub_to_meta_path = {
-            "huggingface": _get_meta_path(
-                cache_dir, model_format, "huggingface", quantization
-            ),
-            "modelscope": _get_meta_path(
-                cache_dir, model_format, "modelscope", quantization
-            ),
-            "openmind_hub": _get_meta_path(
-                cache_dir, model_format, "openmind_hub", quantization
-            ),
-            "csghub": _get_meta_path(cache_dir, model_format, "csghub", quantization),
-        }
-        if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
-            logger.info(f"Cache {cache_dir} exists")
-            return True
-        else:
-            for hub, meta_path in model_hub_to_meta_path.items():
-                if hub != model_hub and os.path.exists(meta_path):
-                    # PyTorch models from modelscope can also be loaded by transformers.
-                    logger.warning(f"Cache {cache_dir} exists, but it was from {hub}")
-                    return True
-            return False
-    elif model_format == "ggufv2":
-        assert quantization is not None
-        return os.path.exists(
-            _get_meta_path(
-                cache_dir, model_format, model_hub, quantization, multimodal_projector
-            )
-        )
-    elif model_format in ["gptq", "awq", "fp8", "mlx"]:
-        assert quantization is not None
-        return os.path.exists(
-            _get_meta_path(cache_dir, model_format, model_hub, quantization)
-        )
-    else:
-        raise ValueError(f"Unsupported format: {model_format}")
-
-
-def _generate_meta_file(
-    meta_path: str,
-    llm_family: "LLMFamilyV1",
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-):
-    assert not valid_model_revision(
-        meta_path, llm_spec.model_revision
-    ), f"meta file {meta_path} should not be valid"
-    with open(meta_path, "w") as f:
-        import json
-
-        from .core import LLMDescription
-
-        desc = LLMDescription(
-            None, None, llm_family, llm_spec, quantization, multimodal_projector
-        )
-        json.dump(desc.to_dict(), f)
-
-
-def _generate_model_file_names(
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> Tuple[List[str], str, bool]:
-    file_names = []
-    final_file_name = llm_spec.model_file_name_template.format(
-        quantization=quantization
-    )
-    need_merge = False
-
-    if (
-        llm_spec.quantization_parts is None
-        or quantization not in llm_spec.quantization_parts
-    ):
-        file_names.append(final_file_name)
-    elif quantization is not None and quantization in llm_spec.quantization_parts:
-        parts = llm_spec.quantization_parts[quantization]
-        need_merge = True
-
-        logger.info(
-            f"Model {llm_spec.model_id} {llm_spec.model_format} {quantization} has {len(parts)} parts."
-        )
-
-        if llm_spec.model_file_name_split_template is None:
-            raise ValueError(
-                f"No model_file_name_split_template for model spec {llm_spec.model_id}"
-            )
-
-        for part in parts:
-            file_name = llm_spec.model_file_name_split_template.format(
-                quantization=quantization, part=part
-            )
-            file_names.append(file_name)
-    if multimodal_projector:
-        file_names.append(multimodal_projector)
-
-    return file_names, final_file_name, need_merge
-
-
-def _merge_cached_files(
-    cache_dir: str, input_file_names: List[str], output_file_name: str
-):
-    # now llama.cpp can find the gguf parts automatically
-    # we only need to provide the first part
-    # thus we create the symlink to the first part
-    symlink_local_file(
-        os.path.join(cache_dir, input_file_names[0]), cache_dir, output_file_name
-    )
-
-    logger.info(f"Merge complete.")
-
-
-def cache_from_csghub(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> str:
-    """
-    Cache model from CSGHub. Return the cache directory.
-    """
-    from pycsghub.file_download import file_download
-    from pycsghub.snapshot_download import snapshot_download
-
-    cache_dir = _get_cache_dir(llm_family, llm_spec)
-
-    if _skip_download(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        llm_spec.model_revision,
-        quantization,
-        multimodal_projector,
-    ):
-        return cache_dir
-
-    if llm_spec.model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
-        download_dir = retry_download(
-            snapshot_download,
-            llm_family.model_name,
-            {
-                "model_size": llm_spec.model_size_in_billions,
-                "model_format": llm_spec.model_format,
-            },
-            llm_spec.model_id,
-            endpoint=XINFERENCE_CSG_ENDPOINT,
-            token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
-        )
-        create_symlink(download_dir, cache_dir)
-
-    elif llm_spec.model_format in ["ggufv2"]:
-        file_names, final_file_name, need_merge = _generate_model_file_names(
-            llm_spec, quantization, multimodal_projector
-        )
-
-        for filename in file_names:
-            download_path = retry_download(
-                file_download,
-                llm_family.model_name,
-                {
-                    "model_size": llm_spec.model_size_in_billions,
-                    "model_format": llm_spec.model_format,
-                },
-                llm_spec.model_id,
-                file_name=filename,
-                endpoint=XINFERENCE_CSG_ENDPOINT,
-                token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
-            )
-            symlink_local_file(download_path, cache_dir, filename)
-
-        if need_merge:
-            _merge_cached_files(cache_dir, file_names, final_file_name)
-    else:
-        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
-
-    meta_path = _get_meta_path(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        quantization,
-        multimodal_projector,
-    )
-    _generate_meta_file(
-        meta_path, llm_family, llm_spec, quantization, multimodal_projector
-    )
-
-    return cache_dir
-
-
-def cache_from_modelscope(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> str:
-    """
-    Cache model from Modelscope. Return the cache directory.
-    """
-    from modelscope.hub.file_download import model_file_download
-    from modelscope.hub.snapshot_download import snapshot_download
-
-    cache_dir = _get_cache_dir(llm_family, llm_spec)
-    if _skip_download(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        llm_spec.model_revision,
-        quantization,
-        multimodal_projector,
-    ):
-        return cache_dir
-
-    if llm_spec.model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
-        download_dir = retry_download(
-            snapshot_download,
-            llm_family.model_name,
-            {
-                "model_size": llm_spec.model_size_in_billions,
-                "model_format": llm_spec.model_format,
-            },
-            llm_spec.model_id,
-            revision=llm_spec.model_revision,
-        )
-        create_symlink(download_dir, cache_dir)
-
-    elif llm_spec.model_format in ["ggufv2"]:
-        file_names, final_file_name, need_merge = _generate_model_file_names(
-            llm_spec, quantization, multimodal_projector
-        )
-
-        for filename in file_names:
-            download_path = retry_download(
-                model_file_download,
-                llm_family.model_name,
-                {
-                    "model_size": llm_spec.model_size_in_billions,
-                    "model_format": llm_spec.model_format,
-                },
-                llm_spec.model_id,
-                filename,
-                revision=llm_spec.model_revision,
-            )
-            symlink_local_file(download_path, cache_dir, filename)
-
-        if need_merge:
-            _merge_cached_files(cache_dir, file_names, final_file_name)
-    else:
-        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
-
-    meta_path = _get_meta_path(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        quantization,
-        multimodal_projector,
-    )
-    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
-
-    return cache_dir
-
-
-def cache_from_openmind_hub(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> str:
-    """
-    Cache model from openmind_hub. Return the cache directory.
-    """
-    from openmind_hub import snapshot_download
-
-    cache_dir = _get_cache_dir(llm_family, llm_spec)
-    if _skip_download(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        llm_spec.model_revision,
-        quantization,
-        multimodal_projector,
-    ):
-        return cache_dir
-
-    if llm_spec.model_format in ["pytorch", "mindspore"]:
-        download_dir = retry_download(
-            snapshot_download,
-            llm_family.model_name,
-            {
-                "model_size": llm_spec.model_size_in_billions,
-                "model_format": llm_spec.model_format,
-            },
-            llm_spec.model_id,
-            revision=llm_spec.model_revision,
-        )
-        create_symlink(download_dir, cache_dir)
-
-    else:
-        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
-
-    meta_path = _get_meta_path(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        quantization,
-        multimodal_projector,
-    )
-    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
-
-    return cache_dir
-
-
-def cache_from_huggingface(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    quantization: Optional[str] = None,
-    multimodal_projector: Optional[str] = None,
-) -> str:
-    """
-    Cache model from Hugging Face. Return the cache directory.
-    """
-    import huggingface_hub
-
-    cache_dir = _get_cache_dir(llm_family, llm_spec)
-    if _skip_download(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        llm_spec.model_revision,
-        quantization,
-        multimodal_projector,
-    ):
-        return cache_dir
-
-    use_symlinks = {}
-    if not IS_NEW_HUGGINGFACE_HUB:
-        use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
-
-    if llm_spec.model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
-        assert isinstance(llm_spec, (PytorchLLMSpecV1, MLXLLMSpecV1))
-        download_dir = retry_download(
-            huggingface_hub.snapshot_download,
-            llm_family.model_name,
-            {
-                "model_size": llm_spec.model_size_in_billions,
-                "model_format": llm_spec.model_format,
-            },
-            llm_spec.model_id,
-            revision=llm_spec.model_revision,
-            **use_symlinks,
-        )
-        if IS_NEW_HUGGINGFACE_HUB:
-            create_symlink(download_dir, cache_dir)
-
-    elif llm_spec.model_format in ["ggufv2"]:
-        assert isinstance(llm_spec, LlamaCppLLMSpecV1)
-        file_names, final_file_name, need_merge = _generate_model_file_names(
-            llm_spec, quantization, multimodal_projector
-        )
-
-        for file_name in file_names:
-            download_file_path = retry_download(
-                huggingface_hub.hf_hub_download,
-                llm_family.model_name,
-                {
-                    "model_size": llm_spec.model_size_in_billions,
-                    "model_format": llm_spec.model_format,
-                },
-                llm_spec.model_id,
-                revision=llm_spec.model_revision,
-                filename=file_name,
-                **use_symlinks,
-            )
-            if IS_NEW_HUGGINGFACE_HUB:
-                symlink_local_file(download_file_path, cache_dir, file_name)
-
-        if need_merge:
-            _merge_cached_files(cache_dir, file_names, final_file_name)
-    else:
-        raise ValueError(f"Unsupported model format: {llm_spec.model_format}")
-
-    meta_path = _get_meta_path(
-        cache_dir,
-        llm_spec.model_format,
-        llm_spec.model_hub,
-        quantization,
-        multimodal_projector,
-    )
-    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
-
-    return cache_dir
-
-
-def _check_revision(
-    llm_family: LLMFamilyV1,
-    llm_spec: "LLMSpecV1",
-    builtin: list,
-    meta_path: str,
-    quantization: Optional[str] = None,
-) -> bool:
-    for family in builtin:
-        if llm_family.model_name == family.model_name:
-            specs = family.model_specs
-            for spec in specs:
-                if (
-                    spec.model_format == "pytorch"
-                    and spec.model_size_in_billions == llm_spec.model_size_in_billions
-                    and (quantization is None or quantization in spec.quantizations)
-                ):
-                    return valid_model_revision(meta_path, spec.model_revision)
-    return False
-
-
-def get_cache_status(
-    llm_family: LLMFamilyV1, llm_spec: "LLMSpecV1", quantization: Optional[str] = None
-) -> Union[bool, List[bool]]:
-    """
-    Checks if a model's cache status is available based on the model format and quantization.
-    Supports different directories and model formats.
-    """
-
-    def check_file_status(meta_path: str) -> bool:
-        return os.path.exists(meta_path)
-
-    def check_revision_status(
-        meta_path: str, families: list, quantization: Optional[str] = None
-    ) -> bool:
-        return _check_revision(llm_family, llm_spec, families, meta_path, quantization)
-
-    def handle_quantization(q: Union[str, None]) -> bool:
-        specific_cache_dir = _get_cache_dir(
-            llm_family, llm_spec, q, create_if_not_exist=False
-        )
-        meta_paths = {
-            "huggingface": _get_meta_path(
-                specific_cache_dir, llm_spec.model_format, "huggingface", q
-            ),
-            "modelscope": _get_meta_path(
-                specific_cache_dir, llm_spec.model_format, "modelscope", q
-            ),
-        }
-        if llm_spec.model_format == "pytorch":
-            return check_revision_status(
-                meta_paths["huggingface"], BUILTIN_LLM_FAMILIES, q
-            ) or check_revision_status(
-                meta_paths["modelscope"], BUILTIN_MODELSCOPE_LLM_FAMILIES, q
-            )
-        else:
-            return check_file_status(meta_paths["huggingface"]) or check_file_status(
-                meta_paths["modelscope"]
-            )
-
-    if llm_spec.model_id and "{" in llm_spec.model_id:
-        return (
-            [handle_quantization(q) for q in llm_spec.quantizations]
-            if quantization is None
-            else handle_quantization(quantization)
-        )
-    else:
-        return (
-            [handle_quantization(q) for q in llm_spec.quantizations]
-            if llm_spec.model_format != "pytorch"
-            else handle_quantization(None)
-        )
-
-
-def get_user_defined_llm_families():
-    with UD_LLM_FAMILIES_LOCK:
-        return UD_LLM_FAMILIES.copy()
-
-
 def match_model_size(
     model_size: Union[int, str], spec_model_size: Union[int, str]
 ) -> bool:
@@ -1097,7 +453,7 @@ def match_model_size(
 
 
 def convert_model_size_to_float(
-    model_size_in_billions: Union[float, int, str]
+    model_size_in_billions: Union[float, int, str],
 ) -> float:
     if isinstance(model_size_in_billions, str):
         if "_" in model_size_in_billions:
@@ -1118,55 +474,68 @@ def match_llm(
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
-) -> Optional[Tuple[LLMFamilyV1, LLMSpecV1, str]]:
+) -> Optional[LLMFamilyV2]:
     """
     Find an LLM family, spec, and quantization that satisfy given criteria.
     """
+    from .custom import get_user_defined_llm_families
+
     user_defined_llm_families = get_user_defined_llm_families()
 
-    def _match_quantization(q: Union[str, None], quantizations: List[str]):
+    def _match_quantization(q: Union[str, None], quant: str):
         # Currently, the quantization name could include both uppercase and lowercase letters,
         # so it is necessary to ensure that the case sensitivity does not
         # affect the matching results.
-        if q is None:
-            return q
-        for quant in quantizations:
-            if q.lower() == quant.lower():
-                return quant
+        if q is None or q.lower() != quant.lower():
+            return None
+        return quant
 
-    def _apply_format_to_model_id(spec: LLMSpecV1, q: str) -> LLMSpecV1:
+    def _apply_format_to_model_id(_spec: "LLMSpecV1", q: str) -> "LLMSpecV1":
         # Different quantized versions of some models use different model ids,
         # Here we check the `{}` in the model id to format the id.
-        if spec.model_id and "{" in spec.model_id:
-            spec.model_id = spec.model_id.format(quantization=q)
-        return spec
+        if _spec.model_id and "{" in _spec.model_id:
+            _spec.model_id = _spec.model_id.format(quantization=q)
+        return _spec
+
+    def _get_model_specs(
+        _model_specs: List["LLMSpecV1"], hub: str
+    ) -> List["LLMSpecV1"]:
+        return [x for x in _model_specs if x.model_hub == hub]
 
     # priority: download_hub > download_from_modelscope() and download_from_csghub()
     # set base model
-    base_families = BUILTIN_LLM_FAMILIES + user_defined_llm_families
-    hub_families_map = {
-        "modelscope": BUILTIN_MODELSCOPE_LLM_FAMILIES,
-        "openmind_hub": BUILTIN_OPENMIND_HUB_LLM_FAMILIES,
-        "csghub": BUILTIN_CSGHUB_LLM_FAMILIES,
-    }
-    if download_hub == "huggingface":
-        all_families = base_families
-    elif download_hub in hub_families_map:
-        all_families = hub_families_map[download_hub] + base_families
-    elif download_from_modelscope():
-        all_families = BUILTIN_MODELSCOPE_LLM_FAMILIES + base_families
-    elif download_from_openmind_hub():
-        all_families = BUILTIN_OPENMIND_HUB_LLM_FAMILIES + base_families
-    elif download_from_csghub():
-        all_families = BUILTIN_CSGHUB_LLM_FAMILIES + base_families
-    else:
-        all_families = base_families
+    families = BUILTIN_LLM_FAMILIES + user_defined_llm_families
 
-    for family in all_families:
+    for family in families:
         if model_name != family.model_name:
             continue
-        for spec in family.model_specs:
-            matched_quantization = _match_quantization(quantization, spec.quantizations)
+
+        # prepare possible quantization matching options
+        if download_hub is not None:
+            if download_hub == "huggingface":
+                model_specs = _get_model_specs(family.model_specs, download_hub)
+            else:
+                model_specs = _get_model_specs(
+                    family.model_specs, download_hub
+                ) + _get_model_specs(family.model_specs, "huggingface")
+        else:
+            if download_from_modelscope():
+                model_specs = _get_model_specs(
+                    family.model_specs, "modelscope"
+                ) + _get_model_specs(family.model_specs, "huggingface")
+            elif download_from_openmind_hub():
+                model_specs = _get_model_specs(
+                    family.model_specs, "openmind_hub"
+                ) + _get_model_specs(family.model_specs, "huggingface")
+            elif download_from_csghub():
+                model_specs = _get_model_specs(
+                    family.model_specs, "csghub"
+                ) + _get_model_specs(family.model_specs, "huggingface")
+            else:
+                model_specs = _get_model_specs(family.model_specs, "huggingface")
+
+        for spec in model_specs:
+            # check model_format and model_size_in_billions
             if (
                 model_format
                 and model_format != spec.model_format
@@ -1174,97 +543,27 @@ def match_llm(
                 and not match_model_size(
                     model_size_in_billions, spec.model_size_in_billions
                 )
-                or quantization
-                and matched_quantization is None
             ):
                 continue
-            # Copy spec to avoid _apply_format_to_model_id modify the original spec.
-            spec = spec.copy()
+
+            # Check quantization
+            matched_quantization = _match_quantization(quantization, spec.quantization)
+            if quantization and matched_quantization is None:
+                continue
+            _llm_family = family.copy()
             if quantization:
-                return (
-                    family,
-                    _apply_format_to_model_id(spec, matched_quantization),
-                    matched_quantization,
-                )
+                _llm_family.model_specs = [
+                    _apply_format_to_model_id(spec, matched_quantization)
+                ]
+                return _llm_family
            else:
                 # TODO: If user does not specify quantization, just use the first one
-                _q = "none" if spec.model_format == "pytorch" else spec.quantizations[0]
-                return family, _apply_format_to_model_id(spec, _q), _q
+                _q = "none" if spec.model_format == "pytorch" else spec.quantization
+                _llm_family.model_specs = [_apply_format_to_model_id(spec, _q)]
+                return _llm_family
     return None
 
 
-def register_llm(llm_family: LLMFamilyV1, persist: bool):
-    from ..utils import is_valid_model_name
-    from . import generate_engine_config_by_model_family
-
-    if not is_valid_model_name(llm_family.model_name):
-        raise ValueError(f"Invalid model name {llm_family.model_name}.")
-
-    for spec in llm_family.model_specs:
-        model_uri = spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
-    with UD_LLM_FAMILIES_LOCK:
-        for family in BUILTIN_LLM_FAMILIES + UD_LLM_FAMILIES:
-            if llm_family.model_name == family.model_name:
-                raise ValueError(
-                    f"Model name conflicts with existing model {family.model_name}"
-                )
-
-        UD_LLM_FAMILIES.append(llm_family)
-        generate_engine_config_by_model_family(llm_family)
-
-    if persist:
-        persist_path = os.path.join(
-            XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
-        )
-        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
-        with open(persist_path, mode="w") as fd:
-            fd.write(llm_family.json())
-
-
-def unregister_llm(model_name: str, raise_error: bool = True):
-    with UD_LLM_FAMILIES_LOCK:
-        llm_family = None
-        for i, f in enumerate(UD_LLM_FAMILIES):
-            if f.model_name == model_name:
-                llm_family = f
-                break
-        if llm_family:
-            UD_LLM_FAMILIES.remove(llm_family)
-            del LLM_ENGINES[model_name]
-
-            persist_path = os.path.join(
-                XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
-            )
-            if os.path.exists(persist_path):
-                os.remove(persist_path)
-
-            llm_spec = llm_family.model_specs[0]
-            cache_dir_name = (
-                f"{llm_family.model_name}-{llm_spec.model_format}"
-                f"-{llm_spec.model_size_in_billions}b"
-            )
-            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name)
-            if os.path.exists(cache_dir):
-                logger.warning(
-                    f"Remove the cache of user-defined model {llm_family.model_name}. "
-                    f"Cache directory: {cache_dir}"
-                )
-                if os.path.islink(cache_dir):
-                    os.remove(cache_dir)
-                else:
-                    logger.warning(
-                        f"Cache directory is not a soft link, please remove it manually."
-                    )
-        else:
-            if raise_error:
-                raise ValueError(f"Model {model_name} not found")
-            else:
-                logger.warning(f"Custom model {model_name} not found")
-
-
 def check_engine_by_spec_parameters(
     model_engine: str,
     model_name: str,
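
For code that consumed this module directly, the most visible change in the hunks above is the return type of match_llm: it now returns a single LLMFamilyV2 with the matched spec bound as the only element of model_specs, instead of the old (family, spec, quantization) tuple. User-defined family registration (register_llm, unregister_llm, get_user_defined_llm_families) appears to have moved to the new custom.py, and download/cache handling to the new cache_manager.py. A hedged usage sketch, assuming xinference 1.8.0 is installed; the model name used here is hypothetical and would need to match a built-in or registered family:

from xinference.model.llm.llm_family import match_llm

# 1.7.x returned a (family, spec, quantization) tuple; in 1.8.0 the matched
# spec is carried inside the returned family itself.
family = match_llm("example-llm", model_format="ggufv2", quantization="Q4_K_M")
if family is not None:
    spec = family.model_specs[0]
    print(spec.model_format, spec.model_size_in_billions, spec.quantization)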