xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +108 -14
- xinference/client/restful/restful_client.py +78 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/event.py +5 -6
- xinference/core/model.py +59 -42
- xinference/core/scheduler.py +46 -18
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +8 -2
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +8 -2
- xinference/model/llm/pytorch/chatglm.py +41 -12
- xinference/model/llm/pytorch/core.py +128 -88
- xinference/model/llm/pytorch/glm4v.py +24 -3
- xinference/model/llm/pytorch/internlm2.py +15 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +69 -189
- xinference/model/llm/utils.py +27 -14
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/rerank/core.py +35 -6
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +28 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
- xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0

--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -32,10 +32,15 @@ from ..._compat import (
     load_str_bytes,
     validator,
 )
-from ...constants import
+from ...constants import (
+    XINFERENCE_CACHE_DIR,
+    XINFERENCE_ENV_CSG_TOKEN,
+    XINFERENCE_MODEL_DIR,
+)
 from ..utils import (
     IS_NEW_HUGGINGFACE_HUB,
     create_symlink,
+    download_from_csghub,
     download_from_modelscope,
     is_valid_model_uri,
     parse_uri,
@@ -232,6 +237,7 @@ LLAMA_CLASSES: List[Type[LLM]] = []

 BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
 BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []
+BUILTIN_CSGHUB_LLM_FAMILIES: List["LLMFamilyV1"] = []

 SGLANG_CLASSES: List[Type[LLM]] = []
 TRANSFORMERS_CLASSES: List[Type[LLM]] = []
@@ -292,6 +298,9 @@ def cache(
     elif llm_spec.model_hub == "modelscope":
         logger.info(f"Caching from Modelscope: {llm_spec.model_id}")
         return cache_from_modelscope(llm_family, llm_spec, quantization)
+    elif llm_spec.model_hub == "csghub":
+        logger.info(f"Caching from CSGHub: {llm_spec.model_id}")
+        return cache_from_csghub(llm_family, llm_spec, quantization)
     else:
         raise ValueError(f"Unknown model hub: {llm_spec.model_hub}")

@@ -566,6 +575,7 @@ def _skip_download(
         "modelscope": _get_meta_path(
             cache_dir, model_format, "modelscope", quantization
         ),
+        "csghub": _get_meta_path(cache_dir, model_format, "csghub", quantization),
     }
     if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
         logger.info(f"Cache {cache_dir} exists")
@@ -650,6 +660,75 @@ def _merge_cached_files(
     logger.info(f"Merge complete.")


+def cache_from_csghub(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    """
+    Cache model from CSGHub. Return the cache directory.
+    """
+    from pycsghub.file_download import file_download
+    from pycsghub.snapshot_download import snapshot_download
+
+    cache_dir = _get_cache_dir(llm_family, llm_spec)
+
+    if _skip_download(
+        cache_dir,
+        llm_spec.model_format,
+        llm_spec.model_hub,
+        llm_spec.model_revision,
+        quantization,
+    ):
+        return cache_dir
+
+    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
+        download_dir = retry_download(
+            snapshot_download,
+            llm_family.model_name,
+            {
+                "model_size": llm_spec.model_size_in_billions,
+                "model_format": llm_spec.model_format,
+            },
+            llm_spec.model_id,
+            endpoint="https://hub-stg.opencsg.com",
+            token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+        )
+        create_symlink(download_dir, cache_dir)
+
+    elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
+        file_names, final_file_name, need_merge = _generate_model_file_names(
+            llm_spec, quantization
+        )
+
+        for filename in file_names:
+            download_path = retry_download(
+                file_download,
+                llm_family.model_name,
+                {
+                    "model_size": llm_spec.model_size_in_billions,
+                    "model_format": llm_spec.model_format,
+                },
+                llm_spec.model_id,
+                file_name=filename,
+                endpoint="https://hub-stg.opencsg.com",
+                token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+            )
+            symlink_local_file(download_path, cache_dir, filename)
+
+        if need_merge:
+            _merge_cached_files(cache_dir, file_names, final_file_name)
+    else:
+        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
+
+    meta_path = _get_meta_path(
+        cache_dir, llm_spec.model_format, llm_spec.model_hub, quantization
+    )
+    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
+
+    return cache_dir
+
+
 def cache_from_modelscope(
     llm_family: LLMFamilyV1,
     llm_spec: "LLMSpecV1",
@@ -931,6 +1010,12 @@ def match_llm(
             + BUILTIN_LLM_FAMILIES
             + user_defined_llm_families
         )
+    elif download_from_csghub():
+        all_families = (
+            BUILTIN_CSGHUB_LLM_FAMILIES
+            + BUILTIN_LLM_FAMILIES
+            + user_defined_llm_families
+        )
     else:
         all_families = BUILTIN_LLM_FAMILIES + user_defined_llm_families

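
The new cache_from_csghub path above comes down to a pycsghub snapshot (or per-file) download followed by a symlink into the xinference cache directory, gated by the usual _skip_download check. Below is a minimal sketch of the underlying download call, assuming pycsghub is installed and invoking it the same way the diff does; the "CSG_TOKEN" name is only a stand-in for whatever environment variable XINFERENCE_ENV_CSG_TOKEN resolves to in xinference/constants.py, and the model id is taken from the new llm_family_csghub.json shown below.

```python
# Hedged sketch: mirrors the snapshot_download call that cache_from_csghub
# makes through retry_download for pytorch/gptq/awq formats.
import os

from pycsghub.snapshot_download import snapshot_download

local_dir = snapshot_download(
    "Qwen/Qwen2-0.5B-Instruct",              # model_id from llm_family_csghub.json
    endpoint="https://hub-stg.opencsg.com",  # endpoint hard-coded in cache_from_csghub
    token=os.environ.get("CSG_TOKEN"),       # stand-in for XINFERENCE_ENV_CSG_TOKEN
)
print("downloaded to:", local_dir)
```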

--- /dev/null
+++ b/xinference/model/llm/llm_family_csghub.json
@@ -0,0 +1,66 @@
+[
+    {
+        "version": 1,
+        "context_length": 32768,
+        "model_name": "qwen2-instruct",
+        "model_lang": [
+            "en",
+            "zh"
+        ],
+        "model_ability": [
+            "chat",
+            "tools"
+        ],
+        "model_description": "Qwen2 is the new series of Qwen large language models",
+        "model_specs": [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "4-bit",
+                    "8-bit",
+                    "none"
+                ],
+                "model_id": "Qwen/Qwen2-0.5B-Instruct",
+                "model_hub": "csghub"
+            },
+            {
+                "model_format": "ggufv2",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "q2_k",
+                    "q3_k_m",
+                    "q4_0",
+                    "q4_k_m",
+                    "q5_0",
+                    "q5_k_m",
+                    "q6_k",
+                    "q8_0",
+                    "fp16"
+                ],
+                "model_id": "qwen/Qwen2-0.5B-Instruct-GGUF",
+                "model_file_name_template": "qwen2-0_5b-instruct-{quantization}.gguf",
+                "model_hub": "csghub"
+            }
+        ],
+        "prompt_style": {
+            "style_name": "QWEN",
+            "system_prompt": "You are a helpful assistant.",
+            "roles": [
+                "user",
+                "assistant"
+            ],
+            "intra_message_sep": "\n",
+            "stop_token_ids": [
+                151643,
+                151644,
+                151645
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>"
+            ]
+        }
+    }
+]
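
For the ggufv2 spec above, the concrete file to fetch is derived by substituting the chosen quantization into model_file_name_template. A quick illustration of that substitution (plain string formatting, not an xinference API):

```python
# Illustrative only: file names the template above expands to.
template = "qwen2-0_5b-instruct-{quantization}.gguf"
for quantization in ("q4_0", "q4_k_m", "q8_0"):
    print(template.format(quantization=quantization))
# qwen2-0_5b-instruct-q4_0.gguf
# qwen2-0_5b-instruct-q4_k_m.gguf
# qwen2-0_5b-instruct-q8_0.gguf
```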

--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -632,6 +632,8 @@
                 "model_format": "pytorch",
                 "model_size_in_billions": 9,
                 "quantizations": [
+                    "4-bit",
+                    "8-bit",
                     "none"
                 ],
                 "model_hub": "modelscope",
@@ -2642,7 +2644,8 @@
             "zh"
         ],
         "model_ability": [
-            "chat"
+            "chat",
+            "tools"
         ],
         "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
         "model_specs": [
@@ -2966,7 +2969,8 @@
             "zh"
         ],
         "model_ability": [
-            "chat"
+            "chat",
+            "tools"
         ],
         "model_description": "Qwen2 is the new series of Qwen large language models. ",
         "model_specs": [
@@ -3348,9 +3352,11 @@
         ],
         "intra_message_sep": "<|im_end|>",
         "stop_token_ids": [
+            2,
             92542
         ],
         "stop": [
+            "</s>",
             "<|im_end|>"
         ]
     }

--- a/xinference/model/llm/pytorch/chatglm.py
+++ b/xinference/model/llm/pytorch/chatglm.py
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
@@ -89,24 +90,30 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             return False
         return True

-
-    def _handle_tools(generate_config) -> Optional[dict]:
+    def _handle_tools(self, generate_config) -> Optional[dict]:
         """Convert openai tools to ChatGLM tools."""
         if generate_config is None:
             return None
         tools = generate_config.pop("tools", None)
         if tools is None:
             return None
-
-
-
-
-
-
-
-
-
-
+        if self.model_family.model_name == "glm4-chat":
+            return {
+                "role": "system",
+                "content": None,
+                "tools": tools,
+            }
+        else:
+            chatglm_tools = []
+            for elem in tools:
+                if elem.get("type") != "function" or "function" not in elem:
+                    raise ValueError("ChatGLM tools only support function type.")
+                chatglm_tools.append(elem["function"])
+            return {
+                "role": "system",
+                "content": f"Answer the following questions as best as you can. You have access to the following tools:",
+                "tools": chatglm_tools,
+            }

     def chat(
         self,
@@ -238,3 +245,25 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
         ),
     )
+
+    @staticmethod
+    def require_attention_mask():
+        """
+        GLM4 needs to use attention mask and position ids during inference.
+        Otherwise, the inference result would be not available.
+        """
+        return True
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Set temperature and top_p to 0.8 by default
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+
+        return raw_config
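
The reworked _handle_tools keeps the old behaviour for earlier ChatGLM models (OpenAI-style tool entries are unwrapped to their "function" payloads and carried in a system message), while glm4-chat now receives the tool list untouched with content set to None. The standalone sketch below replays the non-glm4 branch on a hypothetical get_weather tool so the shape of the converted message is visible; it is an illustration, not the xinference API.

```python
# Sketch of the non-glm4 branch of _handle_tools; get_weather is a made-up
# example payload.
from typing import Optional


def to_chatglm_tool_message(tools: list) -> Optional[dict]:
    chatglm_tools = []
    for elem in tools:
        if elem.get("type") != "function" or "function" not in elem:
            raise ValueError("ChatGLM tools only support function type.")
        chatglm_tools.append(elem["function"])
    return {
        "role": "system",
        "content": "Answer the following questions as best as you can. "
        "You have access to the following tools:",
        "tools": chatglm_tools,
    }


openai_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
print(to_chatglm_tool_message(openai_tools))
```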

--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -15,7 +15,8 @@
 import json
 import logging
 import os
-from
+from functools import lru_cache
+from typing import Iterable, Iterator, List, Optional, Tuple, Union

 from ....core.scheduler import InferenceRequest
 from ....device_utils import (
@@ -28,6 +29,7 @@ from ....types import (
     ChatCompletionChunk,
     ChatCompletionMessage,
     Completion,
+    CompletionChoice,
     CompletionChunk,
     CreateCompletionTorch,
     Embedding,
@@ -281,35 +283,21 @@ class PytorchModel(LLM):
     def generate(
         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import generate_stream
-
-        model_family_name = self.model_family.model_name.lower()
+        from .utils import generate_stream

         def generator_wrapper(
             prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
-
-
-
-
-
-
-
-
-
-
-                    yield completion_chunk
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    completion_chunk["usage"] = completion_usage
-                    yield completion_chunk
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                completion_chunk["usage"] = completion_usage
+                yield completion_chunk

         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -334,26 +322,15 @@ class PytorchModel(LLM):

         stream = generate_config.get("stream", False)
         if not stream:
-
-
-
-
-
-
-
-
-
-                    pass
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    pass
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                pass
             completion = Completion(
                 id=completion_chunk["id"],
                 object=completion_chunk["object"],
@@ -366,6 +343,105 @@ class PytorchModel(LLM):
         else:
             return generator_wrapper(prompt, generate_config)

+    @staticmethod
+    def require_attention_mask():
+        return False
+
+    @lru_cache
+    def get_context_len(self):
+        return get_context_length(self._model.config)
+
+    def get_max_num_seqs(self) -> int:
+        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        return self._sanitize_generate_config(req.generate_config)
+
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        # check some parameters
+        for r in req_list:
+            if r.sanitized_generate_config is None:
+                r.sanitized_generate_config = self.prepare_sanitize_generate_config(r)
+            if r.is_prefill:
+                # check some generate params
+                max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                if max_src_len < 0:
+                    r.stopped = True
+                    r.error_msg = "Max tokens exceeds model's max length"
+                    continue
+                if r.stream_interval <= 0:
+                    r.stopped = True
+                    r.error_msg = "`stream_interval` must be greater than 0"
+                    continue
+                stop_str = r.sanitized_generate_config.get("stop", None)
+                if stop_str and (
+                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                ):
+                    r.stopped = True
+                    r.error_msg = "Invalid `stop` field type"
+                    continue
+
+    def _get_builtin_stop_token_ids(self) -> Tuple:
+        return (
+            tuple(self.model_family.prompt_style.stop_token_ids)
+            if self.model_family.prompt_style
+            and self.model_family.prompt_style.stop_token_ids
+            else tuple()
+        )
+
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
+        for req in req_list:
+            if req.error_msg is None:
+                # nothing need handle for non-stream case
+                if req.stream:
+                    results = []
+                    for i, c in enumerate(req.completion):
+                        if c == "<bos_stream>":
+                            chunk = req.completion[i + 1]
+                            results.append(
+                                CompletionChunk(
+                                    id=chunk["id"],
+                                    object=chunk["object"],
+                                    created=chunk["created"],
+                                    model=chunk["model"],
+                                    choices=[
+                                        CompletionChoice(
+                                            text="",
+                                            index=0,
+                                            logprobs=None,
+                                            finish_reason=None,
+                                        )
+                                    ],
+                                )
+                            )
+                            continue
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(c)
+
+                    if req.stopped and req.include_usage:
+                        results.append(req.completion[-1])
+                    req.completion = results
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        from .utils import batch_inference_one_step
+
+        self.prepare_batch_inference(req_list)
+        context_len = self.get_context_len()
+        assert isinstance(context_len, int)
+        batch_inference_one_step(
+            req_list,
+            self.model_uid,
+            self._model,
+            self._tokenizer,
+            self._device,
+            context_len,
+            self._get_builtin_stop_token_ids(),
+            require_attention_mask=self.require_attention_mask(),
+        )
+        self.handle_batch_inference_results(req_list)
+
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         try:
             import torch
@@ -464,7 +540,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             pytorch_model_config,
             peft_model,
         )
-        self._context_len = None

     def _sanitize_generate_config(
         self,
@@ -540,7 +615,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):

     def load(self):
         super().load()
-        self._context_len = get_context_length(self._model.config)

     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
         assert self.model_family.prompt_style is not None
@@ -553,48 +627,14 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         )
         return full_prompt

-    def
-
-
-    def batch_inference(self, req_list: List[InferenceRequest]):
-        from .utils import batch_inference_one_step
-
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        super().prepare_batch_inference(req_list)
         for r in req_list:
-
-            r.
-
-            )
-            if r.is_prefill:
-                # check some generate params
-                max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
-                if max_src_len < 0:
-                    r.stopped = True
-                    r.error_msg = "Max tokens exceeds model's max length"
-                    continue
-                if r.stream_interval <= 0:
-                    r.stopped = True
-                    r.error_msg = "`stream_interval` must be greater than 0"
-                    continue
-                stop_str = r.sanitized_generate_config.get("stop", None)
-                if stop_str and (
-                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
-                ):
-                    r.stopped = True
-                    r.error_msg = "Invalid `stop` field type"
-                    continue
-                r.full_prompt = self._get_full_prompt(
-                    r.prompt, r.system_prompt, r.chat_history, None
-                )
+            r.full_prompt = self._get_full_prompt(
+                r.prompt, r.system_prompt, r.chat_history, None
+            )

-
-        batch_inference_one_step(
-            req_list,
-            self.model_uid,
-            self._model,
-            self._tokenizer,
-            self._device,
-            self._context_len,
-        )
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.stream and req.error_msg is None:
                 if req.completion:
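
In the new batching path, req.completion accumulates streamed output bracketed by "<bos_stream>" / "<eos_stream>" sentinels, and handle_batch_inference_results rebuilds a clean chunk list before anything reaches the caller. The toy below only shows which raw elements survive that filtering; the real code additionally synthesizes an empty first CompletionChunk from the element that follows "<bos_stream>" and keeps a trailing usage chunk when include_usage is set. Sketch only, not the xinference API.

```python
# Toy illustration of the sentinel filtering in handle_batch_inference_results.
def strip_stream_sentinels(completion: list) -> list:
    results = []
    for c in completion:
        if c == "<bos_stream>":
            # real code also emits an empty "first" chunk built from the next element
            continue
        elif c == "<eos_stream>":
            break
        results.append(c)
    return results


print(strip_stream_sentinels(["<bos_stream>", "chunk-1", "chunk-2", "<eos_stream>"]))
# ['chunk-1', 'chunk-2']
```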

--- a/xinference/model/llm/pytorch/glm4v.py
+++ b/xinference/model/llm/pytorch/glm4v.py
@@ -56,19 +56,40 @@ class Glm4VModel(PytorchChatModel):
             return True
         return False

-    def load(self
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer

         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
-
+
+        kwargs = {"device_map": self._device}
+        quantization = self.quantization
+
+        # referenced from PytorchModel.load
+        if quantization != "none":
+            if self._device == "cuda" and self._is_linux():
+                kwargs["device_map"] = "auto"
+                self._device = "auto"
+                if quantization == "4-bit":
+                    kwargs["load_in_4bit"] = True
+                elif quantization == "8-bit":
+                    kwargs["load_in_8bit"] = True
+                else:
+                    raise ValueError(
+                        f"Quantization {quantization} is not supported in temporary"
+                    )
+            else:
+                if quantization != "8-bit":
+                    raise ValueError(
+                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+                    )

         model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             low_cpu_mem_usage=True,
             trust_remote_code=True,
             torch_dtype=torch.float16,
-
+            **kwargs,
         )
         self._model = model.eval()


--- a/xinference/model/llm/pytorch/internlm2.py
+++ b/xinference/model/llm/pytorch/internlm2.py
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     ChatCompletion,
     ChatCompletionChoice,
@@ -88,6 +89,20 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             return False
         return True

+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Overwrite this func for this special model.
+        Cannot use the default configuration, which works poorly on this model.
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        return raw_config
+
     def chat(
         self,
         prompt: str,
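
Both the ChatGLM/GLM4 and InternLM2 overrides of prepare_sanitize_generate_config above do the same thing: fall back to temperature=0.8 and top_p=0.8 only when the caller left them unset. A self-contained restatement of that default-filling logic for reference (not the xinference API):

```python
# Standalone restatement of the 0.8 defaults applied above.
def apply_sampling_defaults(raw_params: dict) -> dict:
    raw_config = dict(raw_params)
    if raw_config.get("temperature") is None:
        raw_config["temperature"] = 0.8
    if raw_config.get("top_p") is None:
        raw_config["top_p"] = 0.8
    return raw_config


print(apply_sampling_defaults({}))                    # {'temperature': 0.8, 'top_p': 0.8}
print(apply_sampling_defaults({"temperature": 0.2}))  # {'temperature': 0.2, 'top_p': 0.8}
```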

--- a/xinference/model/llm/pytorch/qwen_vl.py
+++ b/xinference/model/llm/pytorch/qwen_vl.py
@@ -45,7 +45,7 @@ class QwenVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "qwen" in model_family.model_name:
+        if "qwen" in model_family.model_name and "vision" in model_family.model_ability:
             return True
         return False
