xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of xinference has been flagged as potentially problematic; details are available from the registry.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/__init__.py
CHANGED

```diff
@@ -45,7 +45,6 @@ from .llm_family import (
     LLMFamilyV1,
     LLMSpecV1,
     MLXLLMSpecV1,
-    PromptStyleV1,
     PytorchLLMSpecV1,
     get_cache_status,
     get_user_defined_llm_families,
```
```diff
@@ -137,13 +136,18 @@ def _install():
     from .transformers.cogvlm2 import CogVLM2Model
     from .transformers.cogvlm2_video import CogVLM2VideoModel
     from .transformers.core import PytorchChatModel, PytorchModel
+    from .transformers.deepseek_v2 import (
+        DeepSeekV2PytorchChatModel,
+        DeepSeekV2PytorchModel,
+    )
     from .transformers.deepseek_vl import DeepSeekVLChatModel
     from .transformers.glm4v import Glm4VModel
     from .transformers.intern_vl import InternVLChatModel
     from .transformers.internlm2 import Internlm2PytorchChatModel
-    from .transformers.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
     from .transformers.minicpmv25 import MiniCPMV25Model
     from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.qwen2_audio import Qwen2AudioChatModel
+    from .transformers.qwen2_vl import Qwen2VLChatModel
     from .transformers.qwen_vl import QwenVLChatModel
     from .transformers.yi_vl import YiVLChatModel
     from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
```
```diff
@@ -170,11 +174,11 @@ def _install():
     TRANSFORMERS_CLASSES.extend(
         [
             ChatglmPytorchChatModel,
-            LlamaPytorchModel,
-            LlamaPytorchChatModel,
             PytorchChatModel,
             Internlm2PytorchChatModel,
             QwenVLChatModel,
+            Qwen2VLChatModel,
+            Qwen2AudioChatModel,
             YiVLChatModel,
             DeepSeekVLChatModel,
             InternVLChatModel,
@@ -184,6 +188,8 @@ def _install():
             MiniCPMV25Model,
             MiniCPMV26Model,
             Glm4VModel,
+            DeepSeekV2PytorchModel,
+            DeepSeekV2PytorchChatModel,
         ]
     )
     if OmniLMMModel:  # type: ignore
```
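The two hunks above register the new Qwen2-VL, Qwen2-Audio, and DeepSeek-V2 classes alongside the existing Transformers backends, and drop the dedicated Llama-2 classes. A minimal usage sketch, assuming a local Xinference server; the model name and engine values are illustrative, not confirmed by this diff (check `xinference registrations` for the exact names):

```python
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(
    model_name="qwen2-vl-instruct",  # assumed registry name for the new Qwen2-VL family
    model_engine="transformers",     # served by the newly registered Qwen2VLChatModel
)
model = client.get_model(model_uid)
# As of this release the chat API takes an OpenAI-style message list instead of
# prompt/system_prompt/chat_history:
print(model.chat(messages=[{"role": "user", "content": "Describe this image."}]))
```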
```diff
@@ -204,13 +210,17 @@ def _install():
         model_spec = LLMFamilyV1.parse_obj(json_obj)
         BUILTIN_LLM_FAMILIES.append(model_spec)
 
-        # register
+        # register chat_template
         if "chat" in model_spec.model_ability and isinstance(
-            model_spec.prompt_style, PromptStyleV1
+            model_spec.chat_template, str
         ):
             # note that the key is the model name,
             # since there are multiple representations of the same prompt style name in json.
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
```
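With this change, each registry entry is a plain dict assembled from the family's Jinja chat template and stop settings, replacing the former `PromptStyleV1` object. An illustrative sketch of the resulting entry shape (all values below are made up):

```python
# Hypothetical BUILTIN_LLM_PROMPT_STYLE entry after this release: a plain dict
# keyed by model name, carrying the chat template plus both stop mechanisms.
BUILTIN_LLM_PROMPT_STYLE: dict = {}

BUILTIN_LLM_PROMPT_STYLE["my-chat-model"] = {
    "chat_template": "{% for m in messages %}...{% endfor %}",  # Jinja template string
    "stop_token_ids": [151643, 151645],                         # token-level stops
    "stop": ["<|endoftext|>"],                                  # string-level stops
}
```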
```diff
@@ -230,10 +240,14 @@ def _install():
         # if duplicated with huggingface json, keep it as the huggingface style
         if (
             "chat" in model_spec.model_ability
-            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and isinstance(model_spec.chat_template, str)
             and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
@@ -253,10 +267,14 @@ def _install():
         # if duplicated with huggingface json, keep it as the huggingface style
         if (
             "chat" in model_spec.model_ability
-            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and isinstance(model_spec.chat_template, str)
             and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
```
xinference/model/llm/llama_cpp/core.py
CHANGED

```diff
@@ -14,12 +14,11 @@
 import logging
 import os
 import time
-from typing import Iterable, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChunk,
     CompletionUsage,
```
```diff
@@ -181,13 +180,12 @@ class LlamaCppModel(LLM):
             for index, _completion_chunk in enumerate(
                 self._llm(prompt=_prompt, **_generate_config)
             ):
+                _completion_chunk["model"] = self.model_uid
                 request_id = _completion_chunk["id"]
-                choice = _completion_chunk["choices"][0]
-                if choice["finish_reason"] is not None:
-                    completion_tokens = index
+                completion_tokens = index + 1
                 total_tokens = prompt_tokens + completion_tokens
                 _completion_chunk["usage"] = CompletionUsage(
-                    prompt_tokens=
+                    prompt_tokens=prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
```
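The hunk above fixes an off-by-one in streamed token accounting: `enumerate` is zero-based, so once the chunk at `index` has been yielded, `index + 1` completion tokens exist, and usage is now attached to every chunk instead of only when a finish reason appears. A minimal sketch of the arithmetic:

```python
# Stand-ins for streamed completion chunks; each chunk carries one token.
chunks = ["Hello", ",", " world"]

for index, _chunk in enumerate(chunks):
    completion_tokens = index + 1  # new behaviour: counts 1, 2, 3 as chunks arrive
    # The old code set `completion_tokens = index` only on the final chunk,
    # which would report 2 instead of 3 here.

print(completion_tokens)  # 3
```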
```diff
@@ -262,39 +260,26 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
         self, generate_config: Optional[LlamaCppGenerateConfig]
     ) -> LlamaCppGenerateConfig:
         generate_config = super()._sanitize_generate_config(generate_config)
-        if self.model_family.prompt_style and self.model_family.prompt_style.stop:
-            generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if self.model_family.stop and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
         return generate_config
 
     def chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[LlamaCppGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-
-        chat_history = chat_history or []
-        assert prompt_style is not None
+        model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
-
+        full_context_kwargs = {}
+        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )
 
         generate_config = self._sanitize_generate_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
-        model_family = self.model_family.model_family or self.model_family.model_name
-        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-            stop = generate_config.get("stop")
-            if isinstance(stop, str):
-                generate_config["stop"] = [stop, "Observation:"]
-            elif isinstance(stop, Iterable):
-                assert not isinstance(stop, str)
-                generate_config["stop"] = stop + ["Observation:"]  # type: ignore
-            else:
-                generate_config["stop"] = "Observation:"
 
         stream = generate_config.get("stream", False)
         if stream:
@@ -305,7 +290,5 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
             c = self.generate(full_prompt, generate_config)
             assert not isinstance(c, Iterator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)
```
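`chat()` now accepts an OpenAI-style `messages` list and builds the full prompt by rendering the family's chat template through `get_full_context`, passing tools as a template kwarg for Qwen families instead of injecting hard-coded "Observation:" stop words. A rough sketch of what that rendering amounts to, assuming a Jinja2 template in the style of `transformers`' `apply_chat_template`; the template string and helper below are illustrative, not xinference's code:

```python
from jinja2 import Template

# Illustrative ChatML-style chat template; real templates ship in llm_family.json.
chat_template = (
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n"
    "{% endfor %}<|im_start|>assistant\n"
)

def render_full_prompt(messages, template_str, **kwargs):
    """Hypothetical stand-in for get_full_context: render template to a prompt string."""
    return Template(template_str).render(messages=messages, **kwargs)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]
print(render_full_prompt(messages, chat_template))
```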