xinference 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (70)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +123 -3
  3. xinference/client/restful/restful_client.py +131 -2
  4. xinference/core/model.py +93 -24
  5. xinference/core/supervisor.py +132 -15
  6. xinference/core/worker.py +165 -8
  7. xinference/deploy/cmdline.py +5 -0
  8. xinference/model/audio/chattts.py +46 -14
  9. xinference/model/audio/core.py +23 -15
  10. xinference/model/core.py +12 -3
  11. xinference/model/embedding/core.py +25 -16
  12. xinference/model/flexible/__init__.py +40 -0
  13. xinference/model/flexible/core.py +228 -0
  14. xinference/model/flexible/launchers/__init__.py +15 -0
  15. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  16. xinference/model/flexible/utils.py +33 -0
  17. xinference/model/image/core.py +21 -14
  18. xinference/model/image/custom.py +1 -1
  19. xinference/model/image/model_spec.json +14 -0
  20. xinference/model/image/stable_diffusion/core.py +43 -6
  21. xinference/model/llm/__init__.py +0 -2
  22. xinference/model/llm/core.py +3 -2
  23. xinference/model/llm/ggml/llamacpp.py +1 -10
  24. xinference/model/llm/llm_family.json +292 -36
  25. xinference/model/llm/llm_family.py +97 -52
  26. xinference/model/llm/llm_family_modelscope.json +220 -27
  27. xinference/model/llm/pytorch/core.py +0 -80
  28. xinference/model/llm/sglang/core.py +7 -2
  29. xinference/model/llm/utils.py +4 -2
  30. xinference/model/llm/vllm/core.py +3 -0
  31. xinference/model/rerank/core.py +24 -25
  32. xinference/types.py +0 -1
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
  36. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  43. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/METADATA +9 -11
  44. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/RECORD +49 -58
  45. xinference/model/llm/ggml/chatglm.py +0 -457
  46. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  47. xinference/thirdparty/ChatTTS/core.py +0 -200
  48. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  49. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  50. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  51. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  52. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  53. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  54. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  55. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  56. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  57. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  58. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  59. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
  66. /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  67. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/LICENSE +0 -0
  68. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/WHEEL +0 -0
  69. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/top_level.txt +0 -0
xinference/model/audio/core.py CHANGED
@@ -14,7 +14,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
@@ -94,7 +94,10 @@ def generate_audio_description(
     return res
 
 
-def match_audio(model_name: str) -> AudioModelFamilyV1:
+def match_audio(
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> AudioModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
     from .custom import get_user_defined_audios
@@ -103,17 +106,17 @@ def match_audio(model_name: str) -> AudioModelFamilyV1:
         if model_spec.model_name == model_name:
             return model_spec
 
-    if download_from_modelscope():
-        if model_name in MODELSCOPE_AUDIO_MODELS:
-            logger.debug(f"Audio model {model_name} found in ModelScope.")
-            return MODELSCOPE_AUDIO_MODELS[model_name]
-        else:
-            logger.debug(
-                f"Audio model {model_name} not found in ModelScope, "
-                f"now try to load it via builtin way."
-            )
-
-    if model_name in BUILTIN_AUDIO_MODELS:
+    if download_hub == "huggingface" and model_name in BUILTIN_AUDIO_MODELS:
+        logger.debug(f"Audio model {model_name} found in huggingface.")
+        return BUILTIN_AUDIO_MODELS[model_name]
+    elif download_hub == "modelscope" and model_name in MODELSCOPE_AUDIO_MODELS:
+        logger.debug(f"Audio model {model_name} found in ModelScope.")
+        return MODELSCOPE_AUDIO_MODELS[model_name]
+    elif download_from_modelscope() and model_name in MODELSCOPE_AUDIO_MODELS:
+        logger.debug(f"Audio model {model_name} found in ModelScope.")
+        return MODELSCOPE_AUDIO_MODELS[model_name]
+    elif model_name in BUILTIN_AUDIO_MODELS:
+        logger.debug(f"Audio model {model_name} found in huggingface.")
         return BUILTIN_AUDIO_MODELS[model_name]
     else:
         raise ValueError(
@@ -141,9 +144,14 @@ def get_cache_status(
 
 
 def create_audio_model_instance(
-    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    **kwargs,
 ) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
-    model_spec = match_audio(model_name)
+    model_spec = match_audio(model_name, download_hub)
     model_path = cache(model_spec)
     model: Union[WhisperModel, ChatTTSModel]
     if model_spec.model_family == "whisper":
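
The effect of threading download_hub through these signatures is that the hub becomes a per-launch choice rather than a process-wide environment setting. A minimal sketch of how a caller might select it, assuming the parameter is surfaced as a download_hub keyword on the client's launch_model (the restful_client.py and cmdline.py changes listed above point that way, but the exact client surface is not shown in this diff):

    # Hypothetical per-launch hub override; the client keyword is an assumption.
    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model_uid = client.launch_model(
        model_name="ChatTTS",
        model_type="audio",
        download_hub="modelscope",  # overrides the environment-driven default
    )
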
xinference/model/core.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Literal, Optional, Tuple, Union
 
 from .._compat import BaseModel
 from ..types import PeftModelConfig
@@ -55,10 +55,12 @@ def create_model_instance(
     model_size_in_billions: Optional[Union[int, str]] = None,
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
     from .audio.core import create_audio_model_instance
     from .embedding.core import create_embedding_model_instance
+    from .flexible.core import create_flexible_model_instance
     from .image.core import create_image_model_instance
     from .llm.core import create_llm_model_instance
     from .rerank.core import create_rerank_model_instance
@@ -74,13 +76,14 @@ def create_model_instance(
             model_size_in_billions,
             quantization,
             peft_model_config,
+            download_hub,
             **kwargs,
         )
     elif model_type == "embedding":
         # embedding model doesn't accept trust_remote_code
         kwargs.pop("trust_remote_code", None)
         return create_embedding_model_instance(
-            subpool_addr, devices, model_uid, model_name, **kwargs
+            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
         )
     elif model_type == "image":
         kwargs.pop("trust_remote_code", None)
@@ -90,16 +93,22 @@ def create_model_instance(
             model_uid,
             model_name,
             peft_model_config,
+            download_hub,
             **kwargs,
         )
     elif model_type == "rerank":
         kwargs.pop("trust_remote_code", None)
         return create_rerank_model_instance(
-            subpool_addr, devices, model_uid, model_name, **kwargs
+            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
        )
     elif model_type == "audio":
         kwargs.pop("trust_remote_code", None)
         return create_audio_model_instance(
+            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+        )
+    elif model_type == "flexible":
+        kwargs.pop("trust_remote_code", None)
+        return create_flexible_model_instance(
             subpool_addr, devices, model_uid, model_name, **kwargs
         )
     else:
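
create_model_instance is the worker-side dispatch point, so the new branch is what makes "flexible" a first-class model_type. A sketch of how the routing above is reached (keyword form; the full signature is internal and only partially visible in this hunk):

    # Hypothetical internal call; parameters shown are inferred from the hunk.
    model, description = create_model_instance(
        subpool_addr="127.0.0.1:42000",
        devices=["0"],
        model_uid="my-classifier-1",
        model_type="flexible",       # routes to create_flexible_model_instance
        model_name="my-classifier",  # a registered flexible model
        download_hub=None,           # flexible models load from their own model_uri
    )
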
xinference/model/embedding/core.py CHANGED
@@ -16,7 +16,7 @@ import gc
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union, no_type_check
+from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
 
@@ -305,7 +305,10 @@ class EmbeddingModel:
         )
 
 
-def match_embedding(model_name: str) -> EmbeddingModelSpec:
+def match_embedding(
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> EmbeddingModelSpec:
     from ..utils import download_from_modelscope
     from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
     from .custom import get_user_defined_embeddings
@@ -315,29 +318,35 @@ def match_embedding(model_name: str) -> EmbeddingModelSpec:
         if model_name == model_spec.model_name:
             return model_spec
 
-    if download_from_modelscope():
-        if model_name in MODELSCOPE_EMBEDDING_MODELS:
-            logger.debug(f"Embedding model {model_name} found in ModelScope.")
-            return MODELSCOPE_EMBEDDING_MODELS[model_name]
-        else:
-            logger.debug(
-                f"Embedding model {model_name} not found in ModelScope, "
-                f"now try to load it via builtin way."
-            )
-
-    if model_name in BUILTIN_EMBEDDING_MODELS:
+    if download_hub == "modelscope" and model_name in MODELSCOPE_EMBEDDING_MODELS:
+        logger.debug(f"Embedding model {model_name} found in ModelScope.")
+        return MODELSCOPE_EMBEDDING_MODELS[model_name]
+    elif download_hub == "huggingface" and model_name in BUILTIN_EMBEDDING_MODELS:
+        logger.debug(f"Embedding model {model_name} found in Huggingface.")
+        return BUILTIN_EMBEDDING_MODELS[model_name]
+    elif download_from_modelscope() and model_name in MODELSCOPE_EMBEDDING_MODELS:
+        logger.debug(f"Embedding model {model_name} found in ModelScope.")
+        return MODELSCOPE_EMBEDDING_MODELS[model_name]
+    elif model_name in BUILTIN_EMBEDDING_MODELS:
+        logger.debug(f"Embedding model {model_name} found in Huggingface.")
         return BUILTIN_EMBEDDING_MODELS[model_name]
     else:
         raise ValueError(
             f"Embedding model {model_name} not found, available"
-            f"model list: {BUILTIN_EMBEDDING_MODELS.keys()}"
+            f"Huggingface: {BUILTIN_EMBEDDING_MODELS.keys()}"
+            f"ModelScope: {MODELSCOPE_EMBEDDING_MODELS.keys()}"
         )
 
 
 def create_embedding_model_instance(
-    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    **kwargs,
 ) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
-    model_spec = match_embedding(model_name)
+    model_spec = match_embedding(model_name, download_hub)
     model_path = cache(model_spec)
     model = EmbeddingModel(model_uid, model_path, **kwargs)
     model_description = EmbeddingModelDescription(
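
The resolution order is the same across the audio, embedding, and image matchers: an explicit download_hub wins, then the ModelScope environment default (download_from_modelscope()), then the builtin Huggingface registry. Note that "csghub" is accepted by the Literal type but has no branch here, so it falls through to the defaults. For a model name present in both registries (the name below is illustrative):

    # Resolution order of match_embedding above.
    match_embedding("bge-large-zh")                              # env default decides
    match_embedding("bge-large-zh", download_hub="modelscope")   # force ModelScope
    match_embedding("bge-large-zh", download_hub="huggingface")  # force Huggingface
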
xinference/model/flexible/__init__.py ADDED
@@ -0,0 +1,40 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+
+from ...constants import XINFERENCE_MODEL_DIR
+from .core import (
+    FLEXIBLE_MODEL_DESCRIPTIONS,
+    FlexibleModel,
+    FlexibleModelSpec,
+    generate_flexible_model_description,
+    get_flexible_model_descriptions,
+    get_flexible_models,
+    register_flexible_model,
+    unregister_flexible_model,
+)
+
+model_dir = os.path.join(XINFERENCE_MODEL_DIR, "flexible")
+if os.path.isdir(model_dir):
+    for f in os.listdir(model_dir):
+        with codecs.open(os.path.join(model_dir, f), encoding="utf-8") as fd:
+            model_spec = FlexibleModelSpec.parse_obj(json.load(fd))
+            register_flexible_model(model_spec, persist=False)
+
+# register model description
+for model in get_flexible_models():
+    FLEXIBLE_MODEL_DESCRIPTIONS.update(generate_flexible_model_description(model))
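
Together with register_flexible_model in core.py below, this startup scan means a spec persisted once is reloaded on every import. A minimal sketch of persisting one (model name, URI, and launcher args are hypothetical):

    # register_flexible_model(..., persist=True) writes
    # $XINFERENCE_MODEL_DIR/flexible/<model_name>.json, which the
    # loop above picks up on the next import.
    from xinference.model.flexible import FlexibleModelSpec, register_flexible_model

    spec = FlexibleModelSpec(
        model_name="my-classifier",
        model_uri="/data/models/my-classifier",  # hypothetical local path
        launcher="xinference.model.flexible.launchers.transformers",
        launcher_args='{"task": "text-classification"}',
    )
    register_flexible_model(spec, persist=True)
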
xinference/model/flexible/core.py ADDED
@@ -0,0 +1,228 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+from collections import defaultdict
+from threading import Lock
+from typing import Dict, List, Optional, Tuple
+
+from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
+from ..core import CacheableModelSpec, ModelDescription
+from .utils import get_launcher
+
+logger = logging.getLogger(__name__)
+
+FLEXIBLE_MODEL_LOCK = Lock()
+
+
+class FlexibleModelSpec(CacheableModelSpec):
+    model_id: Optional[str]  # type: ignore
+    model_description: Optional[str]
+    model_uri: Optional[str]
+    launcher: str
+    launcher_args: Optional[str]
+
+    def parser_args(self):
+        return json.loads(self.launcher_args)
+
+
+class FlexibleModelDescription(ModelDescription):
+    def __init__(
+        self,
+        address: Optional[str],
+        devices: Optional[List[str]],
+        model_spec: FlexibleModelSpec,
+        model_path: Optional[str] = None,
+    ):
+        super().__init__(address, devices, model_path=model_path)
+        self._model_spec = model_spec
+
+    def to_dict(self):
+        return {
+            "model_type": "flexible",
+            "address": self.address,
+            "accelerators": self.devices,
+            "model_name": self._model_spec.model_name,
+            "launcher": self._model_spec.launcher,
+            "launcher_args": self._model_spec.launcher_args,
+        }
+
+    def get_model_version(self) -> str:
+        return f"{self._model_spec.model_name}"
+
+    def to_version_info(self):
+        return {
+            "model_version": self.get_model_version(),
+            "cache_status": True,
+            "model_file_location": self._model_spec.model_uri,
+            "launcher": self._model_spec.launcher,
+            "launcher_args": self._model_spec.launcher_args,
+        }
+
+
+def generate_flexible_model_description(
+    model_spec: FlexibleModelSpec,
+) -> Dict[str, List[Dict]]:
+    res = defaultdict(list)
+    res[model_spec.model_name].append(
+        FlexibleModelDescription(None, None, model_spec).to_version_info()
+    )
+    return res
+
+
+FLEXIBLE_MODELS: List[FlexibleModelSpec] = []
+FLEXIBLE_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+
+
+def get_flexible_models():
+    with FLEXIBLE_MODEL_LOCK:
+        return FLEXIBLE_MODELS.copy()
+
+
+def get_flexible_model_descriptions():
+    import copy
+
+    return copy.deepcopy(FLEXIBLE_MODEL_DESCRIPTIONS)
+
+
+def register_flexible_model(model_spec: FlexibleModelSpec, persist: bool):
+    from ..utils import is_valid_model_name
+
+    if not is_valid_model_name(model_spec.model_name):
+        raise ValueError(f"Invalid model name {model_spec.model_name}.")
+
+    if model_spec.launcher_args:
+        try:
+            model_spec.parser_args()
+        except Exception:
+            raise ValueError(f"Invalid model launcher args {model_spec.launcher_args}.")
+
+    with FLEXIBLE_MODEL_LOCK:
+        for model_name in [spec.model_name for spec in FLEXIBLE_MODELS]:
+            if model_spec.model_name == model_name:
+                raise ValueError(
+                    f"Model name conflicts with existing model {model_spec.model_name}"
+                )
+        FLEXIBLE_MODELS.append(model_spec)
+
+    if persist:
+        persist_path = os.path.join(
+            XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
+        )
+        os.makedirs(os.path.dirname(persist_path), exist_ok=True)
+        with open(persist_path, mode="w") as fd:
+            fd.write(model_spec.json())
+
+
+def unregister_flexible_model(model_name: str, raise_error: bool = True):
+    with FLEXIBLE_MODEL_LOCK:
+        model_spec = None
+        for i, f in enumerate(FLEXIBLE_MODELS):
+            if f.model_name == model_name:
+                model_spec = f
+                break
+        if model_spec:
+            FLEXIBLE_MODELS.remove(model_spec)
+
+            persist_path = os.path.join(
+                XINFERENCE_MODEL_DIR, "flexible", f"{model_spec.model_name}.json"
+            )
+            if os.path.exists(persist_path):
+                os.remove(persist_path)
+
+            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
+            if os.path.exists(cache_dir):
+                logger.warning(
+                    f"Remove the cache of user-defined model {model_spec.model_name}. "
+                    f"Cache directory: {cache_dir}"
+                )
+                if os.path.islink(cache_dir):
+                    os.remove(cache_dir)
+                else:
+                    logger.warning(
+                        f"Cache directory is not a soft link, please remove it manually."
+                    )
+        else:
+            if raise_error:
+                raise ValueError(f"Model {model_name} not found")
+            else:
+                logger.warning(f"Model {model_name} not found")
+
+
+class FlexibleModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        device: Optional[str] = None,
+        config: Optional[Dict] = None,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._device = device
+        self._config = config
+
+    def load(self):
+        """
+        Load the model.
+        """
+
+    def infer(self, **kwargs):
+        """
+        Call model to inference.
+        """
+        raise NotImplementedError("infer method not implemented.")
+
+    @property
+    def model_uid(self):
+        return self._model_uid
+
+    @property
+    def model_path(self):
+        return self._model_path
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def config(self):
+        return self._config
+
+
+def match_flexible_model(model_name):
+    for model_spec in get_flexible_models():
+        if model_name == model_spec.model_name:
+            return model_spec
+
+
+def create_flexible_model_instance(
+    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+) -> Tuple[FlexibleModel, FlexibleModelDescription]:
+    model_spec = match_flexible_model(model_name)
+    model_path = model_spec.model_uri
+    launcher_name = model_spec.launcher
+    launcher_args = model_spec.parser_args()
+    kwargs.update(launcher_args)
+
+    model = get_launcher(launcher_name)(
+        model_uid=model_uid, model_spec=model_spec, **kwargs
+    )
+
+    model_description = FlexibleModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
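
The launcher field is just a dotted import path resolved by get_launcher (see utils.py below), so user code can plug in its own launcher without touching xinference. A minimal sketch of a custom launcher module (module name and behavior are hypothetical):

    # my_launchers.py -- referenced from a spec as launcher="my_launchers.echo".
    from xinference.model.flexible.core import FlexibleModel, FlexibleModelSpec

    class EchoModel(FlexibleModel):
        def infer(self, **kwargs):
            return kwargs  # trivial "model": echo the request back

    def echo(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
        # create_flexible_model_instance calls this with model_uid, model_spec,
        # and the parsed launcher_args as keyword arguments.
        return EchoModel(
            model_uid=model_uid,
            model_path=model_spec.model_uri,
            device=kwargs.get("device"),
            config=kwargs,
        )
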
xinference/model/flexible/launchers/__init__.py ADDED
@@ -0,0 +1,15 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transformers_launcher import launcher as transformers
xinference/model/flexible/launchers/transformers_launcher.py ADDED
@@ -0,0 +1,63 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import pipeline
+
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class MockModel(FlexibleModel):
+    def infer(self, **kwargs):
+        return kwargs
+
+
+class AutoModel(FlexibleModel):
+    def load(self):
+        config = self.config or {}
+        self._pipeline = pipeline(model=self.model_path, device=self.device, **config)
+
+    def infer(self, **kwargs):
+        return self._pipeline(**kwargs)
+
+
+class TransformersTextClassificationModel(FlexibleModel):
+    def load(self):
+        config = self.config or {}
+
+        self._pipeline = pipeline(model=self._model_path, device=self._device, **config)
+
+    def infer(self, **kwargs):
+        return self._pipeline(**kwargs)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    model_path = model_spec.model_uri
+    if model_path is None:
+        raise ValueError("model_path required")
+
+    if task == "text-classification":
+        return TransformersTextClassificationModel(
+            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+        )
+    elif task == "mock":
+        return MockModel(
+            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+        )
+    else:
+        return AutoModel(
+            model_uid=model_uid, model_path=model_path, device=device, config=kwargs
+        )
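
A quick way to see the plumbing end to end is the "mock" task, which skips transformers entirely and echoes its inputs back (spec field values are illustrative):

    # Exercise the launcher directly; MockModel.infer returns its kwargs.
    from xinference.model.flexible.core import FlexibleModelSpec
    from xinference.model.flexible.launchers import transformers as launch

    spec = FlexibleModelSpec(
        model_name="echo-test",
        model_uri="/tmp/anything",  # must be non-None, though MockModel ignores it
        launcher="xinference.model.flexible.launchers.transformers",
        launcher_args='{"task": "mock"}',
    )
    model = launch(model_uid="echo-test-0", model_spec=spec, task="mock")
    model.load()                      # inherited no-op
    print(model.infer(text="hello"))  # -> {'text': 'hello'}
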
xinference/model/flexible/utils.py ADDED
@@ -0,0 +1,33 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+
+def get_launcher(launcher_name: str):
+    try:
+        i = launcher_name.rfind(".")
+        if i != -1:
+            module = importlib.import_module(launcher_name[:i])
+            fn = getattr(module, launcher_name[i + 1 :])
+        else:
+            importlib.import_module(launcher_name)
+            fn = locals().get(launcher_name)
+
+        if fn is None:
+            raise ValueError(f"Launcher {launcher_name} not found.")
+
+        return fn
+    except ImportError as e:
+        raise ImportError(f"Failed to import {launcher_name}: {e}")
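
Note the asymmetry between the two branches: a dotted path is resolved with importlib plus getattr, while the bare-name branch imports the module and then looks the name up in locals(), which will not contain it, so in practice the dotted form is the one to rely on:

    # Resolving the builtin launcher by dotted path; equivalent to
    # `from xinference.model.flexible.launchers import transformers`.
    from xinference.model.flexible.utils import get_launcher

    launch = get_launcher("xinference.model.flexible.launchers.transformers")
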
xinference/model/image/core.py CHANGED
@@ -15,7 +15,7 @@ import collections.abc
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
@@ -45,6 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
+    ability: Optional[str]
     controlnet: Optional[List["ImageModelFamilyV1"]]
 
 
@@ -71,6 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
+            "ability": self._model_spec.ability,
             "controlnet": controlnet,
         }
 
@@ -117,7 +119,10 @@
     return res
 
 
-def match_diffusion(model_name: str) -> ImageModelFamilyV1:
+def match_diffusion(
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> ImageModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
     from .custom import get_user_defined_images
@@ -126,17 +131,17 @@ def match_diffusion(model_name: str) -> ImageModelFamilyV1:
         if model_spec.model_name == model_name:
             return model_spec
 
-    if download_from_modelscope():
-        if model_name in MODELSCOPE_IMAGE_MODELS:
-            logger.debug(f"Image model {model_name} found in ModelScope.")
-            return MODELSCOPE_IMAGE_MODELS[model_name]
-        else:
-            logger.debug(
-                f"Image model {model_name} not found in ModelScope, "
-                f"now try to load it via builtin way."
-            )
-
-    if model_name in BUILTIN_IMAGE_MODELS:
+    if download_hub == "modelscope" and model_name in MODELSCOPE_IMAGE_MODELS:
+        logger.debug(f"Image model {model_name} found in ModelScope.")
+        return MODELSCOPE_IMAGE_MODELS[model_name]
+    elif download_hub == "huggingface" and model_name in BUILTIN_IMAGE_MODELS:
+        logger.debug(f"Image model {model_name} found in Huggingface.")
+        return BUILTIN_IMAGE_MODELS[model_name]
+    elif download_from_modelscope() and model_name in MODELSCOPE_IMAGE_MODELS:
+        logger.debug(f"Image model {model_name} found in ModelScope.")
+        return MODELSCOPE_IMAGE_MODELS[model_name]
+    elif model_name in BUILTIN_IMAGE_MODELS:
+        logger.debug(f"Image model {model_name} found in Huggingface.")
         return BUILTIN_IMAGE_MODELS[model_name]
     else:
         raise ValueError(
@@ -183,9 +188,10 @@ def create_image_model_instance(
     model_uid: str,
     model_name: str,
     peft_model_config: Optional[PeftModelConfig] = None,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
    **kwargs,
 ) -> Tuple[DiffusionModel, ImageModelDescription]:
-    model_spec = match_diffusion(model_name)
+    model_spec = match_diffusion(model_name, download_hub)
     controlnet = kwargs.get("controlnet")
     # Handle controlnet
     if controlnet is not None:
@@ -230,6 +236,7 @@
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
+        ability=model_spec.ability,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/custom.py CHANGED
@@ -66,7 +66,7 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
         raise ValueError(f"Invalid model URI {model_uri}")
 
     persist_path = os.path.join(
-        XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_id}.json"
+        XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_name}.json"
     )
     os.makedirs(os.path.dirname(persist_path), exist_ok=True)
     with open(persist_path, "w") as f: