xinference 0.15.0__py3-none-any.whl → 0.15.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (84)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +204 -1
  3. xinference/client/restful/restful_client.py +4 -2
  4. xinference/core/image_interface.py +28 -0
  5. xinference/core/model.py +30 -2
  6. xinference/core/supervisor.py +6 -0
  7. xinference/model/audio/cosyvoice.py +3 -3
  8. xinference/model/audio/fish_speech.py +9 -9
  9. xinference/model/audio/model_spec.json +9 -9
  10. xinference/model/audio/whisper.py +4 -1
  11. xinference/model/image/core.py +2 -1
  12. xinference/model/image/model_spec.json +16 -4
  13. xinference/model/image/model_spec_modelscope.json +16 -4
  14. xinference/model/image/sdapi.py +136 -0
  15. xinference/model/image/stable_diffusion/core.py +163 -24
  16. xinference/model/llm/__init__.py +9 -1
  17. xinference/model/llm/llm_family.json +1241 -0
  18. xinference/model/llm/llm_family.py +3 -1
  19. xinference/model/llm/llm_family_modelscope.json +1301 -3
  20. xinference/model/llm/sglang/core.py +7 -0
  21. xinference/model/llm/transformers/chatglm.py +1 -1
  22. xinference/model/llm/transformers/core.py +6 -0
  23. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  24. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  25. xinference/model/llm/transformers/qwen2_vl.py +31 -5
  26. xinference/model/llm/utils.py +104 -84
  27. xinference/model/llm/vllm/core.py +13 -0
  28. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +2 -3
  29. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +1 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  34. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  35. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  37. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  38. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  39. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  40. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  41. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  42. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  43. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  44. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  45. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  46. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  47. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  48. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  49. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  50. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  51. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  52. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  53. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  54. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  55. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  56. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  57. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  58. xinference/types.py +7 -4
  59. xinference/web/ui/build/asset-manifest.json +6 -6
  60. xinference/web/ui/build/index.html +1 -1
  61. xinference/web/ui/build/static/css/{main.632e9148.css → main.5061c4c3.css} +2 -2
  62. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  63. xinference/web/ui/build/static/js/{main.9cfafbd6.js → main.29578905.js} +3 -3
  64. xinference/web/ui/build/static/js/main.29578905.js.map +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  67. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/METADATA +13 -7
  68. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/RECORD +73 -75
  69. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  73. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  74. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  75. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  76. xinference/web/ui/build/static/css/main.632e9148.css.map +0 -1
  77. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +0 -1
  80. /xinference/web/ui/build/static/js/{main.9cfafbd6.js.LICENSE.txt → main.29578905.js.LICENSE.txt} +0 -0
  81. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/LICENSE +0 -0
  82. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/WHEEL +0 -0
  83. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/entry_points.txt +0 -0
  84. {xinference-0.15.0.dist-info → xinference-0.15.2.dist-info}/top_level.txt +0 -0
xinference/model/image/model_spec_modelscope.json

@@ -6,7 +6,9 @@
         "model_id": "AI-ModelScope/FLUX.1-schnell",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -16,7 +18,9 @@
         "model_id": "AI-ModelScope/FLUX.1-dev",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -39,7 +43,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "sdxl-turbo",
@@ -49,7 +57,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "stable-diffusion-v1.5",
xinference/model/image/sdapi.py (new file)

@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import io
+import warnings
+
+from PIL import Image
+
+
+class SDAPIToDiffusersConverter:
+    txt2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    txt2img_arg_mapping = {
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        # "denoising_strength": "strength",
+    }
+    img2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    img2img_arg_mapping = {
+        "init_images": "image",
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        "denoising_strength": "strength",
+    }
+
+    @staticmethod
+    def convert_to_diffusers(sd_type: str, params: dict) -> dict:
+        diffusers_params = {}
+
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        for param, value in params.items():
+            if param in identical_args:
+                diffusers_params[param] = value
+            elif param in mapping_args:
+                diffusers_params[mapping_args[param]] = value
+            else:
+                raise ValueError(f"Unknown arg: {param}")
+
+        return diffusers_params
+
+    @staticmethod
+    def get_available_args(sd_type: str) -> set:
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        return identical_args.union(mapping_args)
+
+
+class SDAPIDiffusionModelMixin:
+    @staticmethod
+    def _check_kwargs(sd_type: str, kwargs: dict):
+        available_args = SDAPIToDiffusersConverter.get_available_args(sd_type)
+        unknown_args = []
+        available_kwargs = {}
+        for arg, value in kwargs.items():
+            if arg in available_args:
+                available_kwargs[arg] = value
+            else:
+                unknown_args.append(arg)
+        if unknown_args:
+            warnings.warn(
+                f"Some args are not supported for now and will be ignored: {unknown_args}"
+            )
+
+        converted_kwargs = SDAPIToDiffusersConverter.convert_to_diffusers(
+            sd_type, available_kwargs
+        )
+
+        width, height = converted_kwargs.pop("width", None), converted_kwargs.pop(
+            "height", None
+        )
+        if width and height:
+            converted_kwargs["size"] = f"{width}*{height}"
+
+        return converted_kwargs
+
+    def txt2img(self, **kwargs):
+        converted_kwargs = self._check_kwargs("txt2img", kwargs)
+        result = self.text_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
+
+    @staticmethod
+    def _decode_b64_img(img_str: str) -> Image:
+        # img_str in a format: "data:image/png;base64," + raw_b64_img(image)
+        f, data = img_str.split(",", 1)
+        f, encode_type = f.split(";", 1)
+        assert encode_type == "base64"
+        f = f.split("/", 1)[1]
+        b = base64.b64decode(data)
+        return Image.open(io.BytesIO(b), formats=[f])
+
+    def img2img(self, **kwargs):
+        init_images = kwargs.pop("init_images", [])
+        kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
+        clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
+        converted_kwargs = self._check_kwargs("img2img", kwargs)
+        if clip_skip:
+            converted_kwargs["clip_skip"] = clip_skip
+        result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
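
As a quick illustration of the mapping implemented above, a minimal sketch of feeding an Automatic1111-style txt2img payload through the converter; the payload values are made up for illustration:

from xinference.model.image.sdapi import SDAPIToDiffusersConverter

payload = {
    "prompt": "a lighthouse at dawn",
    "negative_prompt": "blurry",
    "steps": 30,        # renamed to num_inference_steps
    "cfg_scale": 7.5,   # renamed to guidance_scale
    "width": 512,
    "height": 512,
}

diffusers_kwargs = SDAPIToDiffusersConverter.convert_to_diffusers("txt2img", payload)
# identical args pass through unchanged, mapped args are renamed:
# {'prompt': ..., 'negative_prompt': ..., 'num_inference_steps': 30,
#  'guidance_scale': 7.5, 'width': 512, 'height': 512}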
xinference/model/image/stable_diffusion/core.py

@@ -13,28 +13,72 @@
 # limitations under the License.

 import base64
+import contextlib
+import inspect
 import logging
 import os
 import re
 import sys
 import time
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from io import BytesIO
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import PIL.Image
+import torch
 from PIL import ImageOps

 from ....constants import XINFERENCE_IMAGE_DIR
-from ....device_utils import move_model_to_available_device
+from ....device_utils import get_available_device, move_model_to_available_device
 from ....types import Image, ImageList, LoRA
+from ..sdapi import SDAPIDiffusionModelMixin

-logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+    from ..core import ImageModelFamilyV1

+logger = logging.getLogger(__name__)

-class DiffusionModel:
+SAMPLING_METHODS = [
+    "default",
+    "DPM++ 2M",
+    "DPM++ 2M Karras",
+    "DPM++ 2M SDE",
+    "DPM++ 2M SDE Karras",
+    "DPM++ SDE",
+    "DPM++ SDE Karras",
+    "DPM2",
+    "DPM2 Karras",
+    "DPM2 a",
+    "DPM2 a Karras",
+    "Euler",
+    "Euler a",
+    "Heun",
+    "LMS",
+    "LMS Karras",
+]
+
+
+def model_accept_param(params: Union[str, List[str]], model: Any) -> bool:
+    params = [params] if isinstance(params, str) else params
+    # model is diffusers Pipeline
+    parameters = inspect.signature(model.__call__).parameters  # type: ignore
+    allow_params = False
+    for param in parameters.values():
+        if param.kind == inspect.Parameter.VAR_KEYWORD:
+            # the __call__ can accept **kwargs,
+            # we treat it as it can accept any parameters
+            allow_params = True
+            break
+    if not allow_params:
+        if all(param in parameters for param in params):
+            allow_params = True
+    return allow_params
+
+
+class DiffusionModel(SDAPIDiffusionModelMixin):
     def __init__(
         self,
         model_uid: str,
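
model_accept_param above simply checks a pipeline's __call__ signature; a toy sketch of the same idea, where FakePipeline is a stand-in and not a real diffusers class:

import inspect

class FakePipeline:  # stand-in; real diffusers pipelines expose __call__ similarly
    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5):
        ...

params = inspect.signature(FakePipeline().__call__).parameters
print("num_inference_steps" in params)  # True  -> such a kwarg would be kept
print("strength" in params)             # False -> such a kwarg would be dropped with a warning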
@@ -43,7 +87,7 @@ class DiffusionModel:
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
-        abilities: Optional[List[str]] = None,
+        model_spec: Optional["ImageModelFamilyV1"] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
@@ -59,7 +103,8 @@ class DiffusionModel:
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
-        self._abilities = abilities or []
+        self._model_spec = model_spec
+        self._abilities = model_spec.model_ability or []  # type: ignore
         self._kwargs = kwargs

     @property
@@ -80,8 +125,6 @@ class DiffusionModel:
             logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")

     def load(self):
-        import torch
-
         if "text2image" in self._abilities or "image2image" in self._abilities:
             from diffusers import AutoPipelineForText2Image as AutoPipelineModel
         elif "inpainting" in self._abilities:
@@ -143,7 +186,9 @@ class DiffusionModel:
             self._kwargs[text_encoder_name] = text_encoder
             self._kwargs["device_map"] = "balanced"

-        logger.debug("Loading model %s", AutoPipelineModel)
+        logger.debug(
+            "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
+        )
         self._model = AutoPipelineModel.from_pretrained(
             self._model_path,
             **self._kwargs,
@@ -158,6 +203,89 @@ class DiffusionModel:
             self._model.enable_attention_slicing()
         self._apply_lora()

+    @staticmethod
+    def _get_scheduler(model: Any, sampler_name: str):
+        if not sampler_name or sampler_name == "default":
+            return
+
+        assert model is not None
+
+        import diffusers
+
+        # see https://github.com/huggingface/diffusers/issues/4167
+        # to get A1111 <> Diffusers Scheduler mapping
+        if sampler_name == "DPM++ 2M":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ 2M Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM++ 2M SDE":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, algorithm_type="sde-dpmsolver++"
+            )
+        elif sampler_name == "DPM++ 2M SDE Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config,
+                algorithm_type="sde-dpmsolver++",
+                use_karras_sigmas=True,
+            )
+        elif sampler_name == "DPM++ SDE":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ SDE Karras":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2":
+            return diffusers.KDPM2DiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "DPM2 Karras":
+            return diffusers.KDPM2DiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2 a":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM2 a Karras":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "Euler":
+            return diffusers.EulerDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "Euler a":
+            return diffusers.EulerAncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "Heun":
+            return diffusers.HeunDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS":
+            return diffusers.LMSDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS Karras":
+            return diffusers.LMSDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        else:
+            raise ValueError(f"Unknown sampler: {sampler_name}")
+
+    @staticmethod
+    @contextlib.contextmanager
+    def _reset_when_done(model: Any, sampler_name: str):
+        assert model is not None
+        scheduler = DiffusionModel._get_scheduler(model, sampler_name)
+        if scheduler:
+            default_scheduler = model.scheduler
+            model.scheduler = scheduler
+            try:
+                yield
+            finally:
+                model.scheduler = default_scheduler
+        else:
+            yield
+
     def _call_model(
         self,
         response_format: str,
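
The sampler names map onto diffusers schedulers via from_config, and _reset_when_done applies the swap per request and then restores the pipeline's default scheduler. A standalone sketch of the same pattern; the pipeline path is illustrative:

from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# "DPM++ 2M Karras" in A1111 terms: multistep DPM-Solver with Karras sigmas
original = pipe.scheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
# ... run the pipeline ...
pipe.scheduler = original  # restore the default afterwards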
@@ -168,16 +296,20 @@ class DiffusionModel:

         from ....device_utils import empty_cache

-        logger.debug(
-            "stable diffusion args: %s",
-            kwargs,
-        )
+        model = model if model is not None else self._model
         is_padded = kwargs.pop("is_padded", None)
         origin_size = kwargs.pop("origin_size", None)
-
-        model = model if model is not None else self._model
+        seed = kwargs.pop("seed", None)
+        if seed is not None:
+            kwargs["generator"] = generator = torch.Generator(device=get_available_device())  # type: ignore
+            if seed != -1:
+                kwargs["generator"] = generator.manual_seed(seed)
+        sampler_name = kwargs.pop("sampler_name", None)
         assert callable(model)
-        images = model(**kwargs).images
+        with self._reset_when_done(model, sampler_name):
+            logger.debug("stable diffusion args: %s, model: %s", kwargs, model)
+            self._filter_kwargs(model, kwargs)
+            images = model(**kwargs).images

         # revert padding if padded
         if is_padded and origin_size:
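
The seed handling above follows the standard diffusers pattern of passing a torch.Generator; roughly, with an illustrative device string:

import torch

seed = 42
generator = torch.Generator(device="cuda")
if seed != -1:                       # -1 keeps the generator unseeded (random result)
    generator = generator.manual_seed(seed)

# images = pipe(prompt="...", generator=generator).images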
@@ -215,11 +347,17 @@ class DiffusionModel:
             raise ValueError(f"Unsupported response format: {response_format}")

     @classmethod
-    def _filter_kwargs(cls, kwargs: dict):
+    def _filter_kwargs(cls, model, kwargs: dict):
         for arg in ["negative_prompt", "num_inference_steps"]:
             if not kwargs.get(arg):
                 kwargs.pop(arg, None)

+        for key in list(kwargs):
+            allow_key = model_accept_param(key, model)
+            if not allow_key:
+                warnings.warn(f"{type(model)} cannot accept `{key}`, will ignore it")
+                kwargs.pop(key)
+
     def text_to_image(
         self,
         prompt: str,
@@ -231,14 +369,15 @@ class DiffusionModel:
         # References:
         # https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet_sdxl
         width, height = map(int, re.split(r"[^\d]+", size))
-        self._filter_kwargs(kwargs)
+        generate_kwargs = self._model_spec.default_generate_config.copy()  # type: ignore
+        generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
         return self._call_model(
             prompt=prompt,
             height=height,
             width=width,
             num_images_per_prompt=n,
             response_format=response_format,
-            **kwargs,
+            **generate_kwargs,
         )

     @staticmethod
@@ -253,7 +392,6 @@ class DiffusionModel:
         self,
         image: PIL.Image,
         prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
         size: Optional[str] = None,
         response_format: str = "url",
@@ -287,12 +425,15 @@ class DiffusionModel:
             width, height = image.size
             kwargs["width"] = width
             kwargs["height"] = height
+        else:
+            # SD3 image2image cannot accept width and height
+            allow_width_height = model_accept_param(["width", "height"], model)
+            if allow_width_height:
+                kwargs["width"], kwargs["height"] = image.size

-        self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
             prompt=prompt,
-            negative_prompt=negative_prompt,
             num_images_per_prompt=n,
             response_format=response_format,
             model=model,
@@ -304,7 +445,6 @@ class DiffusionModel:
         image: PIL.Image,
         mask_image: PIL.Image,
         prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
         size: str = "1024*1024",
         response_format: str = "url",
@@ -346,7 +486,6 @@ class DiffusionModel:
             image=image,
             mask_image=mask_image,
             prompt=prompt,
-            negative_prompt=negative_prompt,
             height=height,
             width=width,
             num_images_per_prompt=n,
xinference/model/llm/__init__.py

@@ -121,7 +121,7 @@ def register_custom_model():
             with codecs.open(
                 os.path.join(user_defined_llm_dir, f), encoding="utf-8"
             ) as fd:
-                user_defined_llm_family = CustomLLMFamilyV1.parse_obj(json.load(fd))
+                user_defined_llm_family = CustomLLMFamilyV1.parse_raw(fd.read())
                 register_llm(user_defined_llm_family, persist=False)
         except Exception as e:
             warnings.warn(f"{user_defined_llm_dir}/{f} has error, {e}")
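
For context, in pydantic v1 both entry points produce the same model from valid JSON, but parse_raw hands the raw string to the class (so a subclass such as CustomLLMFamilyV1 can hook its own handling into deserialization), while parse_obj only ever sees an already-parsed dict. A generic sketch, where ExampleFamily is a stand-in and not the real CustomLLMFamilyV1:

import json
from pydantic import BaseModel  # pydantic v1 API

class ExampleFamily(BaseModel):  # stand-in for a custom family definition
    model_name: str

raw = '{"model_name": "my-custom-llm"}'
a = ExampleFamily.parse_obj(json.loads(raw))  # dict -> model
b = ExampleFamily.parse_raw(raw)              # raw JSON string -> model
assert a == b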
@@ -136,12 +136,17 @@ def _install():
     from .transformers.cogvlm2 import CogVLM2Model
     from .transformers.cogvlm2_video import CogVLM2VideoModel
     from .transformers.core import PytorchChatModel, PytorchModel
+    from .transformers.deepseek_v2 import (
+        DeepSeekV2PytorchChatModel,
+        DeepSeekV2PytorchModel,
+    )
     from .transformers.deepseek_vl import DeepSeekVLChatModel
     from .transformers.glm4v import Glm4VModel
     from .transformers.intern_vl import InternVLChatModel
     from .transformers.internlm2 import Internlm2PytorchChatModel
     from .transformers.minicpmv25 import MiniCPMV25Model
     from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.qwen2_audio import Qwen2AudioChatModel
     from .transformers.qwen2_vl import Qwen2VLChatModel
     from .transformers.qwen_vl import QwenVLChatModel
     from .transformers.yi_vl import YiVLChatModel
@@ -173,6 +178,7 @@ def _install():
             Internlm2PytorchChatModel,
             QwenVLChatModel,
             Qwen2VLChatModel,
+            Qwen2AudioChatModel,
             YiVLChatModel,
             DeepSeekVLChatModel,
             InternVLChatModel,
@@ -182,6 +188,8 @@ def _install():
             MiniCPMV25Model,
             MiniCPMV26Model,
             Glm4VModel,
+            DeepSeekV2PytorchModel,
+            DeepSeekV2PytorchChatModel,
         ]
     )
     if OmniLMMModel:  # type: ignore