xinference 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/client.py +18 -0
- xinference/constants.py +1 -0
- xinference/core/gradio.py +2 -2
- xinference/core/restful_api.py +31 -5
- xinference/core/supervisor.py +64 -1
- xinference/core/worker.py +22 -0
- xinference/deploy/cmdline.py +39 -13
- xinference/deploy/worker.py +2 -2
- xinference/model/llm/__init__.py +20 -83
- xinference/model/llm/ggml/llamacpp.py +1 -0
- xinference/model/llm/llm_family.json +30 -15
- xinference/model/llm/llm_family.py +152 -7
- xinference/model/llm/pytorch/core.py +63 -40
- xinference/model/llm/pytorch/utils.py +5 -1
- xinference/model/llm/utils.py +6 -0
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/METADATA +133 -29
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/RECORD +22 -22
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/WHEEL +1 -1
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/LICENSE +0 -0
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.1.1.dist-info → xinference-0.1.3.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2023-08-
+ "date": "2023-08-09T18:43:41+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.1.
+ "full-revisionid": "4d2f61cb6591ac94624f035b37259a89002abefd",
+ "version": "0.1.3"
 }
 ''' # END VERSION_JSON

xinference/client.py
CHANGED
@@ -480,6 +480,24 @@ class Client:
         # generate a time-based uuid.
         return str(uuid.uuid1())

+    def register_model(self, model_type: str, model: str, persist: bool):
+        coro = self._supervisor_ref.register_model(model_type, model, persist)
+        self._isolation.call(coro)
+
+    def unregister_model(self, model_type: str, model_name: str):
+        coro = self._supervisor_ref.unregister_model(model_type, model_name)
+        self._isolation.call(coro)
+
+    def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
+        coro = self._supervisor_ref.list_model_registrations(model_type)
+        return self._isolation.call(coro)
+
+    def get_model_registration(
+        self, model_type: str, model_name: str
+    ) -> Dict[str, Any]:
+        coro = self._supervisor_ref.get_model_registration(model_type, model_name)
+        return self._isolation.call(coro)
+
     def launch_model(
         self,
         model_name: str,
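For context, a rough sketch of how these new registration methods might be used. It assumes the client is constructed against a running supervisor address, as elsewhere in client.py, and the model description dict below is hypothetical; only field names that appear elsewhere in this diff are used, and the authoritative schema is `LLMFamilyV1` in llm_family.py:

```python
import json

from xinference.client import Client

# Placeholder supervisor address; adjust to your deployment.
client = Client("127.0.0.1:9999")

# Hypothetical custom model description. Field names mirror the entries in
# llm_family.json shown further down; the full schema is LLMFamilyV1.
custom_llm = {
    "model_name": "my-custom-llm",
    "model_ability": ["generate"],
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "quantizations": ["4-bit", "8-bit", "none"],
            "model_id": "my-org/my-custom-llm",
        }
    ],
    "prompt_style": None,
}

# register_model takes the model description as a JSON string.
client.register_model(model_type="LLM", model=json.dumps(custom_llm), persist=True)

# Registrations are reported as dicts like {"model_name": ..., "is_builtin": ...}
# (see list_model_registrations in supervisor.py below).
print(client.list_model_registrations(model_type="LLM"))
print(client.get_model_registration(model_type="LLM", model_name="my-custom-llm"))

client.unregister_model(model_type="LLM", model_name="my-custom-llm")
```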
xinference/constants.py
CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path

 XINFERENCE_HOME = str(Path.home() / ".xinference")
 XINFERENCE_CACHE_DIR = os.path.join(XINFERENCE_HOME, "cache")
+XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
 XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")

 XINFERENCE_DEFAULT_LOCAL_HOST = "127.0.0.1"
xinference/core/gradio.py
CHANGED
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 import gradio as gr

 from ..locale.utils import Locale
-from ..model.llm import
+from ..model.llm import BUILTIN_LLM_FAMILIES, LLMFamilyV1, match_llm
 from ..model.llm.llm_family import cache
 from .api import SyncSupervisorAPI

@@ -27,7 +27,7 @@ if TYPE_CHECKING:

 MODEL_TO_FAMILIES: Dict[str, LLMFamilyV1] = dict(
     (model_family.model_name, model_family)
-    for model_family in
+    for model_family in BUILTIN_LLM_FAMILIES
     if "chat" in model_family.model_ability
 )

xinference/core/restful_api.py
CHANGED
@@ -480,7 +480,7 @@ class RESTfulAPIActor(xo.Actor):
            (msg["content"] for msg in body.messages if msg["role"] == "system"), None
        )

-        chat_history = body.messages
+        chat_history = body.messages[:-1]  # exclude the prompt

        model_uid = body.model

@@ -494,6 +494,26 @@ class RESTfulAPIActor(xo.Actor):
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

+        try:
+            desc = await self._supervisor_ref.describe_model(model_uid)
+
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            raise HTTPException(status_code=400, detail=str(ve))
+
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+        is_chatglm_ggml = desc.get(
+            "model_format"
+        ) == "ggmlv3" and "chatglm" in desc.get("model_name", "")
+
+        if is_chatglm_ggml and system_prompt is not None:
+            raise HTTPException(
+                status_code=400, detail="ChatGLM ggml does not have system prompt"
+            )
+
        if body.stream:
            # create a pair of memory object streams
            send_chan, recv_chan = anyio.create_memory_object_stream(10)

@@ -501,9 +521,12 @@ class RESTfulAPIActor(xo.Actor):
            async def event_publisher(inner_send_chan: MemoryObjectSendStream):
                async with inner_send_chan:
                    try:
-
-                        prompt,
-
+                        if is_chatglm_ggml:
+                            iterator = await model.chat(prompt, chat_history, kwargs)
+                        else:
+                            iterator = await model.chat(
+                                prompt, system_prompt, chat_history, kwargs
+                            )
                        async for chunk in iterator:
                            await inner_send_chan.send(dict(data=json.dumps(chunk)))
                            if await request.is_disconnected():

@@ -525,7 +548,10 @@ class RESTfulAPIActor(xo.Actor):

        else:
            try:
-
+                if is_chatglm_ggml:
+                    return await model.chat(prompt, chat_history, kwargs)
+                else:
+                    return await model.chat(prompt, system_prompt, chat_history, kwargs)
            except Exception as e:
                logger.error(e, exc_info=True)
                raise HTTPException(status_code=500, detail=str(e))
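To make the new message handling concrete, here is a small self-contained sketch of how an OpenAI-style `messages` list is split by this handler. The extraction of `prompt` from the last message is an assumption based on surrounding code not shown in this diff:

```python
# Illustrative OpenAI-style chat request body.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

# System prompt: the first message with role "system", exactly as in the handler above.
system_prompt = next(
    (msg["content"] for msg in messages if msg["role"] == "system"), None
)

# 0.1.3 change: the history no longer includes the latest (prompt) message.
chat_history = messages[:-1]

# Assumption: the prompt is taken from the last message in the list.
prompt = messages[-1]["content"]

print(system_prompt)      # "You are a helpful assistant."
print(len(chat_history))  # 3
print(prompt)             # "Tell me a joke."
```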
xinference/core/supervisor.py
CHANGED
@@ -16,7 +16,7 @@ import asyncio
 import time
 from dataclasses import dataclass
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import xoscar as xo

@@ -74,6 +74,69 @@ class SupervisorActor(xo.Actor):

         raise RuntimeError("No available worker found")

+    @log_sync(logger=logger)
+    def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
+        if model_type == "LLM":
+            from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
+
+            ret = [
+                {"model_name": f.model_name, "is_builtin": True}
+                for f in BUILTIN_LLM_FAMILIES
+            ]
+            user_defined_llm_families = get_user_defined_llm_families()
+            ret.extend(
+                [
+                    {"model_name": f.model_name, "is_builtin": False}
+                    for f in user_defined_llm_families
+                ]
+            )
+
+            return ret
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_sync(logger=logger)
+    def get_model_registration(
+        self, model_type: str, model_name: str
+    ) -> Dict[str, Any]:
+        if model_type == "LLM":
+            from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
+
+            for f in BUILTIN_LLM_FAMILIES + get_user_defined_llm_families():
+                if f.model_name == model_name:
+                    return f
+
+            raise ValueError(f"Model {model_name} not found")
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_async(logger=logger)
+    async def register_model(self, model_type: str, model: str, persist: bool):
+        if model_type == "LLM":
+            from ..model.llm import LLMFamilyV1, register_llm
+
+            llm_family = LLMFamilyV1.parse_raw(model)
+            register_llm(llm_family, persist)
+
+            if not self.is_local_deployment:
+                for worker in self._worker_address_to_worker.values():
+                    await worker.register_model(model_type, model, persist)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_async(logger=logger)
+    async def unregister_model(self, model_type: str, model_name: str):
+        if model_type == "LLM":
+            from ..model.llm import unregister_llm
+
+            unregister_llm(model_name)
+
+            if not self.is_local_deployment:
+                for worker in self._worker_address_to_worker.values():
+                    await worker.unregister_model(model_name)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
     async def launch_builtin_model(
         self,
         model_uid: str,
xinference/core/worker.py
CHANGED
@@ -108,8 +108,30 @@ class WorkerActor(xo.Actor):
             "model_format": llm_spec.model_format,
             "model_size_in_billions": llm_spec.model_size_in_billions,
             "quantization": quantization,
+            "revision": llm_spec.model_revision,
         }

+    @log_sync(logger=logger)
+    async def register_model(self, model_type: str, model: str, persist: bool):
+        # TODO: centralized model registrations
+        if model_type == "LLM":
+            from ..model.llm import LLMFamilyV1, register_llm
+
+            llm_family = LLMFamilyV1.parse_raw(model)
+            register_llm(llm_family, persist)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_sync(logger=logger)
+    async def unregister_model(self, model_type: str, model_name: str):
+        # TODO: centralized model registrations
+        if model_type == "LLM":
+            from ..model.llm import unregister_llm
+
+            unregister_llm(model_name)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
     @log_async(logger=logger)
     async def launch_builtin_model(
         self,
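The `revision` field exposed here comes from the new `model_revision` entries in llm_family.json (shown below). How the revision is actually consumed is defined in llm_family.py, which is not rendered in this diff; as a hedged illustration only, pinning a Hugging Face download to such a revision looks roughly like this with `huggingface_hub` (an assumption, not necessarily what xinference does internally):

```python
from huggingface_hub import snapshot_download

# Values taken from the Baichuan-7B entry in llm_family.json below.
model_id = "baichuan-inc/Baichuan-7B"
model_revision = "c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756"

# Download (or reuse from cache) the model files at exactly this revision.
local_dir = snapshot_download(repo_id=model_id, revision=model_revision)
print(local_dir)
```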
xinference/deploy/cmdline.py
CHANGED
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import configparser
 import logging
 import os
+import sys
 from typing import Optional

 import click

@@ -30,6 +30,32 @@ from ..constants import (
 )


+def get_config_string(log_level: str) -> str:
+    return f"""
+[loggers]
+keys=root
+
+[handlers]
+keys=stream_handler
+
+[formatters]
+keys=formatter
+
+[logger_root]
+level={log_level.upper()}
+handlers=stream_handler
+
+[handler_stream_handler]
+class=StreamHandler
+formatter=formatter
+level={log_level.upper()}
+args=(sys.stderr,)
+
+[formatter_formatter]
+format=%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s
+"""
+
+
 def get_endpoint(endpoint: Optional[str]) -> str:
     # user didn't specify the endpoint.
     if endpoint is None:

@@ -57,9 +83,10 @@ def cli(
     if ctx.invoked_subcommand is None:
         from .local import main

-
-
-        logging_conf
+        logging_conf = configparser.RawConfigParser()
+        logger_config_string = get_config_string(log_level)
+        logging_conf.read_string(logger_config_string)
+        logging.config.fileConfig(logging_conf)  # type: ignore

         address = f"{host}:{get_next_port()}"

@@ -102,9 +129,10 @@ def supervisor(
 def worker(log_level: str, endpoint: Optional[str], host: str):
     from ..deploy.worker import main

-
-
-    logging_conf
+    logging_conf = configparser.RawConfigParser()
+    logger_config_string = get_config_string(log_level)
+    logging_conf.read_string(logger_config_string)
+    logging.config.fileConfig(level=logging.getLevelName(log_level.upper()))  # type: ignore

     endpoint = get_endpoint(endpoint)

@@ -146,7 +174,7 @@ def model_launch(
         quantization=quantization,
     )

-    print(f"Model uid: {model_uid}")
+    print(f"Model uid: {model_uid}", file=sys.stderr)


 @cli.command("list")

@@ -157,18 +185,16 @@ def model_launch(
 )
 @click.option("--all", is_flag=True)
 def model_list(endpoint: Optional[str], all: bool):
-    import sys
-
     from tabulate import tabulate

     # TODO: get from the supervisor
-    from ..model.llm import
+    from ..model.llm import BUILTIN_LLM_FAMILIES

     endpoint = get_endpoint(endpoint)

     table = []
     if all:
-        for model_family in
+        for model_family in BUILTIN_LLM_FAMILIES:
             table.append(
                 [
                     model_family.model_name,
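A minimal standalone sketch of the logging setup these hunks introduce: `logging.config.fileConfig` accepts an already-parsed `RawConfigParser` instance as well as a file name, which is what the `cli` branch above relies on (the INI string below simply repeats the layout of the helper):

```python
import configparser
import logging
import logging.config

LOGGING_INI = """
[loggers]
keys=root

[handlers]
keys=stream_handler

[formatters]
keys=formatter

[logger_root]
level=DEBUG
handlers=stream_handler

[handler_stream_handler]
class=StreamHandler
formatter=formatter
level=DEBUG
args=(sys.stderr,)

[formatter_formatter]
format=%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s
"""

conf = configparser.RawConfigParser()
conf.read_string(LOGGING_INI)
# fileConfig accepts the parser object directly; the handler args are
# evaluated in the logging module's namespace, so sys.stderr resolves.
logging.config.fileConfig(conf)

logging.getLogger("xinference.demo").debug("stderr logging configured")
```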
xinference/deploy/worker.py
CHANGED
@@ -14,7 +14,7 @@

 import asyncio
 import logging
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 import xoscar as xo

@@ -53,7 +53,7 @@ async def _start_worker(
     await pool.join()


-def main(address: str, supervisor_address: str, logging_conf:
+def main(address: str, supervisor_address: str, logging_conf: Any = None):
     loop = asyncio.get_event_loop()
     task = loop.create_task(_start_worker(address, supervisor_address, logging_conf))

xinference/model/llm/__init__.py
CHANGED
@@ -12,98 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import codecs
 import json
-import logging
 import os
-import platform
-from typing import List, Optional, Tuple, Type

 from .core import LLM
 from .llm_family import (
+    BUILTIN_LLM_FAMILIES,
+    LLM_CLASSES,
     GgmlLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
     PromptStyleV1,
     PytorchLLMSpecV1,
+    get_user_defined_llm_families,
+    match_llm,
+    match_llm_cls,
+    register_llm,
+    unregister_llm,
 )

-_LLM_CLASSES: List[Type[LLM]] = []
-
-LLM_FAMILIES: List["LLMFamilyV1"] = []
-
-logger = logging.getLogger(__name__)
-
-
-def _is_linux():
-    return platform.system() == "Linux"
-
-
-def _has_cuda_device():
-    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
-    if cuda_visible_devices:
-        return True
-    else:
-        from xorbits._mars.resource import cuda_count
-
-        return cuda_count() > 0
-
-
-def match_llm(
-    model_name: str,
-    model_format: Optional[str] = None,
-    model_size_in_billions: Optional[int] = None,
-    quantization: Optional[str] = None,
-    is_local_deployment: bool = False,
-) -> Optional[Tuple[LLMFamilyV1, LLMSpecV1, str]]:
-    """
-    Find an LLM family, spec, and quantization that satisfy given criteria.
-    """
-    for family in LLM_FAMILIES:
-        if model_name != family.model_name:
-            continue
-        for spec in family.model_specs:
-            if (
-                model_format
-                and model_format != spec.model_format
-                or model_size_in_billions
-                and model_size_in_billions != spec.model_size_in_billions
-                or quantization
-                and quantization not in spec.quantizations
-            ):
-                continue
-            if quantization:
-                return family, spec, quantization
-            else:
-                # by default, choose the most coarse-grained quantization.
-                # TODO: too hacky.
-                quantizations = spec.quantizations
-                quantizations.sort()
-                for q in quantizations:
-                    if (
-                        is_local_deployment
-                        and not (_is_linux() and _has_cuda_device())
-                        and q == "4-bit"
-                    ):
-                        logger.warning(
-                            "Skipping %s for non-linux or non-cuda local deployment .",
-                            q,
-                        )
-                        continue
-                    return family, spec, q
-    return None
-
-
-def match_llm_cls(
-    llm_family: LLMFamilyV1, llm_spec: "LLMSpecV1"
-) -> Optional[Type[LLM]]:
-    """
-    Find an LLM implementation for given LLM family and spec.
-    """
-    for cls in _LLM_CLASSES:
-        if cls.match(llm_family, llm_spec):
-            return cls
-    return None
-

 def _install():
     from .ggml.chatglm import ChatglmCppChatModel

@@ -114,7 +42,7 @@ def _install():
     from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
     from .pytorch.vicuna import VicunaPytorchChatModel

-
+    LLM_CLASSES.extend(
         [
             ChatglmCppChatModel,
             LlamaCppModel,

@@ -132,5 +60,14 @@ def _install():
     json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
     )
-    for json_obj in json.load(open(json_path)):
-
+    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        BUILTIN_LLM_FAMILIES.append(LLMFamilyV1.parse_obj(json_obj))
+
+    from ...constants import XINFERENCE_MODEL_DIR
+
+    user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
+    if os.path.isdir(user_defined_llm_dir):
+        for f in os.listdir(user_defined_llm_dir):
+            with codecs.open(f, encoding="utf-8") as fd:
+                user_defined_llm_family = LLMFamilyV1.parse_obj(json.load(fd))
+                register_llm(user_defined_llm_family, persist=False)

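Per the loading logic above, user-defined families are read at import time from JSON files under `XINFERENCE_MODEL_DIR/llm` (i.e. `~/.xinference/model/llm`). A hypothetical sketch of dropping such a file in place as the on-disk equivalent of a persisted registration; field names are borrowed from llm_family.json below and the real schema is `LLMFamilyV1` in llm_family.py:

```python
import json
import os

from xinference.constants import XINFERENCE_MODEL_DIR

# Hypothetical user-defined LLM family; only fields visible elsewhere in this diff.
my_family = {
    "model_name": "my-org-llm",
    "model_ability": ["generate"],
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 13,
            "quantizations": ["none"],
            "model_id": "my-org/my-org-llm-13b",
        }
    ],
    "prompt_style": None,
}

user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
os.makedirs(user_defined_llm_dir, exist_ok=True)
path = os.path.join(user_defined_llm_dir, "my-org-llm.json")
with open(path, "w", encoding="utf-8") as fd:
    json.dump(my_family, fd)

# The loop in _install() above scans this directory the next time
# xinference.model.llm is imported and registers each family it finds.
```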
xinference/model/llm/ggml/llamacpp.py
CHANGED

@@ -139,6 +139,7 @@ class LlamaCppModel(LLM):
             llamacpp_model_config["n_gqa"] = 8

         if self._is_darwin_and_apple_silicon() and self._can_apply_metal():
+            # TODO: platform.processor() is not safe, need to be replaced to other method.
             llamacpp_model_config.setdefault("n_gpu_layers", 1)
         elif self._is_linux() and self._can_apply_cublas():
             llamacpp_model_config.setdefault("n_gpu_layers", self._gpu_layers)
xinference/model/llm/llm_family.json
CHANGED

@@ -41,7 +41,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "baichuan-inc/Baichuan-7B"
+                "model_id": "baichuan-inc/Baichuan-7B",
+                "model_revision": "c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756"
             },
             {
                 "model_format": "pytorch",

@@ -51,7 +52,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "baichuan-inc/Baichuan-13B-Base"
+                "model_id": "baichuan-inc/Baichuan-13B-Base",
+                "model_revision": "0ef0739c7bdd34df954003ef76d80f3dabca2ff9"
             }
         ],
         "prompt_style": null

@@ -98,7 +100,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "baichuan-inc/Baichuan-13B-Chat"
+                "model_id": "baichuan-inc/Baichuan-13B-Chat",
+                "model_revision": "19ef51ba5bad8935b03acd20ff04a269210983bc"
             }
         ],
         "prompt_style": {

@@ -267,7 +270,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "lmsys/vicuna-33b-v1.3"
+                "model_id": "lmsys/vicuna-33b-v1.3",
+                "model_revision": "ef8d6becf883fb3ce52e3706885f761819477ab4"
             },
             {
                 "model_format": "pytorch",

@@ -277,7 +281,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "lmsys/vicuna-13b-v1.3"
+                "model_id": "lmsys/vicuna-13b-v1.3",
+                "model_revision": "6566e9cb1787585d1147dcf4f9bc48f29e1328d2"
             },
             {
                 "model_format": "pytorch",

@@ -287,7 +292,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "lmsys/vicuna-7b-v1.3"
+                "model_id": "lmsys/vicuna-7b-v1.3",
+                "model_revision": "236eeeab96f0dc2e463f2bebb7bb49809279c6d6"
             }
         ],
         "prompt_style": {

@@ -395,7 +401,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "THUDM/chatglm-6b"
+                "model_id": "THUDM/chatglm-6b",
+                "model_revision": "b1502f4f75c71499a3d566b14463edd62620ce9f"
             }
         ],
         "prompt_style": {

@@ -441,7 +448,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "THUDM/chatglm2-6b"
+                "model_id": "THUDM/chatglm2-6b",
+                "model_revision": "b1502f4f75c71499a3d566b14463edd62620ce9f"
             }
         ],
         "prompt_style": {

@@ -474,7 +482,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "THUDM/chatglm2-6b-32k"
+                "model_id": "THUDM/chatglm2-6b-32k",
+                "model_revision": "455746d4706479a1cbbd07179db39eb2741dc692"
             }
         ],
         "prompt_style": {

@@ -643,7 +652,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "facebook/opt-125m"
+                "model_id": "facebook/opt-125m",
+                "model_revision": "3d2b5f275bdf882b8775f902e1bfdb790e2cfc32"
             }
         ],
         "prompt_style": null

@@ -667,7 +677,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "tiiuae/falcon-40b"
+                "model_id": "tiiuae/falcon-40b",
+                "model_revision": "561820f7eef0cc56a31ea38af15ca1acb07fab5d"
             },
             {
                 "model_format": "pytorch",

@@ -677,7 +688,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "tiiuae/falcon-7b"
+                "model_id": "tiiuae/falcon-7b",
+                "model_revision": "378337427557d1df3e742264a2901a49f25d4eb1"
             }
         ],
         "prompt_style": null

@@ -701,7 +713,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "tiiuae/falcon-7b-instruct"
+                "model_id": "tiiuae/falcon-7b-instruct",
+                "model_revision": "eb410fb6ffa9028e97adb801f0d6ec46d02f8b07"
             },
             {
                 "model_format": "pytorch",

@@ -711,7 +724,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "tiiuae/falcon-40b-instruct"
+                "model_id": "tiiuae/falcon-40b-instruct",
+                "model_revision": "ca78eac0ed45bf64445ff0687fabba1598daebf3"
             }
         ],
         "prompt_style": {

@@ -759,7 +773,8 @@
                     "8-bit",
                     "none"
                 ],
-                "model_id": "Qwen/Qwen-7B-Chat"
+                "model_id": "Qwen/Qwen-7B-Chat",
+                "model_revision": "5c611a5cde5769440581f91e8b4bba050f62b1af"
             }
         ],
         "prompt_style": {