xinference 1.5.1 (py3-none-any.whl) → 1.6.0.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference has been flagged as potentially problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/METADATA +59 -39
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/top_level.txt +0 -0
xinference/model/llm/__init__.py

@@ -128,8 +128,38 @@ def register_custom_model():
                 warnings.warn(f"{user_defined_llm_dir}/{f} has error, {e}")


+def load_model_family_from_json(json_filename, target_families):
+    json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), json_filename)
+    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        model_spec = LLMFamilyV1.parse_obj(json_obj)
+        target_families.append(model_spec)
+
+        # register chat_template
+        if (
+            "chat" in model_spec.model_ability
+            and isinstance(model_spec.chat_template, str)
+            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
+        ):
+            # note that the key is the model name,
+            # since there are multiple representations of the same prompt style name in json.
+            if model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE:
+                BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                    "chat_template": model_spec.chat_template,
+                    "stop_token_ids": model_spec.stop_token_ids,
+                    "stop": model_spec.stop,
+                }
+
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
+        if "tools" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+
+
 def _install():
-    from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel, XllamaCppModel
+    from .llama_cpp.core import XllamaCppModel
     from .lmdeploy.core import LMDeployChatModel, LMDeployModel
     from .mlx.core import MLXChatModel, MLXModel, MLXVisionModel
     from .sglang.core import SGLANGChatModel, SGLANGModel, SGLANGVisionModel
@@ -166,8 +196,6 @@ def _install():
     # register llm classes.
     LLAMA_CLASSES.extend(
         [
-            LlamaCppChatModel,
-            LlamaCppModel,
             XllamaCppModel,
         ]
     )
@@ -210,115 +238,14 @@ def _install():
     SUPPORTED_ENGINES["MLX"] = MLX_CLASSES
     SUPPORTED_ENGINES["LMDEPLOY"] = LMDEPLOY_CLASSES

-    json_path = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
+    load_model_family_from_json("llm_family.json", BUILTIN_LLM_FAMILIES)
+    load_model_family_from_json(
+        "llm_family_modelscope.json", BUILTIN_MODELSCOPE_LLM_FAMILIES
     )
-    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
-        model_spec = LLMFamilyV1.parse_obj(json_obj)
-        BUILTIN_LLM_FAMILIES.append(model_spec)
-
-        # register chat_template
-        if "chat" in model_spec.model_ability and isinstance(
-            model_spec.chat_template, str
-        ):
-            # note that the key is the model name,
-            # since there are multiple representations of the same prompt style name in json.
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
-                "chat_template": model_spec.chat_template,
-                "stop_token_ids": model_spec.stop_token_ids,
-                "stop": model_spec.stop,
-            }
-        # register model family
-        if "chat" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
-        else:
-            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-        if "tools" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
-
-    modelscope_json_path = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
+    load_model_family_from_json(
+        "llm_family_openmind_hub.json", BUILTIN_OPENMIND_HUB_LLM_FAMILIES
     )
-    for json_obj in json.load(codecs.open(modelscope_json_path, "r", encoding="utf-8")):
-        model_spec = LLMFamilyV1.parse_obj(json_obj)
-        BUILTIN_MODELSCOPE_LLM_FAMILIES.append(model_spec)
-
-        # register prompt style, in case that we have something missed
-        # if duplicated with huggingface json, keep it as the huggingface style
-        if (
-            "chat" in model_spec.model_ability
-            and isinstance(model_spec.chat_template, str)
-            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
-        ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
-                "chat_template": model_spec.chat_template,
-                "stop_token_ids": model_spec.stop_token_ids,
-                "stop": model_spec.stop,
-            }
-        # register model family
-        if "chat" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
-        else:
-            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-        if "tools" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
-
-    openmind_hub_json_path = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "llm_family_openmind_hub.json"
-    )
-    for json_obj in json.load(
-        codecs.open(openmind_hub_json_path, "r", encoding="utf-8")
-    ):
-        model_spec = LLMFamilyV1.parse_obj(json_obj)
-        BUILTIN_OPENMIND_HUB_LLM_FAMILIES.append(model_spec)
-
-        # register prompt style, in case that we have something missed
-        # if duplicated with huggingface json, keep it as the huggingface style
-
-        if (
-            "chat" in model_spec.model_ability
-            and isinstance(model_spec.chat_template, str)
-            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
-        ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
-                "chat_template": model_spec.chat_template,
-                "stop_token_ids": model_spec.stop_token_ids,
-                "stop": model_spec.stop,
-            }
-        # register model family
-        if "chat" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
-        else:
-            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-        if "tools" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
-
-    csghub_json_path = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
-    )
-    for json_obj in json.load(codecs.open(csghub_json_path, "r", encoding="utf-8")):
-        model_spec = LLMFamilyV1.parse_obj(json_obj)
-        BUILTIN_CSGHUB_LLM_FAMILIES.append(model_spec)
-
-        # register prompt style, in case that we have something missed
-        # if duplicated with huggingface json, keep it as the huggingface style
-        if (
-            "chat" in model_spec.model_ability
-            and isinstance(model_spec.chat_template, str)
-            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
-        ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
-                "chat_template": model_spec.chat_template,
-                "stop_token_ids": model_spec.stop_token_ids,
-                "stop": model_spec.stop,
-            }
-        # register model family
-        if "chat" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
-        else:
-            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-        if "tools" in model_spec.model_ability:
-            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+    load_model_family_from_json("llm_family_csghub.json", BUILTIN_CSGHUB_LLM_FAMILIES)

     for llm_specs in [
         BUILTIN_LLM_FAMILIES,
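
All four built-in registries (Hugging Face, ModelScope, OpenMind Hub, CSGHub) are now filled through the single load_model_family_from_json helper above. As an illustrative sketch only (calling _install() by hand is not something the package normally requires, and repeated calls would append duplicate entries), the resulting registries can be inspected like this:

    from xinference.model.llm import (
        BUILTIN_LLM_FAMILIES,
        BUILTIN_LLM_MODEL_CHAT_FAMILIES,
        _install,
    )

    _install()  # normally invoked once by xinference itself during model registration
    print(f"{len(BUILTIN_LLM_FAMILIES)} built-in Hugging Face LLM families loaded")
    print(sorted(BUILTIN_LLM_MODEL_CHAT_FAMILIES)[:5])  # a few chat-capable family names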
xinference/model/llm/core.py

@@ -17,6 +17,7 @@ import inspect
 import logging
 import os
 import platform
+import warnings
 from abc import abstractmethod
 from collections import defaultdict
 from functools import lru_cache
@@ -134,13 +135,21 @@ class LLM(abc.ABC):
     ) -> bool:
         raise NotImplementedError

-    def prepare_parse_reasoning_content(self, reasoning_content):
-        # Initialize reasoning parser if model has reasoning ability
-        if "reasoning" in self.model_family.model_ability and reasoning_content:
-            self.reasoning_parser = ReasoningParser(
-                self.model_family.reasoning_start_tag,
-                self.model_family.reasoning_end_tag,
+    def prepare_parse_reasoning_content(
+        self, reasoning_content: bool, enable_thinking: bool = True
+    ):
+        if "hybrid" not in self.model_family.model_ability and not enable_thinking:
+            enable_thinking = True
+            warnings.warn(
+                "enable_thinking cannot be disabled for non hybrid model, will be ignored"
             )
+        # Initialize reasoning parser if model has reasoning ability
+        self.reasoning_parser = ReasoningParser(  # type: ignore
+            reasoning_content,
+            self.model_family.reasoning_start_tag,  # type: ignore
+            self.model_family.reasoning_end_tag,  # type: ignore
+            enable_thinking=enable_thinking,
+        )


 class LLMDescription(ModelDescription):
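
With this change the reasoning parser is always constructed, reasoning_content is passed through as its first argument, and enable_thinking is only honored for models whose ability list contains "hybrid"; for any other model it is forced back to True with a warning. A hedged, illustrative client-side sketch of how these two flags reach this method (model name, quantization and endpoint are placeholders, not prescriptions):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(
        model_name="qwen3",          # placeholder: a model advertising hybrid reasoning
        model_engine="llama.cpp",
        model_format="ggufv2",
        quantization="q4_k_m",
        reasoning_content=True,      # forwarded into the model config, popped in load()
        enable_thinking=False,       # ignored (with a warning) unless the model is hybrid
    )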
xinference/model/llm/llama_cpp/core.py

@@ -16,29 +16,17 @@ import importlib.util
 import logging
 import os
 import queue
-import time
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Union

 import orjson

-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    Completion,
-    CompletionChunk,
-    CompletionUsage,
-    CreateCompletionLlamaCpp,
-    LlamaCppGenerateConfig,
-    LlamaCppModelConfig,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, Completion, CompletionChunk
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import DEEPSEEK_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import ChatModelMixin

 logger = logging.getLogger(__name__)

-USE_XLLAMACPP = bool(int(os.environ.get("USE_XLLAMACPP", 1)))
-

 class _Done:
     pass
@@ -57,21 +45,16 @@ class XllamaCppModel(LLM, ChatModelMixin):
         model_spec: "LLMSpecV1",
         quantization: str,
         model_path: str,
-        llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
+        llamacpp_model_config: Optional[dict] = None,
     ):
         super().__init__(model_uid, model_family, model_spec, quantization, model_path)
-
-        self._llamacpp_model_config: LlamaCppModelConfig = self._sanitize_model_config(
-            llamacpp_model_config
-        )
+        self._llamacpp_model_config = self._sanitize_model_config(llamacpp_model_config)
         self._llm = None
         self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None

-    def _sanitize_model_config(
-        self, llamacpp_model_config: Optional[LlamaCppModelConfig]
-    ) -> LlamaCppModelConfig:
+    def _sanitize_model_config(self, llamacpp_model_config: Optional[dict]) -> dict:
         if llamacpp_model_config is None:
-            llamacpp_model_config = LlamaCppModelConfig()
+            llamacpp_model_config = {}

         if self.model_family.context_length:
             llamacpp_model_config.setdefault("n_ctx", self.model_family.context_length)
93
76
 
94
77
  return llamacpp_model_config
95
78
 
96
- def _sanitize_generate_config(
97
- self, generate_config: Optional[LlamaCppGenerateConfig]
98
- ) -> LlamaCppGenerateConfig:
99
- if generate_config is None:
100
- generate_config = LlamaCppGenerateConfig(
101
- **CreateCompletionLlamaCpp().dict()
102
- )
103
- else:
104
- from llama_cpp import LlamaGrammar
105
-
106
- grammar = generate_config.get("grammar")
107
- if grammar is not None and not isinstance(grammar, LlamaGrammar):
108
- generate_config["grammar"] = LlamaGrammar.from_string(
109
- generate_config["grammar"]
110
- )
111
- # Validate generate_config and fill default values to the generate config.
112
- generate_config = LlamaCppGenerateConfig(
113
- **CreateCompletionLlamaCpp(**generate_config).dict()
114
- )
115
- # Currently, llama.cpp does not support lora
116
- generate_config.pop("lora_name", None) # type: ignore
117
- return generate_config
118
-
119
79
  @classmethod
120
80
  def check_lib(cls) -> bool:
121
81
  return importlib.util.find_spec("xllamacpp") is not None
@@ -143,7 +103,10 @@ class XllamaCppModel(LLM, ChatModelMixin):
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

         reasoning_content = self._llamacpp_model_config.pop("reasoning_content")
-        self.prepare_parse_reasoning_content(reasoning_content)
+        enable_thinking = self._llamacpp_model_config.pop("enable_thinking", True)
+        self.prepare_parse_reasoning_content(
+            reasoning_content, enable_thinking=enable_thinking
+        )

         if os.path.isfile(self.model_path):
             # mostly passed from --model_path
@@ -152,7 +115,7 @@ class XllamaCppModel(LLM, ChatModelMixin):
             # handle legacy cache.
             if (
                 self.model_spec.model_file_name_split_template
-                and self.model_spec.quantization_parts
+                and self.quantization in self.model_spec.quantization_parts
             ):
                 part = self.model_spec.quantization_parts[self.quantization]
                 model_path = os.path.join(
@@ -185,7 +148,14 @@ class XllamaCppModel(LLM, ChatModelMixin):
         params.n_parallel = os.cpu_count()
         for k, v in self._llamacpp_model_config.items():
             try:
-                setattr(params, k, v)
+                if "." in k:
+                    parts = k.split(".")
+                    sub_param = params
+                    for p in parts[:-1]:
+                        sub_param = getattr(sub_param, p)
+                    setattr(sub_param, parts[-1], v)
+                else:
+                    setattr(params, k, v)
             except Exception as e:
                 logger.error("Failed to set the param %s = %s, error: %s", k, v, e)
         n_threads = self._llamacpp_model_config.get("n_threads", os.cpu_count())
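
Keys containing dots in llamacpp_model_config are now resolved against nested fields of the xllamacpp server params object instead of being set as a single top-level attribute. A self-contained sketch of the same traversal on toy objects (the classes below are illustrative stand-ins, not the real xllamacpp params):

    class _Sampling:
        temp = 0.8

    class _Params:
        n_ctx = 4096
        sampling = _Sampling()

    def set_nested(obj, dotted_key, value):
        *parents, leaf = dotted_key.split(".")
        target = obj
        for name in parents:
            target = getattr(target, name)  # walk down, e.g. obj.sampling
        setattr(target, leaf, value)        # assign the final field

    params = _Params()
    set_nested(params, "sampling.temp", 0.2)
    set_nested(params, "n_ctx", 8192)
    print(params.sampling.temp, params.n_ctx)  # -> 0.2 8192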
@@ -203,14 +173,13 @@ class XllamaCppModel(LLM, ChatModelMixin):
             raise RuntimeError(f"Load model {self.model_family.model_name} failed")

     def generate(
-        self, prompt: str, generate_config: Optional[LlamaCppGenerateConfig] = None
+        self, prompt: str, generate_config: Optional[dict] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
-        generate_config = self._sanitize_generate_config(generate_config)
+        generate_config = generate_config or {}
         stream = generate_config.get("stream", False)
         q: queue.Queue = queue.Queue()

         def _handle_completion():
-            # TODO(fyrestone): Replace the LlamaCppGenerateConfig with OpenAI params.
             data = generate_config
             data.pop("stopping_criteria", None)
             data.pop("logits_processor", None)
@@ -265,16 +234,15 @@ class XllamaCppModel(LLM, ChatModelMixin):

     def chat(
         self,
-        messages: List[Dict],
-        generate_config: Optional[LlamaCppGenerateConfig] = None,
+        messages: List[dict],
+        generate_config: Optional[dict] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        generate_config = self._sanitize_generate_config(generate_config)
+        generate_config = generate_config or {}
         stream = generate_config.get("stream", False)
         tools = generate_config.pop("tools", []) if generate_config else None
         q: queue.Queue = queue.Queue()

         def _handle_chat_completion():
-            # TODO(fyrestone): Replace the LlamaCppGenerateConfig with OpenAI params.
             data = generate_config
             data.pop("stopping_criteria", None)
             data.pop("logits_processor", None)
@@ -336,299 +304,3 @@ class XllamaCppModel(LLM, ChatModelMixin):
             if type(r) is _Error:
                 raise Exception("Got error in chat: %s", r.msg)
             return self._to_chat_completion(r, self.reasoning_parser)
-
-
-class LlamaCppModel(LLM):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
-    ):
-        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
-
-        self._llamacpp_model_config: LlamaCppModelConfig = self._sanitize_model_config(
-            llamacpp_model_config
-        )
-        self._llm = None
-
-    def _can_apply_cublas(self):
-        # TODO: figure out the quantizations supported.
-        return True
-
-    def _sanitize_model_config(
-        self, llamacpp_model_config: Optional[LlamaCppModelConfig]
-    ) -> LlamaCppModelConfig:
-        if llamacpp_model_config is None:
-            llamacpp_model_config = LlamaCppModelConfig()
-
-        if self.model_family.context_length:
-            llamacpp_model_config.setdefault("n_ctx", self.model_family.context_length)
-        llamacpp_model_config.setdefault("use_mmap", False)
-        llamacpp_model_config.setdefault("use_mlock", True)
-
-        if (
-            "llama-2" in self.model_family.model_name
-            and self.model_spec.model_size_in_billions == 70
-        ):
-            llamacpp_model_config["use_mlock"] = False
-            llamacpp_model_config["n_gqa"] = 8
-
-        if self._is_darwin_and_apple_silicon():
-            llamacpp_model_config.setdefault("n_gpu_layers", -1)
-        elif self._is_linux() and self._can_apply_cublas():
-            llamacpp_model_config.setdefault("n_gpu_layers", -1)
-        llamacpp_model_config.setdefault("reasoning_content", False)
-
-        return llamacpp_model_config
-
-    def _sanitize_generate_config(
-        self, generate_config: Optional[LlamaCppGenerateConfig]
-    ) -> LlamaCppGenerateConfig:
-        if generate_config is None:
-            generate_config = LlamaCppGenerateConfig(
-                **CreateCompletionLlamaCpp().dict()
-            )
-        else:
-            from llama_cpp import LlamaGrammar
-
-            grammar = generate_config.get("grammar")
-            if grammar is not None and not isinstance(grammar, LlamaGrammar):
-                generate_config["grammar"] = LlamaGrammar.from_string(
-                    generate_config["grammar"]
-                )
-        # Validate generate_config and fill default values to the generate config.
-        generate_config = LlamaCppGenerateConfig(
-            **CreateCompletionLlamaCpp(**generate_config).dict()
-        )
-        # Currently, llama.cpp does not support lora
-        generate_config.pop("lora_name", None)  # type: ignore
-        return generate_config
-
-    def load(self):
-        try:
-            import llama_cpp
-            from llama_cpp import Llama
-
-            if llama_cpp.__version__ < "0.2.0":
-                raise ValueError(
-                    "The llama_cpp version must be greater than 0.2.0. "
-                    "Please upgrade your version via `pip install -U llama_cpp` or refer to "
-                    "https://github.com/abetlen/llama-cpp-python#installation-with-openblas--cublas--clblast--metal."
-                )
-        except ImportError:
-            error_message = "Failed to import module 'llama_cpp'"
-            installation_guide = [
-                "Please make sure 'llama_cpp' is installed. ",
-                "You can install it by visiting the installation section of the git repo:\n",
-                "https://github.com/abetlen/llama-cpp-python#installation-with-openblas--cublas--clblast--metal",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        reasoning_content = self._llamacpp_model_config.pop("reasoning_content")
-        self.prepare_parse_reasoning_content(reasoning_content)
-
-        if os.path.isfile(self.model_path):
-            # mostly passed from --model_path
-            model_path = self.model_path
-        else:
-            # handle legacy cache.
-            if (
-                self.model_spec.model_file_name_split_template
-                and self.model_spec.quantization_parts
-            ):
-                part = self.model_spec.quantization_parts[self.quantization]
-                model_path = os.path.join(
-                    self.model_path,
-                    self.model_spec.model_file_name_split_template.format(
-                        quantization=self.quantization, part=part[0]
-                    ),
-                )
-            else:
-                model_path = os.path.join(
-                    self.model_path,
-                    self.model_spec.model_file_name_template.format(
-                        quantization=self.quantization
-                    ),
-                )
-                legacy_model_file_path = os.path.join(self.model_path, "model.bin")
-                if os.path.exists(legacy_model_file_path):
-                    model_path = legacy_model_file_path
-
-        try:
-            self._llm = Llama(
-                model_path=model_path,
-                verbose=True,
-                **self._llamacpp_model_config,
-            )
-        except AssertionError:
-            raise RuntimeError(f"Load model {self.model_family.model_name} failed")
-
-    @classmethod
-    def check_lib(cls) -> bool:
-        return importlib.util.find_spec("llama_cpp") is not None
-
-    @classmethod
-    def match_json(
-        cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
-    ) -> bool:
-        if llm_spec.model_format not in ["ggufv2"]:
-            return False
-        if "qwen" in llm_family.model_name:
-            return False
-        if "generate" not in llm_family.model_ability:
-            return False
-        return True
-
-    def generate(
-        self, prompt: str, generate_config: Optional[LlamaCppGenerateConfig] = None
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        def generator_wrapper(
-            _prompt: str,
-            _generate_config: LlamaCppGenerateConfig,
-        ) -> Iterator[CompletionChunk]:
-            assert self._llm is not None
-            prompt_token_ids: List[int] = (
-                (
-                    self._llm.tokenize(prompt.encode("utf-8"), special=True)
-                    if prompt != ""
-                    else [self._llm.token_bos()]
-                )
-                if isinstance(prompt, str)
-                else prompt
-            )
-            prompt_tokens = len(prompt_token_ids)
-            completion_tokens, total_tokens = 0, 0
-            request_id = 0
-            for index, _completion_chunk in enumerate(
-                self._llm(prompt=_prompt, **_generate_config)
-            ):
-                _completion_chunk["model"] = self.model_uid
-                request_id = _completion_chunk["id"]
-                completion_tokens = index + 1
-                total_tokens = prompt_tokens + completion_tokens
-                _completion_chunk["usage"] = CompletionUsage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                )
-                yield _completion_chunk
-            if include_usage:
-                chunk = CompletionChunk(
-                    id=request_id,
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=self.model_uid,
-                    choices=[],
-                )
-                chunk["usage"] = CompletionUsage(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=total_tokens,
-                )
-                yield chunk
-
-        logger.debug(
-            "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-        stream = generate_config.get("stream", False)
-        stream_options = generate_config.pop("stream_options", None)
-        include_usage = (
-            stream_options["include_usage"]
-            if isinstance(stream_options, dict)
-            else False
-        )
-
-        if not stream:
-            assert self._llm is not None
-            completion = self._llm(prompt=prompt, **generate_config)
-
-            return completion
-        else:
-            return generator_wrapper(prompt, generate_config)
-
-
-class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
-    def __init__(
-        self,
-        model_uid: str,
-        model_family: "LLMFamilyV1",
-        model_spec: "LLMSpecV1",
-        quantization: str,
-        model_path: str,
-        llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
-    ):
-        super().__init__(
-            model_uid,
-            model_family,
-            model_spec,
-            quantization,
-            model_path,
-            llamacpp_model_config,
-        )
-
-    @classmethod
-    def match_json(
-        cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
-    ) -> bool:
-        if llm_spec.model_format not in ["ggufv2"]:
-            return False
-        if "chat" not in llm_family.model_ability:
-            return False
-        return True
-
-    def _sanitize_generate_config(
-        self, generate_config: Optional[LlamaCppGenerateConfig]
-    ) -> LlamaCppGenerateConfig:
-        generate_config = super()._sanitize_generate_config(generate_config)
-        if self.model_family.stop and self.model_family.stop:
-            generate_config["stop"] = self.model_family.stop.copy()
-        return generate_config
-
-    def chat(
-        self,
-        messages: List[Dict],
-        generate_config: Optional[LlamaCppGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        model_family = self.model_family.model_family or self.model_family.model_name
-        tools = generate_config.pop("tools", []) if generate_config else None
-        full_context_kwargs = (
-            self._get_chat_template_kwargs_from_generate_config(generate_config) or {}  # type: ignore
-        )
-        if tools:
-            if (
-                model_family in QWEN_TOOL_CALL_FAMILY
-                or model_family in DEEPSEEK_TOOL_CALL_FAMILY
-            ):
-                full_context_kwargs["tools"] = tools
-        assert self.model_family.chat_template is not None
-        full_prompt = self.get_full_context(
-            messages, self.model_family.chat_template, **full_context_kwargs
-        )
-
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        stream = generate_config.get("stream", False)
-        if stream:
-            it = self.generate(full_prompt, generate_config)
-            assert isinstance(it, Iterator)
-            return self._to_chat_completion_chunks(it, self.reasoning_parser)
-        else:
-            c = self.generate(full_prompt, generate_config)
-            assert not isinstance(c, Iterator)
-            if tools:
-                return self._post_process_completion(
-                    self.model_family, self.model_uid, c, self.reasoning_parser
-                )
-            return self._to_chat_completion(c, self.reasoning_parser)
-
-
-if USE_XLLAMACPP:
-    LlamaCppModel = XllamaCppModel  # type: ignore  # noqa: F811
-    LlamaCppChatModel = XllamaCppModel  # type: ignore  # noqa: F811
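
With the llama-cpp-python based classes and the USE_XLLAMACPP switch removed, XllamaCppModel is the only llama.cpp backend left in this module. A hedged migration sketch for downstream code that previously imported the old names (the fallback shim is illustrative, not part of xinference):

    try:
        # 1.5.1 and earlier
        from xinference.model.llm.llama_cpp.core import LlamaCppChatModel, LlamaCppModel
    except ImportError:
        # 1.6.0.post1 and later only export the xllamacpp-based class
        from xinference.model.llm.llama_cpp.core import XllamaCppModel as LlamaCppModel

        LlamaCppChatModel = LlamaCppModel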