xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/llm_family.py

@@ -78,7 +78,7 @@ class LlamaCppLLMSpecV2(BaseModel):
 
 
 class PytorchLLMSpecV2(BaseModel):
-    model_format: Literal["pytorch", "gptq", "awq", "fp8"]
+    model_format: Literal["pytorch", "gptq", "awq", "fp8", "bnb"]
     # Must in order that `str` first, then `int`
     model_size_in_billions: Union[str, int]
     quantization: str
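The recurring change across the LLM backends in this release is the new "bnb" (bitsandbytes) entry in each `model_format` allow-list. As a rough usage sketch only (the endpoint, model name, and quantization label below are placeholders, not taken from this diff), launching a bnb-format model through the Python client might look like:

```python
# Hypothetical sketch: endpoint, model name and quantization label are
# placeholders; only model_format="bnb" reflects the change in this release.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="qwen2.5-instruct",   # placeholder model
    model_engine="transformers",
    model_format="bnb",              # format newly accepted by the spec classes
    model_size_in_billions=7,
    quantization="4-bit",            # placeholder quantization label
)
print(model_uid)
```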
xinference/model/llm/sglang/core.py

@@ -39,6 +39,7 @@ from ..llm_family import CustomLLMFamilyV2
 from ..utils import (
     DEEPSEEK_TOOL_CALL_FAMILY,
     QWEN_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_SYMBOLS,
     ChatModelMixin,
     generate_completion_chunk,
 )
@@ -337,7 +338,7 @@ class SGLANGModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -471,6 +472,7 @@ class SGLANGModel(LLM):
         *,
         image_data: Optional[Union[List[str], str]] = None,
         generate_config: Optional[SGLANGGenerateConfig] = None,
+        tools: Optional[List[Dict]] = None,
         request_id: Optional[str] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         sanitized_generate_config = self._sanitize_generate_config(generate_config)
@@ -501,6 +503,10 @@ class SGLANGModel(LLM):
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            complete_response = ""
+            match_tool_call_tmp_results: List[CompletionChunk] = []
+            is_match_tool_call = False
+            chunk = None
             finish_reason = None
             async for meta_info, out in self._stream_generate(
                 prompt, image_data, **sanitized_generate_config
@@ -508,6 +514,7 @@ class SGLANGModel(LLM):
                 chunk = self._convert_state_to_completion_chunk(
                     request_id, self.model_uid, output_text=out
                 )
+                complete_response += out
                 finish_reason = meta_info["finish_reason"]
                 prompt_tokens = meta_info["prompt_tokens"]
                 completion_tokens = meta_info["completion_tokens"]
@@ -517,6 +524,49 @@ class SGLANGModel(LLM):
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
+                if tools:
+                    """
+                    The qwen2 tool call returns format like this:
+                    <tool_call>
+                    {...}
+                    </tool_call>
+                    Here is to match this.
+                    """
+                    if (
+                        len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)
+                    ) and (
+                        not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        for c in match_tool_call_tmp_results:
+                            yield c
+                        match_tool_call_tmp_results.clear()
+                        yield chunk
+                    elif (
+                        len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)
+                    ) and (QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)):
+                        match_tool_call_tmp_results.append(chunk)
+                    else:
+                        assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(
+                            complete_response
+                        )
+                        if not is_match_tool_call and complete_response.startswith(
+                            QWEN_TOOL_CALL_SYMBOLS[0]
+                        ):
+                            is_match_tool_call = True
+                            match_tool_call_tmp_results.clear()
+
+                        if not is_match_tool_call:
+                            for c in match_tool_call_tmp_results:
+                                yield c
+                            match_tool_call_tmp_results.clear()
+                            yield chunk
+                        else:
+                            chunk["choices"][0]["text"] = complete_response
+                else:
+                    yield chunk
+
+            if is_match_tool_call:
+                assert chunk is not None
                 yield chunk
 
             finish_reason = (
@@ -561,7 +611,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
            return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -588,6 +638,57 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         generate_config.pop("chat_template_kwargs", None)
         return generate_config
 
+    @staticmethod
+    def is_tool_call_chunk_start(chunk):
+        return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
+
+    @staticmethod
+    def is_tool_call_chunk_end(chunk):
+        return chunk["choices"][0]["text"].endswith(QWEN_TOOL_CALL_SYMBOLS[1])
+
+    async def _async_to_tool_completion_chunks(
+        self,
+        chunks: AsyncGenerator[CompletionChunk, None],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        i = 0
+        previous_texts = [""]
+        tool_call = False
+        tool_call_texts = [""]
+        if self.reasoning_parser:
+            chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
+        async for chunk in chunks:
+            if i == 0:
+                for first_chunk in self._get_first_chat_completion_chunk(
+                    chunk, self.reasoning_parser
+                ):
+                    yield first_chunk
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield self._get_final_chat_completion_chunk(chunk)
+            else:
+                if self.is_tool_call_chunk_start(chunk):
+                    tool_call = True
+                if tool_call:
+                    tool_call_text = tool_call_texts[-1]
+                    tool_call_text += chunk["choices"][0]["text"]
+                    tool_call_texts.append(tool_call_text)
+                    if self.is_tool_call_chunk_end(chunk):
+                        yield self._post_process_completion_chunk(
+                            self.model_family,
+                            self.model_uid,
+                            chunk,
+                            reasoning_parser=self.reasoning_parser,
+                            tool_call_text=tool_call_text,
+                        )
+                        tool_call = False
+                        tool_call_texts = [""]
+                else:
+                    yield self._to_chat_completion_chunk(
+                        chunk, self.reasoning_parser, previous_texts
+                    )
+            i += 1
+
     async def async_chat(
         self,
         messages: List[Dict],
@@ -618,13 +719,15 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
+            agen = await self.async_generate(full_prompt, generate_config=generate_config, tools=tools)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
+            if tools:
+                return self._async_to_tool_completion_chunks(agen)
             return self._async_to_chat_completion_chunks(
                 agen, self.reasoning_parser, chat_template_kwargs
             )
         else:
-            c = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
+            c = await self.async_generate(full_prompt, generate_config=generate_config, tools=tools)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             if tools:
                 return self._post_process_completion(
@@ -642,7 +745,7 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
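The streaming path above now buffers chunks until enough text has accumulated to decide whether the response starts with the Qwen `<tool_call>` marker: ambiguous chunks are held back, flushed as ordinary deltas once the marker is ruled out, or merged into a single tool-call payload once it is confirmed. A stripped-down, framework-free sketch of that buffering rule (assuming the marker string is `<tool_call>`, i.e. QWEN_TOOL_CALL_SYMBOLS[0]):

```python
# Minimal sketch of the prefix-buffering rule used in stream_results():
# hold chunks while the accumulated text could still become "<tool_call>",
# flush them if that possibility is ruled out, merge them if it is confirmed.
TOOL_CALL_START = "<tool_call>"  # assumed value of QWEN_TOOL_CALL_SYMBOLS[0]

def buffer_stream(pieces):
    complete, buffered, is_tool_call, out = "", [], False, []
    for piece in pieces:
        complete += piece
        if len(complete) < len(TOOL_CALL_START):
            if TOOL_CALL_START.startswith(complete):
                buffered.append(piece)           # still ambiguous: keep buffering
            else:
                out.extend(buffered + [piece])   # ruled out: flush everything
                buffered.clear()
        else:
            if not is_tool_call and complete.startswith(TOOL_CALL_START):
                is_tool_call = True
                buffered.clear()
            if is_tool_call:
                pass                             # swallow deltas; emit once at the end
            else:
                out.extend(buffered + [piece])
                buffered.clear()
    if is_tool_call:
        out.append(complete)                     # single merged tool-call payload
    return out

print(buffer_stream(["<tool", "_call>", '{"name": "f"}', "</tool_call>"]))
print(buffer_stream(["Hel", "lo ", "world"]))
```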
xinference/model/llm/transformers/core.py

@@ -286,12 +286,18 @@ class PytorchModel(LLM):
 
         kwargs = {}
 
-        dtype = get_device_preferred_dtype(self._device)
-
-        if dtype is not None:
-            kwargs["torch_dtype"] = dtype
+        torch_dtype = self._pytorch_model_config.get("torch_dtype")
+        if torch_dtype is not None:
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+            kwargs["torch_dtype"] = torch_dtype
         else:
-            raise ValueError(f"Device {self._device} is not supported in temporary")
+            dtype = get_device_preferred_dtype(self._device)
+
+            if dtype is not None:
+                kwargs["torch_dtype"] = dtype
+            else:
+                raise ValueError(f"Device {self._device} is not supported in temporary")
 
         kwargs["revision"] = self._pytorch_model_config.get(
             "revision", self.model_spec.model_revision
@@ -327,6 +333,8 @@ class PytorchModel(LLM):
                 reasoning_content, enable_thinking=enable_thinking
             )
 
+        logger.debug("Loading Transformers model with kwargs: %s", kwargs)
+
         if self._check_tensorizer_integrity():
             self._model, self._tokenizer = self._load_tensorizer(**kwargs)
         else:
@@ -488,7 +496,7 @@ class PytorchModel(LLM):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         model_family = llm_family.model_family or llm_family.model_name
         if model_family in NON_DEFAULT_MODEL_LIST:
@@ -539,15 +547,13 @@ class PytorchModel(LLM):
        So we need pad `0` on the left again.
        """
        data = []
+        max_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs) + 1
        for r in reqs:
            r.extra_kwargs["attention_mask_seq_len"] += 1
+            real_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = max_len - real_len
+
            if self._tokenizer.padding_side == "left":
-                attention_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
-                pad_len = seq_length - attention_mask_seq_len
-                assert pad_len >= 0, (
-                    f"pad_len must be greater equal 0, got {pad_len} = "
-                    f"seq_length({seq_length}) - attention_mask_seq_len({attention_mask_seq_len})"
-                )
                x = torch.cat(
                    [
                        (
@@ -555,14 +561,10 @@ class PytorchModel(LLM):
                            if pad_len > 0
                            else torch.tensor([], dtype=torch.long)
                        ),
-                        torch.ones((attention_mask_seq_len,), dtype=torch.long),
+                        torch.ones((real_len,), dtype=torch.long),
                    ]
                )
            else:
-                max_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs)
-                real_len = r.extra_kwargs["attention_mask_seq_len"]
-                pad_len = max_len - real_len
-
                x = torch.cat(
                    [
                        torch.ones((real_len,), dtype=torch.long),
@@ -878,7 +880,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         model_family = llm_family.model_family or llm_family.model_name
         if model_family in NON_DEFAULT_MODEL_LIST:
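The transformers backend now honours a user-supplied `torch_dtype` from the pytorch model config, converting string names other than "auto" into real torch dtypes and only falling back to the device-preferred dtype when nothing is configured. A small stand-alone sketch of that resolution order (the float16 fallback stands in for get_device_preferred_dtype() and is only a placeholder):

```python
import torch

def resolve_torch_dtype(configured, device_preferred=torch.float16):
    # Mirrors the new branch in PytorchModel: an explicit config value wins,
    # strings such as "bfloat16" are looked up on the torch module, and
    # "auto" is passed through untouched for from_pretrained to interpret.
    if configured is not None:
        if isinstance(configured, str) and configured != "auto":
            return getattr(torch, configured)
        return configured
    # No explicit value: fall back to a device-preferred dtype
    # (float16 here is a placeholder, not the real helper's answer).
    return device_preferred

print(resolve_torch_dtype("bfloat16"))   # torch.bfloat16
print(resolve_torch_dtype("auto"))       # "auto"
print(resolve_torch_dtype(None))         # torch.float16 (placeholder fallback)
```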
xinference/model/llm/transformers/gemma3.py

@@ -28,7 +28,7 @@ class Gemma3TextChatModel(PytorchChatModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         llm_family = model_family.model_family or model_family.model_name
         if "gemma-3-1b-it".lower() in llm_family.lower():
xinference/model/llm/transformers/gpt_oss.py (new file)

@@ -0,0 +1,91 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import logging
+from typing import Dict, Iterator, List, Optional, Union
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    PytorchGenerateConfig,
+    PytorchModelConfig,
+)
+from ..harmony import async_stream_harmony_chat_completion
+from ..llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
+
+logger = logging.getLogger(__name__)
+
+
+@register_transformer
+@register_non_default_model("gpt-oss")
+class GPTOSSPytorchChatModel(PytorchChatModel):
+    def _sanitize_model_config(
+        self, pytorch_model_config: Optional[PytorchModelConfig]
+    ) -> PytorchModelConfig:
+        config = super()._sanitize_model_config(pytorch_model_config)
+        config.setdefault("torch_dtype", "auto")
+        return config  # type:ignore
+
+    @classmethod
+    def match_json(
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "gpt" not in model_family and "oss" not in model_family:
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    async def chat(  # type:ignore
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        gen = super().chat(messages, generate_config=generate_config)
+
+        if inspect.iscoroutine(gen):
+            gen = await gen
+
+        if inspect.isasyncgen(gen):
+            # Streaming
+            async def stream_parser():
+                full_text = ""
+                full_reasoning = ""
+
+                async for parsed_chunk in async_stream_harmony_chat_completion(gen):
+                    choices = parsed_chunk.get("choices")
+                    if choices and len(choices) > 0:
+                        delta = choices[0].get("delta", {})
+                        if delta.get("content"):
+                            full_text += delta["content"]
+                        if delta.get("reasoning_content"):
+                            full_reasoning += delta["reasoning_content"]
+                    yield parsed_chunk
+
+                logger.debug(
+                    "Chat finished, content: %r, reasoning: %r",
+                    full_text,
+                    full_reasoning,
+                )
+
+            return stream_parser()
+
+        else:
+            # Non-streaming sync - handle single result
+            async for parsed_completion in async_stream_harmony_chat_completion(gen):  # type: ignore
+                return parsed_completion
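The new gpt-oss model delegates generation to the parent chat model and re-parses the stream through the Harmony helper; each parsed chunk carries an OpenAI-style delta that may hold either content or reasoning_content, which the stream_parser accumulates for a final debug log. A self-contained toy illustration of that accumulation pattern (the chunks below are invented for the example; the real shape comes from async_stream_harmony_chat_completion):

```python
import asyncio

async def fake_parsed_stream():
    # Made-up chunks in the delta shape the gpt-oss stream_parser consumes.
    for delta in (
        {"reasoning_content": "User wants a greeting. "},
        {"reasoning_content": "Keep it short."},
        {"content": "Hello"},
        {"content": ", world!"},
    ):
        yield {"choices": [{"delta": delta}]}

async def collect(stream):
    full_text, full_reasoning = "", ""
    async for chunk in stream:
        choices = chunk.get("choices")
        if choices:
            delta = choices[0].get("delta", {})
            full_text += delta.get("content") or ""
            full_reasoning += delta.get("reasoning_content") or ""
    return full_text, full_reasoning

print(asyncio.run(collect(fake_parsed_stream())))
# ('Hello, world!', 'User wants a greeting. Keep it short.')
```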
xinference/model/llm/transformers/multimodal/core.py

@@ -21,9 +21,9 @@ from .....types import (
     CompletionChunk,
     PytorchGenerateConfig,
 )
+from ....utils import cache_clean
 from ...utils import generate_chat_completion, generate_completion_chunk
 from ..core import PytorchChatModel
-from ..utils import cache_clean
 
 
 class PytorchMultiModalModel(PytorchChatModel):
xinference/model/llm/transformers/multimodal/gemma3.py

@@ -31,7 +31,7 @@ class Gemma3ChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         llm_family = model_family.model_family or model_family.model_name
         if "gemma-3-it".lower() in llm_family.lower():
xinference/model/llm/transformers/multimodal/glm4_1v.py

@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
 
 
 @register_transformer
-@register_non_default_model("glm-4.1v-thinking")
+@register_non_default_model("glm-4.1v-thinking", "glm-4.5v")
 class Glm4_1VModel(PytorchMultiModalModel):
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
-        if "glm-4.1v" in family.lower():
+        if "glm-4.1v" in family.lower() or "glm-4.5v" in family.lower():
             return True
         return False
 
xinference/model/llm/transformers/multimodal/ovis2.py

@@ -37,7 +37,7 @@ class Ovis2ChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         llm_family = model_family.model_family or model_family.model_name
         if "ovis2".lower() in llm_family.lower():
xinference/model/llm/transformers/multimodal/qwen-omni.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
-import importlib.util
 import io
 import logging
 import time
@@ -20,13 +19,13 @@ import uuid
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
-from .....model.utils import select_device
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
     ChatCompletionChoice,
     CompletionUsage,
 )
+from ....utils import is_flash_attn_available, select_device
 from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
 from ..core import PytorchGenerateConfig, register_non_default_model
 from .core import PytorchMultiModalModel
@@ -46,7 +45,7 @@ class Qwen2_5OmniChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2.5-omni".lower() in llm_family.lower():
@@ -71,12 +70,12 @@ class Qwen2_5OmniChatModel(PytorchMultiModalModel):
 
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if self._device == "cuda" else self._device
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-        kwargs = (
-            {}
-            if not flash_attn_installed
-            else {"attn_implementation": "flash_attention_2"}
+        kwargs = {}
+        enable_flash_attn = self._pytorch_model_config.get(
+            "enable_flash_attn", is_flash_attn_available()
         )
+        if enable_flash_attn:
+            kwargs["attn_implementation"] = "flash_attention_2"
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
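Both multimodal Qwen backends drop the inline importlib probe for flash_attn in favour of a shared is_flash_attn_available() helper from xinference/model/utils.py, and allow an explicit `enable_flash_attn` key in the pytorch model config to override it. Judging only from the removed lines, the helper is presumably close to the following guess (not the actual implementation in model/utils.py):

```python
import importlib.util

def is_flash_attn_available() -> bool:
    # Guessed equivalent of the shared helper: the removed inline check
    # simply looked for an installed `flash_attn` package.
    return importlib.util.find_spec("flash_attn") is not None

# The loaders then only request the flash-attention kernel when the config
# opts in or the package is detected:
kwargs = {}
if is_flash_attn_available():
    kwargs["attn_implementation"] = "flash_attention_2"
print(kwargs)
```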
xinference/model/llm/transformers/multimodal/qwen2_vl.py

@@ -11,15 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import logging
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from .....core.model import register_batching_multimodal_models
 from .....device_utils import is_npu_available
-from .....model.utils import select_device
 from .....types import PytorchModelConfig
 from ....scheduler.request import InferenceRequest
+from ....utils import is_flash_attn_available, select_device
 from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
 from ..core import register_non_default_model
 from .core import PytorchMultiModalModel
@@ -48,7 +47,7 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
            return False
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2-vl-instruct".lower() in llm_family.lower():
@@ -87,7 +86,6 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
             Qwen2_5_VLForConditionalGeneration = None
 
         kwargs = self.apply_bnb_quantization()
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         llm_family = self.model_family.model_family or self.model_family.model_name
         model_cls = (
             Qwen2_5_VLForConditionalGeneration
@@ -97,12 +95,17 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
         if model_cls is None:
             raise ImportError("`transformers` version is too old, please upgrade it")
         device = "auto" if self._device == "cuda" else self._device
-        if flash_attn_installed:
+
+        enable_flash_attn = self._pytorch_model_config.get(
+            "enable_flash_attn", is_flash_attn_available()
+        )
+
+        if enable_flash_attn:
             self._model = model_cls.from_pretrained(
                 self.model_path,
                 torch_dtype="bfloat16",
-                device_map=device,
                 attn_implementation="flash_attention_2",
+                device_map=device,
                 trust_remote_code=True,
                 **kwargs,
             ).eval()
xinference/model/llm/transformers/utils.py

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import functools
+
 import logging
 import os
 import time
@@ -495,34 +494,3 @@ def batch_inference_one_step(
         for r in req_list:
             r.stopped = True
             r.error_msg = str(e)
-
-
-def cache_clean(fn):
-    @functools.wraps(fn)
-    async def _async_wrapper(self, *args, **kwargs):
-        import gc
-
-        from ....device_utils import empty_cache
-
-        result = await fn(self, *args, **kwargs)
-
-        gc.collect()
-        empty_cache()
-        return result
-
-    @functools.wraps(fn)
-    def _wrapper(self, *args, **kwargs):
-        import gc
-
-        from ....device_utils import empty_cache
-
-        result = fn(self, *args, **kwargs)
-
-        gc.collect()
-        empty_cache()
-        return result
-
-    if asyncio.iscoroutinefunction(fn):
-        return _async_wrapper
-    else:
-        return _wrapper