xinference 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +1 -1
  3. xinference/conftest.py +0 -7
  4. xinference/core/media_interface.py +9 -8
  5. xinference/core/model.py +13 -6
  6. xinference/core/scheduler.py +1 -10
  7. xinference/core/worker.py +0 -10
  8. xinference/model/audio/model_spec.json +53 -1
  9. xinference/model/audio/model_spec_modelscope.json +57 -1
  10. xinference/model/embedding/core.py +19 -11
  11. xinference/model/image/model_spec.json +10 -1
  12. xinference/model/image/model_spec_modelscope.json +20 -0
  13. xinference/model/llm/__init__.py +6 -54
  14. xinference/model/llm/core.py +19 -5
  15. xinference/model/llm/llama_cpp/core.py +59 -3
  16. xinference/model/llm/llama_cpp/memory.py +455 -0
  17. xinference/model/llm/llm_family.json +185 -397
  18. xinference/model/llm/llm_family.py +88 -16
  19. xinference/model/llm/llm_family_modelscope.json +199 -421
  20. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  21. xinference/model/llm/sglang/core.py +4 -0
  22. xinference/model/llm/transformers/__init__.py +27 -6
  23. xinference/model/llm/transformers/chatglm.py +4 -2
  24. xinference/model/llm/transformers/core.py +49 -28
  25. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  26. xinference/model/llm/transformers/gemma3.py +119 -164
  27. xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
  28. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  29. xinference/model/llm/transformers/multimodal/core.py +205 -0
  30. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  31. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  32. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  33. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  34. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  35. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  36. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  37. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  38. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  39. xinference/model/llm/transformers/opt.py +4 -2
  40. xinference/model/llm/transformers/utils.py +6 -37
  41. xinference/model/llm/vllm/core.py +4 -0
  42. xinference/model/rerank/core.py +7 -1
  43. xinference/model/rerank/utils.py +17 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
  47. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
  52. xinference/web/ui/src/locales/en.json +3 -1
  53. xinference/web/ui/src/locales/zh.json +3 -1
  54. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/METADATA +16 -14
  55. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
  56. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
  57. xinference/model/llm/transformers/cogvlm2.py +0 -442
  58. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  59. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  60. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  61. xinference/model/llm/transformers/intern_vl.py +0 -526
  62. xinference/model/llm/transformers/internlm2.py +0 -94
  63. xinference/model/llm/transformers/minicpmv25.py +0 -193
  64. xinference/model/llm/transformers/omnilmm.py +0 -132
  65. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  66. xinference/model/llm/transformers/qwen_vl.py +0 -360
  67. xinference/thirdparty/omnilmm/LICENSE +0 -201
  68. xinference/thirdparty/omnilmm/__init__.py +0 -0
  69. xinference/thirdparty/omnilmm/chat.py +0 -218
  70. xinference/thirdparty/omnilmm/constants.py +0 -4
  71. xinference/thirdparty/omnilmm/conversation.py +0 -332
  72. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  73. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  74. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  75. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  76. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  77. xinference/thirdparty/omnilmm/utils.py +0 -134
  78. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  79. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  84. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
  85. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
  86. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
  87. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
@@ -785,40 +785,6 @@
       "</s>"
     ]
   },
-  {
-    "version": 1,
-    "context_length": 8192,
-    "model_name": "cogvlm2",
-    "model_lang": [
-      "en",
-      "zh"
-    ],
-    "model_ability": [
-      "chat",
-      "vision"
-    ],
-    "model_description": "CogVLM2 have achieved good results in many lists compared to the previous generation of CogVLM open source models. Its excellent performance can compete with some non-open source models.",
-    "model_specs": [
-      {
-        "model_format": "pytorch",
-        "model_size_in_billions": 20,
-        "quantizations": [
-          "none"
-        ],
-        "model_id": "AI-Research/cogvlm2-llama3-chinese-chat-19b",
-        "model_hub": "openmind_hub"
-      }
-    ],
-    "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ '<|end_of_text|>' }}{% endif %}",
-    "stop_token_ids": [
-      128001,
-      128009
-    ],
-    "stop": [
-      "<|end_of_text|>",
-      "<|eot_id|>"
-    ]
-  },
   {
     "version": 1,
     "context_length": 8192,
@@ -107,7 +107,11 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
     "deepseek-r1-distill-qwen",
     "deepseek-r1-distill-llama",
     "deepseek-v3",
+    "deepseek-v3-0324",
     "deepseek-r1",
+    "deepseek-r1-0528",
+    "deepseek-r1-0528-qwen3",
+    "deepseek-prover-v2",
     "DianJin-R1",
     "qwen3",
     "HuatuoGPT-o1-Qwen2.5",
@@ -16,12 +16,33 @@
 import importlib
 import os
 import pkgutil
+from typing import Dict
 
-# Get the path of the current package
+
+def import_submodules(package_path: str, package_name: str, globals_dict: Dict) -> None:
+    """
+    Recursively import all classes in submodules and subpackages
+    """
+    for _, module_name, is_pkg in pkgutil.iter_modules([package_path]):
+        full_module_name = f"{package_name}.{module_name}"
+
+        if module_name.startswith(
+            ("_", "test_")
+        ):  # Skip the modules which start with "_" or "test_"
+            continue
+
+        module = importlib.import_module(full_module_name)
+        globals_dict[module_name] = module
+
+        # If it's a pkg, recursive processing
+        if is_pkg:
+            subpackage_path = os.path.join(package_path, module_name)
+            import_submodules(subpackage_path, full_module_name, globals_dict)
+
+
+# Get the path and name of the current package
 __path__ = [os.path.dirname(os.path.abspath(__file__))]
+__package__ = __name__
 
-# Automatically import all modules under the current package
-for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
-    if not module_name.startswith("_"):  # Skip modules starting with underscore
-        module = importlib.import_module(f"{__name__}.{module_name}")
-        globals()[module_name] = module
+# Automatic import of all sub-modules and sub-packages
+import_submodules(__path__[0], __package__, globals())
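
The practical effect of this change is that importing the transformers package now also imports its subpackages (such as the new multimodal directory), so class-level registration decorators in those modules run at import time. A minimal, standalone sketch of the same pattern, shown against a hypothetical package named mypkg (not part of xinference):

    # mypkg/__init__.py -- hypothetical package, for illustration only
    import importlib
    import os
    import pkgutil
    from typing import Dict


    def import_submodules(package_path: str, package_name: str, globals_dict: Dict) -> None:
        """Recursively import every non-private, non-test submodule and subpackage."""
        for _, module_name, is_pkg in pkgutil.iter_modules([package_path]):
            if module_name.startswith(("_", "test_")):
                continue
            module = importlib.import_module(f"{package_name}.{module_name}")
            globals_dict[module_name] = module
            if is_pkg:
                import_submodules(
                    os.path.join(package_path, module_name),
                    f"{package_name}.{module_name}",
                    globals_dict,
                )


    import_submodules(os.path.dirname(os.path.abspath(__file__)), __name__, globals())
    # After this runs, every submodule of mypkg (e.g. mypkg.audio, mypkg.vision.clip)
    # has been imported, so any registration decorators they define have already executed.
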
@@ -22,17 +22,19 @@ import torch
 
 from ....core.scheduler import InferenceRequest
 from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
-from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
 from ..utils import (
     GLM4_TOOL_CALL_FAMILY,
     generate_chat_completion,
     generate_completion_chunk,
 )
-from .core import PytorchChatModel, PytorchModelConfig
+from .core import PytorchChatModel, PytorchModelConfig, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
+@register_transformer
+@register_non_default_model("glm4-chat", "glm4-chat-1m")
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
@@ -16,7 +16,7 @@ import json
 import logging
 import os
 from functools import lru_cache
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -45,36 +45,17 @@ from ..utils import (
     QWEN_TOOL_CALL_FAMILY,
     ChatModelMixin,
 )
-from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
+from .utils import (
+    _get_pad_param,
+    get_context_length,
+    get_max_src_len,
+    pad_prefill_tokens,
+)
 
 logger = logging.getLogger(__name__)
 
-NON_DEFAULT_MODEL_LIST: List[str] = [
-    "opt",
-    "glm4-chat",
-    "glm4-chat-1m",
-    "qwen-vl-chat",
-    "OmniLMM",
-    "deepseek-vl-chat",
-    "cogvlm2",
-    "cogvlm2-video-llama3-chat",
-    "MiniCPM-Llama3-V-2_5",
-    "MiniCPM-V-2.6",
-    "glm-4v",
-    "qwen2-audio",
-    "qwen2-audio-instruct",
-    "deepseek-v2",
-    "deepseek-v2-chat",
-    "deepseek-v2.5",
-    "deepseek-v2-chat-0628",
-    "glm-edge-v",
-    "QvQ-72B-Preview",
-    "cogagent",
-    "gemma-3-1b-it",
-    "gemma-3-it",
-    "Ovis2",
-    "deepseek-vl2",
-]
+# !!!!! Do not add model_name to this list, use `register_non_default_model` below instead!
+NON_DEFAULT_MODEL_LIST: List[str] = []
 
 
 # Define the decorator to support multiple names registration
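
The trailing context line points at the new multi-name registration decorator, whose body sits just below this hunk and is not included in the diff. The following is only a hedged sketch of what `register_non_default_model` plausibly does, under the assumption that it simply records each passed name in NON_DEFAULT_MODEL_LIST and returns the class unchanged:

    from typing import List

    NON_DEFAULT_MODEL_LIST: List[str] = []


    def register_non_default_model(*model_names: str):
        # Assumed implementation: collect every registered name in the module-level list.
        def decorator(cls):
            for name in model_names:
                if name not in NON_DEFAULT_MODEL_LIST:
                    NON_DEFAULT_MODEL_LIST.append(name)
            return cls

        return decorator


    @register_non_default_model("glm4-chat", "glm4-chat-1m")
    class _ExampleModel:  # stand-in for a PytorchChatModel subclass
        pass


    assert "glm4-chat" in NON_DEFAULT_MODEL_LIST

This is why the 1.6.1 release can empty the hardcoded list: each model class now declares its own names at the point where the class is defined, and the recursive package import above guarantees those decorators run.
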
@@ -551,6 +532,36 @@ class PytorchModel(LLM):
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         return self._sanitize_generate_config(req.generate_config)
 
+    def merge_kv_cache(self, past_cache, new_cache):
+        from torch.nn.functional import pad
+        from transformers import DynamicCache
+
+        _, seq_len_idx = self.get_batch_size_and_seq_len_indexes_from_kv()
+        past_seq_len = past_cache[0][0].shape[seq_len_idx]
+        new_seq_len = new_cache[0][0].shape[seq_len_idx]
+        if past_seq_len != new_seq_len:
+            padding_target = new_cache if past_seq_len > new_seq_len else past_cache
+            padding_len = abs(past_seq_len - new_seq_len)
+            pad_param = _get_pad_param(seq_len_idx, padding_len)
+            for idx in range(len(padding_target)):
+                k = padding_target.key_cache[idx]
+                v = padding_target.value_cache[idx]
+                _k = pad(k, pad_param)
+                _v = pad(v, pad_param)
+                padding_target.key_cache[idx] = _k
+                padding_target.value_cache[idx] = _v
+
+        ret_kv = DynamicCache()
+        for idx in range(len(past_cache)):
+            k1, k2 = new_cache.key_cache[idx], past_cache.key_cache[idx]
+            v1, v2 = new_cache.value_cache[idx], past_cache.value_cache[idx]
+            ret_kv.update(
+                torch.cat((k1, k2), 0).contiguous(),
+                torch.cat((v1, v2), 0).contiguous(),
+                idx,
+            )
+        return ret_kv
+
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         # check some parameters
         for r in req_list:
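
merge_kv_cache concatenates two DynamicCache objects along the batch dimension, which only works once both caches share the same sequence length; `_get_pad_param` (defined in transformers/utils.py, not shown in this diff) supplies the argument for torch.nn.functional.pad that aligns them. A small sketch under the assumption that the helper pads the sequence dimension on the left (the hypothetical `_get_pad_param_sketch` below stands in for the real helper):

    import torch
    from torch.nn.functional import pad


    def _get_pad_param_sketch(seq_len_idx: int, padding_len: int):
        # F.pad expects (left, right) pairs starting from the LAST dimension, so for a
        # 4-D KV tensor [batch, heads, seq_len, head_dim] with seq_len_idx == 2 we pad
        # head_dim by (0, 0) and seq_len by (padding_len, 0), i.e. on the left.
        return (0, 0, padding_len, 0)


    past_k = torch.zeros(1, 2, 5, 8)  # one running request, 5 cached tokens
    new_k = torch.zeros(3, 2, 9, 8)   # three new requests, 9 prompt tokens each
    pad_param = _get_pad_param_sketch(seq_len_idx=2, padding_len=9 - 5)
    past_k_padded = pad(past_k, pad_param)         # -> shape [1, 2, 9, 8]
    merged = torch.cat((new_k, past_k_padded), 0)  # -> shape [4, 2, 9, 8], batches merged
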
@@ -642,6 +653,16 @@
         )
         self.handle_batch_inference_results(req_list)
 
+    def build_reduced_kv_cache(self, cache, skipped_indexes: Set[int]):
+        batch_size = cache.key_cache[0].shape[0]
+        batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
+        for idx in range(len(cache)):
+            cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
+            cache.value_cache[idx] = cache.value_cache[idx][
+                batch_slices, ::
+            ].contiguous()
+        return cache
+
 
 class PytorchChatModel(PytorchModel, ChatModelMixin):
     def __init__(
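
build_reduced_kv_cache shrinks the cache after some requests in the batch finish. The core operation is plain fancy indexing on the batch dimension of each per-layer tensor, illustrated here with a toy tensor:

    import torch

    key = torch.arange(4 * 2 * 3, dtype=torch.float32).reshape(4, 2, 3)  # batch of 4 requests
    skipped_indexes = {1, 3}  # requests that have finished decoding
    batch_slices = [i for i in range(key.shape[0]) if i not in skipped_indexes]  # [0, 2]
    reduced = key[batch_slices, ::].contiguous()
    assert reduced.shape[0] == 2  # only the still-running requests remain
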
@@ -15,59 +15,16 @@ import logging
 
 import torch
 
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from .core import PytorchChatModel, PytorchModel
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
-class DeepSeekV2PytorchModel(PytorchModel):
-    def _load_model(self, **kwargs):
-        try:
-            from transformers import (
-                AutoModelForCausalLM,
-                AutoTokenizer,
-                GenerationConfig,
-            )
-        except ImportError:
-            error_message = "Failed to import module 'transformers'"
-            installation_guide = [
-                "Please make sure 'transformers' is installed. ",
-                "You can install it by `pip install transformers`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.model_path,
-            trust_remote_code=kwargs["trust_remote_code"],
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            self.model_path,
-            attn_implementation="eager",
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-            device_map="auto",
-            **kwargs,
-        )
-        model.generation_config = GenerationConfig.from_pretrained(self.model_path)
-        model.generation_config.pad_token_id = model.generation_config.eos_token_id
-        return model, tokenizer
-
-    @classmethod
-    def match_json(
-        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if llm_spec.model_format != "pytorch":
-            return False
-        model_family = llm_family.model_family or llm_family.model_name
-        if "deepseek-v2" not in model_family:
-            return False
-        if "generate" not in llm_family.model_ability:
-            return False
-        return True
-
-
+@register_transformer
+@register_non_default_model(
+    "deepseek-v2-chat", "deepseek-v2.5", "deepseek-v2-chat-0628"
+)
 class DeepSeekV2PytorchChatModel(PytorchChatModel):
     def _load_model(self, **kwargs):
         try:
@@ -11,29 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import logging
-import sys
-import uuid
-from typing import Iterator, List, Optional, Union
-
-from ....model.utils import select_device
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    CompletionChunk,
-    PytorchModelConfig,
-)
-from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import generate_chat_completion, generate_completion_chunk
-from .core import PytorchChatModel, PytorchGenerateConfig
-from .utils import cache_clean
+from typing import Dict, List, Set
+
+from ....core.scheduler import InferenceRequest
+from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+from .core import PytorchChatModel, register_non_default_model
 
 logger = logging.getLogger(__name__)
 
 
+@register_transformer
+@register_non_default_model("gemma-3-1b-it")
 class Gemma3TextChatModel(PytorchChatModel):
     @classmethod
     def match_json(
@@ -46,163 +35,129 @@ class Gemma3TextChatModel(PytorchChatModel):
             return True
         return False
 
+    def _load_model(self, **kwargs):
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
 
-class Gemma3ChatModel(PytorchChatModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._tokenizer = None
-        self._model = None
-        self._device = None
-        self._processor = None
-
-    @classmethod
-    def match_json(
-        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
-    ) -> bool:
-        if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
-            return False
-        llm_family = model_family.model_family or model_family.model_name
-        if "gemma-3-it".lower() in llm_family.lower():
-            return True
-        return False
-
-    def _sanitize_model_config(
-        self, pytorch_model_config: Optional[PytorchModelConfig]
-    ) -> PytorchModelConfig:
-        pytorch_model_config = super()._sanitize_model_config(pytorch_model_config)
-        assert pytorch_model_config is not None
-        pytorch_model_config.setdefault("min_pixels", 256 * 28 * 28)
-        pytorch_model_config.setdefault("max_pixels", 1280 * 28 * 28)
-        return pytorch_model_config
-
-    def load(self):
-        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
-
-        device = self._pytorch_model_config.get("device", "auto")
-        device = select_device(device)
-        self._device = device
-        # for multiple GPU, set back to auto to make multiple devices work
-        device = "auto" if device == "cuda" else device
-        min_pixels = self._pytorch_model_config.get("min_pixels")
-        max_pixels = self._pytorch_model_config.get("max_pixels")
-        kwargs = self.apply_bnb_quantization()
-        self._processor = AutoProcessor.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
+            trust_remote_code=kwargs["trust_remote_code"],
+            revision=kwargs["revision"],
         )
-        self._tokenizer = self._processor.tokenizer
-        self._model = Gemma3ForConditionalGeneration.from_pretrained(
-            self.model_path, device_map="auto", torch_dtype="bfloat16", **kwargs
-        )
-
-    @cache_clean
-    def chat(
-        self,
-        messages: List[ChatCompletionMessage],  # type: ignore
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        messages = self._transform_messages(messages)
-
-        generate_config = generate_config if generate_config else {}
-
-        stream = generate_config.get("stream", False) if generate_config else False
-
-        if stream:
-            it = self._generate_stream(messages, generate_config)
-            return self._to_chat_completion_chunks(it)
-        else:
-            c = self._generate(messages, generate_config)
-            return c
-
-    def _generate(
-        self, messages: List, config: PytorchGenerateConfig = {}
-    ) -> ChatCompletion:
-        inputs = self._processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(self._device)
-        input_len = inputs["input_ids"].shape[-1]
-
-        generation = self._model.generate(
-            **inputs,
-            do_sample=False,
-            max_new_tokens=config.get("max_tokens", 512),
-            temperature=config.get("temperature", 1),
+        kwargs["torch_dtype"] = torch.bfloat16
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            **kwargs,
         )
-        generation = generation[0][input_len:]
-
-        decoded = self._processor.decode(generation, skip_special_tokens=True)
-        return generate_chat_completion(self.model_uid, decoded)
+        self._device = model.device
+        return model, tokenizer
 
-    def _generate_stream(
-        self, messages: List, config: PytorchGenerateConfig = {}
-    ) -> Iterator[CompletionChunk]:
-        from threading import Thread
+    def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):
+        return messages
 
-        from transformers import TextIteratorStreamer
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Note that it is important to prepare `past_key_values` for gemma3 prefill phase
+        """
+        from transformers import HybridCache
 
-        inputs = self._processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
+        inputs = self._tokenizer.apply_chat_template(
+            prompts,
             tokenize=True,
-            return_dict=True,
+            add_generation_prompt=True,
             return_tensors="pt",
+            return_dict=True,
+            padding=True,
         ).to(self._device)
 
-        tokenizer = self._tokenizer
-        streamer = TextIteratorStreamer(
-            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        for i, r in enumerate(req_list):
+            r.prompt_tokens = inputs["input_ids"][i].tolist()
+
+        batch_size = len(prompts)
+        max_cache_len = self.get_context_len()
+        kv = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
+        )
+        return {**inputs, "past_key_values": kv}
+
+    def merge_kv_cache(self, past_cache, new_cache):
+        """
+        Note that: DO NOT use the `update` func of `HybridCache`, that is unrelated to KV cache merging.
+        """
+        import torch
+        from transformers import HybridCache
+
+        max_cache_len = new_cache.max_cache_len
+        batch_size = past_cache.max_batch_size + new_cache.max_batch_size
+
+        kv_batch = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
         )
 
-        gen_kwargs = {"streamer": streamer, **inputs}
-        error = None
-
-        def model_generate():
-            try:
-                return self._model.generate(
-                    **gen_kwargs,
-                    max_new_tokens=config.get("max_tokens", 512),
-                    temperature=config.get("temperature", 1),
-                )
-            except Exception:
-                nonlocal error
-                error = sys.exc_info()
-                streamer.end()
-                raise
-
-        thread = Thread(target=model_generate)
-        thread.start()
-
-        completion_id = str(uuid.uuid1())
-        for new_text in streamer:
-            yield generate_completion_chunk(
-                chunk_text=new_text,
-                finish_reason=None,
-                chunk_id=completion_id,
-                model_uid=self.model_uid,
-                prompt_tokens=-1,
-                completion_tokens=-1,
-                total_tokens=-1,
-                has_choice=True,
-                has_content=True,
-            )
-
-        if error:
-            _, err, tb = error  # type: ignore
-            raise err.with_traceback(tb)
-
-        yield generate_completion_chunk(
-            chunk_text=None,
-            finish_reason="stop",
-            chunk_id=completion_id,
-            model_uid=self.model_uid,
-            prompt_tokens=-1,
-            completion_tokens=-1,
-            total_tokens=-1,
-            has_choice=True,
-            has_content=False,
+        new_ks = [
+            torch.cat([nk, pk], dim=0).contiguous()
+            for nk, pk in zip(new_cache.key_cache, past_cache.key_cache)
+        ]
+        new_vs = [
+            torch.cat([nv, pv], dim=0).contiguous()
+            for nv, pv in zip(new_cache.value_cache, past_cache.value_cache)
+        ]
+
+        kv_batch.key_cache.clear()
+        kv_batch.value_cache.clear()
+        kv_batch.key_cache.extend(new_ks)
+        kv_batch.value_cache.extend(new_vs)
+
+        return kv_batch
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        In Gemma3's inference script, attention_mask is handled internally for decode phase.
+        """
+        return None
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        In Gemma3's inference script, position_ids is handled internally for decode phase.
+        """
+        return None
+
+    def build_reduced_kv_cache(self, cache, skipped_indexes: Set[int]):
+        from transformers import HybridCache
+
+        batch_slices = [
+            num for num in range(cache.max_batch_size) if num not in skipped_indexes
+        ]
+        batch_size = len(batch_slices)
+
+        kv_batch = HybridCache(
+            self._model.config,
+            max_batch_size=batch_size,
+            max_cache_len=cache.max_cache_len,
+            dtype=self._model.dtype,
+            device=self._device,
         )
+
+        ks = cache.key_cache
+        vs = cache.value_cache
+
+        new_ks = [_k[batch_slices, ::].contiguous() for _k in ks]
+        new_vs = [_v[batch_slices, ::].contiguous() for _v in vs]
+        kv_batch.key_cache.clear()
+        kv_batch.value_cache.clear()
+        kv_batch.key_cache.extend(new_ks)
+        kv_batch.value_cache.extend(new_vs)
+
+        return kv_batch
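
Unlike DynamicCache, Gemma3's HybridCache pre-allocates fixed-size per-layer buffers, so the code above rebuilds a fresh cache of the target batch size and swaps in concatenated (or sliced) tensors through the key_cache/value_cache lists rather than calling update. A framework-free sketch of that merge step, using dummy tensors with an assumed [batch, kv_heads, max_cache_len, head_dim] layout:

    import torch

    num_layers = 2
    # per-layer KV tensors; shapes are illustrative, not taken from a real Gemma 3 config
    past_keys = [torch.zeros(1, 4, 1024, 32) for _ in range(num_layers)]  # running batch of 1
    new_keys = [torch.zeros(3, 4, 1024, 32) for _ in range(num_layers)]   # 3 newly prefilled requests

    merged_keys = [
        torch.cat([nk, pk], dim=0).contiguous()  # new requests first, then the old batch
        for nk, pk in zip(new_keys, past_keys)
    ]
    assert merged_keys[0].shape[0] == 4  # combined batch size per layer

In the real code these concatenated tensors replace the contents of a freshly allocated HybridCache whose max_batch_size matches the combined batch, which keeps the buffer shapes consistent for the decode phase.
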
@@ -1,4 +1,4 @@
-# Copyright 2022-2023 XProbe Inc.
+# Copyright 2022-2025 XProbe Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.