xinference 1.7.1.post1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (136)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.post1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/multimodal/glm4_1v.py (new file)
@@ -0,0 +1,167 @@
+ # Copyright 2022-2025 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ from concurrent.futures import ThreadPoolExecutor
+ from threading import Thread
+ from typing import Any, Dict, Iterator, List, Tuple
+
+ import torch
+
+ from .....model.utils import select_device
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
+ from ...utils import _decode_image
+ from ..core import register_non_default_model
+ from .core import PytorchMultiModalModel
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_transformer
+ @register_non_default_model("glm-4.1v-thinking")
+ class Glm4_1VModel(PytorchMultiModalModel):
+     @classmethod
+     def match_json(
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         family = model_family.model_family or model_family.model_name
+         if "glm-4.1v" in family.lower():
+             return True
+         return False
+
+     def decide_device(self):
+         device = self._pytorch_model_config.get("device", "auto")
+         self._device = select_device(device)
+
+     def load_processor(self):
+         from transformers import AutoProcessor
+
+         self._processor = AutoProcessor.from_pretrained(self.model_path, use_fast=True)
+         self._tokenizer = self._processor.tokenizer
+
+     def load_multimodal_model(self):
+         from transformers import Glm4vForConditionalGeneration
+
+         kwargs = {"device_map": "auto"}
+         kwargs = self.apply_bnb_quantization(kwargs)
+
+         model = Glm4vForConditionalGeneration.from_pretrained(
+             self.model_path,
+             torch_dtype=torch.bfloat16,
+             **kwargs,
+         )
+         self._model = model.eval()
+         self._device = self._model.device
+
+     @staticmethod
+     def _get_processed_msgs(messages: List[Dict]) -> List[Dict]:
+         res = []
+         for message in messages:
+             role = message["role"]
+             content = message["content"]
+             if isinstance(content, str):
+                 res.append({"role": role, "content": content})
+             else:
+                 texts = []
+                 image_urls = []
+                 for c in content:
+                     c_type = c.get("type")
+                     if c_type == "text":
+                         texts.append(c["text"])
+                     else:
+                         assert (
+                             c_type == "image_url"
+                         ), "Please follow the image input of the OpenAI API."
+                         image_urls.append(c["image_url"]["url"])
+                 if len(image_urls) > 1:
+                     raise RuntimeError("Only one image per message is supported")
+                 image_futures = []
+                 with ThreadPoolExecutor() as executor:
+                     for image_url in image_urls:
+                         fut = executor.submit(_decode_image, image_url)
+                         image_futures.append(fut)
+                 images = [fut.result() for fut in image_futures]
+                 assert len(images) <= 1
+                 text = " ".join(texts)
+                 if images:
+                     content = [
+                         {"type": "image", "image": images[0]},
+                         {"type": "text", "text": text},
+                     ]
+                     res.append({"role": role, "content": content})
+                 else:
+                     res.append(
+                         {"role": role, "content": {"type": "text", "text": text}}
+                     )
+         return res
+
+     def build_inputs_from_messages(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ):
+         msgs = self._get_processed_msgs(messages)
+         inputs = self._processor.apply_chat_template(
+             msgs,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_tensors="pt",
+             return_dict=True,
+         )  # chat mode
+         inputs = inputs.to(self._model.device)
+         return inputs
+
+     def get_stop_strs(self) -> List[str]:
+         return ["<|endoftext|>"]
+
+     def get_builtin_stop_token_ids(self) -> Tuple:
+         from transformers import AutoConfig
+
+         return tuple(AutoConfig.from_pretrained(self.model_path).eos_token_id)
+
+     def build_generate_kwargs(
+         self,
+         generate_config: Dict,
+     ) -> Dict[str, Any]:
+         return dict(
+             do_sample=True,
+             top_p=generate_config.get("top_p", 1e-5),
+             repetition_penalty=generate_config.get("repetition_penalty", 1.1),
+             top_k=generate_config.get("top_k", 2),
+             max_new_tokens=generate_config.get("max_tokens", 512),
+         )
+
+     def build_streaming_iter(
+         self,
+         messages: List[Dict],
+         generate_config: Dict,
+     ) -> Tuple[Iterator, int]:
+         from transformers import TextIteratorStreamer
+
+         generate_kwargs = self.build_generate_kwargs(generate_config)
+         inputs = self.build_inputs_from_messages(messages, generate_config)
+         streamer = TextIteratorStreamer(
+             tokenizer=self._tokenizer,
+             timeout=60,
+             skip_prompt=True,
+             skip_special_tokens=False,
+         )
+         kwargs = {
+             **inputs,
+             **generate_kwargs,
+             "streamer": streamer,
+         }
+         logger.debug("Generate with kwargs: %s", generate_kwargs)
+         t = Thread(target=self._model.generate, kwargs=kwargs)
+         t.start()
+         return streamer, len(inputs.input_ids[0])
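
The new module registers GLM-4.1V-Thinking as a Transformers multimodal model: it accepts OpenAI-style image_url message content, decodes at most one image per message, and streams tokens through TextIteratorStreamer. A minimal client-side sketch for trying the model against a running Xinference server (the endpoint, engine choice, image URL, and generate_config values are illustrative, not taken from this diff):

from xinference.client import RESTfulClient

# Assumes `xinference-local` is already running on the default port.
client = RESTfulClient("http://127.0.0.1:9997")

# Launch the model registered above; the engine choice is illustrative.
model_uid = client.launch_model(
    model_name="glm-4.1v-thinking",
    model_type="LLM",
    model_engine="transformers",
)

model = client.get_model(model_uid)
result = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    generate_config={"max_tokens": 256},
)
print(result["choices"][0]["message"]["content"])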
xinference/model/llm/transformers/multimodal/glm4v.py
@@ -22,7 +22,7 @@ import torch
  from .....core.model import register_batching_multimodal_models
  from .....core.scheduler import InferenceRequest
  from .....model.utils import select_device
- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ...utils import _decode_image
  from ..core import register_non_default_model
  from ..utils import get_max_src_len
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
  class Glm4VModel(PytorchMultiModalModel):
      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          family = model_family.model_family or model_family.model_name
          if "glm-4v" in family.lower():
xinference/model/llm/transformers/multimodal/intern_vl.py
@@ -19,7 +19,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

  import torch

- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ...utils import _decode_image, parse_messages
  from ..core import register_non_default_model
  from .core import PytorchMultiModalModel
@@ -35,7 +35,7 @@ class InternVLChatModel(PytorchMultiModalModel):

      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          family = model_family.model_family or model_family.model_name
          if "internvl3" in family.lower():
xinference/model/llm/transformers/multimodal/minicpmv26.py
@@ -22,7 +22,7 @@ from .....core.model import register_batching_multimodal_models
  from .....core.scheduler import InferenceRequest
  from .....model.utils import select_device
  from .....types import PytorchModelConfig
- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ...utils import _decode_image, parse_messages
  from ..core import register_non_default_model
  from .core import PytorchMultiModalModel
@@ -33,10 +33,10 @@ logger = logging.getLogger(__name__)
  @register_batching_multimodal_models("MiniCPM-V-2.6")
  @register_transformer
  @register_non_default_model("MiniCPM-V-2.6")
- class Glm4VModel(PytorchMultiModalModel):
+ class MiniCPMV26Model(PytorchMultiModalModel):
      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          family = model_family.model_family or model_family.model_name
          if "MiniCPM-V-2.6".lower() in family.lower():
xinference/model/llm/transformers/multimodal/ovis2.py
@@ -18,7 +18,7 @@ from typing import Any, Dict, Iterator, List, Tuple
  import torch
  from PIL import Image

- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ..core import register_non_default_model
  from .core import PytorchMultiModalModel

@@ -35,7 +35,7 @@ class Ovis2ChatModel(PytorchMultiModalModel):

      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
              return False
xinference/model/llm/transformers/multimodal/qwen-omni.py
@@ -27,7 +27,7 @@ from .....types import (
      ChatCompletionChoice,
      CompletionUsage,
  )
- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ..core import PytorchGenerateConfig, register_non_default_model
  from .core import PytorchMultiModalModel

@@ -44,7 +44,7 @@ class Qwen2_5OmniChatModel(PytorchMultiModalModel):

      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
              return False
xinference/model/llm/transformers/multimodal/qwen2_audio.py
@@ -20,7 +20,7 @@ from urllib.request import urlopen
  import numpy as np

  from .....model.utils import select_device
- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ..core import register_non_default_model
  from .core import PytorchMultiModalModel

@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
  class Qwen2AudioChatModel(PytorchMultiModalModel):
      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          llm_family = model_family.model_family or model_family.model_name
          if "qwen2-audio".lower() in llm_family.lower():
xinference/model/llm/transformers/multimodal/qwen2_vl.py
@@ -20,7 +20,7 @@ from .....core.scheduler import InferenceRequest
  from .....device_utils import is_npu_available
  from .....model.utils import select_device
  from .....types import PytorchModelConfig
- from ...llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ...llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from ..core import register_non_default_model
  from .core import PytorchMultiModalModel

@@ -46,7 +46,7 @@ class Qwen2VLChatModel(PytorchMultiModalModel):

      @classmethod
      def match_json(
-         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          if model_spec.model_format not in ["pytorch", "gptq", "awq"]:
              return False
xinference/model/llm/transformers/opt.py
@@ -16,7 +16,7 @@ from typing import List, Optional

  from ....core.scheduler import InferenceRequest
  from ....types import LoRA
- from ..llm_family import LLMFamilyV1, LLMSpecV1, register_transformer
+ from ..llm_family import LLMFamilyV2, LLMSpecV1, register_transformer
  from .core import PytorchModel, PytorchModelConfig, register_non_default_model


@@ -26,9 +26,7 @@ class OptPytorchModel(PytorchModel):
      def __init__(
          self,
          model_uid: str,
-         model_family: "LLMFamilyV1",
-         model_spec: "LLMSpecV1",
-         quantization: str,
+         model_family: "LLMFamilyV2",
          model_path: str,
          pytorch_model_config: Optional[PytorchModelConfig] = None,
          peft_model: Optional[List[LoRA]] = None,
@@ -36,8 +34,6 @@
          super().__init__(
              model_uid,
              model_family,
-             model_spec,
-             quantization,
              model_path,
              pytorch_model_config=pytorch_model_config,
              peft_model=peft_model,
@@ -45,7 +41,7 @@

      @classmethod
      def match_json(
-         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+         cls, llm_family: "LLMFamilyV2", llm_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          if llm_spec.model_format != "pytorch":
              return False
xinference/model/llm/utils.py
@@ -16,7 +16,6 @@ import base64
  import functools
  import json
  import logging
- import os
  import re
  import time
  import typing
@@ -50,13 +49,7 @@ from ...types import (
      CompletionChunk,
      CompletionUsage,
  )
- from .llm_family import (
-     LlamaCppLLMSpecV1,
-     LLMFamilyV1,
-     LLMSpecV1,
-     _get_cache_dir,
-     get_cache_status,
- )
+ from .core import chat_context_var
  from .reasoning_parser import ReasoningParser

  logger = logging.getLogger(__name__)
@@ -319,9 +312,7 @@ class ChatModelMixin:
          for i, choice in enumerate(choices):  # type: ignore
              delta = ChatCompletionChunkDelta()
              if "text" in choice and choice["finish_reason"] is None:
-                 if not reasoning_parser or not reasoning_parser.check_content_parser():
-                     delta["content"] = choice["text"]
-                 else:
+                 if reasoning_parser and reasoning_parser.check_content_parser():
                      assert previous_texts is not None
                      current_text = previous_texts[-1] + choice["text"]
                      delta = reasoning_parser.extract_reasoning_content_streaming(
@@ -330,6 +321,8 @@
                          delta_text=choice["text"],
                      )
                      previous_texts[-1] = current_text
+                 else:
+                     delta["content"] = choice["text"]
              elif "text" in choice and choice["finish_reason"] is not None:
                  delta["content"] = choice["text"]
                  if reasoning_parser and reasoning_parser.check_content_parser():
@@ -463,12 +456,19 @@
          cls,
          chunks: AsyncGenerator[CompletionChunk, None],
          reasoning_parser: Optional[ReasoningParser] = None,
+         ctx: Optional[Dict[str, Any]] = None,
      ) -> AsyncGenerator[ChatCompletionChunk, None]:
+         def set_context():
+             if ctx:
+                 chat_context_var.set(ctx)
+
          previous_texts = [""]
          # Process chunks
          if reasoning_parser:
+             set_context()
              chunks = reasoning_parser.prepare_reasoning_content_streaming(chunks)
          async for chunk in chunks:
+             set_context()
              choices = chunk.get("choices")
              if not choices:
                  # usage
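
The new ctx argument lets the caller hand the per-request chat context to the async chunk converter, and set_context() re-publishes it on chat_context_var each time a chunk is processed, so code running in the consuming coroutine sees the same context. A small self-contained sketch of the pattern (the ContextVar and the consumer below are illustrative stand-ins, not xinference internals):

import asyncio
import contextvars
from typing import AsyncGenerator, Dict, Optional

# Illustrative stand-in for the chat_context_var imported from .core in the diff.
chat_context_var: contextvars.ContextVar = contextvars.ContextVar("chat_context", default=None)


async def to_chat_chunks(
    chunks: AsyncGenerator[Dict, None],
    ctx: Optional[Dict] = None,
) -> AsyncGenerator[Dict, None]:
    def set_context():
        if ctx:
            chat_context_var.set(ctx)

    async for chunk in chunks:
        set_context()  # make the per-request context visible where the chunk is transformed
        yield {"context": chat_context_var.get(), **chunk}


async def main():
    async def fake_chunks():
        for i in range(2):
            yield {"choices": [{"text": f"token-{i}", "finish_reason": None}]}

    async for c in to_chat_chunks(fake_chunks(), ctx={"model_uid": "demo"}):
        print(c)


asyncio.run(main())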
@@ -560,23 +560,33 @@
          def split_into_blocks(text: str) -> list[str]:
              # Match blocks starting with <think> or <tool_call> and ending with </think> or </tool_call>
              pattern = r"(<(think|tool_call)>.*?</\2>)"
-             blocks = re.findall(pattern, text, re.DOTALL)
-             return [match[0] for match in blocks]
+             parts = []
+             last_end = 0
+             # Find all label blocks and record their positions
+             for m in re.finditer(pattern, text, re.DOTALL):
+                 # Text before adding tags
+                 if m.start() > last_end:
+                     parts.append(text[last_end : m.start()])
+                 # Add label block
+                 parts.append(m.group(0))
+                 last_end = m.end()
+             # Text after adding the last tag
+             if last_end < len(text):
+                 parts.append(text[last_end:])
+             return parts

          contents = split_into_blocks(text)
          results: List[Tuple] = []
          for content in contents:
-             content = content.strip()
-             if content:
+             if content.strip():
                  pos1 = content.find(QWEN_TOOL_CALL_SYMBOLS[0])
                  if pos1 != -1:
                      content = content[pos1 + len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
                  pos2 = content.find(QWEN_TOOL_CALL_SYMBOLS[1])
                  if pos2 != -1:
                      content = content[:pos2]
-                 content = content.strip()
                  try:
-                     res = json.loads(content)
+                     res = json.loads(content, strict=False)
                      results.append((None, res["name"], res["arguments"]))
                  except Exception as e:
                      logger.error(
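
The rewritten split_into_blocks keeps the text between and around <think>/<tool_call> blocks instead of discarding everything outside them, so plain assistant text is no longer lost during tool-call extraction (and json.loads now runs with strict=False). A small standalone repro of the new behaviour (the sample string is made up):

import re

def split_into_blocks(text: str) -> list:
    # Same pattern as the diff: keep <think>/<tool_call> blocks and the text around them.
    pattern = r"(<(think|tool_call)>.*?</\2>)"
    parts, last_end = [], 0
    for m in re.finditer(pattern, text, re.DOTALL):
        if m.start() > last_end:
            parts.append(text[last_end : m.start()])
        parts.append(m.group(0))
        last_end = m.end()
    if last_end < len(text):
        parts.append(text[last_end:])
    return parts

text = 'Sure.<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call> Done.'
print(split_into_blocks(text))
# -> ['Sure.', '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>', ' Done.']
# The previous re.findall-based version returned only the <tool_call> block and dropped "Sure." and " Done."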
@@ -724,7 +734,7 @@
                  failed_contents.append(content)
          finish_reason = "tool_calls" if tool_calls else "stop"

-         content = ". ".join(failed_contents) if failed_contents else None
+         content = "".join(failed_contents) if failed_contents else None

          # fix: qwen tool_call content field return null
          family = model_family.model_family or model_family.model_name
@@ -802,7 +812,7 @@
                  failed_contents.append(content)
          finish_reason = "tool_calls" if tool_calls else "stop"

-         content = ". ".join(failed_contents) if failed_contents else None
+         content = "".join(failed_contents) if failed_contents else None

          # fix: qwen tool_call content field return null
          family = model_family.model_family or model_family.model_name
@@ -880,38 +890,13 @@
          return transformed_messages


- def get_file_location(
-     llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
- ) -> Tuple[str, bool]:
-     cache_dir = _get_cache_dir(
-         llm_family, spec, quantization, create_if_not_exist=False
-     )
-     cache_status = get_cache_status(llm_family, spec, quantization)
-     if isinstance(cache_status, list):
-         is_cached = None
-         for q, cs in zip(spec.quantizations, cache_status):
-             if q == quantization:
-                 is_cached = cs
-                 break
-     else:
-         is_cached = cache_status
-     assert isinstance(is_cached, bool)
-
-     if spec.model_format in ["pytorch", "gptq", "awq", "fp8", "mlx"]:
-         return cache_dir, is_cached
-     elif spec.model_format in ["ggufv2"]:
-         assert isinstance(spec, LlamaCppLLMSpecV1)
-         filename = spec.model_file_name_template.format(quantization=quantization)
-         model_path = os.path.join(cache_dir, filename)
-         return model_path, is_cached
-     else:
-         raise ValueError(f"Not supported model format {spec.model_format}")
-
-
  def get_model_version(
-     llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
+     model_name: str,
+     model_format: str,
+     model_size_in_billions: Union[str, int],
+     quantization: str,
  ) -> str:
-     return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}"
+     return f"{model_name}--{model_size_in_billions}B--{model_format}--{quantization}"


  def _decode_image(_url):
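
get_model_version now takes the plain identifying fields instead of LLMFamilyV1/LLMSpecV1 objects, matching the removal of get_file_location and the cache helpers above. With the new signature the version string looks like this (argument values are made up):

get_model_version(
    model_name="qwen3",
    model_format="ggufv2",
    model_size_in_billions=8,
    quantization="Q4_K_M",
)
# -> "qwen3--8B--ggufv2--Q4_K_M"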