xinference 0.13.3__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (48)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -1
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/constants.py +0 -4
  5. xinference/core/image_interface.py +6 -3
  6. xinference/core/model.py +1 -1
  7. xinference/core/supervisor.py +2 -0
  8. xinference/core/worker.py +7 -0
  9. xinference/deploy/utils.py +6 -0
  10. xinference/model/audio/core.py +4 -2
  11. xinference/model/core.py +25 -4
  12. xinference/model/embedding/core.py +88 -13
  13. xinference/model/embedding/model_spec.json +8 -0
  14. xinference/model/embedding/model_spec_modelscope.json +8 -0
  15. xinference/model/flexible/core.py +8 -2
  16. xinference/model/image/core.py +8 -5
  17. xinference/model/image/model_spec.json +30 -6
  18. xinference/model/image/model_spec_modelscope.json +21 -3
  19. xinference/model/image/stable_diffusion/core.py +30 -27
  20. xinference/model/llm/core.py +6 -4
  21. xinference/model/llm/ggml/llamacpp.py +7 -5
  22. xinference/model/llm/llm_family.py +6 -6
  23. xinference/model/llm/mlx/core.py +7 -0
  24. xinference/model/llm/pytorch/chatglm.py +4 -1
  25. xinference/model/llm/pytorch/deepseek_vl.py +2 -1
  26. xinference/model/llm/pytorch/falcon.py +2 -1
  27. xinference/model/llm/pytorch/llama_2.py +4 -2
  28. xinference/model/llm/pytorch/omnilmm.py +2 -1
  29. xinference/model/llm/pytorch/qwen_vl.py +2 -1
  30. xinference/model/llm/pytorch/vicuna.py +2 -1
  31. xinference/model/llm/pytorch/yi_vl.py +2 -1
  32. xinference/model/llm/sglang/core.py +12 -6
  33. xinference/model/llm/vllm/core.py +1 -5
  34. xinference/model/rerank/core.py +4 -3
  35. xinference/web/ui/build/asset-manifest.json +3 -3
  36. xinference/web/ui/build/index.html +1 -1
  37. xinference/web/ui/build/static/js/{main.2ef0cfaf.js → main.af906659.js} +3 -3
  38. xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
  40. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/METADATA +24 -4
  41. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/RECORD +46 -46
  42. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +0 -1
  43. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +0 -1
  44. /xinference/web/ui/build/static/js/{main.2ef0cfaf.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
  45. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
  46. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
  47. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
  48. {xinference-0.13.3.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-26T18:42:50+0800",
+ "date": "2024-08-02T16:08:07+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "aa51ff22dbfb5644554436270deaf57a7ebaf066",
- "version": "0.13.3"
+ "full-revisionid": "dd85cfe015c9cd2d8110c79213640aa0e21f3a6a",
+ "version": "0.13.4"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -797,6 +797,7 @@ class RESTfulAPI:
         worker_ip = payload.get("worker_ip", None)
         gpu_idx = payload.get("gpu_idx", None)
         download_hub = payload.get("download_hub", None)
+        model_path = payload.get("model_path", None)
 
         exclude_keys = {
             "model_uid",
@@ -813,6 +814,7 @@
             "worker_ip",
             "gpu_idx",
             "download_hub",
+            "model_path",
         }
 
         kwargs = {
@@ -861,6 +863,7 @@
                 worker_ip=worker_ip,
                 gpu_idx=gpu_idx,
                 download_hub=download_hub,
+                model_path=model_path,
                 **kwargs,
             )
         except ValueError as ve:
@@ -1407,7 +1410,7 @@
         negative_prompt: Optional[Union[str, List[str]]] = Form(None),
         n: Optional[int] = Form(1),
         response_format: Optional[str] = Form("url"),
-        size: Optional[str] = Form("1024*1024"),
+        size: Optional[str] = Form(None),
         kwargs: Optional[str] = Form(None),
     ) -> Response:
         model_uid = model
xinference/client/restful/restful_client.py CHANGED
@@ -234,9 +234,9 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         self,
         image: Union[str, bytes],
         prompt: str,
-        negative_prompt: str,
+        negative_prompt: Optional[str] = None,
         n: int = 1,
-        size: str = "1024*1024",
+        size: Optional[str] = None,
         response_format: str = "url",
         **kwargs,
     ) -> "ImageList":
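With `size` and `negative_prompt` now optional on both client and server, image-to-image requests can omit them entirely. A minimal usage sketch, assuming a default local endpoint, a hypothetical model UID and input file, and the `image_to_image` handle method belonging to the class shown above:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-image-model-uid")  # hypothetical UID of a launched image model
with open("input.png", "rb") as f:  # hypothetical input image
    image_list = model.image_to_image(
        image=f.read(),
        prompt="a cat wearing a spacesuit",
        # size omitted: it now defaults to None instead of "1024*1024",
        # letting the backend keep or choose the resolution
    )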
xinference/constants.py CHANGED
@@ -26,8 +26,6 @@ XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
 XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
-XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
-XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
 XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
@@ -72,8 +70,6 @@ XINFERENCE_HEALTH_CHECK_TIMEOUT = int(
 XINFERENCE_DISABLE_HEALTH_CHECK = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
 )
-XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
-XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
xinference/core/image_interface.py CHANGED
@@ -153,7 +153,10 @@ class ImageInterface:
         model = client.get_model(self.model_uid)
         assert isinstance(model, RESTfulImageModelHandle)
 
-        size = f"{int(size_width)}*{int(size_height)}"
+        if size_width > 0 and size_height > 0:
+            size = f"{int(size_width)}*{int(size_height)}"
+        else:
+            size = None
 
         bio = io.BytesIO()
         image.save(bio, format="png")
@@ -195,8 +198,8 @@
 
         with gr.Row():
             n = gr.Number(label="Number of image", value=1)
-            size_width = gr.Number(label="Width", value=512)
-            size_height = gr.Number(label="Height", value=512)
+            size_width = gr.Number(label="Width", value=-1)
+            size_height = gr.Number(label="Height", value=-1)
 
         with gr.Row():
             with gr.Column(scale=1):
xinference/core/model.py CHANGED
@@ -706,7 +706,7 @@ class ModelActor(xo.StatelessActor):
         prompt: str,
         negative_prompt: str,
         n: int = 1,
-        size: str = "1024*1024",
+        size: Optional[str] = None,
         response_format: str = "url",
         *args,
         **kwargs,
xinference/core/supervisor.py CHANGED
@@ -859,6 +859,7 @@ class SupervisorActor(xo.StatelessActor):
         worker_ip: Optional[str] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        model_path: Optional[str] = None,
         **kwargs,
     ) -> str:
         # search in worker first
@@ -942,6 +943,7 @@
             peft_model_config=peft_model_config,
             gpu_idx=replica_gpu_idx,
             download_hub=download_hub,
+            model_path=model_path,
             **kwargs,
         )
         self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
xinference/core/worker.py CHANGED
@@ -743,6 +743,7 @@ class WorkerActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        model_path: Optional[str] = None,
         **kwargs,
     ):
         # !!! Note that The following code must be placed at the very beginning of this function,
@@ -799,6 +800,11 @@
             raise ValueError(
                 f"PEFT adaptors can only be applied to pytorch-like models"
             )
+        if model_path is not None:
+            if not os.path.exists(model_path):
+                raise ValueError(
+                    f"Invalid input. `model_path`: {model_path} File or directory does not exist."
+                )
 
         assert model_uid not in self._model_uid_to_model
         self._check_model_is_valid(model_name, model_format)
@@ -826,6 +832,7 @@
             quantization,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
         await self.update_cache_status(model_name, model_description)
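Taken together with the RESTful API and supervisor changes above, the new `model_path` flows from the launch request down to each worker, which validates that the path exists before loading. A minimal client-side sketch, assuming the default endpoint and hypothetical model name, engine, and path (`launch_model` forwards extra keyword arguments in the launch payload):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="qwen2-instruct",          # hypothetical example model
    model_type="LLM",
    model_engine="transformers",          # assumed engine name
    model_path="/data/models/qwen2-7b-instruct",  # local dir; a missing path raises ValueError on the worker
)
print(model_uid)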
xinference/deploy/utils.py CHANGED
@@ -27,6 +27,9 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# mainly for k8s
+XINFERENCE_POD_NAME_ENV_KEY = "XINFERENCE_POD_NAME"
+
 
 class LoggerNameFilter(logging.Filter):
     def filter(self, record):
@@ -40,6 +43,9 @@ def get_log_file(sub_dir: str):
     """
     sub_dir should contain a timestamp.
     """
+    pod_name = os.environ.get(XINFERENCE_POD_NAME_ENV_KEY, None)
+    if pod_name is not None:
+        sub_dir = sub_dir + "_" + pod_name
     log_dir = os.path.join(XINFERENCE_LOG_DIR, sub_dir)
     # Here should be creating a new directory each time, so `exist_ok=False`
     os.makedirs(log_dir, exist_ok=False)
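This keeps log directories distinct when several pods share a volume. A small sketch of the resulting behavior, with a hypothetical pod name and timestamped sub-directory:

import os

os.environ["XINFERENCE_POD_NAME"] = "xinference-worker-0"  # e.g. injected via the k8s downward API

sub_dir = "local_1722585600000"  # timestamped sub-dir, as get_log_file expects
pod_name = os.environ.get("XINFERENCE_POD_NAME", None)
if pod_name is not None:
    sub_dir = sub_dir + "_" + pod_name
print(sub_dir)  # -> local_1722585600000_xinference-worker-0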
xinference/model/audio/core.py CHANGED
@@ -150,10 +150,12 @@ def create_audio_model_instance(
     model_uid: str,
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
-    model_path = cache(model_spec)
+    if model_path is None:
+        model_path = cache(model_spec)
     model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
@@ -164,6 +166,6 @@
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
-        subpool_addr, devices, model_spec, model_path=model_path
+        subpool_addr, devices, model_spec, model_path
     )
     return model, model_description
xinference/model/core.py CHANGED
@@ -56,6 +56,7 @@ def create_model_instance(
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
     from .audio.core import create_audio_model_instance
@@ -77,13 +78,20 @@
             quantization,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
     elif model_type == "embedding":
         # embedding model doesn't accept trust_remote_code
         kwargs.pop("trust_remote_code", None)
         return create_embedding_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
         )
     elif model_type == "image":
         kwargs.pop("trust_remote_code", None)
@@ -94,22 +102,35 @@
             model_name,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
     elif model_type == "rerank":
         kwargs.pop("trust_remote_code", None)
         return create_rerank_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
         )
     elif model_type == "audio":
         kwargs.pop("trust_remote_code", None)
         return create_audio_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
        )
     elif model_type == "flexible":
         kwargs.pop("trust_remote_code", None)
         return create_flexible_model_instance(
-            subpool_addr, devices, model_uid, model_name, **kwargs
+            subpool_addr, devices, model_uid, model_name, model_path, **kwargs
        )
     else:
         raise ValueError(f"Unsupported model type: {model_type}.")
xinference/model/embedding/core.py CHANGED
@@ -118,12 +118,19 @@ def get_cache_status(
 
 
 class EmbeddingModel:
-    def __init__(self, model_uid: str, model_path: str, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: EmbeddingModelSpec,
+        device: Optional[str] = None,
+    ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model = None
         self._counter = 0
+        self._model_spec = model_spec
 
     def load(self):
         try:
@@ -134,12 +141,26 @@ class EmbeddingModel:
                 "Please make sure 'sentence-transformers' is installed. ",
                 "You can install it by `pip install sentence-transformers`\n",
             ]
-
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        class XSentenceTransformer(SentenceTransformer):
+            def to(self, *args, **kwargs):
+                pass
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
-        self._model = SentenceTransformer(self._model_path, device=self._device)
+        if (
+            "gte-Qwen2" in self._model_spec.model_id
+            or "gte-Qwen2" in self._model_spec.model_name
+        ):
+            self._model = XSentenceTransformer(
+                self._model_path,
+                device=self._device,
+                model_kwargs={"device_map": "auto"},
+            )
+        else:
+            self._model = SentenceTransformer(self._model_path, device=self._device)
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         self._counter += 1
@@ -156,6 +177,8 @@
         def encode(
             model: SentenceTransformer,
             sentences: Union[str, List[str]],
+            prompt_name: Optional[str] = None,
+            prompt: Optional[str] = None,
             batch_size: int = 32,
             show_progress_bar: bool = None,
             output_value: str = "sentence_embedding",
@@ -204,10 +227,43 @@
                 sentences = [sentences]
                 input_was_string = True
 
+            if prompt is None:
+                if prompt_name is not None:
+                    try:
+                        prompt = model.prompts[prompt_name]
+                    except KeyError:
+                        raise ValueError(
+                            f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
+                        )
+                elif model.default_prompt_name is not None:
+                    prompt = model.prompts.get(model.default_prompt_name, None)
+            else:
+                if prompt_name is not None:
+                    logger.warning(
+                        "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
+                        "Ignoring the `prompt_name` in favor of `prompt`."
+                    )
+
+            extra_features = {}
+            if prompt is not None:
+                sentences = [prompt + sentence for sentence in sentences]
+
+                # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
+                # Tracking the prompt length allow us to remove the prompt during pooling
+                tokenized_prompt = model.tokenize([prompt])
+                if "input_ids" in tokenized_prompt:
+                    extra_features["prompt_length"] = (
+                        tokenized_prompt["input_ids"].shape[-1] - 1
+                    )
+
             if device is None:
                 device = model._target_device
 
-            model.to(device)
+            if (
+                "gte-Qwen2" not in self._model_spec.model_id
+                and "gte-Qwen2" not in self._model_spec.model_name
+            ):
+                model.to(device)
 
             all_embeddings = []
             all_token_nums = 0
@@ -228,6 +284,7 @@
                 ]
                 features = model.tokenize(sentences_batch)
                 features = batch_to_device(features, device)
+                features.update(extra_features)
                 all_token_nums += sum([len(f) for f in features])
 
                 with torch.no_grad():
@@ -272,7 +329,10 @@
             ]
 
             if convert_to_tensor:
-                all_embeddings = torch.stack(all_embeddings)
+                if len(all_embeddings):
+                    all_embeddings = torch.stack(all_embeddings)
+                else:
+                    all_embeddings = torch.Tensor()
             elif convert_to_numpy:
                 all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
 
@@ -281,12 +341,24 @@
 
             return all_embeddings, all_token_nums
 
-        all_embeddings, all_token_nums = encode(
-            self._model,
-            sentences,
-            convert_to_numpy=False,
-            **kwargs,
-        )
+        if (
+            "gte-Qwen2" in self._model_spec.model_id
+            or "gte-Qwen2" in self._model_spec.model_name
+        ):
+            all_embeddings, all_token_nums = encode(
+                self._model,
+                sentences,
+                prompt_name="query",
+                convert_to_numpy=False,
+                **kwargs,
+            )
+        else:
+            all_embeddings, all_token_nums = encode(
+                self._model,
+                sentences,
+                convert_to_numpy=False,
+                **kwargs,
+            )
         if isinstance(sentences, str):
             all_embeddings = [all_embeddings]
         embedding_list = []
@@ -344,11 +416,14 @@ def create_embedding_model_instance(
     model_uid: str,
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
     model_spec = match_embedding(model_name, download_hub)
-    model_path = cache(model_spec)
-    model = EmbeddingModel(model_uid, model_path, **kwargs)
+    if model_path is None:
+        model_path = cache(model_spec)
+
+    model = EmbeddingModel(model_uid, model_path, model_spec, **kwargs)
     model_description = EmbeddingModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
    )
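Per the code above, gte-Qwen2 is loaded with device_map="auto" and encoded with its built-in "query" prompt, so a plain embedding call works like any other model. A usage sketch, assuming the default endpoint and the standard embedding response shape (the 3584-dimension figure comes from the Hugging Face spec added below):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
uid = client.launch_model(model_name="gte-Qwen2", model_type="embedding")
model = client.get_model(uid)

# The server prepends the model's "query" prompt automatically.
result = model.create_embedding("What is the capital of China?")
print(len(result["data"][0]["embedding"]))  # 3584 per the spec below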
xinference/model/embedding/model_spec.json CHANGED
@@ -230,5 +230,13 @@
     "language": ["zh", "en"],
     "model_id": "moka-ai/m3e-large",
     "model_revision": "12900375086c37ba5d83d1e417b21dc7d1d1f388"
+  },
+  {
+    "model_name": "gte-Qwen2",
+    "dimensions": 3584,
+    "max_tokens": 32000,
+    "language": ["zh", "en"],
+    "model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
+    "model_revision": "e26182b2122f4435e8b3ebecbf363990f409b45b"
   }
 ]
xinference/model/embedding/model_spec_modelscope.json CHANGED
@@ -232,5 +232,13 @@
     "language": ["zh", "en"],
     "model_id": "AI-ModelScope/m3e-large",
     "model_hub": "modelscope"
+  },
+  {
+    "model_name": "gte-Qwen2",
+    "dimensions": 4096,
+    "max_tokens": 32000,
+    "language": ["zh", "en"],
+    "model_id": "iic/gte_Qwen2-7B-instruct",
+    "model_hub": "modelscope"
   }
 ]
xinference/model/flexible/core.py CHANGED
@@ -210,10 +210,16 @@ def match_flexible_model(model_name):
 
 
 def create_flexible_model_instance(
-    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    model_path: Optional[str] = None,
+    **kwargs,
 ) -> Tuple[FlexibleModel, FlexibleModelDescription]:
     model_spec = match_flexible_model(model_name)
-    model_path = model_spec.model_uri
+    if not model_path:
+        model_path = model_spec.model_uri
     launcher_name = model_spec.launcher
     launcher_args = model_spec.parser_args()
     kwargs.update(launcher_args)
xinference/model/image/core.py CHANGED
@@ -45,7 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
-    ability: Optional[str]
+    abilities: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
 
 
@@ -72,7 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
-            "ability": self._model_spec.ability,
+            "abilities": self._model_spec.abilities,
             "controlnet": controlnet,
         }
 
@@ -189,6 +189,7 @@ def create_image_model_instance(
     model_name: str,
     peft_model_config: Optional[PeftModelConfig] = None,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[DiffusionModel, ImageModelDescription]:
     model_spec = match_diffusion(model_name, download_hub)
@@ -209,7 +210,8 @@
         for name in controlnet:
             for cn_model_spec in model_spec.controlnet:
                 if cn_model_spec.model_name == name:
-                    model_path = cache(cn_model_spec)
+                    if not model_path:
+                        model_path = cache(cn_model_spec)
                     controlnet_model_paths.append(model_path)
                     break
             else:
@@ -220,7 +222,8 @@
             kwargs["controlnet"] = controlnet_model_paths[0]
         else:
             kwargs["controlnet"] = controlnet_model_paths
-    model_path = cache(model_spec)
+    if not model_path:
+        model_path = cache(model_spec)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -236,7 +239,7 @@
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-        ability=model_spec.ability,
+        abilities=model_spec.abilities,
         **kwargs,
     )
     model_description = ImageModelDescription(
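Replacing the single `ability` string with an `abilities` list lets one checkpoint declare several capabilities (see the updated specs below). Reading it back might look like this sketch, assuming a hypothetical model UID and that `describe_model` returns the description dict built above:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
desc = client.describe_model("my-sd-model-uid")  # hypothetical launched image model
if "image2image" in (desc.get("abilities") or []):
    print("this checkpoint supports image-to-image")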
xinference/model/image/model_spec.json CHANGED
@@ -3,25 +3,39 @@
     "model_name": "sd3-medium",
     "model_family": "stable_diffusion",
     "model_id": "stabilityai/stable-diffusion-3-medium-diffusers",
-    "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671"
+    "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671",
+    "abilities": [
+      "text2iamge",
+      "image2image"
+    ]
   },
   {
     "model_name": "sd-turbo",
     "model_family": "stable_diffusion",
     "model_id": "stabilityai/sd-turbo",
-    "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c"
+    "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c",
+    "abilities": [
+      "text2iamge"
+    ]
   },
   {
     "model_name": "sdxl-turbo",
     "model_family": "stable_diffusion",
     "model_id": "stabilityai/sdxl-turbo",
-    "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b"
+    "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b",
+    "abilities": [
+      "text2iamge"
+    ]
   },
   {
     "model_name": "stable-diffusion-v1.5",
     "model_family": "stable_diffusion",
     "model_id": "runwayml/stable-diffusion-v1-5",
     "model_revision": "1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9",
+    "abilities": [
+      "text2iamge",
+      "image2image"
+    ],
     "controlnet": [
       {
         "model_name":"canny",
@@ -72,6 +86,10 @@
     "model_family": "stable_diffusion",
     "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
     "model_revision": "f898a3e026e802f68796b95e9702464bac78d76f",
+    "abilities": [
+      "text2iamge",
+      "image2image"
+    ],
     "controlnet": [
       {
         "model_name":"canny",
@@ -98,20 +116,26 @@
     "model_family": "stable_diffusion",
     "model_id": "runwayml/stable-diffusion-inpainting",
     "model_revision": "51388a731f57604945fddd703ecb5c50e8e7b49d",
-    "ability": "inpainting"
+    "abilities": [
+      "inpainting"
+    ]
   },
   {
     "model_name": "stable-diffusion-2-inpainting",
     "model_family": "stable_diffusion",
     "model_id": "stabilityai/stable-diffusion-2-inpainting",
     "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
-    "ability": "inpainting"
+    "abilities": [
+      "inpainting"
+    ]
   },
   {
     "model_name": "stable-diffusion-xl-inpainting",
     "model_family": "stable_diffusion",
     "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
     "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
-    "ability": "inpainting"
+    "abilities": [
+      "inpainting"
+    ]
   }
 ]