xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -78,7 +78,7 @@ class LlamaCppLLMSpecV2(BaseModel):
 
 
 class PytorchLLMSpecV2(BaseModel):
-    model_format: Literal["pytorch", "gptq", "awq", "fp8"]
+    model_format: Literal["pytorch", "gptq", "awq", "fp8", "bnb"]
     # Must in order that `str` first, then `int`
     model_size_in_billions: Union[str, int]
     quantization: str
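
Here "bnb" denotes bitsandbytes on-the-fly quantization, now accepted alongside the pre-quantized formats. Purely as an illustration (the field values below are hypothetical, and the real spec has further required fields):

    spec_fragment = {
        "model_format": "bnb",        # new in this release
        "model_size_in_billions": 7,  # illustrative
        "quantization": "4-bit",      # illustrative
    }
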
--- a/xinference/model/llm/sglang/core.py
+++ b/xinference/model/llm/sglang/core.py
@@ -39,6 +39,7 @@ from ..llm_family import CustomLLMFamilyV2
 from ..utils import (
     DEEPSEEK_TOOL_CALL_FAMILY,
     QWEN_TOOL_CALL_FAMILY,
+    QWEN_TOOL_CALL_SYMBOLS,
     ChatModelMixin,
     generate_completion_chunk,
 )
@@ -337,7 +338,7 @@ class SGLANGModel(LLM):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -471,6 +472,7 @@ class SGLANGModel(LLM):
         *,
         image_data: Optional[Union[List[str], str]] = None,
         generate_config: Optional[SGLANGGenerateConfig] = None,
+        tools: Optional[List[Dict]] = None,
         request_id: Optional[str] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
         sanitized_generate_config = self._sanitize_generate_config(generate_config)
@@ -501,6 +503,10 @@
 
         async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
             prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+            complete_response = ""
+            match_tool_call_tmp_results: List[CompletionChunk] = []
+            is_match_tool_call = False
+            chunk = None
             finish_reason = None
             async for meta_info, out in self._stream_generate(
                 prompt, image_data, **sanitized_generate_config
@@ -508,6 +514,7 @@
                 chunk = self._convert_state_to_completion_chunk(
                     request_id, self.model_uid, output_text=out
                 )
+                complete_response += out
                 finish_reason = meta_info["finish_reason"]
                 prompt_tokens = meta_info["prompt_tokens"]
                 completion_tokens = meta_info["completion_tokens"]
@@ -517,6 +524,49 @@
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
+                if tools:
+                    """
+                    The qwen2 tool call returns format like this:
+                    <tool_call>
+                    {...}
+                    </tool_call>
+                    Here is to match this.
+                    """
+                    if (
+                        len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)
+                    ) and (
+                        not QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)
+                    ):
+                        for c in match_tool_call_tmp_results:
+                            yield c
+                        match_tool_call_tmp_results.clear()
+                        yield chunk
+                    elif (
+                        len(QWEN_TOOL_CALL_SYMBOLS[0]) > len(complete_response)
+                    ) and (QWEN_TOOL_CALL_SYMBOLS[0].startswith(complete_response)):
+                        match_tool_call_tmp_results.append(chunk)
+                    else:
+                        assert len(QWEN_TOOL_CALL_SYMBOLS[0]) <= len(
+                            complete_response
+                        )
+                        if not is_match_tool_call and complete_response.startswith(
+                            QWEN_TOOL_CALL_SYMBOLS[0]
+                        ):
+                            is_match_tool_call = True
+                            match_tool_call_tmp_results.clear()
+
+                        if not is_match_tool_call:
+                            for c in match_tool_call_tmp_results:
+                                yield c
+                            match_tool_call_tmp_results.clear()
+                            yield chunk
+                        else:
+                            chunk["choices"][0]["text"] = complete_response
+                else:
+                    yield chunk
+
+            if is_match_tool_call:
+                assert chunk is not None
                 yield chunk
 
             finish_reason = (
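
The buffering rule above is easier to see in isolation. A toy restatement, assuming QWEN_TOOL_CALL_SYMBOLS holds the start/end tags shown in the docstring:

    QWEN_TOOL_CALL_SYMBOLS = ["<tool_call>", "</tool_call>"]

    def classify(complete_response: str) -> str:
        start = QWEN_TOOL_CALL_SYMBOLS[0]
        if len(start) > len(complete_response):
            # too little text to decide yet
            return "buffer" if start.startswith(complete_response) else "flush"
        # enough text: it either is a tool call or it is not
        return "tool-call" if complete_response.startswith(start) else "flush"

    assert classify("<tool") == "buffer"             # may still become a tool call
    assert classify("Hello") == "flush"              # plain text, emit buffered chunks
    assert classify("<tool_call>{...}") == "tool-call"

Chunks classified "buffer" are held in match_tool_call_tmp_results; once the prefix is confirmed, the buffered chunks are dropped and the whole accumulated text is carried on one final chunk instead.
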
@@ -561,7 +611,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
@@ -588,6 +638,57 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
             generate_config.pop("chat_template_kwargs", None)
         return generate_config
 
+    @staticmethod
+    def is_tool_call_chunk_start(chunk):
+        return chunk["choices"][0]["text"].startswith(QWEN_TOOL_CALL_SYMBOLS[0])
+
+    @staticmethod
+    def is_tool_call_chunk_end(chunk):
+        return chunk["choices"][0]["text"].endswith(QWEN_TOOL_CALL_SYMBOLS[1])
+
+    async def _async_to_tool_completion_chunks(
+        self,
+        chunks: AsyncGenerator[CompletionChunk, None],
+    ) -> AsyncGenerator[ChatCompletionChunk, None]:
+        i = 0
+        previous_texts = [""]
+        tool_call = False
+        tool_call_texts = [""]
+        if self.reasoning_parser:
+            chunks = self.reasoning_parser.prepare_reasoning_content_streaming(chunks)
+        async for chunk in chunks:
+            if i == 0:
+                for first_chunk in self._get_first_chat_completion_chunk(
+                    chunk, self.reasoning_parser
+                ):
+                    yield first_chunk
+            # usage
+            choices = chunk.get("choices")
+            if not choices:
+                yield self._get_final_chat_completion_chunk(chunk)
+            else:
+                if self.is_tool_call_chunk_start(chunk):
+                    tool_call = True
+                if tool_call:
+                    tool_call_text = tool_call_texts[-1]
+                    tool_call_text += chunk["choices"][0]["text"]
+                    tool_call_texts.append(tool_call_text)
+                    if self.is_tool_call_chunk_end(chunk):
+                        yield self._post_process_completion_chunk(
+                            self.model_family,
+                            self.model_uid,
+                            chunk,
+                            reasoning_parser=self.reasoning_parser,
+                            tool_call_text=tool_call_text,
+                        )
+                        tool_call = False
+                        tool_call_texts = [""]
+                else:
+                    yield self._to_chat_completion_chunk(
+                        chunk, self.reasoning_parser, previous_texts
+                    )
+            i += 1
+
     async def async_chat(
         self,
         messages: List[Dict],
@@ -618,13 +719,15 @@
         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
+            agen = await self.async_generate(full_prompt, generate_config=generate_config, tools=tools)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
+            if tools:
+                return self._async_to_tool_completion_chunks(agen)
             return self._async_to_chat_completion_chunks(
                 agen, self.reasoning_parser, chat_template_kwargs
             )
         else:
-            c = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
+            c = await self.async_generate(full_prompt, generate_config=generate_config, tools=tools)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             if tools:
                 return self._post_process_completion(
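
A hedged sketch of driving this new path from the Python client; the endpoint, model uid, chat signature details, and tool schema are assumptions for illustration, not values confirmed by this diff:

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model = client.get_model("my-qwen-model")  # hypothetical model uid
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]
    # streaming with tools now routes through _async_to_tool_completion_chunks
    for chunk in model.chat(
        messages=[{"role": "user", "content": "Weather in Paris?"}],
        tools=tools,
        generate_config={"stream": True},
    ):
        print(chunk)
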
@@ -642,7 +745,7 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
             return False
         if not cls._is_linux():
             return False
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8", "bnb"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
--- a/xinference/model/llm/transformers/core.py
+++ b/xinference/model/llm/transformers/core.py
@@ -286,12 +286,18 @@ class PytorchModel(LLM):
 
         kwargs = {}
 
-        dtype = get_device_preferred_dtype(self._device)
-
-        if dtype is not None:
-            kwargs["torch_dtype"] = dtype
+        torch_dtype = self._pytorch_model_config.get("torch_dtype")
+        if torch_dtype is not None:
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+            kwargs["torch_dtype"] = torch_dtype
         else:
-            raise ValueError(f"Device {self._device} is not supported in temporary")
+            dtype = get_device_preferred_dtype(self._device)
+
+            if dtype is not None:
+                kwargs["torch_dtype"] = dtype
+            else:
+                raise ValueError(f"Device {self._device} is not supported in temporary")
 
         kwargs["revision"] = self._pytorch_model_config.get(
             "revision", self.model_spec.model_revision
@@ -327,6 +333,8 @@
             reasoning_content, enable_thinking=enable_thinking
         )
 
+        logger.debug("Loading Transformers model with kwargs: %s", kwargs)
+
         if self._check_tensorizer_integrity():
             self._model, self._tokenizer = self._load_tensorizer(**kwargs)
         else:
@@ -488,7 +496,7 @@
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         model_family = llm_family.model_family or llm_family.model_name
         if model_family in NON_DEFAULT_MODEL_LIST:
@@ -539,15 +547,13 @@
         So we need pad `0` on the left again.
         """
         data = []
+        max_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs) + 1
         for r in reqs:
             r.extra_kwargs["attention_mask_seq_len"] += 1
+            real_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = max_len - real_len
+
             if self._tokenizer.padding_side == "left":
-                attention_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
-                pad_len = seq_length - attention_mask_seq_len
-                assert pad_len >= 0, (
-                    f"pad_len must be greater equal 0, got {pad_len} = "
-                    f"seq_length({seq_length}) - attention_mask_seq_len({attention_mask_seq_len})"
-                )
                 x = torch.cat(
                     [
                         (
@@ -555,14 +561,10 @@
                         if pad_len > 0
                         else torch.tensor([], dtype=torch.long)
                     ),
-                    torch.ones((attention_mask_seq_len,), dtype=torch.long),
+                    torch.ones((real_len,), dtype=torch.long),
                 ]
             )
         else:
-            max_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs)
-            real_len = r.extra_kwargs["attention_mask_seq_len"]
-            pad_len = max_len - real_len
-
             x = torch.cat(
                 [
                     torch.ones((real_len,), dtype=torch.long),
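
Net effect: max_len is now computed once for the whole batch (plus one for the token just generated), and every request's run of real_len ones is padded to it on the tokenizer's padding side. A toy equivalent:

    import torch

    def pad_masks(lengths, padding_side="left"):
        max_len = max(lengths) + 1  # +1 for the newly generated token
        rows = []
        for n in lengths:
            real_len = n + 1
            pad = torch.zeros((max_len - real_len,), dtype=torch.long)
            ones = torch.ones((real_len,), dtype=torch.long)
            rows.append(torch.cat([pad, ones] if padding_side == "left" else [ones, pad]))
        return torch.stack(rows)

    print(pad_masks([3, 5]))
    # tensor([[0, 0, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1]])
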
@@ -878,7 +880,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def match_json(
         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         model_family = llm_family.model_family or llm_family.model_name
         if model_family in NON_DEFAULT_MODEL_LIST:
--- a/xinference/model/llm/transformers/gemma3.py
+++ b/xinference/model/llm/transformers/gemma3.py
@@ -28,7 +28,7 @@ class Gemma3TextChatModel(PytorchChatModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
         if "gemma-3-1b-it".lower() in llm_family.lower():
--- /dev/null
+++ b/xinference/model/llm/transformers/gpt_oss.py
@@ -0,0 +1,91 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import logging
+from typing import Dict, Iterator, List, Optional, Union
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    PytorchGenerateConfig,
+    PytorchModelConfig,
+)
+from ..harmony import async_stream_harmony_chat_completion
+from ..llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
+
+logger = logging.getLogger(__name__)
+
+
+@register_transformer
+@register_non_default_model("gpt-oss")
+class GPTOSSPytorchChatModel(PytorchChatModel):
+    def _sanitize_model_config(
+        self, pytorch_model_config: Optional[PytorchModelConfig]
+    ) -> PytorchModelConfig:
+        config = super()._sanitize_model_config(pytorch_model_config)
+        config.setdefault("torch_dtype", "auto")
+        return config  # type: ignore
+
+    @classmethod
+    def match_json(
+        cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if "gpt" not in model_family and "oss" not in model_family:
+            return False
+        if "chat" not in llm_family.model_ability:
+            return False
+        return True
+
+    async def chat(  # type: ignore
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        gen = super().chat(messages, generate_config=generate_config)
+
+        if inspect.iscoroutine(gen):
+            gen = await gen
+
+        if inspect.isasyncgen(gen):
+            # Streaming
+            async def stream_parser():
+                full_text = ""
+                full_reasoning = ""
+
+                async for parsed_chunk in async_stream_harmony_chat_completion(gen):
+                    choices = parsed_chunk.get("choices")
+                    if choices and len(choices) > 0:
+                        delta = choices[0].get("delta", {})
+                        if delta.get("content"):
+                            full_text += delta["content"]
+                        if delta.get("reasoning_content"):
+                            full_reasoning += delta["reasoning_content"]
+                    yield parsed_chunk
+
+                logger.debug(
+                    "Chat finished, content: %r, reasoning: %r",
+                    full_text,
+                    full_reasoning,
+                )
+
+            return stream_parser()
+
+        else:
+            # Non-streaming sync - handle single result
+            async for parsed_completion in async_stream_harmony_chat_completion(gen):  # type: ignore
+                return parsed_completion
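
A hedged sketch of exercising this new backend from the Python client; the model name, engine, and endpoint values are assumptions, not confirmed by this diff:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="gpt-oss",         # assumed registry name for this family
        model_engine="transformers",  # would route to GPTOSSPytorchChatModel
        model_format="pytorch",
    )
    model = client.get_model(uid)
    print(model.chat(messages=[{"role": "user", "content": "Hello!"}]))

Note the _sanitize_model_config override: unless the caller sets one, such a launch runs with torch_dtype="auto", relying on the torch_dtype handling added to PytorchModel above.
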
--- a/xinference/model/llm/transformers/multimodal/core.py
+++ b/xinference/model/llm/transformers/multimodal/core.py
@@ -21,9 +21,9 @@ from .....types import (
     CompletionChunk,
     PytorchGenerateConfig,
 )
+from ....utils import cache_clean
 from ...utils import generate_chat_completion, generate_completion_chunk
 from ..core import PytorchChatModel
-from ..utils import cache_clean
 
 
 class PytorchMultiModalModel(PytorchChatModel):
--- a/xinference/model/llm/transformers/multimodal/gemma3.py
+++ b/xinference/model/llm/transformers/multimodal/gemma3.py
@@ -31,7 +31,7 @@ class Gemma3ChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
         if "gemma-3-it".lower() in llm_family.lower():
--- a/xinference/model/llm/transformers/multimodal/glm4_1v.py
+++ b/xinference/model/llm/transformers/multimodal/glm4_1v.py
@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
 
 
 @register_transformer
-@register_non_default_model("glm-4.1v-thinking")
+@register_non_default_model("glm-4.1v-thinking", "glm-4.5v")
 class Glm4_1VModel(PytorchMultiModalModel):
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
-        if "glm-4.1v" in family.lower():
+        if "glm-4.1v" in family.lower() or "glm-4.5v" in family.lower():
             return True
         return False
 
--- a/xinference/model/llm/transformers/multimodal/ovis2.py
+++ b/xinference/model/llm/transformers/multimodal/ovis2.py
@@ -37,7 +37,7 @@ class Ovis2ChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
         if "ovis2".lower() in llm_family.lower():
--- a/xinference/model/llm/transformers/multimodal/qwen-omni.py
+++ b/xinference/model/llm/transformers/multimodal/qwen-omni.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import base64
-import importlib.util
 import io
 import logging
 import time
@@ -20,13 +19,13 @@ import uuid
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
-from .....model.utils import select_device
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
     ChatCompletionChoice,
     CompletionUsage,
 )
+from ....utils import is_flash_attn_available, select_device
 from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
 from ..core import PytorchGenerateConfig, register_non_default_model
 from .core import PytorchMultiModalModel
@@ -46,7 +45,7 @@ class Qwen2_5OmniChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2.5-omni".lower() in llm_family.lower():
@@ -71,12 +70,12 @@
 
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if self._device == "cuda" else self._device
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
-        kwargs = (
-            {}
-            if not flash_attn_installed
-            else {"attn_implementation": "flash_attention_2"}
+        kwargs = {}
+        enable_flash_attn = self._pytorch_model_config.get(
+            "enable_flash_attn", is_flash_attn_available()
         )
+        if enable_flash_attn:
+            kwargs["attn_implementation"] = "flash_attention_2"
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
--- a/xinference/model/llm/transformers/multimodal/qwen2_vl.py
+++ b/xinference/model/llm/transformers/multimodal/qwen2_vl.py
@@ -11,15 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib.util
 import logging
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 from .....core.model import register_batching_multimodal_models
 from .....device_utils import is_npu_available
-from .....model.utils import select_device
 from .....types import PytorchModelConfig
 from ....scheduler.request import InferenceRequest
+from ....utils import is_flash_attn_available, select_device
 from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
 from ..core import register_non_default_model
 from .core import PytorchMultiModalModel
@@ -48,7 +47,7 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
+        if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
         if "qwen2-vl-instruct".lower() in llm_family.lower():
@@ -87,7 +86,6 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
             Qwen2_5_VLForConditionalGeneration = None
 
         kwargs = self.apply_bnb_quantization()
-        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         llm_family = self.model_family.model_family or self.model_family.model_name
         model_cls = (
             Qwen2_5_VLForConditionalGeneration
@@ -97,12 +95,17 @@ class Qwen2VLChatModel(PytorchMultiModalModel):
         if model_cls is None:
             raise ImportError("`transformers` version is too old, please upgrade it")
         device = "auto" if self._device == "cuda" else self._device
-        if flash_attn_installed:
+
+        enable_flash_attn = self._pytorch_model_config.get(
+            "enable_flash_attn", is_flash_attn_available()
+        )
+
+        if enable_flash_attn:
             self._model = model_cls.from_pretrained(
                 self.model_path,
                 torch_dtype="bfloat16",
-                device_map=device,
                 attn_implementation="flash_attention_2",
+                device_map=device,
                 trust_remote_code=True,
                 **kwargs,
             ).eval()
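
Both the Qwen2.5-Omni and Qwen2-VL loaders now consult pytorch_model_config["enable_flash_attn"], defaulting to auto-detection. A plausible shape for the shared helper, assuming it simply probes for the package (the real one in xinference/model/utils.py may check more, e.g. the device):

    import importlib.util

    def is_flash_attn_available() -> bool:
        # flash attention is usable only if the flash_attn package is importable
        return importlib.util.find_spec("flash_attn") is not None

The practical gain is that flash attention can now be forced off for a single model load (e.g. pytorch_model_config={"enable_flash_attn": False}) instead of being implied solely by whether flash_attn happens to be installed.
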
--- a/xinference/model/llm/transformers/utils.py
+++ b/xinference/model/llm/transformers/utils.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
-import functools
+
 import logging
 import os
 import time
@@ -495,34 +494,3 @@ def batch_inference_one_step(
         for r in req_list:
             r.stopped = True
             r.error_msg = str(e)
-
-
-def cache_clean(fn):
-    @functools.wraps(fn)
-    async def _async_wrapper(self, *args, **kwargs):
-        import gc
-
-        from ....device_utils import empty_cache
-
-        result = await fn(self, *args, **kwargs)
-
-        gc.collect()
-        empty_cache()
-        return result
-
-    @functools.wraps(fn)
-    def _wrapper(self, *args, **kwargs):
-        import gc
-
-        from ....device_utils import empty_cache
-
-        result = fn(self, *args, **kwargs)
-
-        gc.collect()
-        empty_cache()
-        return result
-
-    if asyncio.iscoroutinefunction(fn):
-        return _async_wrapper
-    else:
-        return _wrapper
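
The decorator removed here survives as xinference.model.utils.cache_clean, which is what the multimodal core now imports (from ....utils import cache_clean, above). Usage is unchanged; a minimal sketch of decorating an inference method:

    from xinference.model.utils import cache_clean

    class DummyModel:
        @cache_clean
        def generate(self, prompt: str) -> str:
            # accelerator caches are emptied after each decorated call
            return prompt.upper()
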