xinference 0.12.1__py3-none-any.whl → 0.12.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +34 -8
- xinference/client/restful/restful_client.py +4 -0
- xinference/core/event.py +5 -6
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +13 -3
- xinference/model/llm/llm_family.json +6 -2
- xinference/model/llm/llm_family_modelscope.json +6 -2
- xinference/model/llm/pytorch/chatglm.py +23 -0
- xinference/model/llm/pytorch/core.py +39 -49
- xinference/model/llm/pytorch/glm4v.py +11 -0
- xinference/model/llm/pytorch/internlm2.py +15 -0
- xinference/model/llm/pytorch/utils.py +46 -179
- xinference/model/llm/utils.py +14 -2
- xinference/model/rerank/core.py +35 -6
- xinference/types.py +28 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
- xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/METADATA +1 -1
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/RECORD +41 -40
- xinference/web/ui/build/static/css/main.074e2b31.css +0 -2
- xinference/web/ui/build/static/css/main.074e2b31.css.map +0 -1
- xinference/web/ui/build/static/js/main.a58ff436.js +0 -3
- xinference/web/ui/build/static/js/main.a58ff436.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +0 -1
- /xinference/web/ui/build/static/js/{main.a58ff436.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/LICENSE +0 -0
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/WHEEL +0 -0
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
-    "date": "2024-06-…",
+    "date": "2024-06-22T23:28:43+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…",
-    "version": "0.12.1"
+    "full-revisionid": "7705d4ae1eb68523e87c4f2abf84026dae18b694",
+    "version": "0.12.2.post1"
 }
 ''' # END VERSION_JSON

xinference/api/restful_api.py
CHANGED
@@ -109,6 +109,7 @@ class RerankRequest(BaseModel):
     documents: List[str]
     top_n: Optional[int] = None
     return_documents: Optional[bool] = False
+    return_len: Optional[bool] = False
    max_chunks_per_doc: Optional[int] = None


@@ -981,7 +982,8 @@ class RESTfulAPI:
         return JSONResponse(content=self._supervisor_address)

     async def create_completion(self, request: Request) -> Response:
-        body = CreateCompletionRequest.parse_obj(await request.json())
+        raw_body = await request.json()
+        body = CreateCompletionRequest.parse_obj(raw_body)
         exclude = {
             "prompt",
             "model",
@@ -991,6 +993,7 @@ class RESTfulAPI:
             "logit_bias_type",
             "user",
         }
+        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)

         # TODO: Decide if this default value override is necessary #1061
@@ -1020,7 +1023,9 @@ class RESTfulAPI:
             iterator = None
             try:
                 try:
-                    iterator = await model.generate(body.prompt, kwargs)
+                    iterator = await model.generate(
+                        body.prompt, kwargs, raw_params=raw_kwargs
+                    )
                 except RuntimeError as re:
                     self.handle_request_limit_error(re)
                 async for item in iterator:
@@ -1040,7 +1045,7 @@ class RESTfulAPI:
            return EventSourceResponse(stream_results())
        else:
            try:
-               data = await model.generate(body.prompt, kwargs)
+               data = await model.generate(body.prompt, kwargs, raw_params=raw_kwargs)
               return Response(data, media_type="application/json")
            except Exception as e:
               logger.error(e, exc_info=True)
@@ -1112,6 +1117,7 @@ class RESTfulAPI:
                top_n=body.top_n,
                max_chunks_per_doc=body.max_chunks_per_doc,
                return_documents=body.return_documents,
+               return_len=body.return_len,
                **kwargs,
            )
            return Response(scores, media_type="application/json")
@@ -1341,7 +1347,8 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))

     async def create_chat_completion(self, request: Request) -> Response:
-        body = CreateChatCompletion.parse_obj(await request.json())
+        raw_body = await request.json()
+        body = CreateChatCompletion.parse_obj(raw_body)
         exclude = {
             "prompt",
             "model",
@@ -1351,6 +1358,7 @@ class RESTfulAPI:
             "logit_bias_type",
             "user",
         }
+        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)

         # TODO: Decide if this default value override is necessary #1061
@@ -1425,7 +1433,9 @@ class RESTfulAPI:
            "gorilla-openfunctions-v1",
            "qwen-chat",
            "qwen1.5-chat",
+           "qwen1.5-moe-chat",
            "qwen2-instruct",
+           "qwen2-moe-instruct",
        ]

        is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1451,7 +1461,9 @@ class RESTfulAPI:
        if not is_vllm or model_family not in [
            "qwen-chat",
            "qwen1.5-chat",
+           "qwen1.5-moe-chat",
            "qwen2-instruct",
+           "qwen2-moe-instruct",
        ]:
            raise HTTPException(
                status_code=400,
@@ -1465,10 +1477,16 @@ class RESTfulAPI:
            try:
                try:
                    if is_qwen:
-                       iterator = await model.chat(prompt, chat_history, kwargs)
+                       iterator = await model.chat(
+                           prompt, chat_history, kwargs, raw_params=raw_kwargs
+                       )
                    else:
                        iterator = await model.chat(
-                           prompt, system_prompt, chat_history, kwargs
+                           prompt,
+                           system_prompt,
+                           chat_history,
+                           kwargs,
+                           raw_params=raw_kwargs,
                        )
                except RuntimeError as re:
                    await self._report_error_event(model_uid, str(re))
@@ -1498,9 +1516,17 @@ class RESTfulAPI:
        else:
            try:
                if is_qwen:
-                   data = await model.chat(prompt, chat_history, kwargs)
+                   data = await model.chat(
+                       prompt, chat_history, kwargs, raw_params=raw_kwargs
+                   )
                else:
-                   data = await model.chat(prompt, system_prompt, chat_history, kwargs)
+                   data = await model.chat(
+                       prompt,
+                       system_prompt,
+                       chat_history,
+                       kwargs,
+                       raw_params=raw_kwargs,
+                   )
                return Response(content=data, media_type="application/json")
            except Exception as e:
                logger.error(e, exc_info=True)
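The hunks above keep the client's raw request JSON alongside the validated pydantic body and forward the leftover fields as raw_params. A minimal sketch of that pattern, assuming a made-up request body and an abbreviated exclude set (the full set in the source is longer than shown here):

# Hedged sketch of the raw_params pattern: preserving the untouched client JSON
# lets downstream model code tell "parameter omitted" apart from "parameter defaulted".
raw_body = {"model": "qwen2-instruct", "prompt": "hi", "stream": False}  # hypothetical request
exclude = {"prompt", "model", "logit_bias", "logit_bias_type", "user"}   # abbreviated

raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
print(raw_kwargs)  # {'stream': False} -> later forwarded as model.generate(prompt, kwargs, raw_params=raw_kwargs)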
xinference/client/restful/restful_client.py
CHANGED
@@ -135,6 +135,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
         top_n: Optional[int] = None,
         max_chunks_per_doc: Optional[int] = None,
         return_documents: Optional[bool] = None,
+        return_len: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -152,6 +153,8 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             The maximum number of chunks derived from a document
         return_documents: bool
             if return documents
+        return_len: bool
+            if return tokens len
         Returns
         -------
         Scores
@@ -170,6 +173,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "top_n": top_n,
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
+            "return_len": return_len,
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
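For context, a hypothetical client-side call using the new flag might look like the sketch below; the endpoint address and model uid are placeholders, not taken from the diff.

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # assumed local endpoint
model = client.get_model("my-rerank-model")       # placeholder rerank model uid
result = model.rerank(
    documents=["first document", "second document"],
    query="which document is relevant?",
    return_documents=True,
    return_len=True,  # new in 0.12.2: ask the server to include token lengths
)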
xinference/core/event.py
CHANGED
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import …
-from collections import defaultdict
+from collections import defaultdict, deque
 from enum import Enum
 from typing import Dict, List, TypedDict

@@ -37,8 +36,8 @@ class Event(TypedDict):
 class EventCollectorActor(xo.StatelessActor):
     def __init__(self):
         super().__init__()
-        self._model_uid_to_events: Dict[str, …
-            lambda: …
+        self._model_uid_to_events: Dict[str, deque] = defaultdict(  # type: ignore
+            lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL)
         )

     @classmethod
@@ -50,7 +49,7 @@ class EventCollectorActor(xo.StatelessActor):
         if event_queue is None:
             return []
         else:
-            return [dict(e, event_type=e["event_type"].name) for e in event_queue…
+            return [dict(e, event_type=e["event_type"].name) for e in iter(event_queue)]

     def report_event(self, model_uid: str, event: Event):
-        self._model_uid_to_events[model_uid].…
+        self._model_uid_to_events[model_uid].append(event)
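The switch to deque(maxlen=...) gives a bounded per-model event buffer for free: once the cap is reached, the oldest entries are dropped automatically, with no explicit eviction code. A small standalone sketch (the constant value here is illustrative, not the real one):

from collections import defaultdict, deque

MAX_EVENT_COUNT_PER_MODEL = 3  # illustrative value

model_uid_to_events = defaultdict(lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL))
for i in range(5):
    model_uid_to_events["model-a"].append({"event_type": "INFO", "event_content": f"event {i}"})

print(list(model_uid_to_events["model-a"]))  # only the 3 most recent events remain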
xinference/core/model.py
CHANGED
@@ -264,13 +264,14 @@ class ModelActor(xo.StatelessActor):
         return isinstance(self._model, VLLMModel)

     def allow_batching(self) -> bool:
-        from ..model.llm.pytorch.core import …
+        from ..model.llm.pytorch.core import PytorchModel
+
+        model_ability = self._model_description.get("model_ability", [])

         return (
             XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
             and isinstance(self._model, PytorchModel)
-            and …
-            in (PytorchChatModel.__name__, PytorchModel.__name__)
+            and "vision" not in model_ability
         )

     async def load(self):
@@ -399,6 +400,7 @@ class ModelActor(xo.StatelessActor):
                 prompt, "generate", *args, **kwargs
             )
         else:
+            kwargs.pop("raw_params", None)
             if hasattr(self._model, "generate"):
                 return await self._call_wrapper(
                     self._model.generate, prompt, *args, **kwargs
@@ -481,6 +483,7 @@ class ModelActor(xo.StatelessActor):
                 prompt, "chat", *args, **kwargs
             )
         else:
+            kwargs.pop("raw_params", None)
             if hasattr(self._model, "chat"):
                 response = await self._call_wrapper(
                     self._model.chat, prompt, *args, **kwargs
@@ -540,6 +543,7 @@ class ModelActor(xo.StatelessActor):
         top_n: Optional[int],
         max_chunks_per_doc: Optional[int],
         return_documents: Optional[bool],
+        return_len: Optional[bool],
         *args,
         **kwargs,
     ):
@@ -551,6 +555,7 @@ class ModelActor(xo.StatelessActor):
             top_n,
             max_chunks_per_doc,
             return_documents,
+            return_len,
             *args,
             **kwargs,
         )
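The kwargs.pop("raw_params", None) added on the non-batching paths strips the scheduler-only hint before calling backends whose generate()/chat() signatures do not accept it. A rough sketch of that guard, with stand-in function names:

def call_generate(model_generate, prompt, *args, **kwargs):
    # drop the scheduler-only raw_params hint before the real backend call
    kwargs.pop("raw_params", None)
    return model_generate(prompt, *args, **kwargs)

def fake_generate(prompt, generate_config=None):
    return {"text": prompt.upper(), "config": generate_config}

print(call_generate(fake_generate, "hello", {"stream": False}, raw_params={"temperature": 0.1}))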
xinference/core/scheduler.py
CHANGED
@@ -18,7 +18,7 @@ import logging
 import uuid
 from collections import deque
 from enum import Enum
-from typing import List, Optional, Set
+from typing import List, Optional, Set, Tuple

 import xoscar as xo

@@ -53,7 +53,8 @@ class InferenceRequest:
         self._kv_cache = None
         # use passed args from upstream interface
         self._inference_args = args
-        # use passed kwargs from upstream interface, …
+        # use passed kwargs from upstream interface, currently for getting raw generate config from upstream,
+        # which is useful for some special models
         self._inference_kwargs = kwargs
         # should this request be stopped
         self._stopped = False
@@ -66,6 +67,8 @@ class InferenceRequest:
         self._sanitized_generate_config = None
         # Chunk id for results. In stream mode, all the chunk ids should be same.
         self._stream_chunk_id = str(uuid.uuid4())
+        # For calculate attention mask if needed
+        self.padding_len = 0
         # Use in stream mode
         self.last_output_length = 0
         # inference results,
@@ -172,6 +175,10 @@
     def sanitized_generate_config(self, value: dict):
         self._sanitized_generate_config = value

+    @property
+    def inference_kwargs(self):
+        return self._inference_kwargs
+
     @property
     def stopped(self):
         return self._stopped
@@ -231,7 +238,9 @@
         )

     @functools.lru_cache
-    def get_generate_configs(self, eos_token_id: int):
+    def get_generate_configs(
+        self, eos_token_id: int, builtin_stop_token_ids: Optional[Tuple[int]] = None
+    ):
         from ..types import max_tokens_field

         max_new_tokens = int(
@@ -245,6 +254,7 @@
         )
         stop_token_ids = set(stop_token_ids)
         stop_token_ids.add(eos_token_id)
+        stop_token_ids.update(builtin_stop_token_ids or [])
         temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
         repetition_penalty = float(
             self.sanitized_generate_config.get("repetition_penalty", 1.0)
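A note on the new builtin_stop_token_ids parameter being typed as Optional[Tuple[int]]: get_generate_configs is wrapped in functools.lru_cache, so its arguments must be hashable, and a list of stop token ids would raise TypeError. A simplified, standalone illustration (not the real method):

import functools

@functools.lru_cache
def get_generate_configs(eos_token_id: int, builtin_stop_token_ids: tuple = ()):
    stop_token_ids = set(builtin_stop_token_ids)
    stop_token_ids.add(eos_token_id)
    return {"stop_token_ids": stop_token_ids}

print(get_generate_configs(2, (2, 92542)))  # tuples are hashable, so caching works
# get_generate_configs(2, [2, 92542])       # would raise TypeError: unhashable type: 'list'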
xinference/model/llm/llm_family.json
CHANGED
@@ -2290,7 +2290,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
     ],
     "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
     "model_specs": [
@@ -2595,7 +2596,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
     ],
     "model_description": "Qwen2 is the new series of Qwen large language models. ",
     "model_specs": [
@@ -5675,9 +5677,11 @@
     ],
     "intra_message_sep": "<|im_end|>",
     "stop_token_ids": [
+      2,
       92542
     ],
     "stop": [
+      "</s>",
       "<|im_end|>"
     ]
   }
xinference/model/llm/llm_family_modelscope.json
CHANGED
@@ -2644,7 +2644,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
     ],
     "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
     "model_specs": [
@@ -2968,7 +2969,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
     ],
     "model_description": "Qwen2 is the new series of Qwen large language models. ",
     "model_specs": [
@@ -3350,9 +3352,11 @@
     ],
     "intra_message_sep": "<|im_end|>",
     "stop_token_ids": [
+      2,
       92542
     ],
     "stop": [
+      "</s>",
       "<|im_end|>"
     ]
   }
xinference/model/llm/pytorch/chatglm.py
CHANGED
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
@@ -244,3 +245,25 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
             ),
         )
+
+    @staticmethod
+    def require_attention_mask():
+        """
+        GLM4 needs to use attention mask and position ids during inference.
+        Otherwise, the inference result would be not available.
+        """
+        return True
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Set temperature and top_p to 0.8 by default
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+
+        return raw_config
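The prepare_sanitize_generate_config override above only supplies defaults for parameters the client did not send, which is exactly what the raw_params plumbing makes possible. An illustrative, simplified version of the same defaulting logic, with the request object replaced by a plain dict:

def apply_glm_defaults(raw_params: dict) -> dict:
    # only fill in values the caller omitted; explicit values win
    config = dict(raw_params)
    config.setdefault("temperature", 0.8)
    config.setdefault("top_p", 0.8)
    return config

print(apply_glm_defaults({"temperature": 0.1}))  # {'temperature': 0.1, 'top_p': 0.8}
print(apply_glm_defaults({}))                    # {'temperature': 0.8, 'top_p': 0.8}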
xinference/model/llm/pytorch/core.py
CHANGED
@@ -16,7 +16,7 @@ import json
 import logging
 import os
 from functools import lru_cache
-from typing import Iterable, Iterator, List, Optional, Union
+from typing import Iterable, Iterator, List, Optional, Tuple, Union

 from ....core.scheduler import InferenceRequest
 from ....device_utils import (
@@ -283,35 +283,21 @@ class PytorchModel(LLM):
     def generate(
         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import generate_stream, …
-
-        model_family_name = self.model_family.model_name.lower()
+        from .utils import generate_stream

         def generator_wrapper(
             prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
-            …
-                    yield completion_chunk
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    completion_chunk["usage"] = completion_usage
-                    yield completion_chunk
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                completion_chunk["usage"] = completion_usage
+                yield completion_chunk

         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -336,26 +322,15 @@ class PytorchModel(LLM):

         stream = generate_config.get("stream", False)
         if not stream:
-            …
-                pass
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    pass
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                pass
             completion = Completion(
                 id=completion_chunk["id"],
                 object=completion_chunk["object"],
@@ -368,6 +343,10 @@ class PytorchModel(LLM):
         else:
             return generator_wrapper(prompt, generate_config)

+    @staticmethod
+    def require_attention_mask():
+        return False
+
     @lru_cache
     def get_context_len(self):
         return get_context_length(self._model.config)
@@ -375,13 +354,14 @@ class PytorchModel(LLM):
     def get_max_num_seqs(self) -> int:
         return self._pytorch_model_config.get("max_num_seqs")  # type: ignore

+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        return self._sanitize_generate_config(req.generate_config)
+
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         # check some parameters
         for r in req_list:
             if r.sanitized_generate_config is None:
-                r.sanitized_generate_config = self._sanitize_generate_config(
-                    r.generate_config
-                )
+                r.sanitized_generate_config = self.prepare_sanitize_generate_config(r)
             if r.is_prefill:
                 # check some generate params
                 max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
@@ -401,6 +381,14 @@ class PytorchModel(LLM):
                     r.error_msg = "Invalid `stop` field type"
                     continue

+    def _get_builtin_stop_token_ids(self) -> Tuple:
+        return (
+            tuple(self.model_family.prompt_style.stop_token_ids)
+            if self.model_family.prompt_style
+            and self.model_family.prompt_style.stop_token_ids
+            else tuple()
+        )
+
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.error_msg is None:
@@ -449,6 +437,8 @@ class PytorchModel(LLM):
             self._tokenizer,
             self._device,
             context_len,
+            self._get_builtin_stop_token_ids(),
+            require_attention_mask=self.require_attention_mask(),
         )
         self.handle_batch_inference_results(req_list)

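The require_attention_mask() hook added above is a small template-method pattern: the base transformers implementation opts out, and subclasses that need masks during batched inference (GLM4 in this diff) opt in. A stripped-down sketch with stand-in class names and the batch-inference call reduced to a print:

class BasePytorchModel:
    @staticmethod
    def require_attention_mask() -> bool:
        # default: the batching utilities do not need an explicit attention mask
        return False

class Glm4LikeModel(BasePytorchModel):
    @staticmethod
    def require_attention_mask() -> bool:
        # GLM4-style models need attention mask / position ids when batching
        return True

def run_batch(model: BasePytorchModel) -> None:
    # the real code forwards this flag into the batch inference utilities
    print("require_attention_mask:", model.require_attention_mask())

run_batch(BasePytorchModel())  # require_attention_mask: False
run_batch(Glm4LikeModel())     # require_attention_mask: True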
xinference/model/llm/pytorch/glm4v.py
CHANGED
@@ -64,6 +64,8 @@ class Glm4VModel(PytorchChatModel):

         kwargs = {"device_map": self._device}
         quantization = self.quantization
+
+        # referenced from PytorchModel.load
         if quantization != "none":
             if self._device == "cuda" and self._is_linux():
                 kwargs["device_map"] = "auto"
@@ -72,6 +74,15 @@ class Glm4VModel(PytorchChatModel):
                     kwargs["load_in_4bit"] = True
                 elif quantization == "8-bit":
                     kwargs["load_in_8bit"] = True
+                else:
+                    raise ValueError(
+                        f"Quantization {quantization} is not supported in temporary"
+                    )
+            else:
+                if quantization != "8-bit":
+                    raise ValueError(
+                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+                    )

         model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
xinference/model/llm/pytorch/internlm2.py
CHANGED
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     ChatCompletion,
     ChatCompletionChoice,
@@ -88,6 +89,20 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             return False
         return True

+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Overwrite this func for this special model.
+        Cannot use the default configuration, which works poorly on this model.
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        return raw_config
+
     def chat(
         self,
         prompt: str,