xinference 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (103)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +30 -5
  4. xinference/client/restful/restful_client.py +18 -3
  5. xinference/constants.py +0 -4
  6. xinference/core/chat_interface.py +2 -2
  7. xinference/core/image_interface.py +6 -3
  8. xinference/core/model.py +9 -4
  9. xinference/core/scheduler.py +4 -4
  10. xinference/core/supervisor.py +2 -0
  11. xinference/core/worker.py +7 -0
  12. xinference/deploy/utils.py +6 -0
  13. xinference/model/audio/core.py +9 -4
  14. xinference/model/audio/cosyvoice.py +136 -0
  15. xinference/model/audio/model_spec.json +24 -0
  16. xinference/model/audio/model_spec_modelscope.json +27 -0
  17. xinference/model/core.py +25 -4
  18. xinference/model/embedding/core.py +88 -13
  19. xinference/model/embedding/model_spec.json +8 -0
  20. xinference/model/embedding/model_spec_modelscope.json +8 -0
  21. xinference/model/flexible/core.py +8 -2
  22. xinference/model/flexible/launchers/__init__.py +1 -0
  23. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  24. xinference/model/image/core.py +8 -5
  25. xinference/model/image/model_spec.json +36 -5
  26. xinference/model/image/model_spec_modelscope.json +21 -3
  27. xinference/model/image/stable_diffusion/core.py +36 -28
  28. xinference/model/llm/core.py +6 -4
  29. xinference/model/llm/ggml/llamacpp.py +7 -5
  30. xinference/model/llm/llm_family.json +802 -82
  31. xinference/model/llm/llm_family.py +6 -6
  32. xinference/model/llm/llm_family_csghub.json +39 -0
  33. xinference/model/llm/llm_family_modelscope.json +295 -47
  34. xinference/model/llm/mlx/core.py +7 -0
  35. xinference/model/llm/pytorch/chatglm.py +246 -5
  36. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  37. xinference/model/llm/pytorch/deepseek_vl.py +2 -1
  38. xinference/model/llm/pytorch/falcon.py +2 -1
  39. xinference/model/llm/pytorch/llama_2.py +4 -2
  40. xinference/model/llm/pytorch/omnilmm.py +2 -1
  41. xinference/model/llm/pytorch/qwen_vl.py +2 -1
  42. xinference/model/llm/pytorch/vicuna.py +2 -1
  43. xinference/model/llm/pytorch/yi_vl.py +2 -1
  44. xinference/model/llm/sglang/core.py +12 -6
  45. xinference/model/llm/utils.py +78 -1
  46. xinference/model/llm/vllm/core.py +9 -5
  47. xinference/model/rerank/core.py +4 -3
  48. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  50. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  51. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  52. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  53. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  54. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  55. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  56. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  58. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  59. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  61. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  63. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  64. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  65. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  66. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  67. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  68. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  69. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  70. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  71. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  72. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  73. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  74. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  77. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  78. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  79. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  80. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  81. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  82. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  83. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  84. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  85. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  86. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  87. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  88. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  89. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  90. xinference/web/ui/build/asset-manifest.json +3 -3
  91. xinference/web/ui/build/index.html +1 -1
  92. xinference/web/ui/build/static/js/{main.95c1d652.js → main.af906659.js} +3 -3
  93. xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
  94. xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
  95. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/METADATA +39 -11
  96. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/RECORD +101 -57
  97. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
  100. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
  101. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
  102. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0
xinference/model/embedding/core.py

@@ -118,12 +118,19 @@ def get_cache_status(
 
 
 class EmbeddingModel:
-    def __init__(self, model_uid: str, model_path: str, device: Optional[str] = None):
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: EmbeddingModelSpec,
+        device: Optional[str] = None,
+    ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model = None
         self._counter = 0
+        self._model_spec = model_spec
 
     def load(self):
         try:
@@ -134,12 +141,26 @@ class EmbeddingModel:
                 "Please make sure 'sentence-transformers' is installed. ",
                 "You can install it by `pip install sentence-transformers`\n",
             ]
-
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        class XSentenceTransformer(SentenceTransformer):
+            def to(self, *args, **kwargs):
+                pass
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
-        self._model = SentenceTransformer(self._model_path, device=self._device)
+        if (
+            "gte-Qwen2" in self._model_spec.model_id
+            or "gte-Qwen2" in self._model_spec.model_name
+        ):
+            self._model = XSentenceTransformer(
+                self._model_path,
+                device=self._device,
+                model_kwargs={"device_map": "auto"},
+            )
+        else:
+            self._model = SentenceTransformer(self._model_path, device=self._device)
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         self._counter += 1
@@ -156,6 +177,8 @@
         def encode(
             model: SentenceTransformer,
             sentences: Union[str, List[str]],
+            prompt_name: Optional[str] = None,
+            prompt: Optional[str] = None,
             batch_size: int = 32,
             show_progress_bar: bool = None,
             output_value: str = "sentence_embedding",
@@ -204,10 +227,43 @@
                 sentences = [sentences]
                 input_was_string = True
 
+            if prompt is None:
+                if prompt_name is not None:
+                    try:
+                        prompt = model.prompts[prompt_name]
+                    except KeyError:
+                        raise ValueError(
+                            f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
+                        )
+                elif model.default_prompt_name is not None:
+                    prompt = model.prompts.get(model.default_prompt_name, None)
+            else:
+                if prompt_name is not None:
+                    logger.warning(
+                        "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
+                        "Ignoring the `prompt_name` in favor of `prompt`."
+                    )
+
+            extra_features = {}
+            if prompt is not None:
+                sentences = [prompt + sentence for sentence in sentences]
+
+                # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
+                # Tracking the prompt length allow us to remove the prompt during pooling
+                tokenized_prompt = model.tokenize([prompt])
+                if "input_ids" in tokenized_prompt:
+                    extra_features["prompt_length"] = (
+                        tokenized_prompt["input_ids"].shape[-1] - 1
+                    )
+
             if device is None:
                 device = model._target_device
 
-            model.to(device)
+            if (
+                "gte-Qwen2" not in self._model_spec.model_id
+                and "gte-Qwen2" not in self._model_spec.model_name
+            ):
+                model.to(device)
 
             all_embeddings = []
             all_token_nums = 0
@@ -228,6 +284,7 @@
                 ]
                 features = model.tokenize(sentences_batch)
                 features = batch_to_device(features, device)
+                features.update(extra_features)
                 all_token_nums += sum([len(f) for f in features])
 
                 with torch.no_grad():
@@ -272,7 +329,10 @@
             ]
 
             if convert_to_tensor:
-                all_embeddings = torch.stack(all_embeddings)
+                if len(all_embeddings):
+                    all_embeddings = torch.stack(all_embeddings)
+                else:
+                    all_embeddings = torch.Tensor()
             elif convert_to_numpy:
                 all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
 
@@ -281,12 +341,24 @@
 
             return all_embeddings, all_token_nums
 
-        all_embeddings, all_token_nums = encode(
-            self._model,
-            sentences,
-            convert_to_numpy=False,
-            **kwargs,
-        )
+        if (
+            "gte-Qwen2" in self._model_spec.model_id
+            or "gte-Qwen2" in self._model_spec.model_name
+        ):
+            all_embeddings, all_token_nums = encode(
+                self._model,
+                sentences,
+                prompt_name="query",
+                convert_to_numpy=False,
+                **kwargs,
+            )
+        else:
+            all_embeddings, all_token_nums = encode(
+                self._model,
+                sentences,
+                convert_to_numpy=False,
+                **kwargs,
+            )
         if isinstance(sentences, str):
             all_embeddings = [all_embeddings]
         embedding_list = []
@@ -344,11 +416,14 @@ def create_embedding_model_instance(
     model_uid: str,
     model_name: str,
    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
     model_spec = match_embedding(model_name, download_hub)
-    model_path = cache(model_spec)
-    model = EmbeddingModel(model_uid, model_path, **kwargs)
+    if model_path is None:
+        model_path = cache(model_spec)
+
+    model = EmbeddingModel(model_uid, model_path, model_spec, **kwargs)
     model_description = EmbeddingModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
     )
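
Taken together, the embedding changes above let a caller pass a pre-downloaded directory as model_path (skipping cache()) and add a gte-Qwen2 spec whose inputs are automatically prefixed with the model's built-in "query" prompt. A minimal client-side sketch of how that might be exercised; the server address, the local directory, and the assumption that extra launch kwargs such as model_path are forwarded down to create_embedding_model_instance are illustrative, not taken from this diff:

    # Hedged sketch, not part of the diff: assumes a running xinference server and
    # that launch kwargs reach create_embedding_model_instance().
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="gte-Qwen2",
        model_type="embedding",
        model_path="/data/models/gte-Qwen2-7B-instruct",  # hypothetical local copy
    )
    embedding_model = client.get_model(uid)
    # For gte-Qwen2, the server-side encode() injects prompt_name="query" itself,
    # so plain sentences are enough here.
    resp = embedding_model.create_embedding("How do I deploy xinference?")
    print(len(resp["data"][0]["embedding"]))  # 3584 dimensions per the new huggingface spec entry
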
xinference/model/embedding/model_spec.json

@@ -230,5 +230,13 @@
         "language": ["zh", "en"],
         "model_id": "moka-ai/m3e-large",
         "model_revision": "12900375086c37ba5d83d1e417b21dc7d1d1f388"
+    },
+    {
+        "model_name": "gte-Qwen2",
+        "dimensions": 3584,
+        "max_tokens": 32000,
+        "language": ["zh", "en"],
+        "model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
+        "model_revision": "e26182b2122f4435e8b3ebecbf363990f409b45b"
     }
 ]
xinference/model/embedding/model_spec_modelscope.json

@@ -232,5 +232,13 @@
         "language": ["zh", "en"],
         "model_id": "AI-ModelScope/m3e-large",
         "model_hub": "modelscope"
+    },
+    {
+        "model_name": "gte-Qwen2",
+        "dimensions": 4096,
+        "max_tokens": 32000,
+        "language": ["zh", "en"],
+        "model_id": "iic/gte_Qwen2-7B-instruct",
+        "model_hub": "modelscope"
     }
 ]
xinference/model/flexible/core.py

@@ -210,10 +210,16 @@ def match_flexible_model(model_name):
 
 
 def create_flexible_model_instance(
-    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    model_path: Optional[str] = None,
+    **kwargs,
 ) -> Tuple[FlexibleModel, FlexibleModelDescription]:
     model_spec = match_flexible_model(model_name)
-    model_path = model_spec.model_uri
+    if not model_path:
+        model_path = model_spec.model_uri
     launcher_name = model_spec.launcher
     launcher_args = model_spec.parser_args()
     kwargs.update(launcher_args)
xinference/model/flexible/launchers/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .image_process_launcher import launcher as image_process
 from .transformers_launcher import launcher as transformers
xinference/model/flexible/launchers/image_process_launcher.py (new file)

@@ -0,0 +1,70 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from io import BytesIO
+
+import PIL.Image
+import PIL.ImageOps
+
+from ....types import Image
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ImageRemoveBackgroundModel(FlexibleModel):
+    def infer(self, **kwargs):
+        invert = kwargs.get("invert", False)
+        b64_image: str = kwargs.get("image")  # type: ignore
+        only_mask = kwargs.pop("only_mask", True)
+        image_format = kwargs.pop("image_format", "PNG")
+        if not b64_image:
+            raise ValueError("No image found to remove background")
+        image = base64.b64decode(b64_image)
+
+        try:
+            from rembg import remove
+        except ImportError:
+            error_message = "Failed to import module 'rembg'"
+            installation_guide = [
+                "Please make sure 'rembg' is installed. ",
+                "You can install it by visiting the installation section of the git repo:\n",
+                "https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        im = PIL.Image.open(BytesIO(image))
+        om = remove(im, only_mask=only_mask, **kwargs)
+        if invert:
+            om = PIL.ImageOps.invert(om)
+
+        buffered = BytesIO()
+        om.save(buffered, format=image_format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return Image(url=None, b64_json=img_str)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    if task == "remove_background":
+        return ImageRemoveBackgroundModel(
+            model_uid=model_uid,
+            model_path=model_spec.model_uri,  # type: ignore
+            device=device,
+            config=kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown Task for image processing: {task}")
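
The new launcher maps a "remove_background" task onto rembg. A small sketch of driving ImageRemoveBackgroundModel.infer() directly, outside the supervisor/worker path, assuming rembg is installed; the file names are placeholders and the constructor arguments simply mirror what the launcher above passes:

    # Hedged sketch, not part of the diff: calls the flexible model class directly.
    import base64

    from xinference.model.flexible.launchers.image_process_launcher import (
        ImageRemoveBackgroundModel,
    )

    with open("input.png", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode()

    model = ImageRemoveBackgroundModel(
        model_uid="remove-bg-demo",
        model_path="/tmp/unused",  # rembg ships its own weights; path is a placeholder
        device=None,
        config={},
    )
    result = model.infer(image=b64_image, only_mask=False, image_format="PNG")
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["b64_json"]))
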
xinference/model/image/core.py

@@ -45,7 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
-    ability: Optional[str]
+    abilities: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
 
 
@@ -72,7 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
-            "ability": self._model_spec.ability,
+            "abilities": self._model_spec.abilities,
             "controlnet": controlnet,
         }
 
@@ -189,6 +189,7 @@ def create_image_model_instance(
     model_name: str,
     peft_model_config: Optional[PeftModelConfig] = None,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[DiffusionModel, ImageModelDescription]:
     model_spec = match_diffusion(model_name, download_hub)
@@ -209,7 +210,8 @@
         for name in controlnet:
             for cn_model_spec in model_spec.controlnet:
                 if cn_model_spec.model_name == name:
-                    model_path = cache(cn_model_spec)
+                    if not model_path:
+                        model_path = cache(cn_model_spec)
                     controlnet_model_paths.append(model_path)
                     break
             else:
@@ -220,7 +222,8 @@
             kwargs["controlnet"] = controlnet_model_paths[0]
         else:
             kwargs["controlnet"] = controlnet_model_paths
-    model_path = cache(model_spec)
+    if not model_path:
+        model_path = cache(model_spec)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -236,7 +239,7 @@
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-        ability=model_spec.ability,
+        abilities=model_spec.abilities,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json

@@ -3,25 +3,39 @@
         "model_name": "sd3-medium",
         "model_family": "stable_diffusion",
         "model_id": "stabilityai/stable-diffusion-3-medium-diffusers",
-        "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671"
+        "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ]
     },
     {
         "model_name": "sd-turbo",
         "model_family": "stable_diffusion",
         "model_id": "stabilityai/sd-turbo",
-        "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c"
+        "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c",
+        "abilities": [
+            "text2iamge"
+        ]
     },
     {
         "model_name": "sdxl-turbo",
         "model_family": "stable_diffusion",
         "model_id": "stabilityai/sdxl-turbo",
-        "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b"
+        "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b",
+        "abilities": [
+            "text2iamge"
+        ]
     },
     {
         "model_name": "stable-diffusion-v1.5",
         "model_family": "stable_diffusion",
         "model_id": "runwayml/stable-diffusion-v1-5",
         "model_revision": "1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ],
         "controlnet": [
             {
                 "model_name":"canny",
@@ -72,6 +86,10 @@
         "model_family": "stable_diffusion",
         "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
         "model_revision": "f898a3e026e802f68796b95e9702464bac78d76f",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ],
         "controlnet": [
             {
                 "model_name":"canny",
@@ -98,13 +116,26 @@
         "model_family": "stable_diffusion",
         "model_id": "runwayml/stable-diffusion-inpainting",
         "model_revision": "51388a731f57604945fddd703ecb5c50e8e7b49d",
-        "ability": "inpainting"
+        "abilities": [
+            "inpainting"
+        ]
     },
     {
         "model_name": "stable-diffusion-2-inpainting",
         "model_family": "stable_diffusion",
         "model_id": "stabilityai/stable-diffusion-2-inpainting",
         "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
-        "ability": "inpainting"
+        "abilities": [
+            "inpainting"
+        ]
+    },
+    {
+        "model_name": "stable-diffusion-xl-inpainting",
+        "model_family": "stable_diffusion",
+        "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+        "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
+        "abilities": [
+            "inpainting"
+        ]
     }
 ]
xinference/model/image/model_spec_modelscope.json

@@ -4,21 +4,31 @@
         "model_family": "stable_diffusion",
         "model_hub": "modelscope",
         "model_id": "AI-ModelScope/stable-diffusion-3-medium-diffusers",
-        "model_revision": "master"
+        "model_revision": "master",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ]
     },
     {
         "model_name": "sd-turbo",
         "model_family": "stable_diffusion",
         "model_hub": "modelscope",
         "model_id": "AI-ModelScope/sd-turbo",
-        "model_revision": "master"
+        "model_revision": "master",
+        "abilities": [
+            "text2iamge"
+        ]
     },
     {
         "model_name": "sdxl-turbo",
         "model_family": "stable_diffusion",
         "model_hub": "modelscope",
         "model_id": "AI-ModelScope/sdxl-turbo",
-        "model_revision": "master"
+        "model_revision": "master",
+        "abilities": [
+            "text2iamge"
+        ]
     },
     {
         "model_name": "stable-diffusion-v1.5",
@@ -26,6 +36,10 @@
         "model_hub": "modelscope",
         "model_id": "AI-ModelScope/stable-diffusion-v1-5",
         "model_revision": "master",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ],
         "controlnet": [
             {
                 "model_name":"canny",
@@ -77,6 +91,10 @@
         "model_hub": "modelscope",
         "model_id": "AI-ModelScope/stable-diffusion-xl-base-1.0",
         "model_revision": "master",
+        "abilities": [
+            "text2iamge",
+            "image2image"
+        ],
         "controlnet": [
             {
                 "model_name":"canny",
xinference/model/image/stable_diffusion/core.py

@@ -35,22 +35,23 @@ class DiffusionModel:
     def __init__(
         self,
         model_uid: str,
-        model_path: str,
+        model_path: Optional[str] = None,
         device: Optional[str] = None,
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
-        ability: Optional[str] = None,
+        abilities: Optional[List[str]] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model = None
+        self._i2i_model = None  # image to image model
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
-        self._ability = ability
+        self._abilities = abilities
         self._kwargs = kwargs
 
     def _apply_lora(self):
@@ -69,12 +70,12 @@
     def load(self):
         import torch
 
-        if self._ability in [None, "text2image", "image2image"]:
+        if "text2image" in self._abilities or "image2image" in self._abilities:
             from diffusers import AutoPipelineForText2Image as AutoPipelineModel
-        elif self._ability == "inpainting":
+        elif "inpainting" in self._abilities:
             from diffusers import AutoPipelineForInpainting as AutoPipelineModel
         else:
-            raise ValueError(f"Unknown ability: {self._ability}")
+            raise ValueError(f"Unknown ability: {self._abilities}")
 
         controlnet = self._kwargs.get("controlnet")
         if controlnet is not None:
@@ -94,35 +95,29 @@
                 self._model_path,
                 **self._kwargs,
             )
-        self._model = move_model_to_available_device(self._model)
+        if self._kwargs.get("cpu_offload", False):
+            logger.debug("CPU offloading model")
+            self._model.enable_model_cpu_offload()
+        else:
+            logger.debug("Loading model to available device")
+            self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
         self._apply_lora()
 
     def _call_model(
         self,
-        height: int,
-        width: int,
-        num_images_per_prompt: int,
         response_format: str,
+        model=None,
        **kwargs,
     ):
         logger.debug(
             "stable diffusion args: %s",
-            dict(
-                kwargs,
-                height=height,
-                width=width,
-                num_images_per_prompt=num_images_per_prompt,
-            ),
+            kwargs,
         )
-        assert callable(self._model)
-        images = self._model(
-            height=height,
-            width=width,
-            num_images_per_prompt=num_images_per_prompt,
-            **kwargs,
-        ).images
+        model = model if model is not None else self._model
+        assert callable(model)
+        images = model(**kwargs).images
         if response_format == "url":
             os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
             image_list = []
@@ -140,7 +135,7 @@
                 return base64.b64encode(buffered.getvalue()).decode()
 
             with ThreadPoolExecutor() as executor:
-                results = list(map(partial(executor.submit, _gen_base64_image), images))
+                results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
                 image_list = [Image(url=None, b64_json=s.result()) for s in results]
             return ImageList(created=int(time.time()), data=image_list)
         else:
@@ -172,19 +167,32 @@
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
-        size: str = "1024*1024",
+        size: Optional[str] = None,
         response_format: str = "url",
         **kwargs,
     ):
-        width, height = map(int, re.split(r"[^\d]+", size))
+        if "controlnet" in self._kwargs:
+            model = self._model
+        else:
+            if self._i2i_model is not None:
+                model = self._i2i_model
+            else:
+                from diffusers import AutoPipelineForImage2Image
+
+                self._i2i_model = model = AutoPipelineForImage2Image.from_pipe(
+                    self._model
+                )
+        if size:
+            width, height = map(int, re.split(r"[^\d]+", size))
+            kwargs["width"] = width
+            kwargs["height"] = height
         return self._call_model(
             image=image,
             prompt=prompt,
             negative_prompt=negative_prompt,
-            height=height,
-            width=width,
             num_images_per_prompt=n,
             response_format=response_format,
+            model=model,
             **kwargs,
         )
 
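
With abilities now a list, a pipeline loaded for text2image can also serve image2image by deriving a second pipeline through AutoPipelineForImage2Image.from_pipe, and cpu_offload can be requested at launch time instead of a full device load. A rough client-side sketch; the server address, the image-handle methods, and the assumption that extra launch kwargs such as cpu_offload reach DiffusionModel are illustrative, not confirmed by this diff:

    # Hedged sketch, not part of the diff: assumes a running server and that launch
    # kwargs are forwarded into DiffusionModel(**kwargs).
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="stable-diffusion-v1.5",
        model_type="image",
        cpu_offload=True,  # new: enable_model_cpu_offload() instead of moving to GPU
    )
    image_model = client.get_model(uid)

    # text2image works as before.
    out = image_model.text_to_image("a lighthouse at dusk", size="512*512")

    # image2image reuses the same weights via AutoPipelineForImage2Image.from_pipe;
    # size is now optional and only sets width/height when provided.
    with open("sketch.png", "rb") as f:
        edited = image_model.image_to_image(image=f.read(), prompt="watercolor style", n=1)
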
xinference/model/llm/core.py

@@ -194,6 +194,7 @@ def create_llm_model_instance(
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[LLM, LLMDescription]:
     from .llm_family import cache, check_engine_by_spec_parameters, match_llm
@@ -221,7 +222,8 @@
         )
     logger.debug(f"Launching {model_uid} with {llm_cls.__name__}")
 
-    save_path = cache(llm_family, llm_spec, quantization)
+    if not model_path:
+        model_path = cache(llm_family, llm_spec, quantization)
 
     peft_model = peft_model_config.peft_model if peft_model_config else None
     if peft_model is not None:
@@ -231,7 +233,7 @@
                 llm_family,
                 llm_spec,
                 quantization,
-                save_path,
+                model_path,
                 kwargs,
                 peft_model,
             )
@@ -241,11 +243,11 @@
                 f"Load this without lora."
             )
             model = llm_cls(
-                model_uid, llm_family, llm_spec, quantization, save_path, kwargs
+                model_uid, llm_family, llm_spec, quantization, model_path, kwargs
             )
     else:
         model = llm_cls(
-            model_uid, llm_family, llm_spec, quantization, save_path, kwargs
+            model_uid, llm_family, llm_spec, quantization, model_path, kwargs
         )
     return model, LLMDescription(
         subpool_addr, devices, llm_family, llm_spec, quantization
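
The LLM path gets the same optional model_path handling: when a directory is supplied, cache() is skipped and the given path is loaded directly. A hedged sketch with a placeholder local directory, assuming the REST layer forwards model_path down to create_llm_model_instance:

    # Hedged sketch, not part of the diff: launches an LLM from a pre-downloaded directory.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="qwen2-instruct",
        model_engine="transformers",
        model_format="pytorch",
        model_size_in_billions=7,
        quantization="none",
        model_path="/data/models/Qwen2-7B-Instruct",  # placeholder local directory
    )
    chat_model = client.get_model(uid)
    print(chat_model.chat("Hello!")["choices"][0]["message"]["content"])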