xinference 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of xinference has been flagged as potentially problematic; see the registry listing for details.
- {xinference-0.1.1/xinference.egg-info → xinference-0.1.3}/PKG-INFO +105 -1
- {xinference-0.1.1 → xinference-0.1.3}/README.md +104 -0
- {xinference-0.1.1 → xinference-0.1.3}/setup.cfg +2 -2
- {xinference-0.1.1 → xinference-0.1.3}/xinference/_version.py +3 -3
- {xinference-0.1.1 → xinference-0.1.3}/xinference/client.py +18 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/constants.py +1 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/gradio.py +2 -2
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/restful_api.py +31 -5
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/supervisor.py +64 -1
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/worker.py +22 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/cmdline.py +39 -13
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/worker.py +2 -2
- xinference-0.1.3/xinference/model/llm/__init__.py +73 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/ggml/llamacpp.py +1 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/llm_family.json +30 -15
- xinference-0.1.3/xinference/model/llm/llm_family.py +279 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/core.py +63 -40
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/utils.py +5 -1
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/utils.py +6 -0
- {xinference-0.1.1 → xinference-0.1.3/xinference.egg-info}/PKG-INFO +105 -1
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/requires.txt +2 -2
- xinference-0.1.1/xinference/model/llm/__init__.py +0 -136
- xinference-0.1.1/xinference/model/llm/llm_family.py +0 -134
- {xinference-0.1.1 → xinference-0.1.3}/LICENSE +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/MANIFEST.in +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/pyproject.toml +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/setup.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/versioneer.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/api.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/model.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/resource.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/core/utils.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/local.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/supervisor.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/test/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/utils.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/isolation.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/locale/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/locale/utils.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/locale/zh_CN.json +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/core.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/core.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/ggml/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/ggml/chatglm.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/__init__.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/baichuan.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/chatglm.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/compression.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/falcon.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/model/llm/pytorch/vicuna.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference/types.py +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/SOURCES.txt +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/dependency_links.txt +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/entry_points.txt +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/not-zip-safe +0 -0
- {xinference-0.1.1 → xinference-0.1.3}/xinference.egg-info/top_level.txt +0 -0
{xinference-0.1.1/xinference.egg-info → xinference-0.1.3}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: xinference
-Version: 0.1.1
+Version: 0.1.3
 Summary: Model Serving Made Easy
 Home-page: https://github.com/xorbitsai/inference
 Author: Qin Xuye
@@ -238,6 +238,110 @@ $ xinference list --all
 - If you want to use Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
 - `llama-2-chat` 70B ggmlv3 model only supports q4_0 quantization currently.
 
+## Custom models \[Experimental\]
+Custom models are currently an experimental feature and are expected to be officially released in version v0.2.0.
+
+Define a custom model based on the following template:
+```python
+custom_model = {
+  "version": 1,
+  # model name. must start with a letter or a
+  # digit, and can only contain letters, digits,
+  # underscores, or dashes.
+  "model_name": "nsql-2B",
+  # supported languages
+  "model_lang": [
+    "en"
+  ],
+  # model abilities. could be "embed", "generate"
+  # and "chat".
+  "model_ability": [
+    "generate"
+  ],
+  # model specifications.
+  "model_specs": [
+    {
+      # model format.
+      "model_format": "pytorch",
+      "model_size_in_billions": 2,
+      # quantizations.
+      "quantizations": [
+        "4-bit",
+        "8-bit",
+        "none"
+      ],
+      # hugging face model ID.
+      "model_id": "NumbersStation/nsql-2B"
+    }
+  ],
+  # prompt style, required by chat models.
+  # for more details, see: xinference/model/llm/tests/test_utils.py
+  "prompt_style": None
+}
+```
+
+Register the custom model:
+```python
+import json
+
+from xinference.client import Client
+
+# replace with real xinference endpoint
+endpoint = "http://localhost:9997"
+client = Client(endpoint)
+client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False)
+```
+
+Load the custom model:
+```python
+uid = client.launch_model(model_name='nsql-2B')
+```
+
+Run the custom model:
+```python
+text = """CREATE TABLE work_orders (
+    ID NUMBER,
+    CREATED_AT TEXT,
+    COST FLOAT,
+    INVOICE_AMOUNT FLOAT,
+    IS_DUE BOOLEAN,
+    IS_OPEN BOOLEAN,
+    IS_OVERDUE BOOLEAN,
+    COUNTRY_NAME TEXT,
+)
+
+-- Using valid SQLite, answer the following questions for the tables provided above.
+
+-- how many work orders are open?
+
+SELECT"""
+
+model = client.get_model(model_uid=uid)
+model.generate(prompt=text)
+```
+
+Result:
+```json
+{
+  "id":"aeb5c87a-352e-11ee-89ad-9af9f16816c5",
+  "object":"text_completion",
+  "created":1691418511,
+  "model":"3b912fc4-352e-11ee-8e66-9af9f16816c5",
+  "choices":[
+    {
+      "text":" COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
+      "index":0,
+      "logprobs":"None",
+      "finish_reason":"stop"
+    }
+  ],
+  "usage":{
+    "prompt_tokens":117,
+    "completion_tokens":17,
+    "total_tokens":134
+  }
+}
+```
 
 ## Pytorch Model Best Practices
 
{xinference-0.1.1 → xinference-0.1.3}/README.md
@@ -210,6 +210,110 @@ $ xinference list --all
README.md gains the same "Custom models \[Experimental\]" section shown in the PKG-INFO hunk above (PKG-INFO embeds the README, so the added lines are identical).
{xinference-0.1.1 → xinference-0.1.3}/setup.cfg
@@ -60,7 +60,6 @@ dev =
     flake8>=3.8.0
     black
 all =
-    chatglm-cpp
     llama-cpp-python>=0.1.77
     transformers>=4.31.0
     torch
@@ -70,8 +69,8 @@ all =
     bitsandbytes
     protobuf
     einops
+    tiktoken
 ggml =
-    chatglm-cpp
     llama-cpp-python>=0.1.77
 pytorch =
     transformers>=4.31.0
@@ -82,6 +81,7 @@ pytorch =
     bitsandbytes
     protobuf
     einops
+    tiktoken
 doc =
     ipython>=6.5.0
     sphinx>=3.0.0,<5.0.0
{xinference-0.1.1 → xinference-0.1.3}/xinference/_version.py
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2023-08-
+ "date": "2023-08-09T18:43:41+0800",
 "dirty": false,
 "error": null,
- "full-revisionid": "
- "version": "0.1.1"
+ "full-revisionid": "4d2f61cb6591ac94624f035b37259a89002abefd",
+ "version": "0.1.3"
 }
 '''  # END VERSION_JSON
 
{xinference-0.1.1 → xinference-0.1.3}/xinference/client.py
@@ -480,6 +480,24 @@ class Client:
         # generate a time-based uuid.
         return str(uuid.uuid1())
 
+    def register_model(self, model_type: str, model: str, persist: bool):
+        coro = self._supervisor_ref.register_model(model_type, model, persist)
+        self._isolation.call(coro)
+
+    def unregister_model(self, model_type: str, model_name: str):
+        coro = self._supervisor_ref.unregister_model(model_type, model_name)
+        self._isolation.call(coro)
+
+    def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
+        coro = self._supervisor_ref.list_model_registrations(model_type)
+        return self._isolation.call(coro)
+
+    def get_model_registration(
+        self, model_type: str, model_name: str
+    ) -> Dict[str, Any]:
+        coro = self._supervisor_ref.get_model_registration(model_type, model_name)
+        return self._isolation.call(coro)
+
     def launch_model(
         self,
         model_name: str,
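Assuming the nsql-2B definition from the README section above has already been registered with `client.register_model` as shown there, the other new client methods can be exercised roughly as follows. A minimal sketch against a local endpoint; the endpoint URL is just an example, and the returned keys follow the supervisor changes further down in this diff:

```python
from xinference.client import Client

client = Client("http://localhost:9997")  # replace with a real xinference endpoint

# every built-in family plus any user-defined one, flagged accordingly
for reg in client.list_model_registrations(model_type="LLM"):
    print(reg["model_name"], reg["is_builtin"])

# fetch the full family definition for one registration
family = client.get_model_registration(model_type="LLM", model_name="nsql-2B")

# remove the user-defined family again
client.unregister_model(model_type="LLM", model_name="nsql-2B")
```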
{xinference-0.1.1 → xinference-0.1.3}/xinference/constants.py
@@ -17,6 +17,7 @@ from pathlib import Path
 
 XINFERENCE_HOME = str(Path.home() / ".xinference")
 XINFERENCE_CACHE_DIR = os.path.join(XINFERENCE_HOME, "cache")
+XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
 XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
 
 XINFERENCE_DEFAULT_LOCAL_HOST = "127.0.0.1"
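The new XINFERENCE_MODEL_DIR is where user-defined model definitions live on disk; the rewritten xinference/model/llm/__init__.py at the end of this diff scans its llm subdirectory at startup. A quick sketch of the paths these constants resolve to:

```python
import os
from pathlib import Path

XINFERENCE_HOME = str(Path.home() / ".xinference")
XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")

# directory scanned for user-defined LLM family definitions
user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
print(user_defined_llm_dir)  # e.g. ~/.xinference/model/llm
```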
{xinference-0.1.1 → xinference-0.1.3}/xinference/core/gradio.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 import gradio as gr
 
 from ..locale.utils import Locale
-from ..model.llm import
+from ..model.llm import BUILTIN_LLM_FAMILIES, LLMFamilyV1, match_llm
 from ..model.llm.llm_family import cache
 from .api import SyncSupervisorAPI
 
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
 
 MODEL_TO_FAMILIES: Dict[str, LLMFamilyV1] = dict(
     (model_family.model_name, model_family)
-    for model_family in
+    for model_family in BUILTIN_LLM_FAMILIES
     if "chat" in model_family.model_ability
 )
 
{xinference-0.1.1 → xinference-0.1.3}/xinference/core/restful_api.py
@@ -480,7 +480,7 @@ class RESTfulAPIActor(xo.Actor):
             (msg["content"] for msg in body.messages if msg["role"] == "system"), None
         )
 
-        chat_history = body.messages
+        chat_history = body.messages[:-1]  # exclude the prompt
 
         model_uid = body.model
 
@@ -494,6 +494,26 @@ class RESTfulAPIActor(xo.Actor):
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+        try:
+            desc = await self._supervisor_ref.describe_model(model_uid)
+
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            raise HTTPException(status_code=400, detail=str(ve))
+
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+        is_chatglm_ggml = desc.get(
+            "model_format"
+        ) == "ggmlv3" and "chatglm" in desc.get("model_name", "")
+
+        if is_chatglm_ggml and system_prompt is not None:
+            raise HTTPException(
+                status_code=400, detail="ChatGLM ggml does not have system prompt"
+            )
+
         if body.stream:
             # create a pair of memory object streams
             send_chan, recv_chan = anyio.create_memory_object_stream(10)
@@ -501,9 +521,12 @@ class RESTfulAPIActor(xo.Actor):
             async def event_publisher(inner_send_chan: MemoryObjectSendStream):
                 async with inner_send_chan:
                     try:
-
-                            prompt,
-
+                        if is_chatglm_ggml:
+                            iterator = await model.chat(prompt, chat_history, kwargs)
+                        else:
+                            iterator = await model.chat(
+                                prompt, system_prompt, chat_history, kwargs
+                            )
                         async for chunk in iterator:
                             await inner_send_chan.send(dict(data=json.dumps(chunk)))
                             if await request.is_disconnected():
@@ -525,7 +548,10 @@ class RESTfulAPIActor(xo.Actor):
 
         else:
            try:
-
+                if is_chatglm_ggml:
+                    return await model.chat(prompt, chat_history, kwargs)
+                else:
+                    return await model.chat(prompt, system_prompt, chat_history, kwargs)
             except Exception as e:
                 logger.error(e, exc_info=True)
                 raise HTTPException(status_code=500, detail=str(e))
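For context on the chat_history change above, a rough sketch of what a chat request to the RESTful API looks like. The exact route path is not visible in this hunk, so the OpenAI-style /v1/chat/completions path and the requests library are assumptions here; what the hunk does show is that everything except the last message becomes chat_history, the last message is the prompt, and a system message is pulled out separately (and rejected for ChatGLM ggml models):

```python
import requests

resp = requests.post(
    "http://localhost:9997/v1/chat/completions",  # assumed endpoint path
    json={
        "model": "<model-uid>",  # uid returned by launch_model
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            # the last message is treated as the prompt; earlier ones become chat_history
            {"role": "user", "content": "What is the capital of France?"},
        ],
    },
)
print(resp.json())
```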
{xinference-0.1.1 → xinference-0.1.3}/xinference/core/supervisor.py
@@ -16,7 +16,7 @@ import asyncio
 import time
 from dataclasses import dataclass
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import xoscar as xo
 
@@ -74,6 +74,69 @@ class SupervisorActor(xo.Actor):
 
         raise RuntimeError("No available worker found")
 
+    @log_sync(logger=logger)
+    def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
+        if model_type == "LLM":
+            from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
+
+            ret = [
+                {"model_name": f.model_name, "is_builtin": True}
+                for f in BUILTIN_LLM_FAMILIES
+            ]
+            user_defined_llm_families = get_user_defined_llm_families()
+            ret.extend(
+                [
+                    {"model_name": f.model_name, "is_builtin": False}
+                    for f in user_defined_llm_families
+                ]
+            )
+
+            return ret
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_sync(logger=logger)
+    def get_model_registration(
+        self, model_type: str, model_name: str
+    ) -> Dict[str, Any]:
+        if model_type == "LLM":
+            from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
+
+            for f in BUILTIN_LLM_FAMILIES + get_user_defined_llm_families():
+                if f.model_name == model_name:
+                    return f
+
+            raise ValueError(f"Model {model_name} not found")
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_async(logger=logger)
+    async def register_model(self, model_type: str, model: str, persist: bool):
+        if model_type == "LLM":
+            from ..model.llm import LLMFamilyV1, register_llm
+
+            llm_family = LLMFamilyV1.parse_raw(model)
+            register_llm(llm_family, persist)
+
+            if not self.is_local_deployment:
+                for worker in self._worker_address_to_worker.values():
+                    await worker.register_model(model_type, model, persist)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_async(logger=logger)
+    async def unregister_model(self, model_type: str, model_name: str):
+        if model_type == "LLM":
+            from ..model.llm import unregister_llm
+
+            unregister_llm(model_name)
+
+            if not self.is_local_deployment:
+                for worker in self._worker_address_to_worker.values():
+                    await worker.unregister_model(model_name)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
     async def launch_builtin_model(
         self,
         model_uid: str,
{xinference-0.1.1 → xinference-0.1.3}/xinference/core/worker.py
@@ -108,8 +108,30 @@ class WorkerActor(xo.Actor):
             "model_format": llm_spec.model_format,
             "model_size_in_billions": llm_spec.model_size_in_billions,
             "quantization": quantization,
+            "revision": llm_spec.model_revision,
         }
 
+    @log_sync(logger=logger)
+    async def register_model(self, model_type: str, model: str, persist: bool):
+        # TODO: centralized model registrations
+        if model_type == "LLM":
+            from ..model.llm import LLMFamilyV1, register_llm
+
+            llm_family = LLMFamilyV1.parse_raw(model)
+            register_llm(llm_family, persist)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+    @log_sync(logger=logger)
+    async def unregister_model(self, model_type: str, model_name: str):
+        # TODO: centralized model registrations
+        if model_type == "LLM":
+            from ..model.llm import unregister_llm
+
+            unregister_llm(model_name)
+        else:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
     @log_async(logger=logger)
     async def launch_builtin_model(
         self,
{xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/cmdline.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
+import configparser
 import logging
 import os
+import sys
 from typing import Optional
 
 import click
@@ -30,6 +30,32 @@ from ..constants import (
 )
 
 
+def get_config_string(log_level: str) -> str:
+    return f"""
+[loggers]
+keys=root
+
+[handlers]
+keys=stream_handler
+
+[formatters]
+keys=formatter
+
+[logger_root]
+level={log_level.upper()}
+handlers=stream_handler
+
+[handler_stream_handler]
+class=StreamHandler
+formatter=formatter
+level={log_level.upper()}
+args=(sys.stderr,)
+
+[formatter_formatter]
+format=%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s
+"""
+
+
 def get_endpoint(endpoint: Optional[str]) -> str:
     # user didn't specify the endpoint.
     if endpoint is None:
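As a standalone illustration (not part of the package), the config string produced above can be fed straight to the standard logging machinery, which is what the command hunks below do. A minimal sketch; the sample output line is only an approximation of the configured format:

```python
import configparser
import logging
import logging.config

from xinference.deploy.cmdline import get_config_string  # added in 0.1.3

conf = configparser.RawConfigParser()
conf.read_string(get_config_string("debug"))
logging.config.fileConfig(conf)  # fileConfig accepts a RawConfigParser instance

# emits something like:
# 2023-08-09 18:43:41,000 example      12345 DEBUG    hello
logging.getLogger("example").debug("hello")
```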
@@ -57,9 +83,10 @@ def cli(
     if ctx.invoked_subcommand is None:
         from .local import main
 
-
-
-        logging_conf
+        logging_conf = configparser.RawConfigParser()
+        logger_config_string = get_config_string(log_level)
+        logging_conf.read_string(logger_config_string)
+        logging.config.fileConfig(logging_conf)  # type: ignore
 
         address = f"{host}:{get_next_port()}"
 
@@ -102,9 +129,10 @@ def supervisor(
 def worker(log_level: str, endpoint: Optional[str], host: str):
     from ..deploy.worker import main
 
-
-
-    logging_conf
+    logging_conf = configparser.RawConfigParser()
+    logger_config_string = get_config_string(log_level)
+    logging_conf.read_string(logger_config_string)
+    logging.config.fileConfig(level=logging.getLevelName(log_level.upper()))  # type: ignore
 
     endpoint = get_endpoint(endpoint)
 
@@ -146,7 +174,7 @@ def model_launch(
         quantization=quantization,
     )
 
-    print(f"Model uid: {model_uid}")
+    print(f"Model uid: {model_uid}", file=sys.stderr)
 
 
 @cli.command("list")
@@ -157,18 +185,16 @@ def model_launch(
 )
 @click.option("--all", is_flag=True)
 def model_list(endpoint: Optional[str], all: bool):
-    import sys
-
     from tabulate import tabulate
 
     # TODO: get from the supervisor
-    from ..model.llm import
+    from ..model.llm import BUILTIN_LLM_FAMILIES
 
     endpoint = get_endpoint(endpoint)
 
     table = []
     if all:
-        for model_family in
+        for model_family in BUILTIN_LLM_FAMILIES:
             table.append(
                 [
                     model_family.model_name,
{xinference-0.1.1 → xinference-0.1.3}/xinference/deploy/worker.py
@@ -14,7 +14,7 @@
 
 import asyncio
 import logging
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import xoscar as xo
 
@@ -53,7 +53,7 @@ async def _start_worker(
     await pool.join()
 
 
-def main(address: str, supervisor_address: str, logging_conf:
+def main(address: str, supervisor_address: str, logging_conf: Any = None):
     loop = asyncio.get_event_loop()
     task = loop.create_task(_start_worker(address, supervisor_address, logging_conf))
 
xinference-0.1.3/xinference/model/llm/__init__.py (new file)
@@ -0,0 +1,73 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+
+from .core import LLM
+from .llm_family import (
+    BUILTIN_LLM_FAMILIES,
+    LLM_CLASSES,
+    GgmlLLMSpecV1,
+    LLMFamilyV1,
+    LLMSpecV1,
+    PromptStyleV1,
+    PytorchLLMSpecV1,
+    get_user_defined_llm_families,
+    match_llm,
+    match_llm_cls,
+    register_llm,
+    unregister_llm,
+)
+
+
+def _install():
+    from .ggml.chatglm import ChatglmCppChatModel
+    from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
+    from .pytorch.baichuan import BaichuanPytorchChatModel
+    from .pytorch.chatglm import ChatglmPytorchChatModel
+    from .pytorch.core import PytorchChatModel, PytorchModel
+    from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
+    from .pytorch.vicuna import VicunaPytorchChatModel
+
+    LLM_CLASSES.extend(
+        [
+            ChatglmCppChatModel,
+            LlamaCppModel,
+            LlamaCppChatModel,
+            PytorchModel,
+            PytorchChatModel,
+            BaichuanPytorchChatModel,
+            VicunaPytorchChatModel,
+            FalconPytorchModel,
+            FalconPytorchChatModel,
+            ChatglmPytorchChatModel,
+        ]
+    )
+
+    json_path = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
+    )
+    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        BUILTIN_LLM_FAMILIES.append(LLMFamilyV1.parse_obj(json_obj))
+
+    from ...constants import XINFERENCE_MODEL_DIR
+
+    user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
+    if os.path.isdir(user_defined_llm_dir):
+        for f in os.listdir(user_defined_llm_dir):
+            with codecs.open(f, encoding="utf-8") as fd:
+                user_defined_llm_family = LLMFamilyV1.parse_obj(json.load(fd))
+                register_llm(user_defined_llm_family, persist=False)
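The tail of _install() above is what lets user-defined models reappear after a restart: every JSON document found under XINFERENCE_MODEL_DIR/llm is parsed as an LLMFamilyV1 and registered. A rough sketch of persisting the nsql-2B template from the README section by hand; the file name is arbitrary here (the loader simply iterates the directory), and whether register_model(persist=True) writes this exact layout is not shown in this diff:

```python
import json
import os

from xinference.constants import XINFERENCE_MODEL_DIR

# the nsql-2B template from the README section above, serialized as strict JSON
nsql_family = {
    "version": 1,
    "model_name": "nsql-2B",
    "model_lang": ["en"],
    "model_ability": ["generate"],
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 2,
            "quantizations": ["4-bit", "8-bit", "none"],
            "model_id": "NumbersStation/nsql-2B",
        }
    ],
    "prompt_style": None,  # written out as null
}

user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
os.makedirs(user_defined_llm_dir, exist_ok=True)

# illustrative file name; _install() reads every file in this directory
with open(os.path.join(user_defined_llm_dir, "nsql-2B.json"), "w", encoding="utf-8") as fd:
    json.dump(nsql_family, fd)
```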
|