xinference 1.9.0__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (92)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +415 -1
  3. xinference/constants.py +2 -0
  4. xinference/core/model.py +3 -4
  5. xinference/core/supervisor.py +29 -1
  6. xinference/core/worker.py +4 -1
  7. xinference/deploy/cmdline.py +2 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/model/audio/core.py +5 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +64 -20
  14. xinference/model/embedding/flag/core.py +5 -0
  15. xinference/model/embedding/llama_cpp/core.py +22 -19
  16. xinference/model/embedding/sentence_transformers/core.py +19 -4
  17. xinference/model/embedding/vllm/core.py +40 -8
  18. xinference/model/image/cache_manager.py +56 -0
  19. xinference/model/image/core.py +9 -0
  20. xinference/model/image/model_spec.json +116 -9
  21. xinference/model/image/stable_diffusion/core.py +141 -31
  22. xinference/model/llm/core.py +10 -0
  23. xinference/model/llm/llama_cpp/core.py +42 -40
  24. xinference/model/llm/llm_family.json +435 -23
  25. xinference/model/llm/llm_family.py +1 -0
  26. xinference/model/llm/mlx/core.py +52 -33
  27. xinference/model/llm/sglang/core.py +2 -44
  28. xinference/model/llm/tool_parsers/__init__.py +58 -0
  29. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  30. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +128 -0
  31. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  32. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  33. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  34. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  35. xinference/model/llm/transformers/core.py +6 -12
  36. xinference/model/llm/utils.py +128 -46
  37. xinference/model/llm/vllm/core.py +8 -61
  38. xinference/model/rerank/core.py +3 -0
  39. xinference/model/rerank/sentence_transformers/core.py +1 -1
  40. xinference/model/rerank/vllm/core.py +56 -6
  41. xinference/model/utils.py +1 -2
  42. xinference/model/video/model_spec.json +95 -1
  43. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  44. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  45. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  46. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  47. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  48. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  49. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  50. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  51. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  52. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  53. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  54. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  55. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  56. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  57. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  58. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  59. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  60. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  61. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  62. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  63. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  64. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  65. xinference/types.py +105 -2
  66. xinference/ui/gradio/chat_interface.py +2 -0
  67. xinference/ui/gradio/media_interface.py +353 -7
  68. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  69. xinference/ui/web/ui/build/index.html +1 -1
  70. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  71. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  72. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  73. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  74. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  75. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  76. xinference/ui/web/ui/src/locales/en.json +2 -0
  77. xinference/ui/web/ui/src/locales/ja.json +2 -0
  78. xinference/ui/web/ui/src/locales/ko.json +2 -0
  79. xinference/ui/web/ui/src/locales/zh.json +2 -0
  80. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/METADATA +16 -12
  81. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/RECORD +86 -77
  82. xinference/ui/web/ui/build/static/js/main.4918643a.js +0 -3
  83. xinference/ui/web/ui/build/static/js/main.4918643a.js.map +0 -1
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +0 -1
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +0 -1
  88. /xinference/ui/web/ui/build/static/js/{main.4918643a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  89. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/WHEEL +0 -0
  90. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/entry_points.txt +0 -0
  91. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/licenses/LICENSE +0 -0
  92. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/top_level.txt +0 -0
xinference/model/image/core.py:

@@ -51,6 +51,10 @@ class ImageModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
     gguf_model_id: Optional[str]
     gguf_quantizations: Optional[List[str]]
     gguf_model_file_name_template: Optional[str]
+    lightning_model_id: Optional[str]
+    lightning_versions: Optional[List[str]]
+    lightning_model_file_name_template: Optional[str]
+
     virtualenv: Optional[VirtualEnvSettings]

     class Config:

@@ -180,6 +184,8 @@ def create_image_model_instance(
     model_path: Optional[str] = None,
     gguf_quantization: Optional[str] = None,
     gguf_model_path: Optional[str] = None,
+    lightning_version: Optional[str] = None,
+    lightning_model_path: Optional[str] = None,
     **kwargs,
 ) -> Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model]:
     from .cache_manager import ImageCacheManager

@@ -235,6 +241,8 @@ def create_image_model_instance(
         model_path = cache_manager.cache()
     if not gguf_model_path and gguf_quantization:
         gguf_model_path = cache_manager.cache_gguf(gguf_quantization)
+    if not lightning_model_path and lightning_version:
+        lightning_model_path = cache_manager.cache_lightning(lightning_version)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs

@@ -262,6 +270,7 @@ def create_image_model_instance(
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
         gguf_model_path=gguf_model_path,
+        lightning_model_path=lightning_model_path,
         **kwargs,
     )
     return model
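
Note: the new lightning_version / lightning_model_path kwargs mirror the existing gguf_quantization / gguf_model_path pair, with ImageCacheManager.cache_lightning() fetching the Lightning LoRA file. A minimal launch sketch, assuming extra launch kwargs reach create_image_model_instance the same way the GGUF ones do (the client plumbing is not part of this diff):

from xinference.client import Client

client = Client("http://localhost:9997")
# "lightning_version" is assumed to pass through launch kwargs; it is resolved to a
# local file via cache_manager.cache_lightning(), as in the hunk above.
model_uid = client.launch_model(
    model_name="Qwen-Image",
    model_type="image",
    lightning_version="8steps-V1.1",
)
model = client.get_model(model_uid)
image = model.text_to_image("a watercolor fox", size="1024*1024")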
xinference/model/image/model_spec.json:

@@ -169,7 +169,9 @@
     },
     "virtualenv": {
       "packages": [
-        "git+https://github.com/huggingface/diffusers.git",
+        "diffusers==0.35.1",
+        "peft>=0.17.0",
+        "#system_torch#",
         "#system_numpy#"
       ],
       "no_build_isolation": true

@@ -180,7 +182,9 @@
     "model_name": "Qwen-Image",
     "model_family": "stable_diffusion",
     "model_ability": [
-      "text2image"
+      "text2image",
+      "image2image",
+      "inpainting"
     ],
     "model_src": {
       "huggingface": {

@@ -202,7 +206,16 @@
           "Q6_K",
           "Q8_0"
         ],
-        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf"
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0",
+          "8steps-V1.1-bf16",
+          "8steps-V1.1"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Lightning-{lightning_version}.safetensors"
       },
       "modelscope": {
         "model_id": "Qwen/Qwen-Image",
@@ -223,7 +236,102 @@
           "Q6_K",
           "Q8_0"
         ],
-        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf"
+        "gguf_model_file_name_template": "qwen-image-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0",
+          "8steps-V1.1-bf16",
+          "8steps-V1.1"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Lightning-{lightning_version}.safetensors"
+      }
+    },
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder",
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {
+      "guidance_scale": 1.0,
+      "true_cfg_scale": 1.0
+    },
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "peft>=0.17.0",
+        "#system_torch#",
+        "#system_numpy#"
+      ],
+      "no_build_isolation": true
+    }
+  },
+  {
+    "version": 2,
+    "model_name": "Qwen-Image-Edit",
+    "model_family": "stable_diffusion",
+    "model_ability": [
+      "image2image"
+    ],
+    "model_src": {
+      "huggingface": {
+        "model_id": "Qwen/Qwen-Image-Edit",
+        "model_revision": "0b71959872ea3bf4d106c578b7c480ebb133dba7",
+        "gguf_model_id": "QuantStack/Qwen-Image-Edit-GGUF",
+        "gguf_quantizations": [
+          "Q2_K",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "Qwen_Image_Edit-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0-bf16",
+          "8steps-V1.0"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Edit-Lightning-{lightning_version}.safetensors"
+      },
+      "modelscope": {
+        "model_id": "Qwen/Qwen-Image-Edit",
+        "model_revision": "master",
+        "gguf_model_id": "QuantStack/Qwen-Image-Edit-GGUF",
+        "gguf_quantizations": [
+          "Q2_K",
+          "Q3_K_M",
+          "Q3_K_S",
+          "Q4_0",
+          "Q4_1",
+          "Q4_K_M",
+          "Q4_K_S",
+          "Q5_0",
+          "Q5_1",
+          "Q5_K_M",
+          "Q5_K_S",
+          "Q6_K",
+          "Q8_0"
+        ],
+        "gguf_model_file_name_template": "Qwen_Image_Edit-{quantization}.gguf",
+        "lightning_model_id": "lightx2v/Qwen-Image-Lightning",
+        "lightning_versions": [
+          "4steps-V1.0-bf16",
+          "4steps-V1.0",
+          "8steps-V1.0-bf16",
+          "8steps-V1.0"
+        ],
+        "lightning_model_file_name_template": "Qwen-Image-Edit-Lightning-{lightning_version}.safetensors"
       }
     },
     "default_model_config": {
@@ -232,11 +340,11 @@
       "torch_dtype": "bfloat16"
     },
     "default_generate_config": {
-      "guidance_scale": 1.0
+      "true_cfg_scale": 4.0
     },
     "virtualenv": {
       "packages": [
-        "git+https://github.com/huggingface/diffusers.git",
+        "diffusers==0.35.1",
         "peft>=0.17.0",
         "#system_torch#",
         "#system_numpy#"

@@ -716,13 +824,12 @@
         "deepspeed==0.12.3",
         "peft==0.4.0",
         "tiktoken==0.6.0",
-        "bitsandbytes==0.41.0",
-        "scikit-learn==1.2.2",
         "sentencepiece==0.1.99",
         "einops==0.6.1",
         "einops-exts==0.0.4",
         "timm==0.6.13",
-        "numpy==1.26.4"
+        "#system_numpy#",
+        "#system_torch#"
       ]
     },
     "model_src": {
xinference/model/image/stable_diffusion/core.py:

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 import contextlib
 import gc
 import importlib

@@ -19,6 +20,7 @@ import inspect
 import itertools
 import json
 import logging
+import math
 import os
 import re
 import sys

@@ -30,7 +32,11 @@ import PIL.Image
 import torch
 from PIL import ImageOps

-from ....device_utils import get_available_device, move_model_to_available_device
+from ....device_utils import (
+    get_available_device,
+    gpu_count,
+    move_model_to_available_device,
+)
 from ....types import LoRA
 from ..sdapi import SDAPIDiffusionModelMixin
 from ..utils import handle_image_result

@@ -89,6 +95,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         lora_fuse_kwargs: Optional[Dict] = None,
         model_spec: Optional["ImageModelFamilyV2"] = None,
         gguf_model_path: Optional[str] = None,
+        lightning_model_path: Optional[str] = None,
         **kwargs,
     ):
         self.model_family = model_spec

@@ -115,6 +122,8 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         self._kwargs = kwargs
         # gguf
         self._gguf_model_path = gguf_model_path
+        # lightning
+        self._lightning_model_path = lightning_model_path

     @property
     def model_ability(self):
@@ -171,7 +180,32 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             )
             model = model_type.from_pipe(self._model, controlnet=controlnet)
         else:
-            model = model_type.from_pipe(self._model)
+            try:
+                from diffusers import (
+                    QwenImageImg2ImgPipeline,
+                    QwenImageInpaintPipeline,
+                    QwenImagePipeline,
+                )
+            except ImportError:
+                QwenImagePipeline = None
+                QwenImageImg2ImgPipeline = None
+                QwenImageInpaintPipeline = None
+
+            if QwenImagePipeline is not None and isinstance(
+                self._model, QwenImagePipeline
+            ):
+                # special process for Qwen-image
+                if ability == "image2image":
+                    model = QwenImageImg2ImgPipeline.from_pipe(
+                        self._model, torch_dtype=None
+                    )
+                else:
+                    assert ability == "inpainting"
+                    model = QwenImageInpaintPipeline.from_pipe(
+                        self._model, torch_dtype=None
+                    )
+            else:
+                model = model_type.from_pipe(self._model)
         self._load_to_device(model)

         self._ability_to_models[ability, controlnet_name] = model
@@ -237,35 +271,42 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         else:
             self._quantize_transformer()

+        if (device_count := gpu_count()) > 1 and "device_map" not in self._kwargs:
+            logger.debug(
+                "Device count (%d) > 1, force to set device_map=balanced", device_count
+            )
+            self._kwargs["device_map"] = "balanced"
+
         logger.debug(
             "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
         )
-        try:
-            self._model = AutoPipelineModel.from_pretrained(
-                self._model_path,
-                **self._kwargs,
-            )
-        except ValueError:
-            if "kontext" in self._model_spec.model_name.lower():
-                # TODO: remove this branch when auto pipeline supports
-                # flux.1-kontext-dev
-                from diffusers import FluxKontextPipeline
-
-                self._model = FluxKontextPipeline.from_pretrained(
-                    self._model_path, **self._kwargs
+        with self._process_lightning(self._kwargs):
+            try:
+                self._model = AutoPipelineModel.from_pretrained(
+                    self._model_path,
+                    **self._kwargs,
                 )
-            elif "qwen" in self._model_spec.model_name.lower():
-                # TODO: remove this branch when auto pipeline supports
-                # Qwen-Image
-                from diffusers import DiffusionPipeline
-
-                self._model = DiffusionPipeline.from_pretrained(
-                    self._model_path, **self._kwargs
-                )
-            else:
-                raise
-        self._load_to_device(self._model)
-        self._apply_lora()
+            except ValueError:
+                if "kontext" in self._model_spec.model_name.lower():
+                    # TODO: remove this branch when auto pipeline supports
+                    # flux.1-kontext-dev
+                    from diffusers import FluxKontextPipeline
+
+                    self._model = FluxKontextPipeline.from_pretrained(
+                        self._model_path, **self._kwargs
+                    )
+                elif "qwen" in self._model_spec.model_name.lower():
+                    # TODO: remove this branch when auto pipeline supports
+                    # Qwen-Image
+                    from diffusers import DiffusionPipeline
+
+                    self._model = DiffusionPipeline.from_pretrained(
+                        self._model_path, **self._kwargs
+                    )
+                else:
+                    raise
+            self._load_to_device(self._model)
+            self._apply_lora()

         if self._kwargs.get("deepcache", False):
             try:
@@ -440,6 +481,44 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             config=os.path.join(self._model_path, "transformer"),
         )

+    @contextlib.contextmanager
+    def _process_lightning(self, kwargs):
+        lightning_model_path = self._lightning_model_path
+        if not lightning_model_path:
+            yield
+            return
+
+        from diffusers import FlowMatchEulerDiscreteScheduler
+
+        if "qwen" in self._model_spec.model_name.lower():
+            scheduler_config = {
+                "base_image_seq_len": 256,
+                "base_shift": math.log(3),  # We use shift=3 in distillation
+                "invert_sigmas": False,
+                "max_image_seq_len": 8192,
+                "max_shift": math.log(3),  # We use shift=3 in distillation
+                "num_train_timesteps": 1000,
+                "shift": 1.0,
+                "shift_terminal": None,  # set shift_terminal to None
+                "stochastic_sampling": False,
+                "time_shift_type": "exponential",
+                "use_beta_sigmas": False,
+                "use_dynamic_shifting": True,
+                "use_exponential_sigmas": False,
+                "use_karras_sigmas": False,
+            }
+            scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+            kwargs["scheduler"] = scheduler
+
+            yield
+
+            model = self._model
+            logger.debug("Loading lightning lora: %s", self._lightning_model_path)
+            model.load_lora_weights(self._lightning_model_path)
+        else:
+            logger.debug("No lightning applied")
+            yield
+
     def _load_to_device(self, model):
         if self._kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
@@ -687,7 +766,6 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
             await self._image_batch_scheduler.add_request(
                 prompt, future, n, size, response_format, **kwargs
             )
-            import asyncio

             fut = asyncio.wrap_future(future)
             return await fut

@@ -702,6 +780,18 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if self._image_batch_scheduler and not self._image_batch_scheduler._running:
             await self._image_batch_scheduler.start()

+    def _gen_config_for_lightning(self, kwargs):
+        if (
+            not kwargs.get("num_inference_steps")
+            and self._lightning_model_path is not None
+        ):
+            is_4_steps = "4steps" in self._lightning_model_path
+            if is_4_steps:
+                kwargs["num_inference_steps"] = 4
+            else:
+                assert "8steps" in self._lightning_model_path
+                kwargs["num_inference_steps"] = 8
+
     async def _direct_text_to_image(
         self,
         prompt: str,
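
Note: _gen_config_for_lightning only fills num_inference_steps when the caller left it unset, reading the default off the "4steps"/"8steps" marker in the cached LoRA filename. Illustrative values:

kwargs = {}  # caller did not set num_inference_steps
lightning_model_path = "Qwen-Image-Lightning-8steps-V1.1.safetensors"  # example filename

if not kwargs.get("num_inference_steps") and lightning_model_path is not None:
    kwargs["num_inference_steps"] = 4 if "4steps" in lightning_model_path else 8

print(kwargs)  # {'num_inference_steps': 8}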
@@ -714,14 +804,28 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         generate_kwargs = self._model_spec.default_generate_config.copy()  # type: ignore
         generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
         generate_kwargs["width"], generate_kwargs["height"] = width, height
+        self._gen_config_for_lightning(generate_kwargs)

-        return self._call_model(
-            prompt=prompt,
-            num_images_per_prompt=n,
+        return await asyncio.to_thread(
+            self._call_model,
+            prompt=prompt,  # type: ignore
+            num_images_per_prompt=n,  # type: ignore
             response_format=response_format,
             **generate_kwargs,
         )

+    async def abort_request(self, request_id: str) -> str:
+        """Abort a running request."""
+        from ....model.scheduler.core import AbortRequestMessage
+
+        # Check if we have a cancel callback for this request
+        if hasattr(self, "_cancel_callbacks") and request_id in self._cancel_callbacks:
+            cancel_callback = self._cancel_callbacks.pop(request_id)
+            cancel_callback()
+            return AbortRequestMessage.DONE.name
+
+        return AbortRequestMessage.NO_OP.name
+
     @staticmethod
     def pad_to_multiple(image, multiple=8):
         x, y = image.size
@@ -769,6 +873,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if allow_width_height:
             kwargs["width"], kwargs["height"] = image.size

+        # generate config for lightning
+        self._gen_config_for_lightning(kwargs)
+
         return self._call_model(
             image=image,
             prompt=prompt,

@@ -819,6 +926,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         # calculate actual image size after padding
         kwargs["width"], kwargs["height"] = image.size

+        # generate config for lightning
+        self._gen_config_for_lightning(kwargs)
+
         return self._call_model(
             image=image,
             mask_image=mask_image,
xinference/model/llm/core.py:

@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
 from ...core.utils import parse_replica_model_uid
 from ...types import PeftModelConfig
 from .reasoning_parser import ReasoningParser
+from .tool_parsers import TOOL_PARSERS

 if TYPE_CHECKING:
     from .llm_family import LLMFamilyV2, LLMSpecV1

@@ -59,6 +60,7 @@ class LLM(abc.ABC):
         self.quantization = model_family.model_specs[0].quantization
         self.model_path = model_path
         self.reasoning_parser = None
+        self.tool_parser = None
         if args:
             raise ValueError(f"Unrecognized positional arguments: {args}")
         if kwargs:

@@ -171,6 +173,14 @@ class LLM(abc.ABC):
             enable_thinking=enable_thinking,
         )

+    def prepare_parse_tool_calls(self):
+        if self.model_family.tool_parser is None:
+            return
+        if self.model_family.tool_parser not in TOOL_PARSERS:
+            return
+        tool_parser = TOOL_PARSERS[self.model_family.tool_parser]
+        self.tool_parser = tool_parser()
+

 # Context variable for passing per-request chat context (e.g., chat_template_kwargs).
 # This variable should be set at the beginning of each chat or stream_chat call.
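
Note: prepare_parse_tool_calls() pairs with the new xinference/model/llm/tool_parsers package (TOOL_PARSERS plus the qwen, llama3, glm4 and deepseek parsers in the file list above). A minimal lookup sketch; the "qwen" registry key is an assumption, only TOOL_PARSERS itself appears in this hunk:

from xinference.model.llm.tool_parsers import TOOL_PARSERS

family_tool_parser = "qwen"  # hypothetical value of model_family.tool_parser
if family_tool_parser in TOOL_PARSERS:
    # Same lookup as prepare_parse_tool_calls(): resolve the class and instantiate it.
    tool_parser = TOOL_PARSERS[family_tool_parser]()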
xinference/model/llm/llama_cpp/core.py:

@@ -19,11 +19,11 @@ import pprint
 import queue
 from typing import Iterator, List, Optional, Union

-import orjson
+from packaging import version

 from ....constants import XINFERENCE_MAX_TOKENS
 from ....types import ChatCompletion, ChatCompletionChunk, Completion, CompletionChunk
-from ..core import LLM
+from ..core import LLM, chat_context_var
 from ..llm_family import LLMFamilyV2, LLMSpecV1
 from ..utils import ChatModelMixin

@@ -98,10 +98,19 @@ class XllamaCppModel(LLM, ChatModelMixin):
             from xllamacpp import (
                 CommonParams,
                 Server,
+                __version__,
                 estimate_gpu_layers,
                 get_device_info,
                 ggml_backend_dev_type,
             )
+
+            try:
+                if version.parse(__version__) < version.parse("0.2.0"):
+                    raise RuntimeError(
+                        "Please update xllamacpp to >= 0.2.0 by `pip install -U xllamacpp`"
+                    )
+            except version.InvalidVersion:
+                pass  # If the version parse failed, we just skip the version check.
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -113,6 +122,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
         self.prepare_parse_reasoning_content(
             reasoning_content, enable_thinking=enable_thinking
         )
+        self.prepare_parse_tool_calls()

         if os.path.isfile(self.model_path):
             # mostly passed from --model_path

@@ -160,6 +170,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
             params.mmproj.path = mmproj
         if self.model_family.chat_template:
             params.chat_template = self.model_family.chat_template
+            params.use_jinja = True
         # This is the default value, could be overwritten by _llamacpp_model_config
         params.n_parallel = min(8, os.cpu_count() or 1)
         for k, v in self._llamacpp_model_config.items():

@@ -208,7 +219,8 @@ class XllamaCppModel(LLM, ChatModelMixin):
             )
             logger.info("Estimate num gpu layers: %s", estimate)
             if estimate.tensor_split:
-                params.tensor_split = estimate.tensor_split
+                for i in range(len(estimate.tensor_split)):
+                    params.tensor_split[i] = estimate.tensor_split[i]
             else:
                 params.n_gpu_layers = estimate.layers
         except Exception as e:
@@ -242,28 +254,18 @@ class XllamaCppModel(LLM, ChatModelMixin):
                 {
                     "prompt": prompt,
                     "stream": stream,
+                    "model": self.model_uid,
                 }
             )
-            prompt_json = orjson.dumps(data)
-
-            def _error_callback(err):
-                try:
-                    msg = orjson.loads(err)
-                    q.put(_Error(msg))
-                except Exception as e:
-                    q.put(_Error(str(e)))
+            try:

-            def _ok_callback(ok):
-                try:
-                    res = orjson.loads(ok)
-                    res["model"] = self.model_uid
-                    q.put(res)
-                except Exception as e:
-                    logger.exception("handle_completions callback failed: %s", e)
-                    q.put(_Error(str(e)))
+                def _callback(res):
+                    if res.get("code"):
+                        q.put(_Error(res))
+                    else:
+                        q.put(res)

-            try:
-                self._llm.handle_completions(prompt_json, _error_callback, _ok_callback)
+                self._llm.handle_completions(data, _callback)
             except Exception as ex:
                 logger.exception("handle_completions failed: %s", ex)
                 q.put(_Error(str(ex)))
@@ -296,6 +298,15 @@ class XllamaCppModel(LLM, ChatModelMixin):
         if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
             generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
         stream = generate_config.get("stream", False)
+
+        chat_template_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
+            or {}
+        )
+        chat_context_var.set(chat_template_kwargs)
+
         tools = generate_config.pop("tools", []) if generate_config else None
         q: queue.Queue = queue.Queue()

@@ -310,30 +321,21 @@ class XllamaCppModel(LLM, ChatModelMixin):
                     "messages": messages,
                     "stream": stream,
                     "tools": tools,
+                    "model": self.model_uid,
                 }
             )
-            prompt_json = orjson.dumps(data)
+            if chat_template_kwargs:
+                data["chat_template_kwargs"] = chat_template_kwargs

-            def _error_callback(err):
-                try:
-                    msg = orjson.loads(err)
-                    q.put(_Error(msg))
-                except Exception as e:
-                    q.put(_Error(str(e)))
+            try:

-            def _ok_callback(ok):
-                try:
-                    res = orjson.loads(ok)
-                    res["model"] = self.model_uid
-                    q.put(res)
-                except Exception as e:
-                    logger.exception("handle_chat_completions callback failed: %s", e)
-                    q.put(_Error(str(e)))
+                def _callback(res):
+                    if res.get("code"):
+                        q.put(_Error(res))
+                    else:
+                        q.put(res)

-            try:
-                self._llm.handle_chat_completions(
-                    prompt_json, _error_callback, _ok_callback
-                )
+                self._llm.handle_chat_completions(data, _callback)
             except Exception as ex:
                 logger.exception("handle_chat_completions failed: %s", ex)
                 q.put(_Error(str(ex)))
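
Note: with chat_template_kwargs now read from generate_config and forwarded as data["chat_template_kwargs"], per-request template options reach the llama.cpp backend. A hedged client-side example; the keys a model accepts depend on its chat template, and enable_thinking here is only illustrative:

from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("my-gguf-chat-model")  # hypothetical model uid

# generate_config is where _get_chat_template_kwargs_from_generate_config() looks,
# per the hunk above.
completion = model.chat(
    messages=[{"role": "user", "content": "Hello"}],
    generate_config={"chat_template_kwargs": {"enable_thinking": False}},
)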