xinference 0.7.4.1__py3-none-any.whl → 0.7.5__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +22 -8
- xinference/client/oscar/actor_client.py +78 -8
- xinference/core/model.py +14 -7
- xinference/core/supervisor.py +12 -0
- xinference/deploy/cmdline.py +16 -0
- xinference/deploy/test/test_cmdline.py +1 -0
- xinference/model/embedding/model_spec.json +40 -0
- xinference/model/llm/__init__.py +14 -1
- xinference/model/llm/llm_family.json +10 -1
- xinference/model/llm/llm_family.py +38 -2
- xinference/model/llm/llm_family_modelscope.json +10 -1
- xinference/model/llm/pytorch/chatglm.py +1 -0
- xinference/model/llm/pytorch/core.py +1 -1
- xinference/model/llm/pytorch/utils.py +50 -18
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +13 -4
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.31d347d8.js → main.236e72e7.js} +3 -3
- xinference/web/ui/build/static/js/main.236e72e7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/78f2521da2e2a98b075a2666cb782c7e2c019cd3c72199eecd5901c82d8655df.json +1 -0
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/METADATA +9 -2
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/RECORD +29 -29
- xinference/web/ui/build/static/js/main.31d347d8.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/ca8515ecefb4a06c5305417bfd9c04e13cf6b9103f52a47c925921b26c0a9f9d.json +0 -1
- /xinference/web/ui/build/static/js/{main.31d347d8.js.LICENSE.txt → main.236e72e7.js.LICENSE.txt} +0 -0
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/LICENSE +0 -0
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/WHEEL +0 -0
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/entry_points.txt +0 -0
- {xinference-0.7.4.1.dist-info → xinference-0.7.5.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 version_json = '''
 {
-    "date": "
+    "date": "2024-01-05T15:29:43+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "
-    "version": "0.7.4.1"
+    "full-revisionid": "56b28b3e4149b0a9ab6f5322401b1c3f1fc95c1a",
+    "version": "0.7.5"
 }
 ''' # END VERSION_JSON
xinference/api/restful_api.py
CHANGED
@@ -160,6 +160,9 @@ class RESTfulAPI:
         self._router.add_api_route(
             "/v1/models/prompts", self._get_builtin_prompts, methods=["GET"]
         )
+        self._router.add_api_route(
+            "/v1/models/families", self._get_builtin_families, methods=["GET"]
+        )
         self._router.add_api_route(
             "/v1/cluster/devices", self._get_devices_count, methods=["GET"]
         )
@@ -312,6 +315,17 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

+    async def _get_builtin_families(self) -> JSONResponse:
+        """
+        For internal usage
+        """
+        try:
+            data = await (await self._get_supervisor_ref()).get_builtin_families()
+            return JSONResponse(content=data)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def _get_devices_count(self) -> JSONResponse:
         """
         For internal usage
@@ -565,7 +579,7 @@ class RESTfulAPI:
             except RuntimeError as re:
                 self.handle_request_limit_error(re)
             async for item in iterator:
-                yield
+                yield item
         except Exception as ex:
             if iterator is not None:
                 await iterator.destroy()
@@ -577,7 +591,7 @@ class RESTfulAPI:
         else:
             try:
                 data = await model.generate(body.prompt, kwargs)
-                return
+                return Response(data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
                 self.handle_request_limit_error(e)
@@ -634,7 +648,7 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

-    async def create_images(self, request: TextToImageRequest) ->
+    async def create_images(self, request: TextToImageRequest) -> Response:
         model_uid = request.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -655,7 +669,7 @@ class RESTfulAPI:
                 response_format=request.response_format,
                 **kwargs,
             )
-            return
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             self.handle_request_limit_error(re)
@@ -674,7 +688,7 @@ class RESTfulAPI:
         response_format: Optional[str] = Form("url"),
         size: Optional[str] = Form("1024*1024"),
         kwargs: Optional[str] = Form(None),
-    ) ->
+    ) -> Response:
         model_uid = model
         try:
             model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -697,7 +711,7 @@ class RESTfulAPI:
                 response_format=response_format,
                 **kwargs,
             )
-            return
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -828,7 +842,7 @@ class RESTfulAPI:
             except RuntimeError as re:
                 self.handle_request_limit_error(re)
             async for item in iterator:
-                yield
+                yield item
         except Exception as ex:
             if iterator is not None:
                 await iterator.destroy()
@@ -843,7 +857,7 @@ class RESTfulAPI:
                 data = await model.chat(prompt, chat_history, kwargs)
             else:
                 data = await model.chat(prompt, system_prompt, chat_history, kwargs)
-            return
+            return Response(content=data, media_type="application/json")
         except Exception as e:
             logger.error(e, exc_info=True)
             self.handle_request_limit_error(e)
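As a quick check of the new route, the family lists can be fetched over plain HTTP. A minimal sketch, assuming a local xinference server on its usual default port 9997; the response shape mirrors the supervisor's get_builtin_families further below.

import requests

resp = requests.get("http://127.0.0.1:9997/v1/models/families")
resp.raise_for_status()
families = resp.json()  # {"chat": [...], "generate": [...]}
print(families["chat"][:5], families["generate"][:5])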
xinference/client/oscar/actor_client.py
CHANGED
@@ -13,11 +13,13 @@
 # limitations under the License.

 import asyncio
+import re
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

+import orjson
 import xoscar as xo

-from ...core.model import ModelActor
+from ...core.model import IteratorWrapper, ModelActor
 from ...core.supervisor import SupervisorActor
 from ...isolation import Isolation
 from ..restful.restful_client import Client
@@ -38,6 +40,52 @@ if TYPE_CHECKING:
 )


+class SSEEvent(object):
+    # https://github.com/btubbs/sseclient/blob/master/sseclient.py
+    sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")
+
+    def __init__(self, data="", event="message", id=None, retry=None):
+        self.data = data
+        self.event = event
+        self.id = id
+        self.retry = retry
+
+    @classmethod
+    def parse(cls, raw):
+        """
+        Given a possibly-multiline string representing an SSE message, parse it
+        and return a Event object.
+        """
+        msg = cls()
+        for line in raw.splitlines():
+            m = cls.sse_line_pattern.match(line)
+            if m is None:
+                # Malformed line. Discard but warn.
+                continue
+
+            name = m.group("name")
+            if name == "":
+                # line began with a ":", so is a comment. Ignore
+                continue
+            value = m.group("value")
+
+            if name == "data":
+                # If we already have some data, then join to it with a newline.
+                # Else this is it.
+                if msg.data:
+                    msg.data = "%s\n%s" % (msg.data, value)
+                else:
+                    msg.data = value
+            elif name == "event":
+                msg.event = value
+            elif name == "id":
+                msg.id = value
+            elif name == "retry":
+                msg.retry = int(value)
+
+        return msg
+
+
 class ModelHandle:
     """
     A sync model interface (for rpc client) which provides type hints that makes it much easier to use xinference
@@ -49,6 +97,19 @@ class ModelHandle:
         self._isolation = isolation


+class ClientIteratorWrapper(IteratorWrapper):
+    async def __anext__(self):
+        r = await super().__anext__()
+        text = r.decode("utf-8")
+        return orjson.loads(SSEEvent.parse(text).data)
+
+    @classmethod
+    def wrap(cls, iterator_wrapper):
+        c = cls.__new__(cls)
+        c.__dict__.update(iterator_wrapper.__dict__)
+        return c
+
+
 class EmbeddingModelHandle(ModelHandle):
     def create_embedding(self, input: Union[str, List[str]]) -> bytes:
         """
@@ -68,7 +129,7 @@ class EmbeddingModelHandle(ModelHandle):
         """

         coro = self._model_ref.create_embedding(input)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))


 class RerankModelHandle(ModelHandle):
@@ -104,7 +165,7 @@ class RerankModelHandle(ModelHandle):
         coro = self._model_ref.rerank(
             documents, query, top_n, max_chunks_per_doc, return_documents
         )
-        results = self._isolation.call(coro)
+        results = orjson.loads(self._isolation.call(coro))
         for r in results["results"]:
             r["document"] = documents[r["index"]]
         return results
@@ -140,7 +201,10 @@ class GenerateModelHandle(EmbeddingModelHandle):
         """

         coro = self._model_ref.generate(prompt, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ChatModelHandle(GenerateModelHandle):
@@ -185,7 +249,10 @@ class ChatModelHandle(GenerateModelHandle):
         coro = self._model_ref.chat(
             prompt, system_prompt, chat_history, generate_config
         )
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ChatglmCppChatModelHandle(EmbeddingModelHandle):
@@ -217,7 +284,10 @@ class ChatglmCppChatModelHandle(EmbeddingModelHandle):
         """

         coro = self._model_ref.chat(prompt, chat_history, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ImageModelHandle(ModelHandle):
@@ -249,7 +319,7 @@ class ImageModelHandle(ModelHandle):
         """

         coro = self._model_ref.text_to_image(prompt, n, size, response_format, **kwargs)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))

     def image_to_image(
         self,
@@ -294,7 +364,7 @@ class ImageModelHandle(ModelHandle):
         coro = self._model_ref.image_to_image(
             image, prompt, negative_prompt, n, size, response_format, **kwargs
         )
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))


 class ActorClient:
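The client-side decoding above is easy to exercise standalone. A minimal sketch of what ClientIteratorWrapper.__anext__ does to one SSE frame, using stdlib json in place of orjson; the chunk payload is illustrative.

import json
import re

# Same line pattern SSEEvent uses to split "name: value" fields.
sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")

raw = 'data: {"choices": [{"text": "hello"}]}'  # one frame off the wire

data = ""
for line in raw.splitlines():
    m = sse_line_pattern.match(line)
    if m and m.group("name") == "data":
        data = m.group("value")

chunk = json.loads(data)  # the wrapper uses orjson.loads(SSEEvent.parse(text).data)
print(chunk["choices"][0]["text"])  # -> hello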
xinference/core/model.py
CHANGED
@@ -14,6 +14,7 @@

 import asyncio
 import inspect
+import json
 import os
 import uuid
 from typing import (
@@ -30,6 +31,7 @@ from typing import (
     Union,
 )

+import sse_starlette.sse
 import xoscar as xo

 if TYPE_CHECKING:
@@ -186,7 +188,7 @@ class ModelActor(xo.StatelessActor):
             )
         )

-
+    def _wrap_generator(self, ret: Any):
         if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
             if self._lock is not None and self._generators:
                 raise Exception("Parallel generation is not supported by ggml.")
@@ -199,7 +201,7 @@ class ModelActor(xo.StatelessActor):
                 model_actor_uid=self.uid,
             )
         else:
-            return ret
+            return json_dumps(ret)

     async def _call_wrapper(self, _wrapper: Callable):
         try:
@@ -335,9 +337,10 @@ class ModelActor(xo.StatelessActor):
         )

         def _wrapper():
-            return getattr(self._model, "text_to_image")(
+            r = getattr(self._model, "text_to_image")(
                 prompt, n, size, response_format, *args, **kwargs
             )
+            return json_dumps(r)

         return await self._call_wrapper(_wrapper)
@@ -358,7 +361,7 @@ class ModelActor(xo.StatelessActor):
         )

         def _wrapper():
-            return getattr(self._model, "image_to_image")(
+            r = getattr(self._model, "image_to_image")(
                 image,
                 prompt,
                 negative_prompt,
@@ -368,10 +371,10 @@ class ModelActor(xo.StatelessActor):
                 *args,
                 **kwargs,
             )
+            return json_dumps(r)

         return await self._call_wrapper(_wrapper)

-    @log_async(logger=logger)
     async def next(
         self, generator_uid: str
     ) -> Union["ChatCompletionChunk", "CompletionChunk"]:
@@ -381,14 +384,18 @@ class ModelActor(xo.StatelessActor):

         def _wrapper():
             try:
-
+                v = dict(data=json.dumps(next(gen)))
+                return sse_starlette.sse.ensure_bytes(v, None)
             except StopIteration:
                 return stop

         async def _async_wrapper():
             try:
                 # anext is only available for Python >= 3.10
-
+                v = await gen.__anext__()
+                v = await asyncio.to_thread(json.dumps, v)
+                v = dict(data=v)  # noqa: F821
+                return await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
             except StopAsyncIteration:
                 return stop
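On the server side, ModelActor.next now hands back pre-rendered SSE bytes instead of raw chunks. A rough sketch of that encoding; the exact framing (e.g. the \r\n separators) is produced by sse_starlette.sse.ensure_bytes and is an assumption here.

import json

chunk = {"id": "chunk-0", "choices": [{"text": "hi"}]}  # illustrative chunk

# next() builds dict(data=json.dumps(chunk)) and passes it to ensure_bytes;
# the rendered frame is approximately:
frame = ("data: %s\r\n\r\n" % json.dumps(chunk)).encode("utf-8")
print(frame)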
xinference/core/supervisor.py
CHANGED
@@ -114,6 +114,18 @@ class SupervisorActor(xo.StatelessActor):
             data[k] = v.dict()
         return data

+    @staticmethod
+    async def get_builtin_families() -> Dict[str, List[str]]:
+        from ..model.llm.llm_family import (
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
+        )
+
+        return {
+            "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
+            "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
+        }
+
     async def get_devices_count(self) -> int:
         from ..utils import cuda_count
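Because get_builtin_families is a @staticmethod coroutine, it can be awaited without an actor reference. A small sketch, assuming xinference is installed; note the two family sets are filled by xinference.model.llm._install(), which the server runs at startup, so the lists may be empty in a bare interpreter.

import asyncio

from xinference.core.supervisor import SupervisorActor

families = asyncio.run(SupervisorActor.get_builtin_families())
print(sorted(families["chat"])[:3], sorted(families["generate"])[:3])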
xinference/deploy/cmdline.py
CHANGED
@@ -402,6 +402,22 @@ def list_model_registrations(
             tabulate(table, headers=["Type", "Name", "Family", "Is-built-in"]),
             file=sys.stderr,
         )
+    elif model_type == "multimodal":
+        for registration in registrations:
+            model_name = registration["model_name"]
+            model_family = client.get_model_registration(model_type, model_name)
+            table.append(
+                [
+                    model_type,
+                    model_family["model_name"],
+                    model_family["model_lang"],
+                    registration["is_builtin"],
+                ]
+            )
+        print(
+            tabulate(table, headers=["Type", "Name", "Language", "Is-built-in"]),
+            file=sys.stderr,
+        )
     else:
         raise NotImplementedError(f"List {model_type} is not implemented.")
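For reference, a sketch of the table the new multimodal branch prints, reusing the same tabulate call; the registration values are illustrative, not a real builtin.

from tabulate import tabulate

model_family = {"model_name": "my-multimodal-model", "model_lang": ["en", "zh"]}
table = [["multimodal", model_family["model_name"], model_family["model_lang"], True]]
print(tabulate(table, headers=["Type", "Name", "Language", "Is-built-in"]))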
xinference/model/embedding/model_spec.json
CHANGED
@@ -142,5 +142,45 @@
     "language": ["en"],
     "model_id": "jinaai/jina-embeddings-v2-base-en",
     "model_revision": "7302ac470bed880590f9344bfeee32ff8722d0e5"
+  },
+  {
+    "model_name": "text2vec-large-chinese",
+    "dimensions": 1024,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-bge-large-chinese",
+    "model_revision": "f5027ca48ea8316d63ee26d2b9bd27a061de33a3"
+  },
+  {
+    "model_name": "text2vec-base-chinese",
+    "dimensions": 768,
+    "max_tokens": 128,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese",
+    "model_revision": "8acc1289891d75f6b665ad623359798b55f86adb"
+  },
+  {
+    "model_name": "text2vec-base-chinese-paraphrase",
+    "dimensions": 768,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese-paraphrase",
+    "model_revision": "beaf10481a5d9ca3b0daa9f0df6831ec956bf739"
+  },
+  {
+    "model_name": "text2vec-base-chinese-sentence",
+    "dimensions": 768,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-chinese-sentence",
+    "model_revision": "e73a94e821f22c6163166bfab9408d03933a5525"
+  },
+  {
+    "model_name": "text2vec-base-multilingual",
+    "dimensions": 384,
+    "max_tokens": 256,
+    "language": ["zh"],
+    "model_id": "shibing624/text2vec-base-multilingual",
+    "model_revision": "f241877385fa56ebcc75f04d1850e1579cfa661d"
   }
 ]
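The new text2vec entries are launchable like any other builtin embedding model. A minimal sketch with the RESTful client; the server address is an assumption, and the 768 matches the "dimensions" field above.

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local server
model_uid = client.launch_model(
    model_name="text2vec-base-chinese", model_type="embedding"
)
result = client.get_model(model_uid).create_embedding("今天天气真好")
print(len(result["data"][0]["embedding"]))  # 768, per the spec above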
xinference/model/llm/__init__.py
CHANGED
@@ -19,9 +19,12 @@ import os
 from .core import LLM
 from .llm_family import (
     BUILTIN_LLM_FAMILIES,
+    BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+    BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
     BUILTIN_LLM_PROMPT_STYLE,
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLM_CLASSES,
+    CustomLLMFamilyV1,
     GgmlLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
@@ -94,6 +97,11 @@ def _install():
         # note that the key is the model name,
         # since there are multiple representations of the same prompt style name in json.
         BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)

     modelscope_json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
@@ -110,6 +118,11 @@ def _install():
             and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
             BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)

     from ...constants import XINFERENCE_MODEL_DIR

@@ -119,5 +132,5 @@ def _install():
             with codecs.open(
                 os.path.join(user_defined_llm_dir, f), encoding="utf-8"
             ) as fd:
-                user_defined_llm_family =
+                user_defined_llm_family = CustomLLMFamilyV1.parse_obj(json.load(fd))
                 register_llm(user_defined_llm_family, persist=False)
xinference/model/llm/llm_family.json
CHANGED
@@ -557,7 +557,7 @@
         "none"
       ],
       "model_id": "THUDM/chatglm3-6b",
-      "model_revision": "
+      "model_revision": "b098244a71fbe69ce149682d9072a7629f7e908c"
     }
   ],
   "prompt_style": {
@@ -566,6 +566,15 @@
       "roles": [
         "user",
         "assistant"
+      ],
+      "stop_token_ids": [
+        64795,
+        64797,
+        2
+      ],
+      "stop": [
+        "<|user|>",
+        "<|observation|>"
       ]
     }
   },
xinference/model/llm/llm_family.py
CHANGED
@@ -17,7 +17,7 @@ import os
 import platform
 import shutil
 from threading import Lock
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 from pydantic import BaseModel, Field, Protocol, ValidationError, validator
 from pydantic.error_wrappers import ErrorWrapper
@@ -41,6 +41,8 @@ logger = logging.getLogger(__name__)

 DEFAULT_CONTEXT_LENGTH = 2048
 BUILTIN_LLM_PROMPT_STYLE: Dict[str, "PromptStyleV1"] = {}
+BUILTIN_LLM_MODEL_CHAT_FAMILIES: Set[str] = set()
+BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set()


 class GgmlLLMSpecV1(BaseModel):
@@ -105,6 +107,8 @@ class LLMFamilyV1(BaseModel):
     model_lang: List[str]
     model_ability: List[Literal["embed", "generate", "chat"]]
     model_description: Optional[str]
+    # reason for not required str here: legacy registration
+    model_family: Optional[str]
     model_specs: List["LLMSpecV1"]
     prompt_style: Optional["PromptStyleV1"]
@@ -134,7 +138,39 @@ class CustomLLMFamilyV1(LLMFamilyV1):
             )
         except (ValueError, TypeError, UnicodeDecodeError) as e:
             raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls)
-        llm_spec = cls.parse_obj(obj)
+        llm_spec: CustomLLMFamilyV1 = cls.parse_obj(obj)
+
+        # check model_family
+        if llm_spec.model_family is None:
+            raise ValueError(
+                f"You must specify `model_family` when registering custom LLM models."
+            )
+        assert isinstance(llm_spec.model_family, str)
+        if (
+            llm_spec.model_family != "other"
+            and "chat" in llm_spec.model_ability
+            and llm_spec.model_family not in BUILTIN_LLM_MODEL_CHAT_FAMILIES
+        ):
+            raise ValueError(
+                f"`model_family` for chat model must be `other` or one of the following values: \n"
+                f"{', '.join(list(BUILTIN_LLM_MODEL_CHAT_FAMILIES))}"
+            )
+        if (
+            llm_spec.model_family != "other"
+            and "chat" not in llm_spec.model_ability
+            and llm_spec.model_family not in BUILTIN_LLM_MODEL_GENERATE_FAMILIES
+        ):
+            raise ValueError(
+                f"`model_family` for generate model must be `other` or one of the following values: \n"
+                f"{', '.join(list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES))}"
+            )
+        # set prompt style when it is the builtin model family
+        if (
+            llm_spec.prompt_style is None
+            and llm_spec.model_family != "other"
+            and "chat" in llm_spec.model_ability
+        ):
+            llm_spec.prompt_style = llm_spec.model_family

         # handle prompt style when user choose existing style
         if llm_spec.prompt_style is not None and isinstance(llm_spec.prompt_style, str):
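The practical effect of the new checks: a custom spec must now carry model_family. A sketch of the failure path; the field values are illustrative and the surrounding pydantic schema may require more fields than shown.

import json

from xinference.model.llm.llm_family import CustomLLMFamilyV1

spec = {
    "version": 1,
    "context_length": 2048,
    "model_name": "my-custom-model",
    "model_lang": ["en"],
    "model_ability": ["generate"],
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "quantizations": ["none"],
            "model_id": "my-org/my-custom-model",
        }
    ],
    # "model_family" deliberately omitted
}
try:
    CustomLLMFamilyV1.parse_raw(json.dumps(spec))
except ValueError as e:  # pydantic's ValidationError also subclasses ValueError
    print(e)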
xinference/model/llm/llm_family_modelscope.json
CHANGED
@@ -331,6 +331,15 @@
       "roles": [
         "user",
         "assistant"
+      ],
+      "stop_token_ids": [
+        64795,
+        64797,
+        2
+      ],
+      "stop": [
+        "<|user|>",
+        "<|observation|>"
       ]
     }
   },
@@ -357,7 +366,7 @@
       ],
       "model_hub": "modelscope",
       "model_id": "ZhipuAI/chatglm3-6b-32k",
-      "model_revision": "
+      "model_revision": "master"
     }
   ],
   "prompt_style": {
xinference/model/llm/pytorch/chatglm.py
CHANGED
@@ -58,6 +58,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_path,
             trust_remote_code=kwargs["trust_remote_code"],
+            encode_special_tokens=True,
             revision=kwargs["revision"],
         )
         model = AutoModel.from_pretrained(
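For context on the added kwarg: chatglm3's tokenizer is loaded via remote code, and encode_special_tokens=True makes literal role markers such as "<|user|>" encode to their special ids, which is what lets the stop_token_ids above (64795/64797) fire. A sketch, assuming transformers plus the model files are available and that THUDM's tokenizer accepts the flag as this diff shows.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b",
    trust_remote_code=True,
    encode_special_tokens=True,  # "<|user|>" becomes one special token
)
print(tok.encode("<|user|>"))  # should contain the special id (64795)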
xinference/model/llm/pytorch/core.py
CHANGED
@@ -409,7 +409,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     ) -> PytorchGenerateConfig:
         generate_config = super()._sanitize_generate_config(generate_config)
         if (
-            generate_config.get("stop"
+            (not generate_config.get("stop"))
             and self.model_family.prompt_style
             and self.model_family.prompt_style.stop
         ):
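The fixed condition means an absent or empty user "stop" now falls back to the family's prompt-style stop words. A self-contained sketch of that logic, using the chatglm3 values from the JSON above:

generate_config = {"max_tokens": 512}                # user supplied no "stop"
prompt_style_stop = ["<|user|>", "<|observation|>"]  # chatglm3 prompt_style.stop

# Mirrors PytorchChatModel._sanitize_generate_config after the fix:
if (not generate_config.get("stop")) and prompt_style_stop:
    generate_config["stop"] = list(prompt_style_stop)

print(generate_config["stop"])  # ['<|user|>', '<|observation|>']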