xinference 1.3.1.post1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (75)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +4 -0
  4. xinference/core/chat_interface.py +1 -1
  5. xinference/core/model.py +23 -3
  6. xinference/core/supervisor.py +6 -0
  7. xinference/core/worker.py +54 -11
  8. xinference/model/llm/__init__.py +7 -2
  9. xinference/model/llm/core.py +1 -0
  10. xinference/model/llm/llama_cpp/core.py +50 -15
  11. xinference/model/llm/llm_family.json +388 -13
  12. xinference/model/llm/llm_family_modelscope.json +373 -14
  13. xinference/model/llm/mlx/core.py +15 -11
  14. xinference/model/llm/reasoning_parser.py +17 -9
  15. xinference/model/llm/sglang/core.py +112 -12
  16. xinference/model/llm/transformers/core.py +4 -2
  17. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  18. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  19. xinference/model/llm/transformers/gemma3.py +185 -0
  20. xinference/model/llm/transformers/intern_vl.py +0 -2
  21. xinference/model/llm/utils.py +62 -42
  22. xinference/model/llm/vllm/core.py +157 -11
  23. xinference/model/llm/vllm/distributed_executor.py +314 -0
  24. xinference/model/rerank/core.py +16 -11
  25. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  26. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  27. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  28. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  29. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  30. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  31. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  32. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  33. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  34. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  35. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  36. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  37. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  38. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  39. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  40. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  41. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  42. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  43. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  44. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  45. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  46. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  47. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  48. xinference/types.py +2 -2
  49. xinference/web/ui/build/asset-manifest.json +6 -6
  50. xinference/web/ui/build/index.html +1 -1
  51. xinference/web/ui/build/static/css/main.b494ae7e.css +2 -0
  52. xinference/web/ui/build/static/css/main.b494ae7e.css.map +1 -0
  53. xinference/web/ui/build/static/js/main.5ca4eea1.js +3 -0
  54. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +1 -0
  59. xinference/web/ui/src/locales/en.json +2 -2
  60. xinference/web/ui/src/locales/zh.json +1 -1
  61. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/METADATA +4 -4
  62. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/RECORD +67 -41
  63. xinference/web/ui/build/static/css/main.f8177338.css +0 -2
  64. xinference/web/ui/build/static/css/main.f8177338.css.map +0 -1
  65. xinference/web/ui/build/static/js/main.55b70cb7.js +0 -3
  66. xinference/web/ui/build/static/js/main.55b70cb7.js.map +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/2deac8d5636974533e3714f34e94fc754f9153a07c6ee11e72846cb8eae47e4b.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/e23d476fcbf6fd69c8986bf82133d257d28aa8fc9a5cab231d81c1c75c58cd99.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/e7a8c37fda8725cab69c7ef8c627060bd7fc806adc67e00fe628ba148cb86d7f.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.55b70cb7.js.LICENSE.txt → main.5ca4eea1.js.LICENSE.txt} +0 -0
  72. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/LICENSE +0 -0
  73. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/WHEEL +0 -0
  74. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-1.3.1.post1.dist-info → xinference-1.4.1.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/gemma3.py
@@ -0,0 +1,185 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import sys
+import uuid
+from typing import Iterator, List, Optional, Union
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    CompletionChunk,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import generate_chat_completion, generate_completion_chunk
+from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import cache_clean
+
+logger = logging.getLogger(__name__)
+
+
+class Gemma3TextChatModel(PytorchChatModel):
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
+        llm_family = model_family.model_family or model_family.model_name
+        if "gemma-3-1b-it".lower() in llm_family.lower():
+            return True
+        return False
+
+
+class Gemma3ChatModel(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+        self._device = None
+        self._processor = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+            return False
+        llm_family = model_family.model_family or model_family.model_name
+        if "gemma-3-it".lower() in llm_family.lower():
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        device = self._pytorch_model_config.get("device", "auto")
+        device = select_device(device)
+        self._device = device
+        # for multiple GPU, set back to auto to make multiple devices work
+        device = "auto" if device == "cuda" else device
+
+        self._processor = AutoProcessor.from_pretrained(self.model_path)
+        self._tokenizer = self._processor.tokenizer
+        self._model = Gemma3ForConditionalGeneration.from_pretrained(
+            self.model_path,
+            device_map="auto",
+            torch_dtype="bfloat16",
+        )
+
+    @cache_clean
+    def chat(
+        self,
+        messages: List[ChatCompletionMessage],  # type: ignore
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        messages = self._transform_messages(messages)
+
+        generate_config = generate_config if generate_config else {}
+
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        if stream:
+            it = self._generate_stream(messages, generate_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = self._generate(messages, generate_config)
+            return c
+
+    def _generate(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> ChatCompletion:
+        inputs = self._processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(self._device)
+        input_len = inputs["input_ids"].shape[-1]
+
+        generation = self._model.generate(**inputs, do_sample=False)
+        generation = generation[0][input_len:]
+
+        decoded = self._processor.decode(generation, skip_special_tokens=True)
+        return generate_chat_completion(self.model_uid, decoded)
+
+    def _generate_stream(
+        self, messages: List, config: PytorchGenerateConfig = {}
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        inputs = self._processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(self._device)
+
+        tokenizer = self._tokenizer
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        )
+
+        gen_kwargs = {"streamer": streamer, **inputs}
+        error = None
+
+        def model_generate():
+            try:
+                return self._model.generate(**gen_kwargs)
+            except Exception:
+                nonlocal error
+                error = sys.exc_info()
+                streamer.end()
+                raise
+
+        thread = Thread(target=model_generate)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+                has_choice=True,
+                has_content=True,
+            )
+
+        if error:
+            _, err, tb = error  # type: ignore
+            raise err.with_traceback(tb)
+
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
+        )
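
Note: Gemma3ChatModel._generate_stream above follows the standard transformers streaming recipe: generate() runs on a background thread while the caller drains a TextIteratorStreamer. A minimal, self-contained sketch of that pattern (using a tiny public test model as a stand-in for Gemma 3, which would need the real weights and a GPU):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Any causal LM works here; "sshleifer/tiny-gpt2" is just a small stand-in.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello there", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs on a worker thread while the
# main thread consumes decoded text fragments as they are produced.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()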

xinference/model/llm/transformers/intern_vl.py
@@ -245,8 +245,6 @@ class InternVLChatModel(PytorchChatModel):
         family = model_family.model_family or model_family.model_name
         if "internvl" not in family.lower():
             return False
-        if "pytorch" not in model_spec.model_format:
-            return False
         return True
 
     def _get_model_class(self):

xinference/model/llm/utils.py
@@ -79,8 +79,7 @@ LLAMA3_TOOL_CALL_FAMILY = [
 ]
 
 DEEPSEEK_TOOL_CALL_FAMILY = [
-    "deepseek-r1-distill-qwen",
-    "deepseek-r1-distill-llama",
+    "deepseek-v3",
 ]
 
 TOOL_CALL_FAMILY = (
@@ -256,19 +255,26 @@ class ChatModelMixin:
             and choices
             and "delta" in choices[0]
         ):
-            if reasoning_parser is not None:
-                # process parsing reasoning content
-                assert previous_texts is not None
+            if choices[0]["finish_reason"] is None:
+                if reasoning_parser is not None:
+                    # process parsing reasoning content
+                    assert previous_texts is not None
+                    delta = choices[0]["delta"]  # type: ignore
+                    if text := delta.get("content"):
+                        current_text = previous_texts[-1] + text
+                        delta = reasoning_parser.extract_reasoning_content_streaming(
+                            previous_text=previous_texts[-1],
+                            current_text=current_text,
+                            delta_text=text,
+                        )
+                        previous_texts[-1] = current_text
+                        choices[0]["delta"] = delta  # type: ignore
+            elif choices[0]["finish_reason"] is not None:
                 delta = choices[0]["delta"]  # type: ignore
-                if text := delta.get("content"):
-                    current_text = previous_texts[-1] + text
-                    delta = reasoning_parser.extract_reasoning_content_streaming(
-                        previous_text=previous_texts[-1],
-                        current_text=current_text,
-                        delta_text=text,
-                    )
-                    previous_texts[-1] = current_text
-                    choices[0]["delta"] = delta  # type: ignore
+                if "content" not in delta:
+                    delta["content"] = ""  # type: ignore
+                if reasoning_parser is not None:
+                    delta["reasoning_content"] = None  # type: ignore
             # Already a ChatCompletionChunk, we don't need to convert chunk.
             return cast(ChatCompletionChunk, chunk)
 
@@ -287,7 +293,11 @@ class ChatModelMixin:
                     delta_text=choice["text"],
                 )
                 previous_texts[-1] = current_text
-            if "tool_calls" in choice:
+            elif "text" in choice and choice["finish_reason"] is not None:
+                delta["content"] = choice["text"]
+                if reasoning_parser is not None:
+                    delta["reasoning_content"] = None
+            elif "tool_calls" in choice:
                 delta["tool_calls"] = choice["tool_calls"]
             choices_list.append(
                 {
@@ -296,12 +306,19 @@ class ChatModelMixin:
                     "finish_reason": choice["finish_reason"],
                 }
             )
+        assert choices is not None
+        usage = (
+            chunk["usage"]
+            if choices[0]["finish_reason"] is not None and reasoning_parser is not None
+            else None
+        )
         chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
             "created": chunk["created"],
             "object": "chat.completion.chunk",
             "choices": choices_list,
+            "usage": usage,
         }
         return cast(ChatCompletionChunk, chat_chunk)
 
@@ -313,12 +330,9 @@ class ChatModelMixin:
     ) -> ChatCompletionChunk:
         choices_list = []
         for i, choice in enumerate(chunk["choices"]):
-            delta = {
-                "role": "assistant",
-            }
-            if reasoning_parser is None:
-                delta["content"] = ""
-            else:
+            delta = ChatCompletionChunkDelta(role="assistant", content="")
+            if reasoning_parser is not None:
+                delta["content"] = None
                 delta["reasoning_content"] = ""
             choices_list.append(
                 {
@@ -359,9 +373,7 @@ class ChatModelMixin:
         reasoning_parse: Optional[ReasoningParser] = None,
     ) -> Iterator[ChatCompletionChunk]:
         previous_texts = [""]
-        for i, chunk in enumerate(chunks):
-            if i == 0:
-                yield cls._get_first_chat_completion_chunk(chunk, reasoning_parse)
+        for _, chunk in enumerate(chunks):
             # usage
             choices = chunk.get("choices")
             if not choices:
@@ -407,14 +419,10 @@ class ChatModelMixin:
         chunks: AsyncGenerator[CompletionChunk, None],
         reasoning_parser: Optional[ReasoningParser] = None,
     ) -> AsyncGenerator[ChatCompletionChunk, None]:
-        i = 0
         previous_texts = [""]
         async for chunk in chunks:
-            if i == 0:
-                chat_chunk = cls._get_first_chat_completion_chunk(
-                    chunk, reasoning_parser
-                )
-            elif not chunk.get("choices"):
+            choices = chunk.get("choices")
+            if not choices:
                 # usage
                 chat_chunk = cls._get_final_chat_completion_chunk(chunk)
             else:
@@ -422,7 +430,6 @@ class ChatModelMixin:
                     chunk, reasoning_parser, previous_texts
                 )
             yield chat_chunk
-            i += 1
 
     @staticmethod
     def _to_chat_completion(
@@ -533,7 +540,7 @@ class ChatModelMixin:
     @classmethod
    def _eval_deepseek_chat_arguments(cls, c) -> List[Tuple]:
         """
-        Parses tool calls from deepseek-r1 format and removes duplicates.
+        Parses tool calls from deepseek-v3 format and removes duplicates.
 
         Returns:
             List[Tuple[Optional[str], Optional[str], Optional[dict]]]
@@ -541,20 +548,24 @@ class ChatModelMixin:
        - (content, None, None) if parsing failed (content is raw JSON text).
 
         Example input:
-        <|tool▁call|>get_current_weather
         ```json
-        {"location": "tokyo", "unit": "fahrenheit"}
+        {
+            "name": "get_weather_and_time",
+            "parameters": {
+                "location": "Hangzhou"
+            }
+        }
         ```
 
         Output:
         [
-            (None, "get_current_weather", {"location": "tokyo", "unit": "fahrenheit"})
+            (None, "get_current_weather", {"location": "Hangzhou"})
         ]
         """
 
         text = c["choices"][0]["text"]
 
-        pattern = r"<|tool▁call|>(\w+)\s*```json\s*(.*?)\s*```"
+        pattern = r"\s*```json\s*(.*?)\s*```"
         matches = re.findall(pattern, text, re.DOTALL)
 
         if not matches:
@@ -563,22 +574,31 @@ class ChatModelMixin:
         tool_calls = set()  # Used for deduplication
         results = []
 
-        for function_name, args_json in matches:
+        for raw_json in matches:
+            func_and_args = None
             try:
-                arguments = json.loads(args_json)
+                func_and_args = json.loads(raw_json)
                 # Convert dictionary to frozenset for deduplication
-                arguments_hashable = frozenset(arguments.items())
-                tool_call_tuple = (None, function_name, arguments)
+                arguments_hashable = frozenset(func_and_args["parameters"])
+                tool_call_tuple = (
+                    None,
+                    func_and_args["name"],
+                    func_and_args["parameters"],
+                )
             except json.JSONDecodeError:
                 tool_call_tuple = (
-                    args_json,
+                    raw_json,
                     None,
                     None,
                 )  # If parsing fails, treat as raw content
                 arguments_hashable = None  # No need for hashing
 
             # Avoid duplicate entries
-            dedup_key = (function_name, arguments_hashable)
+            dedup_key = (
+                (func_and_args["name"], arguments_hashable)
+                if func_and_args is not None
+                else (raw_json)
+            )
             if dedup_key not in tool_calls:
                 tool_calls.add(dedup_key)
                 results.append(tool_call_tuple)
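
Note: with the change above, _eval_deepseek_chat_arguments no longer keys on a <|tool▁call|> marker; it extracts every fenced ```json block and reads the "name" and "parameters" fields from the parsed object. A standalone sketch of that parsing idea (illustrative only; the parse_tool_calls helper and the extra KeyError guard are not part of xinference):

import json
import re


def parse_tool_calls(text: str):
    # Same regex as the new code above: grab the body of each ```json block.
    pattern = r"\s*```json\s*(.*?)\s*```"
    results, seen = [], set()
    for raw_json in re.findall(pattern, text, re.DOTALL):
        try:
            call = json.loads(raw_json)
            key = (call["name"], frozenset(call["parameters"]))
            value = (None, call["name"], call["parameters"])
        except (json.JSONDecodeError, KeyError):
            # Malformed output is kept as raw content, mirroring the fallback above.
            key, value = raw_json, (raw_json, None, None)
        if key not in seen:  # de-duplicate repeated tool calls
            seen.add(key)
            results.append(value)
    return results


text = '```json\n{"name": "get_weather", "parameters": {"location": "Hangzhou"}}\n```'
print(parse_tool_calls(text))
# [(None, 'get_weather', {'location': 'Hangzhou'})]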

xinference/model/llm/vllm/core.py
@@ -13,12 +13,16 @@
 # limitations under the License.
 
 import asyncio
+import itertools
 import json
 import logging
 import multiprocessing
 import os
+import sys
+import threading
 import time
 import uuid
+from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -27,10 +31,13 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypedDict,
     Union,
 )
 
+import xoscar as xo
+
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -73,6 +80,7 @@ class VLLMModelConfig(TypedDict, total=False):
     guided_decoding_backend: Optional[str]
     scheduling_policy: Optional[str]
     reasoning_content: bool
+    model_quantization: Optional[str]
 
 
 class VLLMGenerateConfig(TypedDict, total=False):
@@ -161,6 +169,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B")
     VLLM_SUPPORTED_CHAT_MODELS.append("marco-o1")
     VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-r1-distill-qwen")
+    VLLM_SUPPORTED_CHAT_MODELS.append("fin-r1")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -216,6 +225,10 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.7.2":
 if VLLM_INSTALLED and vllm.__version__ >= "0.7.3":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-instruct-1m")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.8.0":
+    VLLM_SUPPORTED_CHAT_MODELS.append("gemma-3-1b-it")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("gemma-3-it")
+
 
 class VLLMModel(LLM):
     def __init__(
@@ -244,15 +257,59 @@ class VLLMModel(LLM):
         self.lora_modules = peft_model
         self.lora_requests: List[LoRARequest] = []
         self._xavier_config = None
+        # distributed inference
+        self._device_count = None
+        self._address = model_config.pop("address", None)  # type: ignore
+        self._n_worker = model_config.pop("n_worker", 1)  # type: ignore
+        self._shard = model_config.pop("shard", 0)  # type: ignore
+        self._driver_info = model_config.pop("driver_info", None)  # type: ignore
+        self._loading_thread: Optional[threading.Thread] = None
+        self._loading_error = None
+        # variables used for distributed inference and multiple GPUs
+        self._pool_addresses = None
+        self._worker_addresses: Optional[Dict[int, List[str]]] = None
+        self._all_worker_ready: Optional[threading.Event] = None
+        # used to call async
+        self._loop = None
 
     def set_xavier_config(self, value: Optional[Dict]):
         self._xavier_config = value  # type: ignore
 
+    def set_worker_addresses(self, shard: int, worker_addresses: List[str]):
+        assert self._worker_addresses is not None
+        self._worker_addresses[shard] = worker_addresses
+        if (
+            self._all_worker_ready is not None
+            and len(self._worker_addresses) == self._n_worker
+        ):
+            self._all_worker_ready.set()
+
+    @property
+    def driver_info(self) -> Optional[dict]:
+        return self._driver_info
+
+    @property
+    def need_create_pools(self):
+        return True
+
+    def set_pool_addresses(self, pool_addresses: List[str]):
+        self._pool_addresses = pool_addresses  # type: ignore
+
+    def get_pool_addresses(self) -> Optional[List[str]]:
+        return self._pool_addresses
+
+    def set_loop(self, loop: asyncio.AbstractEventLoop):
+        # loop will be passed into XinferenceDistributedExecutor,
+        # to call aynsc method with asyncio.run_coroutine_threadsafe
+        self._loop = loop  # type: ignore
+
     def load(self):
         try:
             import vllm
+            from vllm.config import VllmConfig
             from vllm.engine.arg_utils import AsyncEngineArgs
             from vllm.engine.async_llm_engine import AsyncLLMEngine
+            from vllm.executor.executor_base import ExecutorBase
             from vllm.lora.request import LoRARequest
         except ImportError:
             error_message = "Failed to import module 'vllm'"
@@ -271,6 +328,7 @@ class VLLMModel(LLM):
             # we need to set it to fork to make cupy NCCL work
             multiprocessing.set_start_method("fork", force=True)
 
+        self._device_count = self._get_cuda_count()
         self._model_config = self._sanitize_model_config(self._model_config)
         reasoning_content = self._model_config.pop("reasoning_content")
 
@@ -316,6 +374,83 @@ class VLLMModel(LLM):
             self._engine = XavierEngine.from_engine_args(
                 engine_args, xavier_config=self._xavier_config
             )
+        elif self._n_worker > 1 or (
+            self._device_count > 1 and vllm.__version__ >= "0.7.0"
+        ):
+            from .distributed_executor import XinferenceDistributedExecutor
+
+            # model across multiple workers or GPUs
+            engine_args = AsyncEngineArgs(
+                model=self.model_path,
+                enable_lora=enable_lora,
+                max_loras=max_loras,
+                **self._model_config,
+            )
+
+            assert self._loop is not None
+            self._worker_addresses = {}
+
+            def _load():
+                try:
+                    assert self._pool_addresses
+
+                    if self._shard > 0:
+                        assert self._driver_info
+                        address = self._driver_info["address"]
+
+                        coro = xo.actor_ref(address, self.raw_model_uid)
+                        model_ref = asyncio.run_coroutine_threadsafe(
+                            coro, self._loop
+                        ).result()
+                        coro = model_ref.set_worker_addresses(
+                            self._shard, self._pool_addresses
+                        )
+                        asyncio.run_coroutine_threadsafe(coro, self._loop).result()
+                    else:
+                        self.set_worker_addresses(0, self._pool_addresses)
+                        self._driver_info = {"address": self._address}
+
+                        if self._n_worker > 1:
+                            self._all_worker_ready = threading.Event()
+                            # if model across workers, wait for other workers ready
+                            self._all_worker_ready.wait()
+
+                        # gather all worker addresses
+                        worker_addresses = list(
+                            itertools.chain(
+                                *[
+                                    self._worker_addresses[shard]
+                                    for shard in range(self._n_worker)
+                                ]
+                            )
+                        )
+                        assert worker_addresses
+                        loop = self._loop
+
+                        class XinferenceAsyncLLMEngine(AsyncLLMEngine):
+                            @classmethod
+                            def _get_executor_cls(
+                                cls, engine_config: VllmConfig
+                            ) -> Type[ExecutorBase]:
+                                return partial(  # type: ignore
+                                    XinferenceDistributedExecutor,
+                                    pool_addresses=worker_addresses,
+                                    n_worker=self._n_worker,
+                                    loop=loop,
+                                )
+
+                        self._engine = XinferenceAsyncLLMEngine.from_engine_args(
+                            engine_args
+                        )
+                except:
+                    logger.exception("Creating vllm engine failed")
+                    self._loading_error = sys.exc_info()
+
+            self._loading_thread = threading.Thread(target=_load)
+            self._loading_thread.start()
+            # wait some time for init finish
+            if self._shard == 0:
+                self._loading_thread.join(1)
         else:
             engine_args = AsyncEngineArgs(
                 model=self.model_path,
@@ -328,7 +463,14 @@ class VLLMModel(LLM):
         self._check_health_task = None
         if hasattr(self._engine, "check_health"):
             # vLLM introduced `check_health` since v0.4.1
-            self._check_health_task = asyncio.create_task(self._check_healthy())
+            self._check_health_task = self._loop.create_task(self._check_healthy())
+
+    def wait_for_load(self):
+        if self._loading_thread:
+            self._loading_thread.join()
+        if self._loading_error:
+            _, err, tb = self._loading_error
+            raise err.with_traceback(tb)
 
     def stop(self):
         # though the vLLM engine will shutdown when deleted,
@@ -337,9 +479,10 @@ class VLLMModel(LLM):
         logger.info("Stopping vLLM engine")
         if self._check_health_task:
             self._check_health_task.cancel()
-        if model_executor := getattr(self._engine.engine, "model_executor", None):
-            model_executor.shutdown()
-        self._engine = None
+        if self._engine:
+            if model_executor := getattr(self._engine.engine, "model_executor", None):
+                model_executor.shutdown()
+            self._engine = None
 
     async def init_xavier(self):
         await self._engine.init_xavier()
@@ -370,16 +513,18 @@ class VLLMModel(LLM):
         if model_config is None:
             model_config = VLLMModelConfig()
 
-        cuda_count = self._get_cuda_count()
-
         model_config.setdefault("tokenizer_mode", "auto")
         model_config.setdefault("trust_remote_code", True)
-        model_config.setdefault("tensor_parallel_size", cuda_count)
+        model_config.setdefault("tensor_parallel_size", self._device_count)  # type: ignore
+        model_config.setdefault("pipeline_parallel_size", self._n_worker)  # type: ignore
         model_config.setdefault("block_size", 16)
         model_config.setdefault("swap_space", 4)
         model_config.setdefault("gpu_memory_utilization", 0.90)
         model_config.setdefault("max_num_seqs", 256)
-        model_config.setdefault("quantization", None)
+        if "model_quantization" in model_config:
+            model_config["quantization"] = model_config.pop("model_quantization")
+        else:
+            model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
         model_config.setdefault("guided_decoding_backend", "outlines")
         model_config.setdefault("reasoning_content", False)
@@ -840,10 +985,11 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
         model_family = self.model_family.model_family or self.model_family.model_name
         full_context_kwargs = {}
         if tools:
-            if model_family in QWEN_TOOL_CALL_FAMILY:
+            if (
+                model_family in QWEN_TOOL_CALL_FAMILY
+                or model_family in DEEPSEEK_TOOL_CALL_FAMILY
+            ):
                 full_context_kwargs["tools"] = tools
-            elif model_family in DEEPSEEK_TOOL_CALL_FAMILY:
-                self._tools_to_messages_for_deepseek(messages, tools)
         assert self.model_family.chat_template is not None
         full_prompt = self.get_full_context(
             messages, self.model_family.chat_template, **full_context_kwargs
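
Note on the _sanitize_model_config change above: the user-facing model_quantization key is now forwarded to vLLM's quantization argument, tensor parallelism defaults to the detected GPU count, and pipeline parallelism to the number of xinference workers serving the model. A standalone sketch of that mapping (the sanitize_vllm_config helper and its defaults are invented here for illustration, not xinference API):

from typing import Any, Dict


def sanitize_vllm_config(
    cfg: Dict[str, Any], device_count: int = 1, n_worker: int = 1
) -> Dict[str, Any]:
    cfg = dict(cfg)  # do not mutate the caller's dict
    cfg.setdefault("tensor_parallel_size", device_count)
    cfg.setdefault("pipeline_parallel_size", n_worker)
    if "model_quantization" in cfg:
        # the xinference-level key maps onto vLLM's `quantization`
        cfg["quantization"] = cfg.pop("model_quantization")
    else:
        cfg.setdefault("quantization", None)
    return cfg


print(sanitize_vllm_config({"model_quantization": "awq"}, device_count=2))
# {'tensor_parallel_size': 2, 'pipeline_parallel_size': 1, 'quantization': 'awq'}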