xinference 1.3.0.post2__py3-none-any.whl → 1.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (53)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +1 -0
  3. xinference/conftest.py +7 -0
  4. xinference/core/chat_interface.py +39 -24
  5. xinference/core/model.py +3 -1
  6. xinference/core/scheduler.py +3 -0
  7. xinference/core/worker.py +1 -1
  8. xinference/model/embedding/core.py +12 -5
  9. xinference/model/llm/__init__.py +2 -1
  10. xinference/model/llm/core.py +10 -0
  11. xinference/model/llm/llama_cpp/core.py +266 -3
  12. xinference/model/llm/llm_family.json +390 -17
  13. xinference/model/llm/llm_family_modelscope.json +348 -29
  14. xinference/model/llm/mlx/core.py +15 -4
  15. xinference/model/llm/{reasoning_parsers/deepseek_r1_reasoning_parser.py → reasoning_parser.py} +9 -13
  16. xinference/model/llm/sglang/core.py +7 -2
  17. xinference/model/llm/transformers/chatglm.py +4 -4
  18. xinference/model/llm/transformers/core.py +22 -5
  19. xinference/model/llm/transformers/intern_vl.py +2 -1
  20. xinference/model/llm/transformers/utils.py +1 -1
  21. xinference/model/llm/utils.py +134 -60
  22. xinference/model/llm/vllm/core.py +31 -42
  23. xinference/types.py +4 -0
  24. xinference/web/ui/build/asset-manifest.json +3 -3
  25. xinference/web/ui/build/index.html +1 -1
  26. xinference/web/ui/build/static/js/main.55b70cb7.js +3 -0
  27. xinference/web/ui/build/static/js/main.55b70cb7.js.map +1 -0
  28. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +1 -0
  29. xinference/web/ui/node_modules/.cache/babel-loader/2deac8d5636974533e3714f34e94fc754f9153a07c6ee11e72846cb8eae47e4b.json +1 -0
  30. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +1 -0
  31. xinference/web/ui/node_modules/.cache/babel-loader/87a9b13f2466f375ae5c6e7c08b279cc38351d29710d7f7626bbb07a85262b79.json +1 -0
  32. xinference/web/ui/node_modules/.cache/babel-loader/e23d476fcbf6fd69c8986bf82133d257d28aa8fc9a5cab231d81c1c75c58cd99.json +1 -0
  33. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/e7a8c37fda8725cab69c7ef8c627060bd7fc806adc67e00fe628ba148cb86d7f.json +1 -0
  35. xinference/web/ui/src/locales/en.json +9 -1
  36. xinference/web/ui/src/locales/zh.json +9 -1
  37. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/METADATA +9 -5
  38. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/RECORD +43 -44
  39. xinference/model/llm/reasoning_parsers/__init__.py +0 -13
  40. xinference/model/llm/reasoning_parsers/abs_reasoning_parsers.py +0 -98
  41. xinference/web/ui/build/static/js/main.ad42919c.js +0 -3
  42. xinference/web/ui/build/static/js/main.ad42919c.js.map +0 -1
  43. xinference/web/ui/node_modules/.cache/babel-loader/074a42304bbbaa79e1bfc3b28502457a390df55708de9006f4cc8e35c60aea87.json +0 -1
  44. xinference/web/ui/node_modules/.cache/babel-loader/279ace390216236a82b3d8995c78eca4d637ac9a523e9f521a2d9c76607a43d7.json +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/630a7bd592596cc6e291fc32238ce7c08238038a64ed8ccee0eb0c13c9902910.json +0 -1
  46. xinference/web/ui/node_modules/.cache/babel-loader/914c33e91c1012e3bcd3e96f3a25884cbef148290632d0266dab972b8cc1e95f.json +0 -1
  47. xinference/web/ui/node_modules/.cache/babel-loader/b7939cd3a48adf12fccfdd0803019b5cc235ff7de3a297dae70ce635e0eea13e.json +0 -1
  48. xinference/web/ui/node_modules/.cache/babel-loader/fecf076bcd198a458c2a6ab0e85e40dc1c99994c353164e79c469be162cb74c9.json +0 -1
  49. /xinference/web/ui/build/static/js/{main.ad42919c.js.LICENSE.txt → main.55b70cb7.js.LICENSE.txt} +0 -0
  50. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/LICENSE +0 -0
  51. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/WHEEL +0 -0
  52. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/entry_points.txt +0 -0
  53. {xinference-1.3.0.post2.dist-info → xinference-1.3.1.post1.dist-info}/top_level.txt +0 -0

xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-02-22T23:10:02+0800",
+ "date": "2025-03-11T12:00:36+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "378a47adad8506a13105b063322ccd7a04f7ea5d",
- "version": "1.3.0.post2"
+ "full-revisionid": "2ef99fbb5450a76a6ba07a909f58b8c2e4c22a28",
+ "version": "1.3.1.post1"
 }
 ''' # END VERSION_JSON
 

xinference/api/restful_api.py CHANGED
@@ -1330,6 +1330,7 @@ class RESTfulAPI(CancelMixin):
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            kwargs["model_uid"] = model_uid
             embedding = await model.create_embedding(body.input, **kwargs)
             return Response(embedding, media_type="application/json")
         except Exception as e:

xinference/conftest.py CHANGED
@@ -304,3 +304,10 @@ def setup_with_auth():
             os.remove(auth_file)
         except:
             pass
+
+
+@pytest.fixture
+def set_use_xllamacpp():
+    os.environ["USE_XLLAMACPP"] = "1"
+    yield
+    del os.environ["USE_XLLAMACPP"]
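
The new set_use_xllamacpp fixture only exports the USE_XLLAMACPP environment variable; llama_cpp/core.py reads that flag at import time to decide whether to alias the llama.cpp classes to XllamaCppModel. A minimal sketch of how a test might opt in (test name and assertion are illustrative, not part of the release):

import os


def test_with_xllamacpp_backend(set_use_xllamacpp):
    # The fixture sets USE_XLLAMACPP=1 for the duration of the test and removes
    # it afterwards; the actual backend switch happens when
    # xinference.model.llm.llama_cpp.core is imported with the flag set.
    assert os.environ["USE_XLLAMACPP"] == "1"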

xinference/core/chat_interface.py CHANGED
@@ -113,6 +113,7 @@ class GradioInterface:
             max_tokens: int,
             temperature: float,
             lora_name: str,
+            stream: bool,
         ) -> Generator:
             from ..client import RESTfulClient
 
@@ -123,29 +124,40 @@ class GradioInterface:
             messages = to_chat(flatten(history))
             messages.append(dict(role="user", content=message))
 
-            response_content = ""
-            for chunk in model.chat(
-                messages,
-                generate_config={
-                    "max_tokens": int(max_tokens),
-                    "temperature": temperature,
-                    "stream": True,
-                    "lora_name": lora_name,
-                },
-            ):
-                assert isinstance(chunk, dict)
-                delta = chunk["choices"][0]["delta"]
-                if "content" not in delta:
-                    continue
-                else:
-                    # some model like deepseek-r1-distill-qwen
-                    # will generate <think>...</think> ...
-                    # in gradio, no output will be rendered,
-                    # thus escape html tags in advance
-                    response_content += html.escape(delta["content"])
-                    yield response_content
-
-            yield response_content
+            if stream:
+                response_content = ""
+                for chunk in model.chat(
+                    messages,
+                    generate_config={
+                        "max_tokens": int(max_tokens),
+                        "temperature": temperature,
+                        "stream": True,
+                        "lora_name": lora_name,
+                    },
+                ):
+                    assert isinstance(chunk, dict)
+                    delta = chunk["choices"][0]["delta"]
+                    if "content" not in delta:
+                        continue
+                    else:
+                        # some model like deepseek-r1-distill-qwen
+                        # will generate <think>...</think> ...
+                        # in gradio, no output will be rendered,
+                        # thus escape html tags in advance
+                        response_content += html.escape(delta["content"])
+                        yield response_content
+
+                yield response_content
+            else:
+                result = model.chat(
+                    messages,
+                    generate_config={
+                        "max_tokens": int(max_tokens),
+                        "temperature": temperature,
+                        "lora_name": lora_name,
+                    },
+                )
+                yield html.escape(result["choices"][0]["message"]["content"])  # type: ignore
 
         return gr.ChatInterface(
             fn=generate_wrapper,
@@ -153,7 +165,9 @@ class GradioInterface:
                 gr.Slider(
                     minimum=1,
                     maximum=self.context_length,
-                    value=512,
+                    value=512
+                    if "reasoning" not in self.model_ability
+                    else self.context_length // 2,
                     step=1,
                     label="Max Tokens",
                 ),
@@ -161,6 +175,7 @@ class GradioInterface:
                     minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
                 ),
                 gr.Text(label="LoRA Name"),
+                gr.Checkbox(label="Stream", value=True),
             ],
             title=f"🚀 Xinference Chat Bot : {self.model_name} 🚀",
             css="""

xinference/core/model.py CHANGED
@@ -231,6 +231,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         driver_info: Optional[dict] = None,  # for model across workers
     ):
         super().__init__()
+        from ..model.llm.llama_cpp.core import XllamaCppModel
         from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
@@ -251,7 +252,8 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         self._lock = (
             None
             if isinstance(
-                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+                self._model,
+                (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel, XllamaCppModel),
             )
             else asyncio.locks.Lock()
         )

xinference/core/scheduler.py CHANGED
@@ -97,6 +97,9 @@ class InferenceRequest:
         # check the integrity of args passed upstream
         self._check_args()
 
+        # for reasoning_content using
+        self.previous_texts = [""]
+
     def _check_args(self):
         assert len(self._inference_args) == 1
         # generate config

xinference/core/worker.py CHANGED
@@ -1002,7 +1002,7 @@ class WorkerActor(xo.StatelessActor):
         )
         try:
             subpool_address = self._model_uid_to_addr[model_uid]
-            await self._main_pool.remove_sub_pool(subpool_address)
+            await self._main_pool.remove_sub_pool(subpool_address, force=True)
         except Exception as e:
             logger.debug(
                 "Remove sub pool failed, model uid: %s, error: %s", model_uid, e

xinference/model/embedding/core.py CHANGED
@@ -268,7 +268,7 @@ class EmbeddingModel:
         **kwargs,
     ):
         sentences = self._fix_langchain_openai_inputs(sentences)
-
+        model_uid = kwargs.pop("model_uid", None)
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
@@ -546,8 +546,14 @@ class EmbeddingModel:
             # when batching, the attention mask 1 means there is a token
             # thus we just sum up it to get the total number of tokens
             if "clip" in self._model_spec.model_name.lower():
-                all_token_nums += features["input_ids"].numel()
-                all_token_nums += features["pixel_values"].numel()
+                if "input_ids" in features and hasattr(
+                    features["input_ids"], "numel"
+                ):
+                    all_token_nums += features["input_ids"].numel()
+                if "pixel_values" in features and hasattr(
+                    features["pixel_values"], "numel"
+                ):
+                    all_token_nums += features["pixel_values"].numel()
             else:
                 all_token_nums += features["attention_mask"].sum().item()
 
@@ -657,7 +663,7 @@ class EmbeddingModel:
                 self._model,
                 objs,
                 convert_to_numpy=False,
-                **self._kwargs,
+                **kwargs,
             )
         else:
             all_embeddings, all_token_nums = encode(
@@ -693,7 +699,8 @@ class EmbeddingModel:
                 if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
                 else "dict"
             ),
-            model=self._model_uid,
+            model=model_uid,  # type: ignore
+            model_replica=self._model_uid,
             data=embedding_list,
             usage=usage,
         )
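
With model_uid now threaded through kwargs from the REST layer, the embedding response reports the client-facing model UID in model and keeps the internal replica UID in a new model_replica field. A rough illustration of the resulting shape (all values are placeholders, not taken from the release):

response = {
    "object": "list",
    "model": "bge-base-en",            # UID the client requested (placeholder)
    "model_replica": "bge-base-en-0",  # internal replica UID (placeholder)
    "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02]}],
    "usage": {"prompt_tokens": 3, "total_tokens": 3},
}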

xinference/model/llm/__init__.py CHANGED
@@ -129,7 +129,7 @@ def register_custom_model():
 
 
 def _install():
-    from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel
+    from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel, XllamaCppModel
     from .lmdeploy.core import LMDeployChatModel, LMDeployModel
     from .mlx.core import MLXChatModel, MLXModel, MLXVisionModel
     from .sglang.core import SGLANGChatModel, SGLANGModel
@@ -169,6 +169,7 @@ def _install():
         [
             LlamaCppChatModel,
             LlamaCppModel,
+            XllamaCppModel,
         ]
     )
     SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])

xinference/model/llm/core.py CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 from ...core.utils import parse_replica_model_uid
 from ...types import PeftModelConfig
 from ..core import ModelDescription
+from .reasoning_parser import ReasoningParser
 
 if TYPE_CHECKING:
     from .llm_family import LLMFamilyV1, LLMSpecV1
@@ -57,6 +58,7 @@ class LLM(abc.ABC):
         self.model_spec = model_spec
         self.quantization = quantization
         self.model_path = model_path
+        self.reasoning_parser = None
         if args:
             raise ValueError(f"Unrecognized positional arguments: {args}")
         if kwargs:
@@ -117,6 +119,14 @@ class LLM(abc.ABC):
     ) -> bool:
         raise NotImplementedError
 
+    def prepare_parse_reasoning_content(self, reasoning_content):
+        # Initialize reasoning parser if model has reasoning ability
+        if "reasoning" in self.model_family.model_ability and reasoning_content:
+            self.reasoning_parser = ReasoningParser(
+                self.model_family.reasoning_start_tag,
+                self.model_family.reasoning_end_tag,
+            )
+
 
 class LLMDescription(ModelDescription):
     def __init__(
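
prepare_parse_reasoning_content wires the family-level reasoning_start_tag and reasoning_end_tag into a ReasoningParser; the parser's own API is not shown in this diff. As a rough illustration of the idea only (a hypothetical helper, not the ReasoningParser class), splitting a reasoning model's raw output into reasoning_content and content could look like:

def split_reasoning(text, start_tag="<think>", end_tag="</think>"):
    # Return (reasoning_content, content) for a complete, non-streaming reply.
    start = text.find(start_tag)
    end = text.find(end_tag)
    if start == -1 or end == -1 or end < start:
        return None, text
    reasoning = text[start + len(start_tag):end].strip()
    content = text[end + len(end_tag):].lstrip()
    return reasoning, content


# split_reasoning("<think>chain of thought</think>Final answer")
# -> ("chain of thought", "Final answer")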

xinference/model/llm/llama_cpp/core.py CHANGED
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import concurrent.futures
 import logging
 import os
+import queue
 import time
 from typing import Dict, Iterator, List, Optional, Union
 
+import orjson
+
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -32,6 +36,254 @@ from ..utils import DEEPSEEK_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY, ChatModelM
 
 logger = logging.getLogger(__name__)
 
+USE_XLLAMACPP = bool(int(os.environ.get("USE_XLLAMACPP", 0)))
+
+
+class _Sentinel:
+    pass
+
+
+class XllamaCppModel(LLM, ChatModelMixin):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+
+        self._llamacpp_model_config: LlamaCppModelConfig = self._sanitize_model_config(
+            llamacpp_model_config
+        )
+        self._llm = None
+        self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
+
+    def _sanitize_model_config(
+        self, llamacpp_model_config: Optional[LlamaCppModelConfig]
+    ) -> LlamaCppModelConfig:
+        if llamacpp_model_config is None:
+            llamacpp_model_config = LlamaCppModelConfig()
+
+        if self.model_family.context_length:
+            llamacpp_model_config.setdefault("n_ctx", self.model_family.context_length)
+        llamacpp_model_config.setdefault("use_mmap", False)
+        llamacpp_model_config.setdefault("use_mlock", True)
+
+        if (
+            "llama-2" in self.model_family.model_name
+            and self.model_spec.model_size_in_billions == 70
+        ):
+            llamacpp_model_config["use_mlock"] = False
+            llamacpp_model_config["n_gqa"] = 8
+
+        if self._is_darwin_and_apple_silicon():
+            llamacpp_model_config.setdefault("n_gpu_layers", -1)
+        elif self._is_linux():
+            llamacpp_model_config.setdefault("n_gpu_layers", -1)
+        llamacpp_model_config.setdefault("reasoning_content", False)
+
+        return llamacpp_model_config
+
+    def _sanitize_generate_config(
+        self, generate_config: Optional[LlamaCppGenerateConfig]
+    ) -> LlamaCppGenerateConfig:
+        if generate_config is None:
+            generate_config = LlamaCppGenerateConfig(
+                **CreateCompletionLlamaCpp().dict()
+            )
+        else:
+            from llama_cpp import LlamaGrammar
+
+            grammar = generate_config.get("grammar")
+            if grammar is not None and not isinstance(grammar, LlamaGrammar):
+                generate_config["grammar"] = LlamaGrammar.from_string(
+                    generate_config["grammar"]
+                )
+            # Validate generate_config and fill default values to the generate config.
+            generate_config = LlamaCppGenerateConfig(
+                **CreateCompletionLlamaCpp(**generate_config).dict()
+            )
+        # Currently, llama.cpp does not support lora
+        generate_config.pop("lora_name", None)  # type: ignore
+        return generate_config
+
+    @classmethod
+    def match(
+        cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
+    ) -> bool:
+        if llm_spec.model_format not in ["ggufv2"]:
+            return False
+        if (
+            "chat" not in llm_family.model_ability
+            and "generate" not in llm_family.model_ability
+        ):
+            return False
+        return True
+
+    def load(self):
+        try:
+            from xllamacpp import CommonParams, Server
+        except ImportError:
+            error_message = "Failed to import module 'xllamacpp'"
+            installation_guide = ["Please make sure 'xllamacpp' is installed. "]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        reasoning_content = self._llamacpp_model_config.pop("reasoning_content")
+        self.prepare_parse_reasoning_content(reasoning_content)
+
+        if os.path.isfile(self.model_path):
+            # mostly passed from --model_path
+            model_path = os.path.realpath(self.model_path)
+        else:
+            # handle legacy cache.
+            model_path = os.path.realpath(
+                os.path.join(
+                    self.model_path,
+                    self.model_spec.model_file_name_template.format(
+                        quantization=self.quantization
+                    ),
+                )
+            )
+            legacy_model_file_path = os.path.join(self.model_path, "model.bin")
+            if os.path.exists(legacy_model_file_path):
+                model_path = legacy_model_file_path
+
+        try:
+            params = CommonParams()
+            params.model = model_path
+            if self.model_family.chat_template:
+                params.chat_template = self.model_family.chat_template
+            # This is the default value, could be overwritten by _llamacpp_model_config
+            params.n_parallel = os.cpu_count()
+            for k, v in self._llamacpp_model_config.items():
+                try:
+                    setattr(params, k, v)
+                except Exception as e:
+                    logger.error("Failed to set the param %s = %s, error: %s", k, v, e)
+            n_threads = self._llamacpp_model_config.get("n_threads", os.cpu_count())
+            params.cpuparams.n_threads = n_threads
+            params.cpuparams_batch.n_threads = n_threads
+            if params.n_gpu_layers == -1:
+                # Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
+                # 0x7FFFFFFF is INT32 max, will be auto set to all layers
+                params.n_gpu_layers = 0x7FFFFFFF
+            self._llm = Server(params)
+            self._executor = concurrent.futures.ThreadPoolExecutor(
+                max_workers=max(10, n_threads)
+            )
+        except AssertionError:
+            raise RuntimeError(f"Load model {self.model_family.model_name} failed")
+
+    def generate(
+        self, prompt: str, generate_config: Optional[LlamaCppGenerateConfig] = None
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        generate_config = self._sanitize_generate_config(generate_config)
+        stream = generate_config.get("stream", False)
+        q: queue.Queue = queue.Queue()
+
+        def _handle_completion():
+            # TODO(fyrestone): Replace the LlamaCppGenerateConfig with OpenAI params.
+            data = generate_config
+            data.pop("stopping_criteria", None)
+            data.pop("logits_processor", None)
+            data.pop("suffix", None)
+            data.pop("best_of", None)
+            data.update(
+                {
+                    "prompt": prompt,
+                    "stream": stream,
+                }
+            )
+            prompt_json = orjson.dumps(data)
+
+            def _res_callback(ok):
+                try:
+                    res = orjson.loads(ok)
+                    res["model"] = self.model_uid
+                    q.put(res)
+                except Exception as e:
+                    logger.exception("handle_completions callback failed: %s", e)
+
+            try:
+                self._llm.handle_completions(prompt_json, _res_callback, _res_callback)
+            except Exception as ex:
+                logger.exception("handle_completions failed: %s", ex)
+            q.put(_Sentinel)
+
+        assert self._executor
+        self._executor.submit(_handle_completion)
+
+        if stream:
+
+            def _to_iterator():
+                while (r := q.get()) is not _Sentinel:
+                    yield r
+
+            return _to_iterator()
+        else:
+            return q.get()
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[LlamaCppGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        generate_config = self._sanitize_generate_config(generate_config)
+        stream = generate_config.get("stream", False)
+        tools = generate_config.pop("tools", []) if generate_config else None
+        q: queue.Queue = queue.Queue()
+
+        def _handle_chat_completion():
+            # TODO(fyrestone): Replace the LlamaCppGenerateConfig with OpenAI params.
+            data = generate_config
+            data.pop("stopping_criteria", None)
+            data.pop("logits_processor", None)
+            data.pop("suffix", None)
+            data.pop("best_of", None)
+            data.update(
+                {
+                    "messages": messages,
+                    "stream": stream,
+                    "tools": tools,
+                }
+            )
+            prompt_json = orjson.dumps(data)
+
+            def _res_callback(ok):
+                try:
+                    res = orjson.loads(ok)
+                    res["model"] = self.model_uid
+                    q.put(res)
+                except Exception as e:
+                    logger.exception("handle_chat_completions callback failed: %s", e)
+
+            try:
+                self._llm.handle_chat_completions(
+                    prompt_json, _res_callback, _res_callback
+                )
+            except Exception as ex:
+                logger.exception("handle_chat_completions failed: %s", ex)
+            q.put(_Sentinel)
+
+        assert self._executor
+        self._executor.submit(_handle_chat_completion)
+
+        if stream:
+
+            def _to_iterator():
+                while (r := q.get()) is not _Sentinel:
+                    yield r
+
+            return self._to_chat_completion_chunks(
+                _to_iterator(), self.reasoning_parser
+            )
+        else:
+            return self._to_chat_completion(q.get(), self.reasoning_parser)
+
 
 class LlamaCppModel(LLM):
     def __init__(
@@ -76,6 +328,7 @@ class LlamaCppModel(LLM):
             llamacpp_model_config.setdefault("n_gpu_layers", -1)
         elif self._is_linux() and self._can_apply_cublas():
             llamacpp_model_config.setdefault("n_gpu_layers", -1)
+        llamacpp_model_config.setdefault("reasoning_content", False)
 
         return llamacpp_model_config
 
@@ -123,6 +376,9 @@ class LlamaCppModel(LLM):
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
 
+        reasoning_content = self._llamacpp_model_config.pop("reasoning_content")
+        self.prepare_parse_reasoning_content(reasoning_content)
+
         if os.path.isfile(self.model_path):
            # mostly passed from --model_path
            model_path = os.path.realpath(self.model_path)
@@ -292,10 +548,17 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
         if stream:
             it = self.generate(full_prompt, generate_config)
             assert isinstance(it, Iterator)
-            return self._to_chat_completion_chunks(it)
+            return self._to_chat_completion_chunks(it, self.reasoning_parser)
         else:
             c = self.generate(full_prompt, generate_config)
             assert not isinstance(c, Iterator)
             if tools:
-                return self._tool_calls_completion(self.model_family, self.model_uid, c)
-            return self._to_chat_completion(c)
+                return self._post_process_completion(
+                    self.model_family, self.model_uid, c, self.reasoning_parser
+                )
+            return self._to_chat_completion(c, self.reasoning_parser)
+
+
+if USE_XLLAMACPP:
+    LlamaCppModel = XllamaCppModel  # type: ignore  # noqa: F811
+    LlamaCppChatModel = XllamaCppModel  # type: ignore  # noqa: F811
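
XllamaCppModel bridges xllamacpp's callback-style handle_completions/handle_chat_completions into xinference's blocking and streaming return types with a queue, a sentinel and a thread pool. A standalone sketch of that bridge pattern under those assumptions (backend_call is a stand-in, not the real xllamacpp Server API):

import concurrent.futures
import queue


class _Sentinel:
    pass


def stream_results(backend_call, request):
    q = queue.Queue()

    def _worker():
        try:
            # backend_call is expected to invoke on_chunk once per result chunk.
            backend_call(request, on_chunk=q.put)
        finally:
            q.put(_Sentinel)  # always signal end-of-stream, even on error

    concurrent.futures.ThreadPoolExecutor(max_workers=1).submit(_worker)

    def _iterator():
        while (item := q.get()) is not _Sentinel:
            yield item

    return _iterator()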