xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
--- a/xinference/model/image/cache_manager.py
+++ b/xinference/model/image/cache_manager.py
@@ -60,3 +60,59 @@ class ImageCacheManager(CacheManager):
             raise NotImplementedError
 
         return full_path
+
+    def cache_lightning(self, lightning_version: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not lightning_version:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.lightning_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support lightning"
+            )
+        if lightning_version not in (self._model_family.lightning_versions or []):
+            raise ValueError(
+                f"Cannot support lightning version {lightning_version}, "
+                f"available lightning version: {self._model_family.lightning_versions}"
+            )
+
+        filename = self._model_family.lightning_model_file_name_template.format(lightning_version=lightning_version)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
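
For orientation, a minimal usage sketch (not part of the diff) of the new helper, assuming a cache manager built from a registered image family whose spec carries the new lightning fields; the constructor argument and the example version string are illustrative assumptions:

    # Hypothetical sketch; ImageCacheManager comes from the module above, the
    # version string must be one of the family's lightning_versions.
    from xinference.model.image.cache_manager import ImageCacheManager

    def fetch_lightning_lora(model_family, version="8steps-V1.1"):
        cache_manager = ImageCacheManager(model_family)
        # Returns None when no version is requested, otherwise the local path
        # of the downloaded Lightning safetensors file inside the cache dir.
        return cache_manager.cache_lightning(version)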
--- a/xinference/model/image/core.py
+++ b/xinference/model/image/core.py
@@ -51,6 +51,10 @@ class ImageModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
     gguf_model_id: Optional[str]
     gguf_quantizations: Optional[List[str]]
     gguf_model_file_name_template: Optional[str]
+    lightning_model_id: Optional[str]
+    lightning_versions: Optional[List[str]]
+    lightning_model_file_name_template: Optional[str]
+
     virtualenv: Optional[VirtualEnvSettings]
 
     class Config:
@@ -180,6 +184,8 @@ def create_image_model_instance(
     model_path: Optional[str] = None,
     gguf_quantization: Optional[str] = None,
     gguf_model_path: Optional[str] = None,
+    lightning_version: Optional[str] = None,
+    lightning_model_path: Optional[str] = None,
     **kwargs,
 ) -> Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model]:
     from .cache_manager import ImageCacheManager
@@ -235,6 +241,8 @@ def create_image_model_instance(
         model_path = cache_manager.cache()
     if not gguf_model_path and gguf_quantization:
         gguf_model_path = cache_manager.cache_gguf(gguf_quantization)
+    if not lightning_model_path and lightning_version:
+        lightning_model_path = cache_manager.cache_lightning(lightning_version)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -262,6 +270,7 @@
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
         gguf_model_path=gguf_model_path,
+        lightning_model_path=lightning_model_path,
         **kwargs,
     )
     return model
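
Taken together with the Qwen-Image spec entries added below, these parameters are what a client launch request ultimately feeds into. As a hypothetical sketch (not from the diff): passing lightning_version as a model kwarg is an assumption here, modeled on how gguf_quantization already flows into create_image_model_instance, and the endpoint and version string are placeholders:

    # Hypothetical client-side sketch; kwarg plumbing, endpoint and version
    # string are assumptions for illustration only.
    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_model(
        model_name="Qwen-Image",
        model_type="image",
        lightning_version="8steps-V1.1",  # one of the spec's lightning_versions
    )
    model = client.get_model(model_uid)
    image_bytes = model.text_to_image("a watercolor fox in a bamboo forest")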
--- a/xinference/model/image/model_spec.json
+++ b/xinference/model/image/model_spec.json
@@ -169,7 +169,184 @@
     },
     "virtualenv": {
       "packages": [
-        "git+https://github.com/huggingface/diffusers.git",
+        "diffusers==0.35.1",
+        "peft>=0.17.0",
+        "#system_torch#",
+        "#system_numpy#"
+      ],
+      "no_build_isolation": true
+    }
+  },
+  {
+    "version": 2,
+    "model_name": "Qwen-Image",
+    "model_family": "stable_diffusion",
+    "model_ability": [
+      "text2image",
+      "image2image",
+      "inpainting"
+    ],
+    "model_src": {
+      "huggingface": {
+        "model_id": "Qwen/Qwen-Image",
+        "model_revision": "4516c4d3058302ff35cd86c62ffa645d039fefad",
+        "gguf_model_id": "city96/Qwen-Image-gguf",
+        "gguf_quantizations": [
+          "F16",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0",
+          "8steps-V1.1-bf16",
+          "8steps-V1.1"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Lightning-{lightning_version}.safetensors"
+      },
+      "modelscope": {
+        "model_id": "Qwen/Qwen-Image",
+        "model_revision": "master",
+        "gguf_model_id": "city96/Qwen-Image-gguf",
+        "gguf_quantizations": [
+          "F16",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0",
+          "8steps-V1.1-bf16",
+          "8steps-V1.1"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Lightning-{lightning_version}.safetensors"
+      }
+    },
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 1.0,
+      "true_cfg_scale": 1.0
+    },
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "peft>=0.17.0",
+        "#system_torch#",
+        "#system_numpy#"
+      ],
+      "no_build_isolation": true
+    }
+  },
+  {
+    "version": 2,
+    "model_name": "Qwen-Image-Edit",
+    "model_family": "stable_diffusion",
+    "model_ability": [
+      "image2image"
+    ],
+    "model_src": {
+      "huggingface": {
+        "model_id": "Qwen/Qwen-Image-Edit",
+        "model_revision": "0b71959872ea3bf4d106c578b7c480ebb133dba7",
+        "gguf_model_id": "QuantStack/Qwen-Image-Edit-GGUF",
+        "gguf_quantizations": [
+          "Q2_K",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "Qwen_Image_Edit-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0-bf16",
+          "8steps-V1.0"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Edit-Lightning-{lightning_version}.safetensors"
+      },
+      "modelscope": {
+        "model_id": "Qwen/Qwen-Image-Edit",
+        "model_revision": "master",
+        "gguf_model_id": "QuantStack/Qwen-Image-Edit-GGUF",
+        "gguf_quantizations": [
+          "Q2_K",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "Qwen_Image_Edit-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0-bf16",
+          "8steps-V1.0"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Edit-Lightning-{lightning_version}.safetensors"
+      }
+    },
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "true_cfg_scale": 4.0
+    },
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "peft>=0.17.0",
+        "#system_torch#",
         "#system_numpy#"
       ],
       "no_build_isolation": true
--- a/xinference/model/image/stable_diffusion/core.py
+++ b/xinference/model/image/stable_diffusion/core.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import contextlib
 import gc
 import importlib
@@ -19,6 +20,7 @@ import inspect
 import itertools
 import json
 import logging
+import math
 import os
 import re
 import sys
@@ -30,7 +32,11 @@ import PIL.Image
 import torch
 from PIL import ImageOps
 
-from ....device_utils import get_available_device, move_model_to_available_device
+from ....device_utils import (
+    get_available_device,
+    gpu_count,
+    move_model_to_available_device,
+)
 from ....types import LoRA
 from ..sdapi import SDAPIDiffusionModelMixin
 from ..utils import handle_image_result
@@ -89,6 +95,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         lora_fuse_kwargs: Optional[Dict] = None,
         model_spec: Optional["ImageModelFamilyV2"] = None,
         gguf_model_path: Optional[str] = None,
+        lightning_model_path: Optional[str] = None,
         **kwargs,
     ):
         self.model_family = model_spec
@@ -115,6 +122,8 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         self._kwargs = kwargs
         # gguf
         self._gguf_model_path = gguf_model_path
+        # lightning
+        self._lightning_model_path = lightning_model_path
 
     @property
     def model_ability(self):
@@ -171,7 +180,32 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
                 )
                 model = model_type.from_pipe(self._model, controlnet=controlnet)
             else:
-                model = model_type.from_pipe(self._model)
+                try:
+                    from diffusers import (
+                        QwenImageImg2ImgPipeline,
+                        QwenImageInpaintPipeline,
+                        QwenImagePipeline,
+                    )
+                except ImportError:
+                    QwenImagePipeline = None
+                    QwenImageImg2ImgPipeline = None
+                    QwenImageInpaintPipeline = None
+
+                if QwenImagePipeline is not None and isinstance(
+                    self._model, QwenImagePipeline
+                ):
+                    # special process for Qwen-image
+                    if ability == "image2image":
+                        model = QwenImageImg2ImgPipeline.from_pipe(
+                            self._model, torch_dtype=None
+                        )
+                    else:
+                        assert ability == "inpainting"
+                        model = QwenImageInpaintPipeline.from_pipe(
+                            self._model, torch_dtype=None
+                        )
+                else:
+                    model = model_type.from_pipe(self._model)
             self._load_to_device(model)
 
             self._ability_to_models[ability, controlnet_name] = model
@@ -237,27 +271,42 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         else:
             self._quantize_transformer()
 
+        if (device_count := gpu_count()) > 1 and "device_map" not in self._kwargs:
+            logger.debug(
+                "Device count (%d) > 1, force to set device_map=balanced", device_count
+            )
+            self._kwargs["device_map"] = "balanced"
+
         logger.debug(
             "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
         )
-        try:
-            self._model = AutoPipelineModel.from_pretrained(
-                self._model_path,
-                **self._kwargs,
-            )
-        except ValueError:
-            if "kontext" in self._model_spec.model_name.lower():
-                # TODO: remove this branch when auto pipeline supports
-                # flux.1-kontext-dev
-                from diffusers import FluxKontextPipeline
-
-                self._model = FluxKontextPipeline.from_pretrained(
-                    self._model_path, **self._kwargs
+        with self._process_lightning(self._kwargs):
+            try:
+                self._model = AutoPipelineModel.from_pretrained(
+                    self._model_path,
+                    **self._kwargs,
                 )
-            else:
-                raise
-        self._load_to_device(self._model)
-        self._apply_lora()
+            except ValueError:
+                if "kontext" in self._model_spec.model_name.lower():
+                    # TODO: remove this branch when auto pipeline supports
+                    # flux.1-kontext-dev
+                    from diffusers import FluxKontextPipeline
+
+                    self._model = FluxKontextPipeline.from_pretrained(
+                        self._model_path, **self._kwargs
+                    )
+                elif "qwen" in self._model_spec.model_name.lower():
+                    # TODO: remove this branch when auto pipeline supports
+                    # Qwen-Image
+                    from diffusers import DiffusionPipeline
+
+                    self._model = DiffusionPipeline.from_pretrained(
+                        self._model_path, **self._kwargs
+                    )
+                else:
+                    raise
+            self._load_to_device(self._model)
+            self._apply_lora()
 
         if self._kwargs.get("deepcache", False):
             try:
@@ -348,11 +397,19 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             return
 
         if not quantize_text_encoder:
+            logger.debug("No text encoder quantization")
             return
 
        quantization_method = self._kwargs.pop("text_encoder_quantize_method", "bnb")
        quantization = self._kwargs.pop("text_encoder_quantization", "8-bit")
 
+        logger.debug(
+            "Quantize text encoder %s with method %s, quantization %s",
+            quantize_text_encoder,
+            quantization_method,
+            quantization,
+        )
+
         torch_dtype = self._torch_dtype
         for text_encoder_name in quantize_text_encoder.split(","):
             quantization_kwargs: Dict[str, Any] = {}
@@ -389,8 +446,13 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
 
         if not quantization:
             # skip if no quantization specified
+            logger.debug("No transformer quantization")
             return
 
+        logger.debug(
+            "Quantize transformer with %s, quantization %s", method, quantization
+        )
+
         torch_dtype = self._torch_dtype
         transformer_cls = self._get_layer_cls("transformer")
         quantization_config = self._get_quantize_config(
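
The new log lines sit in front of the existing "bnb" (bitsandbytes) path driven by quantize_text_encoder and the "8-bit" default. For reference only, a standalone sketch of what 8-bit text-encoder quantization typically looks like with transformers; this is not the exact code behind _get_quantize_config, and the helper name and defaults are assumptions:

    # Illustrative sketch; assumes a transformers-based text encoder and
    # bitsandbytes installed.
    import torch
    from transformers import AutoModel, BitsAndBytesConfig

    def load_text_encoder_8bit(model_path: str, subfolder: str = "text_encoder"):
        config = BitsAndBytesConfig(load_in_8bit=True)  # mirrors the "8-bit" default above
        return AutoModel.from_pretrained(
            model_path,
            subfolder=subfolder,
            quantization_config=config,
            torch_dtype=torch.bfloat16,
        )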
@@ -409,6 +471,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
 
         # GGUF transformer
         torch_dtype = self._torch_dtype
+        logger.debug("Quantize transformer with gguf file %s", self._gguf_model_path)
         self._kwargs["transformer"] = self._get_layer_cls(
             "transformer"
         ).from_single_file(
@@ -418,6 +481,44 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             config=os.path.join(self._model_path, "transformer"),
         )
 
+    @contextlib.contextmanager
+    def _process_lightning(self, kwargs):
+        lightning_model_path = self._lightning_model_path
+        if not lightning_model_path:
+            yield
+            return
+
+        from diffusers import FlowMatchEulerDiscreteScheduler
+
+        if "qwen" in self._model_spec.model_name.lower():
+            scheduler_config = {
+                "base_image_seq_len": 256,
+                "base_shift": math.log(3),  # We use shift=3 in distillation
+                "invert_sigmas": False,
+                "max_image_seq_len": 8192,
+                "max_shift": math.log(3),  # We use shift=3 in distillation
+                "num_train_timesteps": 1000,
+                "shift": 1.0,
+                "shift_terminal": None,  # set shift_terminal to None
+                "stochastic_sampling": False,
+                "time_shift_type": "exponential",
+                "use_beta_sigmas": False,
+                "use_dynamic_shifting": True,
+                "use_exponential_sigmas": False,
+                "use_karras_sigmas": False,
+            }
+            scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+            kwargs["scheduler"] = scheduler
+
+            yield
+
+            model = self._model
+            logger.debug("Loading lightning lora: %s", self._lightning_model_path)
+            model.load_lora_weights(self._lightning_model_path)
+        else:
+            logger.debug("No lightning applied")
+            yield
+
     def _load_to_device(self, model):
         if self._kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
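
Outside of this wrapper, the same Lightning recipe reads roughly as below; a hedged standalone sketch, with the scheduler values copied from the hunk above and the repository and file names taken from the new model_spec.json entries (a diffusers release with Qwen-Image support is assumed, and the prompt and step count are placeholders):

    # Standalone sketch of the pattern _process_lightning implements: install a
    # distillation-tuned FlowMatchEulerDiscreteScheduler, load the pipeline,
    # then apply the Lightning LoRA and generate with few steps.
    import math
    import torch
    from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        {
            "base_shift": math.log(3),
            "max_shift": math.log(3),
            "num_train_timesteps": 1000,
            "shift": 1.0,
            "shift_terminal": None,
            "time_shift_type": "exponential",
            "use_dynamic_shifting": True,
        }
    )
    pipe = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image", scheduler=scheduler, torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights(
        "lightx2v/Qwen-Image-Lightning",
        weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors",
    )
    image = pipe("a cat in a spacesuit", num_inference_steps=8).images[0]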
@@ -665,7 +766,6 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             await self._image_batch_scheduler.add_request(
                 prompt, future, n, size, response_format, **kwargs
             )
-            import asyncio
 
             fut = asyncio.wrap_future(future)
             return await fut
@@ -680,6 +780,18 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if self._image_batch_scheduler and not self._image_batch_scheduler._running:
             await self._image_batch_scheduler.start()
 
+    def _gen_config_for_lightning(self, kwargs):
+        if (
+            not kwargs.get("num_inference_steps")
+            and self._lightning_model_path is not None
+        ):
+            is_4_steps = "4steps" in self._lightning_model_path
+            if is_4_steps:
+                kwargs["num_inference_steps"] = 4
+            else:
+                assert "8steps" in self._lightning_model_path
+                kwargs["num_inference_steps"] = 8
+
     async def _direct_text_to_image(
         self,
         prompt: str,
@@ -692,14 +804,28 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         generate_kwargs = self._model_spec.default_generate_config.copy()  # type: ignore
         generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
         generate_kwargs["width"], generate_kwargs["height"] = width, height
+        self._gen_config_for_lightning(generate_kwargs)
 
-        return self._call_model(
-            prompt=prompt,
-            num_images_per_prompt=n,
+        return await asyncio.to_thread(
+            self._call_model,
+            prompt=prompt,  # type: ignore
+            num_images_per_prompt=n,  # type: ignore
             response_format=response_format,
             **generate_kwargs,
         )
 
+    async def abort_request(self, request_id: str) -> str:
+        """Abort a running request."""
+        from ....model.scheduler.core import AbortRequestMessage
+
+        # Check if we have a cancel callback for this request
+        if hasattr(self, "_cancel_callbacks") and request_id in self._cancel_callbacks:
+            cancel_callback = self._cancel_callbacks.pop(request_id)
+            cancel_callback()
+            return AbortRequestMessage.DONE.name
+
+        return AbortRequestMessage.NO_OP.name
+
     @staticmethod
     def pad_to_multiple(image, multiple=8):
         x, y = image.size
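
The switch from a direct self._call_model(...) call to asyncio.to_thread keeps the worker's event loop responsive while the blocking diffusers pipeline runs. A reduced sketch of the pattern, detached from the class above (the function names here are placeholders, not xinference APIs):

    # Minimal sketch: run a blocking generate call in a worker thread so other
    # coroutines (health checks, new requests) keep being served.
    import asyncio

    def blocking_generate(prompt: str) -> str:
        # stands in for the synchronous pipeline call, which can block for minutes
        return f"png bytes for {prompt!r}"

    async def direct_text_to_image(prompt: str) -> str:
        return await asyncio.to_thread(blocking_generate, prompt)

    print(asyncio.run(direct_text_to_image("a watercolor fox")))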
@@ -747,6 +873,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if allow_width_height:
             kwargs["width"], kwargs["height"] = image.size
 
+        # generate config for lightning
+        self._gen_config_for_lightning(kwargs)
+
         return self._call_model(
             image=image,
             prompt=prompt,
@@ -797,6 +926,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         # calculate actual image size after padding
         kwargs["width"], kwargs["height"] = image.size
 
+        # generate config for lightning
+        self._gen_config_for_lightning(kwargs)
+
         return self._call_model(
             image=image,
             mask_image=mask_image,
--- a/xinference/model/llm/cache_manager.py
+++ b/xinference/model/llm/cache_manager.py
@@ -1,3 +1,17 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 from typing import TYPE_CHECKING, Optional
@@ -81,7 +95,7 @@ class LLMCacheManager(CacheManager):
         if not IS_NEW_HUGGINGFACE_HUB:
             use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 huggingface_hub.snapshot_download,
                 self._model_name,
@@ -144,7 +158,7 @@
         if self.get_cache_status():
             return cache_dir
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "bnb", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 snapshot_download,
                 self._model_name,
@@ -234,7 +248,7 @@
         if self.get_cache_status():
             return cache_dir
 
-        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
+        if self._model_format in ["pytorch", "gptq", "awq", "fp8", "bnb", "mlx"]:
             download_dir = retry_download(
                 snapshot_download,
                 self._model_name,