xinference 1.5.0.post2__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (89)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +10 -3
  3. xinference/constants.py +5 -1
  4. xinference/core/supervisor.py +1 -1
  5. xinference/core/utils.py +1 -1
  6. xinference/core/worker.py +2 -2
  7. xinference/deploy/cmdline.py +17 -0
  8. xinference/model/audio/core.py +1 -1
  9. xinference/model/audio/model_spec.json +43 -43
  10. xinference/model/audio/model_spec_modelscope.json +13 -13
  11. xinference/model/llm/__init__.py +3 -5
  12. xinference/model/llm/core.py +14 -0
  13. xinference/model/llm/llama_cpp/core.py +15 -4
  14. xinference/model/llm/llm_family.json +3251 -4304
  15. xinference/model/llm/llm_family.py +62 -6
  16. xinference/model/llm/llm_family_csghub.json +0 -32
  17. xinference/model/llm/llm_family_modelscope.json +1161 -1789
  18. xinference/model/llm/llm_family_openmind_hub.json +19 -325
  19. xinference/model/llm/lmdeploy/core.py +7 -2
  20. xinference/model/llm/mlx/core.py +19 -6
  21. xinference/model/llm/sglang/core.py +25 -10
  22. xinference/model/llm/transformers/chatglm.py +8 -1
  23. xinference/model/llm/transformers/cogagent.py +10 -12
  24. xinference/model/llm/transformers/cogvlm2.py +6 -3
  25. xinference/model/llm/transformers/cogvlm2_video.py +3 -6
  26. xinference/model/llm/transformers/core.py +50 -58
  27. xinference/model/llm/transformers/deepseek_v2.py +4 -2
  28. xinference/model/llm/transformers/deepseek_vl.py +10 -4
  29. xinference/model/llm/transformers/deepseek_vl2.py +9 -4
  30. xinference/model/llm/transformers/gemma3.py +4 -5
  31. xinference/model/llm/transformers/glm4v.py +2 -20
  32. xinference/model/llm/transformers/glm_edge_v.py +3 -20
  33. xinference/model/llm/transformers/intern_vl.py +3 -6
  34. xinference/model/llm/transformers/internlm2.py +1 -1
  35. xinference/model/llm/transformers/minicpmv25.py +4 -2
  36. xinference/model/llm/transformers/minicpmv26.py +5 -3
  37. xinference/model/llm/transformers/omnilmm.py +1 -1
  38. xinference/model/llm/transformers/opt.py +1 -1
  39. xinference/model/llm/transformers/ovis2.py +302 -0
  40. xinference/model/llm/transformers/qwen-omni.py +2 -1
  41. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  42. xinference/model/llm/transformers/qwen2_vl.py +5 -1
  43. xinference/model/llm/transformers/qwen_vl.py +5 -2
  44. xinference/model/llm/utils.py +28 -0
  45. xinference/model/llm/vllm/core.py +73 -9
  46. xinference/model/llm/vllm/distributed_executor.py +8 -7
  47. xinference/model/llm/vllm/xavier/allocator.py +1 -1
  48. xinference/model/llm/vllm/xavier/block_manager.py +1 -1
  49. xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
  50. xinference/model/llm/vllm/xavier/executor.py +1 -1
  51. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -1
  52. xinference/model/video/diffusers.py +30 -3
  53. xinference/model/video/model_spec.json +46 -0
  54. xinference/model/video/model_spec_modelscope.json +48 -0
  55. xinference/types.py +2 -0
  56. xinference/web/ui/build/asset-manifest.json +6 -6
  57. xinference/web/ui/build/index.html +1 -1
  58. xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
  59. xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
  60. xinference/web/ui/build/static/js/main.91e77b5c.js +3 -0
  61. xinference/web/ui/build/static/js/main.91e77b5c.js.map +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
  69. xinference/web/ui/src/locales/en.json +1 -0
  70. xinference/web/ui/src/locales/zh.json +1 -0
  71. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/METADATA +1 -1
  72. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/RECORD +77 -78
  73. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/WHEEL +1 -1
  74. xinference/model/llm/transformers/compression.py +0 -258
  75. xinference/model/llm/transformers/yi_vl.py +0 -239
  76. xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
  77. xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
  78. xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
  86. /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.91e77b5c.js.LICENSE.txt} +0 -0
  87. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/entry_points.txt +0 -0
  88. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/licenses/LICENSE +0 -0
  89. {xinference-1.5.0.post2.dist-info → xinference-1.5.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import logging
 import platform
 import sys
@@ -172,7 +172,11 @@ class MLXModel(LLM):
         self._model, self._tokenizer = self._load_model(**kwargs)

     @classmethod
-    def match(
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("mlx_lm") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
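This release splits the old `match` classmethod into two checks: `check_lib`, which only asks whether the backend's runtime library is importable, and `match_json`, which decides whether a given family/spec/quantization fits the backend. A minimal sketch of the pattern, using a hypothetical backend class that is not part of this diff:

    import importlib.util

    class ExampleBackend:
        @classmethod
        def check_lib(cls) -> bool:
            # True only if the backend's library is installed.
            return importlib.util.find_spec("mlx_lm") is not None

        @classmethod
        def match_json(cls, llm_family, llm_spec, quantization: str) -> bool:
            # Spec-level compatibility check; no heavy imports needed here.
            return llm_spec.model_format == "mlx"

Presumably this lets "engine installed?" and "model supported?" be answered independently; the exact dispatch logic lives in llm_family.py (+62 -6 in this release).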
@@ -423,7 +427,7 @@ class MLXChatModel(MLXModel, ChatModelMixin):
         return generate_config

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
@@ -445,7 +449,9 @@ class MLXChatModel(MLXModel, ChatModelMixin):
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_context_kwargs = {}
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}  # type: ignore
+        )
         if tools:
             if (
                 model_family in QWEN_TOOL_CALL_FAMILY
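Several chat backends in this release now pull extra chat-template arguments out of `generate_config` via `_get_chat_template_kwargs_from_generate_config` and forward them into `get_full_context`. A hedged sketch of what a caller might send; the `enable_thinking` key is purely illustrative, not something this diff defines:

    # `model` stands for an already-obtained chat model handle.
    completion = model.chat(
        messages=[{"role": "user", "content": "Hello"}],
        generate_config={
            "max_tokens": 256,
            # Forwarded into the chat template rendering as keyword arguments.
            "chat_template_kwargs": {"enable_thinking": False},  # illustrative key
        },
    )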
@@ -476,7 +482,11 @@ class MLXChatModel(MLXModel, ChatModelMixin):

 class MLXVisionModel(MLXModel, ChatModelMixin):
     @classmethod
-    def match(
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("mlx_vlm") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["mlx"]:
@@ -623,7 +633,10 @@ class MLXVisionModel(MLXModel, ChatModelMixin):
         if "internvl2" not in model_family.lower():
             from qwen_vl_utils import process_vision_info

-            full_context_kwargs = {}
+            full_context_kwargs = (
+                self._get_chat_template_kwargs_from_generate_config(generate_config)  # type: ignore
+                or {}
+            )
             if tools and model_family in QWEN_TOOL_CALL_FAMILY:
                 full_context_kwargs["tools"] = tools
             assert self.model_family.chat_template is not None
xinference/model/llm/sglang/core.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import json
 import logging
 import sys
@@ -107,6 +107,7 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
     "deepseek-r1-distill-llama",
     "deepseek-v3",
     "deepseek-r1",
+    "qwen3",
 ]
 SGLANG_SUPPORTED_VISION_MODEL_LIST = [
     "qwen2.5-vl-instruct",
@@ -297,7 +298,11 @@ class SGLANGModel(LLM):
         return generate_config

     @classmethod
-    def match(
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("sglang") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
@@ -435,6 +440,7 @@
     async def async_generate(
         self,
         prompt: str,
+        *,
         image_data: Optional[Union[List[str], str]] = None,
         generate_config: Optional[SGLANGGenerateConfig] = None,
         request_id: Optional[str] = None,
@@ -524,7 +530,7 @@

 class SGLANGChatModel(SGLANGModel, ChatModelMixin):
     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
@@ -551,6 +557,7 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         if self.model_family.stop:
             if (not generate_config.get("stop")) and self.model_family.stop:
                 generate_config["stop"] = self.model_family.stop.copy()
+        generate_config.pop("chat_template_kwargs", None)
         return generate_config

     async def async_chat(
@@ -560,23 +567,28 @@ class SGLANGChatModel(SGLANGModel, ChatModelMixin):
         request_id: Optional[str] = None,
     ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
         assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(messages, self.model_family.chat_template)
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}
+        )
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )

         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(full_prompt, generate_config)  # type: ignore
+            agen = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
         else:
-            c = await self.async_generate(full_prompt, generate_config)  # type: ignore
+            c = await self.async_generate(full_prompt, generate_config=generate_config)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c, self.reasoning_parser)


 class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if not cls._has_cuda_device():
@@ -627,7 +639,10 @@ class SGLANGVisionModel(SGLANGModel, ChatModelMixin):
             self.model_family.chat_template if self.model_family.chat_template else ""
         )

-        prompt = self.get_full_context(messages, chat_template)
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}
+        )
+        prompt = self.get_full_context(messages, chat_template, **full_context_kwargs)
         images, video_inputs = process_vision_info(messages)
         if video_inputs:
             raise ValueError("Not support video input now.")
@@ -650,10 +665,10 @@

         generate_config = self._sanitize_chat_config(generate_config)
         stream = generate_config.get("stream", None)
         if stream:
-            agen = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
+            agen = await self.async_generate(prompt, image_data=base64_images, generate_config=generate_config)  # type: ignore
             assert isinstance(agen, AsyncGenerator)
             return self._async_to_chat_completion_chunks(agen, self.reasoning_parser)
         else:
-            c = await self.async_generate(prompt, base64_images, generate_config)  # type: ignore
+            c = await self.async_generate(prompt, image_data=base64_images, generate_config=generate_config)  # type: ignore
             assert not isinstance(c, AsyncGenerator)
             return self._to_chat_completion(c, self.reasoning_parser)
xinference/model/llm/transformers/chatglm.py

@@ -84,7 +84,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
@@ -462,6 +462,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tools = list(tools) if tools is not None else None
         tool_choice = r.generate_config.get("tool_choice", "none")

+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                r.generate_config
+            )
+            or {}
+        )
         r.prompt = self._process_messages(
             r.prompt, tools=tools, tool_choice=tool_choice
         )
@@ -469,6 +475,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             r.prompt,
             self.model_family.chat_template,  # type: ignore
             tokenizer=self._tokenizer,
+            **full_context_kwargs,
         )
         if tools:
             r.tools = tools
xinference/model/llm/transformers/cogagent.py

@@ -46,8 +46,8 @@ class CogAgentChatModel(PytorchChatModel):
         self._device = None
         self._tokenizer = None
         self._model = None
-        self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"
-        self._format: Literal[
+        self._platform: Literal["Mac", "WIN", "Mobile"] | None = "Mac"  # type: ignore
+        self._format: Literal[  # type: ignore
             "(Answer in Action-Operation-Sensitive format.)",
             "(Answer in Status-Plan-Action-Operation format.)",
             "(Answer in Status-Action-Operation-Sensitive format.)",
@@ -56,7 +56,7 @@ class CogAgentChatModel(PytorchChatModel):
         ] | None = "(Answer in Action-Operation-Sensitive format.)"

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -64,8 +64,8 @@ class CogAgentChatModel(PytorchChatModel):
             return True
         return False

-    def load(self, **kwargs):
-        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+    def load(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer

         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
@@ -73,19 +73,14 @@ class CogAgentChatModel(PytorchChatModel):
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        if self.quantization == "4-bit":
-            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-        elif self.quantization == "8-bit":
-            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-        else:
-            quantization_config = None
+        kwargs = self.apply_bnb_quantization()

         self._model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map=self._device,
-            quantization_config=quantization_config,
+            **kwargs,
         ).eval()

     def _message_content_to_cogagent(self, content):
@@ -211,6 +206,9 @@ class CogAgentChatModel(PytorchChatModel):
             "return_tensors": "pt",
             "return_dict": True,
         }
+        full_context_kwargs.update(
+            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}  # type: ignore
+        )
         assert self.model_family.chat_template is not None
         inputs = self.get_full_context(
             [{"role": "user", "image": image, "content": query}],
xinference/model/llm/transformers/cogvlm2.py

@@ -64,7 +64,7 @@ class CogVLM2Model(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -72,7 +72,7 @@ class CogVLM2Model(PytorchChatModel):
             return True
         return False

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -88,6 +88,8 @@ class CogVLM2Model(PytorchChatModel):
             self._model, self._tokenizer = self._load_tensorizer()
             return

+        kwargs = self.apply_bnb_quantization()
+
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
             trust_remote_code=True,
@@ -99,6 +101,7 @@ class CogVLM2Model(PytorchChatModel):
             trust_remote_code=True,
             low_cpu_mem_usage=True,
             device_map="auto",
+            **kwargs
         ).eval()

         # Specify hyperparameters for generation
@@ -313,7 +316,7 @@ class CogVLM2Model(PytorchChatModel):
     def get_dtype(self):
         return self._torch_type

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools):  # type: ignore
         prompt, system_prompt, chat_history = parse_messages(messages)
         system_prompt = system_prompt or ""
         query, image, history = self.get_query_and_history(
xinference/model/llm/transformers/cogvlm2_video.py

@@ -63,7 +63,7 @@ class CogVLM2VideoModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -71,7 +71,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             return True
         return False

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -87,10 +87,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             self._model, self._tokenizer = self._load_tensorizer()
             return

-        if "8-bit" in self.quantization.lower():
-            kwargs["load_in_8bit"] = True
-        elif "4-bit" in self.quantization.lower():
-            kwargs["load_in_4bit"] = True
+        kwargs = self.apply_bnb_quantization()

         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
xinference/model/llm/transformers/core.py

@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import json
 import logging
 import os
 from functools import lru_cache
-from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import torch

@@ -53,11 +53,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "opt",
     "glm4-chat",
     "glm4-chat-1m",
-    "internlm2-chat",
-    "internlm2.5-chat",
     "qwen-vl-chat",
     "OmniLMM",
-    "yi-vl-chat",
     "deepseek-vl-chat",
     "cogvlm2",
     "cogvlm2-video-llama3-chat",
@@ -75,6 +72,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "cogagent",
     "gemma-3-1b-it",
     "gemma-3-it",
+    "Ovis2",
     "deepseek-vl2",
 ]

@@ -142,6 +140,7 @@ class PytorchModel(LLM):
         pytorch_model_config.setdefault("max_num_seqs", 16)
         pytorch_model_config.setdefault("enable_tensorizer", False)
         pytorch_model_config.setdefault("reasoning_content", False)
+        pytorch_model_config.setdefault("quantization_config", {})
         return pytorch_model_config

     def _sanitize_generate_config(
@@ -264,16 +263,39 @@
             f"PEFT adaptor '{peft_model.lora_name}' successfully loaded for model '{self.model_uid}'."
         )

-    def load(self):
-        try:
-            import torch
-        except ImportError:
-            raise ImportError(
-                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+    def apply_bnb_quantization(
+        self, kwargs: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        model_format = self.model_spec.model_format
+        _kwargs = kwargs if kwargs is not None else {}
+        if model_format == "pytorch":
+            quantization_config = self._pytorch_model_config.get(
+                "quantization_config", {}
             )
-        from .compression import load_compress_model
+            if quantization_config:
+                # If `load_in_4bit` is enabled, apply default quantization presets.
+                if quantization_config.get("load_in_4bit", False):
+                    quantization_config.setdefault(
+                        "bnb_4bit_compute_dtype", torch.float16
+                    )
+                    quantization_config.setdefault("bnb_4bit_use_double_quant", True)
+                    quantization_config.setdefault(
+                        "llm_int8_skip_modules",
+                        [
+                            "lm_head",
+                            "encoder",
+                            "EncDecAttention",
+                        ],
+                    )
+
+                from transformers import BitsAndBytesConfig
+
+                _kwargs["quantization_config"] = BitsAndBytesConfig(
+                    **quantization_config
+                )
+        return _kwargs

-        quantization = self.quantization
+    def load(self):
         num_gpus = gpu_count()
         device = self._pytorch_model_config.get("device", "auto")
         self._pytorch_model_config["device"] = select_device(device)
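The hard-coded `4-bit`/`8-bit` branches (and the removed compression.py fallback) are replaced by `apply_bnb_quantization`, which reads a `quantization_config` dict from `pytorch_model_config` and hands it to `transformers.BitsAndBytesConfig`. A hedged sketch of supplying it at launch time, assuming extra engine kwargs are forwarded into `pytorch_model_config`; the endpoint and model arguments are placeholders:

    from xinference.client import Client

    client = Client("http://localhost:9997")  # placeholder endpoint
    client.launch_model(
        model_name="qwen2.5-instruct",        # placeholder model
        model_engine="transformers",
        model_format="pytorch",
        model_size_in_billions=7,
        # Forwarded to BitsAndBytesConfig(**quantization_config); with
        # load_in_4bit=True the defaults above (fp16 compute dtype, double
        # quantization, llm_int8_skip_modules) are filled in automatically.
        quantization_config={"load_in_4bit": True},
    )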
@@ -294,7 +316,6 @@
         kwargs["trust_remote_code"] = self._pytorch_model_config.get(
             "trust_remote_code"
         )
-        model_format = self.model_spec.model_format

         is_device_map_auto = False

@@ -310,45 +331,8 @@
             }
             kwargs["max_memory"] = max_memory

-        if quantization != "none" and model_format == "pytorch":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                is_device_map_auto = True
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                    kwargs["bnb_4bit_compute_dtype"] = torch.float16
-                    kwargs["bnb_4bit_use_double_quant"] = True
-                    kwargs["llm_int8_skip_modules"] = [
-                        "lm_head",
-                        "encoder",
-                        "EncDecAttention",
-                    ]
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if num_gpus != 1 and self._device == "cuda":
-                    raise ValueError(f"Quantization is not supported for multi-gpu")
-                elif quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
-                else:
-                    (
-                        self._model,
-                        self._tokenizer,
-                    ) = load_compress_model(
-                        model_path=self.model_path,
-                        device=self._device,
-                        torch_dtype=kwargs["torch_dtype"],
-                        use_fast=self._use_fast_tokenizer,
-                        revision=kwargs["revision"],
-                    )
-                    logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
-                    return
+        # handle bnb quantization
+        kwargs = self.apply_bnb_quantization(kwargs)

         if num_gpus > 0 and is_hf_accelerate_supported(self._device):
             kwargs.update({"device_map": "auto"})
@@ -372,7 +356,11 @@
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")

     @classmethod
-    def match(
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("transformers") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -689,7 +677,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         return generate_config

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -711,9 +699,11 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def load(self):
         super().load()

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):
         model_family = self.model_family.model_family or self.model_family.model_name
-        full_context_kwargs = {}
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}
+        )
         if (
             tools
             and model_family in QWEN_TOOL_CALL_FAMILY
@@ -736,7 +726,9 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             try:
                 if not r.stopped and r.is_prefill:
                     tools = r.generate_config.get("tools", None)
-                    r.full_prompt = self._get_full_prompt(r.prompt, tools)
+                    r.full_prompt = self._get_full_prompt(
+                        r.prompt, tools, r.generate_config
+                    )
                     if tools:
                         r.tools = tools
             except Exception as e:
xinference/model/llm/transformers/deepseek_v2.py

@@ -48,13 +48,14 @@ class DeepSeekV2PytorchModel(PytorchModel):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map="auto",
+            **kwargs,
         )
         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
@@ -95,13 +96,14 @@ class DeepSeekV2PytorchChatModel(PytorchChatModel):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map="auto",
+            **kwargs,
         )
         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
xinference/model/llm/transformers/deepseek_vl.py

@@ -42,11 +42,11 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._type = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
-        if "deepseek-vl" == llm_family.lower():
+        if "deepseek-vl-chat" == llm_family.lower():
             return True
         return False

@@ -62,6 +62,8 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._device = select_device(self._device)
         self._type = torch.float16 if self._device == "mps" else torch.bfloat16

+        kwargs = self.apply_bnb_quantization()
+
         # specify the path to the model
         self._vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(  # type: ignore
             self.model_path
@@ -69,9 +71,13 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._tokenizer = self._vl_chat_processor.tokenizer

         vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
-            self.model_path, trust_remote_code=True, device_map=self._device
+            self.model_path,
+            trust_remote_code=True,
+            device_map=self._device,
+            torch_dtype=self._type,
+            **kwargs,
         )
-        self._model = vl_gpt.to(self._type).eval()
+        self._model = vl_gpt.eval()

     @staticmethod
     def _message_content_to_deepseek(content) -> Tuple[str, List[str]]:
xinference/model/llm/transformers/deepseek_vl2.py

@@ -42,7 +42,7 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
         self._type = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
@@ -60,7 +60,8 @@ class DeepSeekVL2ChatModel(PytorchChatModel):

         self._device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(self._device)
-        self._type = torch.float16 if self._device == "mps" else torch.bfloat16
+        self._type = torch.bfloat16
+        kwargs = self.apply_bnb_quantization()

         # specify the path to the model
         self._vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(  # type: ignore
@@ -69,9 +70,13 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
         self._tokenizer = self._vl_chat_processor.tokenizer

         vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
-            self.model_path, trust_remote_code=True, device_map=self._device
+            self.model_path,
+            trust_remote_code=True,
+            device_map=self._device,
+            torch_dtype=self._type,
+            **kwargs,
         )
-        self._model = vl_gpt.to(torch.bfloat16).cuda().eval()
+        self._model = vl_gpt.cuda().eval()

     @staticmethod
     def _message_content_to_deepseek(content) -> Tuple[str, List[str]]:
xinference/model/llm/transformers/gemma3.py

@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)

 class Gemma3TextChatModel(PytorchChatModel):
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -56,7 +56,7 @@ class Gemma3ChatModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -85,6 +85,7 @@ class Gemma3ChatModel(PytorchChatModel):
         device = "auto" if device == "cuda" else device
         min_pixels = self._pytorch_model_config.get("min_pixels")
         max_pixels = self._pytorch_model_config.get("max_pixels")
+        kwargs = self.apply_bnb_quantization()
         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=min_pixels,
@@ -92,9 +93,7 @@ class Gemma3ChatModel(PytorchChatModel):
         )
         self._tokenizer = self._processor.tokenizer
         self._model = Gemma3ForConditionalGeneration.from_pretrained(
-            self.model_path,
-            device_map="auto",
-            torch_dtype="bfloat16",
+            self.model_path, device_map="auto", torch_dtype="bfloat16", **kwargs
         )

     @cache_clean