xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +107 -11
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/constants.py +5 -1
  5. xinference/core/media_interface.py +758 -0
  6. xinference/core/model.py +49 -9
  7. xinference/core/supervisor.py +1 -1
  8. xinference/core/utils.py +1 -1
  9. xinference/core/worker.py +33 -39
  10. xinference/deploy/cmdline.py +17 -0
  11. xinference/deploy/utils.py +0 -3
  12. xinference/model/audio/__init__.py +16 -27
  13. xinference/model/audio/core.py +2 -1
  14. xinference/model/audio/cosyvoice.py +4 -2
  15. xinference/model/audio/model_spec.json +63 -46
  16. xinference/model/audio/model_spec_modelscope.json +31 -14
  17. xinference/model/embedding/__init__.py +16 -24
  18. xinference/model/image/__init__.py +15 -25
  19. xinference/model/llm/__init__.py +40 -115
  20. xinference/model/llm/core.py +29 -6
  21. xinference/model/llm/llama_cpp/core.py +30 -347
  22. xinference/model/llm/llm_family.json +1674 -2203
  23. xinference/model/llm/llm_family.py +71 -7
  24. xinference/model/llm/llm_family_csghub.json +0 -32
  25. xinference/model/llm/llm_family_modelscope.json +1838 -2016
  26. xinference/model/llm/llm_family_openmind_hub.json +19 -325
  27. xinference/model/llm/lmdeploy/core.py +7 -2
  28. xinference/model/llm/mlx/core.py +23 -7
  29. xinference/model/llm/reasoning_parser.py +281 -5
  30. xinference/model/llm/sglang/core.py +39 -11
  31. xinference/model/llm/transformers/chatglm.py +9 -2
  32. xinference/model/llm/transformers/cogagent.py +10 -12
  33. xinference/model/llm/transformers/cogvlm2.py +6 -3
  34. xinference/model/llm/transformers/cogvlm2_video.py +3 -6
  35. xinference/model/llm/transformers/core.py +58 -60
  36. xinference/model/llm/transformers/deepseek_v2.py +4 -2
  37. xinference/model/llm/transformers/deepseek_vl.py +10 -4
  38. xinference/model/llm/transformers/deepseek_vl2.py +9 -4
  39. xinference/model/llm/transformers/gemma3.py +4 -5
  40. xinference/model/llm/transformers/glm4v.py +3 -21
  41. xinference/model/llm/transformers/glm_edge_v.py +3 -20
  42. xinference/model/llm/transformers/intern_vl.py +3 -6
  43. xinference/model/llm/transformers/internlm2.py +1 -1
  44. xinference/model/llm/transformers/minicpmv25.py +4 -2
  45. xinference/model/llm/transformers/minicpmv26.py +5 -3
  46. xinference/model/llm/transformers/omnilmm.py +1 -1
  47. xinference/model/llm/transformers/opt.py +1 -1
  48. xinference/model/llm/transformers/ovis2.py +302 -0
  49. xinference/model/llm/transformers/qwen-omni.py +8 -1
  50. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  51. xinference/model/llm/transformers/qwen2_vl.py +5 -1
  52. xinference/model/llm/transformers/qwen_vl.py +5 -2
  53. xinference/model/llm/utils.py +96 -45
  54. xinference/model/llm/vllm/core.py +108 -24
  55. xinference/model/llm/vllm/distributed_executor.py +8 -7
  56. xinference/model/llm/vllm/xavier/allocator.py +1 -1
  57. xinference/model/llm/vllm/xavier/block_manager.py +1 -1
  58. xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
  59. xinference/model/llm/vllm/xavier/executor.py +1 -1
  60. xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
  61. xinference/model/rerank/__init__.py +13 -24
  62. xinference/model/video/__init__.py +15 -25
  63. xinference/model/video/core.py +3 -3
  64. xinference/model/video/diffusers.py +157 -13
  65. xinference/model/video/model_spec.json +100 -0
  66. xinference/model/video/model_spec_modelscope.json +104 -0
  67. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  68. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  69. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  70. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  71. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  74. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  75. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  76. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  77. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  78. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  79. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  80. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  81. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  84. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  85. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  86. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  87. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  88. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  89. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  90. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  91. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  92. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  93. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  94. xinference/types.py +2 -71
  95. xinference/web/ui/build/asset-manifest.json +6 -6
  96. xinference/web/ui/build/index.html +1 -1
  97. xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
  98. xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
  99. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  100. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  101. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  102. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  103. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  109. xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
  112. xinference/web/ui/src/locales/en.json +7 -4
  113. xinference/web/ui/src/locales/zh.json +7 -4
  114. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  115. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
  116. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  117. xinference/core/image_interface.py +0 -377
  118. xinference/model/llm/transformers/compression.py +0 -258
  119. xinference/model/llm/transformers/yi_vl.py +0 -239
  120. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  121. xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
  122. xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
  123. xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
  124. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  133. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
  134. /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  135. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  136. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  137. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/cogvlm2_video.py

@@ -63,7 +63,7 @@ class CogVLM2VideoModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -71,7 +71,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             return True
         return False

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -87,10 +87,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             self._model, self._tokenizer = self._load_tensorizer()
             return

-        if "8-bit" in self.quantization.lower():
-            kwargs["load_in_8bit"] = True
-        elif "4-bit" in self.quantization.lower():
-            kwargs["load_in_4bit"] = True
+        kwargs = self.apply_bnb_quantization()

         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
xinference/model/llm/transformers/core.py

@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import importlib.util
 import json
 import logging
 import os
 from functools import lru_cache
-from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import torch

@@ -53,11 +53,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "opt",
     "glm4-chat",
     "glm4-chat-1m",
-    "internlm2-chat",
-    "internlm2.5-chat",
     "qwen-vl-chat",
     "OmniLMM",
-    "yi-vl-chat",
     "deepseek-vl-chat",
     "cogvlm2",
     "cogvlm2-video-llama3-chat",
@@ -75,6 +72,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "cogagent",
     "gemma-3-1b-it",
     "gemma-3-it",
+    "Ovis2",
     "deepseek-vl2",
 ]

@@ -142,6 +140,7 @@ class PytorchModel(LLM):
         pytorch_model_config.setdefault("max_num_seqs", 16)
         pytorch_model_config.setdefault("enable_tensorizer", False)
         pytorch_model_config.setdefault("reasoning_content", False)
+        pytorch_model_config.setdefault("quantization_config", {})
         return pytorch_model_config

     def _sanitize_generate_config(
@@ -264,16 +263,39 @@ class PytorchModel(LLM):
             f"PEFT adaptor '{peft_model.lora_name}' successfully loaded for model '{self.model_uid}'."
         )

-    def load(self):
-        try:
-            import torch
-        except ImportError:
-            raise ImportError(
-                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+    def apply_bnb_quantization(
+        self, kwargs: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        model_format = self.model_spec.model_format
+        _kwargs = kwargs if kwargs is not None else {}
+        if model_format == "pytorch":
+            quantization_config = self._pytorch_model_config.get(
+                "quantization_config", {}
             )
-        from .compression import load_compress_model
+            if quantization_config:
+                # If `load_in_4bit` is enabled, apply default quantization presets.
+                if quantization_config.get("load_in_4bit", False):
+                    quantization_config.setdefault(
+                        "bnb_4bit_compute_dtype", torch.float16
+                    )
+                    quantization_config.setdefault("bnb_4bit_use_double_quant", True)
+                    quantization_config.setdefault(
+                        "llm_int8_skip_modules",
+                        [
+                            "lm_head",
+                            "encoder",
+                            "EncDecAttention",
+                        ],
+                    )

-        quantization = self.quantization
+                from transformers import BitsAndBytesConfig
+
+                _kwargs["quantization_config"] = BitsAndBytesConfig(
+                    **quantization_config
+                )
+        return _kwargs
+
+    def load(self):
         num_gpus = gpu_count()
         device = self._pytorch_model_config.get("device", "auto")
         self._pytorch_model_config["device"] = select_device(device)
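Note: the new apply_bnb_quantization helper above replaces the old per-flag load_in_8bit / load_in_4bit handling. It reads a quantization_config dict from pytorch_model_config (defaulted to {} earlier in this file) and wraps it in transformers.BitsAndBytesConfig. A minimal standalone sketch of the equivalent call, with a placeholder model path and config values mirroring the defaults above (illustration only, not the xinference code path):

    # Illustration only: equivalent bitsandbytes wiring outside xinference.
    # "my-model-path" is a placeholder; the config keys mirror the defaults
    # that apply_bnb_quantization fills in above.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
        "llm_int8_skip_modules": ["lm_head", "encoder", "EncDecAttention"],
    }
    model = AutoModelForCausalLM.from_pretrained(
        "my-model-path",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(**quantization_config),
    )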
@@ -294,7 +316,6 @@ class PytorchModel(LLM):
         kwargs["trust_remote_code"] = self._pytorch_model_config.get(
             "trust_remote_code"
         )
-        model_format = self.model_spec.model_format

         is_device_map_auto = False

@@ -310,52 +331,18 @@ class PytorchModel(LLM):
             }
             kwargs["max_memory"] = max_memory

-        if quantization != "none" and model_format == "pytorch":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                is_device_map_auto = True
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                    kwargs["bnb_4bit_compute_dtype"] = torch.float16
-                    kwargs["bnb_4bit_use_double_quant"] = True
-                    kwargs["llm_int8_skip_modules"] = [
-                        "lm_head",
-                        "encoder",
-                        "EncDecAttention",
-                    ]
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if num_gpus != 1 and self._device == "cuda":
-                    raise ValueError(f"Quantization is not supported for multi-gpu")
-                elif quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
-                else:
-                    (
-                        self._model,
-                        self._tokenizer,
-                    ) = load_compress_model(
-                        model_path=self.model_path,
-                        device=self._device,
-                        torch_dtype=kwargs["torch_dtype"],
-                        use_fast=self._use_fast_tokenizer,
-                        revision=kwargs["revision"],
-                    )
-                    logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
-                    return
+        # handle bnb quantization
+        kwargs = self.apply_bnb_quantization(kwargs)

         if num_gpus > 0 and is_hf_accelerate_supported(self._device):
             kwargs.update({"device_map": "auto"})
             is_device_map_auto = True

         reasoning_content = self._pytorch_model_config.pop("reasoning_content")
-        self.prepare_parse_reasoning_content(reasoning_content)
+        enable_thinking = self._pytorch_model_config.pop("enable_thinking", False)
+        self.prepare_parse_reasoning_content(
+            reasoning_content, enable_thinking=enable_thinking
+        )

         if self._check_tensorizer_integrity():
             self._model, self._tokenizer = self._load_tensorizer(**kwargs)
@@ -372,7 +359,11 @@ class PytorchModel(LLM):
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")

     @classmethod
-    def match(
+    def check_lib(cls) -> bool:
+        return importlib.util.find_spec("transformers") is not None
+
+    @classmethod
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -689,7 +680,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         return generate_config

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -711,9 +702,14 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def load(self):
         super().load()

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):
         model_family = self.model_family.model_family or self.model_family.model_name
-        full_context_kwargs = {}
+        full_context_kwargs = (
+            self._get_chat_template_kwargs_from_generate_config(
+                generate_config, self.reasoning_parser
+            )
+            or {}
+        )
         if (
             tools
             and model_family in QWEN_TOOL_CALL_FAMILY
@@ -736,7 +732,9 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             try:
                 if not r.stopped and r.is_prefill:
                     tools = r.generate_config.get("tools", None)
-                    r.full_prompt = self._get_full_prompt(r.prompt, tools)
+                    r.full_prompt = self._get_full_prompt(
+                        r.prompt, tools, r.generate_config
+                    )
                     if tools:
                         r.tools = tools
             except Exception as e:
@@ -761,7 +759,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         results = []
         for i, c in enumerate(req.completion):
             if c == "<bos_stream>":
-                results.append(
+                results.extend(
                     self._get_first_chat_completion_chunk(
                         req.completion[i + 1], self.reasoning_parser
                     )
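Note: _get_full_prompt now receives the request's generate_config so that chat-template options (for example an enable_thinking toggle for reasoning models) can reach the tokenizer. The body of _get_chat_template_kwargs_from_generate_config is not shown in this diff; the sketch below only illustrates the general idea, and its helper name and config key are assumptions:

    # Hedged sketch: pulling template kwargs out of a generate_config and
    # forwarding them to apply_chat_template. The extraction helper and the
    # "chat_template_kwargs" key are assumed for illustration; only
    # apply_chat_template itself is a real transformers API.
    def get_template_kwargs(generate_config: dict) -> dict:
        kwargs = generate_config.get("chat_template_kwargs") or {}
        return kwargs if isinstance(kwargs, dict) else {}

    full_context_kwargs = get_template_kwargs(
        {"chat_template_kwargs": {"enable_thinking": False}}
    )
    # prompt = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True, **full_context_kwargs
    # )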
xinference/model/llm/transformers/deepseek_v2.py

@@ -48,13 +48,14 @@ class DeepSeekV2PytorchModel(PytorchModel):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map="auto",
+            **kwargs,
         )
         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
@@ -95,13 +96,14 @@ class DeepSeekV2PytorchChatModel(PytorchChatModel):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
             device_map="auto",
+            **kwargs,
         )
         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":
xinference/model/llm/transformers/deepseek_vl.py

@@ -42,11 +42,11 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._type = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
-        if "deepseek-vl" == llm_family.lower():
+        if "deepseek-vl-chat" == llm_family.lower():
             return True
         return False

@@ -62,6 +62,8 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._device = select_device(self._device)
         self._type = torch.float16 if self._device == "mps" else torch.bfloat16

+        kwargs = self.apply_bnb_quantization()
+
         # specify the path to the model
         self._vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(  # type: ignore
             self.model_path
@@ -69,9 +71,13 @@ class DeepSeekVLChatModel(PytorchChatModel):
         self._tokenizer = self._vl_chat_processor.tokenizer

         vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
-            self.model_path, trust_remote_code=True, device_map=self._device
+            self.model_path,
+            trust_remote_code=True,
+            device_map=self._device,
+            torch_dtype=self._type,
+            **kwargs,
         )
-        self._model = vl_gpt.to(self._type).eval()
+        self._model = vl_gpt.eval()

     @staticmethod
     def _message_content_to_deepseek(content) -> Tuple[str, List[str]]:
xinference/model/llm/transformers/deepseek_vl2.py

@@ -42,7 +42,7 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
         self._type = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
@@ -60,7 +60,8 @@ class DeepSeekVL2ChatModel(PytorchChatModel):

         self._device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(self._device)
-        self._type = torch.float16 if self._device == "mps" else torch.bfloat16
+        self._type = torch.bfloat16
+        kwargs = self.apply_bnb_quantization()

         # specify the path to the model
         self._vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(  # type: ignore
@@ -69,9 +70,13 @@ class DeepSeekVL2ChatModel(PytorchChatModel):
         self._tokenizer = self._vl_chat_processor.tokenizer

         vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(  # type: ignore
-            self.model_path, trust_remote_code=True, device_map=self._device
+            self.model_path,
+            trust_remote_code=True,
+            device_map=self._device,
+            torch_dtype=self._type,
+            **kwargs,
         )
-        self._model = vl_gpt.to(torch.bfloat16).cuda().eval()
+        self._model = vl_gpt.cuda().eval()

     @staticmethod
     def _message_content_to_deepseek(content) -> Tuple[str, List[str]]:
xinference/model/llm/transformers/gemma3.py

@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)

 class Gemma3TextChatModel(PytorchChatModel):
     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -56,7 +56,7 @@ class Gemma3ChatModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
@@ -85,6 +85,7 @@ class Gemma3ChatModel(PytorchChatModel):
         device = "auto" if device == "cuda" else device
         min_pixels = self._pytorch_model_config.get("min_pixels")
         max_pixels = self._pytorch_model_config.get("max_pixels")
+        kwargs = self.apply_bnb_quantization()
         self._processor = AutoProcessor.from_pretrained(
             self.model_path,
             min_pixels=min_pixels,
@@ -92,9 +93,7 @@ class Gemma3ChatModel(PytorchChatModel):
         )
         self._tokenizer = self._processor.tokenizer
         self._model = Gemma3ForConditionalGeneration.from_pretrained(
-            self.model_path,
-            device_map="auto",
-            torch_dtype="bfloat16",
+            self.model_path, device_map="auto", torch_dtype="bfloat16", **kwargs
         )

     @cache_clean
xinference/model/llm/transformers/glm4v.py

@@ -39,7 +39,7 @@ class Glm4VModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -54,25 +54,7 @@ class Glm4VModel(PytorchChatModel):
         self._device = select_device(device)

         kwargs = {"device_map": self._device}
-        quantization = self.quantization
-
-        # referenced from PytorchModel.load
-        if quantization != "none":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
+        kwargs = self.apply_bnb_quantization(kwargs)

         if self._check_tensorizer_integrity():
             self._model, self._tokenizer = self._load_tensorizer()
@@ -214,7 +196,7 @@ class Glm4VModel(PytorchChatModel):
             has_content=False,
         )

-    def _get_full_prompt(self, messages, tools):
+    def _get_full_prompt(self, messages, tools, generate_config: dict):
         msgs = self._get_processed_msgs(messages)
         inputs = self._tokenizer.apply_chat_template(
             msgs,
xinference/model/llm/transformers/glm_edge_v.py

@@ -42,7 +42,7 @@ class GlmEdgeVModel(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -57,25 +57,7 @@ class GlmEdgeVModel(PytorchChatModel):
         self._device = select_device(device)

         kwargs = {"device_map": self._device}
-        quantization = self.quantization
-
-        # referenced from PytorchModel.load
-        if quantization != "none":
-            if self._device == "cuda" and self._is_linux():
-                kwargs["device_map"] = "auto"
-                if quantization == "4-bit":
-                    kwargs["load_in_4bit"] = True
-                elif quantization == "8-bit":
-                    kwargs["load_in_8bit"] = True
-                else:
-                    raise ValueError(
-                        f"Quantization {quantization} is not supported in temporary"
-                    )
-            else:
-                if quantization != "8-bit":
-                    raise ValueError(
-                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
-                    )
+        kwargs = self.apply_bnb_quantization(kwargs)

         processor = AutoImageProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
@@ -87,6 +69,7 @@ class GlmEdgeVModel(PytorchChatModel):
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
             device_map="auto",
+            **kwargs
         )

         self._model = model
xinference/model/llm/transformers/intern_vl.py

@@ -243,7 +243,7 @@ class InternVLChatModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -311,7 +311,7 @@ class InternVLChatModel(PytorchChatModel):
             device_map[f"language_model.model.layers.{num_layers - 1}"] = 0
         return device_map

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoTokenizer

         if self._check_tensorizer_integrity():
@@ -329,10 +329,7 @@ class InternVLChatModel(PytorchChatModel):
         if device is not None:
             kwargs["device_map"] = device

-        if "8-bit" in self.quantization.lower():
-            kwargs["load_in_8bit"] = True
-        elif "4-bit" in self.quantization.lower():
-            kwargs["load_in_4bit"] = True
+        kwargs = self.apply_bnb_quantization(kwargs)

         self._model = AutoModel.from_pretrained(self.model_path, **kwargs).eval()

xinference/model/llm/transformers/internlm2.py

@@ -71,7 +71,7 @@ class Internlm2PytorchChatModel(PytorchChatModel):
         return model, tokenizer

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         model_family = llm_family.model_family or llm_family.model_name
xinference/model/llm/transformers/minicpmv25.py

@@ -42,7 +42,7 @@ class MiniCPMV25Model(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -55,7 +55,7 @@ class MiniCPMV25Model(PytorchChatModel):

         return AutoModel

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -76,11 +76,13 @@ class MiniCPMV25Model(PytorchChatModel):
         if "int4" in self.model_path:
             model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
         else:
+            kwargs = self.apply_bnb_quantization()
             model = AutoModel.from_pretrained(
                 self.model_path,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 device_map=self._device,
+                **kwargs
             )
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
xinference/model/llm/transformers/minicpmv26.py

@@ -49,7 +49,7 @@ class MiniCPMV26Model(PytorchChatModel):
         self._processor = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
@@ -71,7 +71,7 @@ class MiniCPMV26Model(PytorchChatModel):

         return AutoModel

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModel, AutoProcessor, AutoTokenizer
         from transformers.generation import GenerationConfig

@@ -96,11 +96,13 @@ class MiniCPMV26Model(PytorchChatModel):
         if "int4" in self.model_path:
             model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
         else:
+            kwargs = self.apply_bnb_quantization()
             model = AutoModel.from_pretrained(
                 self.model_path,
                 trust_remote_code=True,
                 torch_dtype=torch.float16,
                 device_map=self._device,
+                **kwargs,
             )
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_path, trust_remote_code=True
@@ -322,7 +324,7 @@ class MiniCPMV26Model(PytorchChatModel):
             "input_image": images,
         }

-    def _get_full_prompt(self, messages: List[Dict], tools):
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):  # type: ignore
         msgs, video_existed = self._convert_to_specific_style(messages)
         if video_existed:
             raise RuntimeError(
xinference/model/llm/transformers/omnilmm.py

@@ -35,7 +35,7 @@ class OmniLMMModel(PytorchChatModel):
         self._model = None

     @classmethod
-    def match(
+    def match_json(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         llm_family = model_family.model_family or model_family.model_name
xinference/model/llm/transformers/opt.py

@@ -42,7 +42,7 @@ class OptPytorchModel(PytorchModel):
         )

     @classmethod
-    def match(
+    def match_json(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if llm_spec.model_format != "pytorch":