xinference 0.12.1__py3-none-any.whl → 0.12.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (55)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +34 -8
  3. xinference/client/restful/restful_client.py +4 -0
  4. xinference/core/event.py +5 -6
  5. xinference/core/model.py +8 -3
  6. xinference/core/scheduler.py +13 -3
  7. xinference/model/llm/llm_family.json +6 -2
  8. xinference/model/llm/llm_family_modelscope.json +6 -2
  9. xinference/model/llm/pytorch/chatglm.py +23 -0
  10. xinference/model/llm/pytorch/core.py +39 -49
  11. xinference/model/llm/pytorch/glm4v.py +11 -0
  12. xinference/model/llm/pytorch/internlm2.py +15 -0
  13. xinference/model/llm/pytorch/utils.py +46 -179
  14. xinference/model/llm/utils.py +14 -2
  15. xinference/model/rerank/core.py +35 -6
  16. xinference/types.py +28 -0
  17. xinference/web/ui/build/asset-manifest.json +6 -6
  18. xinference/web/ui/build/index.html +1 -1
  19. xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
  20. xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
  21. xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
  22. xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
  23. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
  24. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
  25. xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
  26. xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
  27. xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
  28. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
  29. xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
  30. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
  31. xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
  32. xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
  33. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
  35. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/METADATA +1 -1
  36. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/RECORD +41 -40
  37. xinference/web/ui/build/static/css/main.074e2b31.css +0 -2
  38. xinference/web/ui/build/static/css/main.074e2b31.css.map +0 -1
  39. xinference/web/ui/build/static/js/main.a58ff436.js +0 -3
  40. xinference/web/ui/build/static/js/main.a58ff436.js.map +0 -1
  41. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +0 -1
  42. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
  43. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
  44. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
  46. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +0 -1
  47. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +0 -1
  48. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +0 -1
  51. /xinference/web/ui/build/static/js/{main.a58ff436.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
  52. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/LICENSE +0 -0
  53. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/WHEEL +0 -0
  54. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/entry_points.txt +0 -0
  55. {xinference-0.12.1.dist-info → xinference-0.12.2.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
  version_json = '''
  {
- "date": "2024-06-14T17:17:50+0800",
+ "date": "2024-06-22T23:28:43+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "34a57df449f0890415c424802d3596f3c8758412",
- "version": "0.12.1"
+ "full-revisionid": "7705d4ae1eb68523e87c4f2abf84026dae18b694",
+ "version": "0.12.2.post1"
  }
  ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -109,6 +109,7 @@ class RerankRequest(BaseModel):
  documents: List[str]
  top_n: Optional[int] = None
  return_documents: Optional[bool] = False
+ return_len: Optional[bool] = False
  max_chunks_per_doc: Optional[int] = None
 
 
@@ -981,7 +982,8 @@ class RESTfulAPI:
  return JSONResponse(content=self._supervisor_address)
 
  async def create_completion(self, request: Request) -> Response:
- body = CreateCompletionRequest.parse_obj(await request.json())
+ raw_body = await request.json()
+ body = CreateCompletionRequest.parse_obj(raw_body)
  exclude = {
  "prompt",
  "model",
@@ -991,6 +993,7 @@ class RESTfulAPI:
  "logit_bias_type",
  "user",
  }
+ raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
  # TODO: Decide if this default value override is necessary #1061
@@ -1020,7 +1023,9 @@ class RESTfulAPI:
  iterator = None
  try:
  try:
- iterator = await model.generate(body.prompt, kwargs)
+ iterator = await model.generate(
+ body.prompt, kwargs, raw_params=raw_kwargs
+ )
  except RuntimeError as re:
  self.handle_request_limit_error(re)
  async for item in iterator:
@@ -1040,7 +1045,7 @@ class RESTfulAPI:
  return EventSourceResponse(stream_results())
  else:
  try:
- data = await model.generate(body.prompt, kwargs)
+ data = await model.generate(body.prompt, kwargs, raw_params=raw_kwargs)
  return Response(data, media_type="application/json")
  except Exception as e:
  logger.error(e, exc_info=True)
@@ -1112,6 +1117,7 @@ class RESTfulAPI:
  top_n=body.top_n,
  max_chunks_per_doc=body.max_chunks_per_doc,
  return_documents=body.return_documents,
+ return_len=body.return_len,
  **kwargs,
  )
  return Response(scores, media_type="application/json")
@@ -1341,7 +1347,8 @@ class RESTfulAPI:
  raise HTTPException(status_code=500, detail=str(e))
 
  async def create_chat_completion(self, request: Request) -> Response:
- body = CreateChatCompletion.parse_obj(await request.json())
+ raw_body = await request.json()
+ body = CreateChatCompletion.parse_obj(raw_body)
  exclude = {
  "prompt",
  "model",
@@ -1351,6 +1358,7 @@ class RESTfulAPI:
  "logit_bias_type",
  "user",
  }
+ raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
  # TODO: Decide if this default value override is necessary #1061
@@ -1425,7 +1433,9 @@ class RESTfulAPI:
  "gorilla-openfunctions-v1",
  "qwen-chat",
  "qwen1.5-chat",
+ "qwen1.5-moe-chat",
  "qwen2-instruct",
+ "qwen2-moe-instruct",
  ]
 
  is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1451,7 +1461,9 @@ class RESTfulAPI:
  if not is_vllm or model_family not in [
  "qwen-chat",
  "qwen1.5-chat",
+ "qwen1.5-moe-chat",
  "qwen2-instruct",
+ "qwen2-moe-instruct",
  ]:
  raise HTTPException(
  status_code=400,
@@ -1465,10 +1477,16 @@ class RESTfulAPI:
  try:
  try:
  if is_qwen:
- iterator = await model.chat(prompt, chat_history, kwargs)
+ iterator = await model.chat(
+ prompt, chat_history, kwargs, raw_params=raw_kwargs
+ )
  else:
  iterator = await model.chat(
- prompt, system_prompt, chat_history, kwargs
+ prompt,
+ system_prompt,
+ chat_history,
+ kwargs,
+ raw_params=raw_kwargs,
  )
  except RuntimeError as re:
  await self._report_error_event(model_uid, str(re))
@@ -1498,9 +1516,17 @@ class RESTfulAPI:
  else:
  try:
  if is_qwen:
- data = await model.chat(prompt, chat_history, kwargs)
+ data = await model.chat(
+ prompt, chat_history, kwargs, raw_params=raw_kwargs
+ )
  else:
- data = await model.chat(prompt, system_prompt, chat_history, kwargs)
+ data = await model.chat(
+ prompt,
+ system_prompt,
+ chat_history,
+ kwargs,
+ raw_params=raw_kwargs,
+ )
  return Response(content=data, media_type="application/json")
  except Exception as e:
  logger.error(e, exc_info=True)
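The raw_params plumbing above means any extra field in the request body that is not in the exclude set is collected into raw_kwargs and handed to the model alongside the sanitized kwargs. A minimal sketch of exercising this over HTTP; the server address and model UID below are hypothetical:

import requests

# Hypothetical endpoint and model UID. Fields such as "temperature" also reach
# the model untouched via raw_params, in addition to the validated kwargs.
resp = requests.post(
    "http://localhost:9997/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.8,
    },
)
print(resp.json())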
xinference/client/restful/restful_client.py CHANGED
@@ -135,6 +135,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  top_n: Optional[int] = None,
  max_chunks_per_doc: Optional[int] = None,
  return_documents: Optional[bool] = None,
+ return_len: Optional[bool] = None,
  **kwargs,
  ):
  """
@@ -152,6 +153,8 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  The maximum number of chunks derived from a document
  return_documents: bool
  if return documents
+ return_len: bool
+ if return tokens len
  Returns
  -------
  Scores
@@ -170,6 +173,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  "top_n": top_n,
  "max_chunks_per_doc": max_chunks_per_doc,
  "return_documents": return_documents,
+ "return_len": return_len,
  }
  request_body.update(kwargs)
  response = requests.post(url, json=request_body, headers=self.auth_headers)
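For reference, a minimal sketch of using the new return_len flag through the Python client; the endpoint address and model UID are made up for illustration:

from xinference.client import RESTfulClient

# Hypothetical endpoint and a rerank model launched under a hypothetical UID.
client = RESTfulClient("http://localhost:9997")
model = client.get_model("my-rerank-model")

result = model.rerank(
    documents=["A man is eating food.", "A man is riding a horse."],
    query="A man is eating pasta.",
    return_documents=True,
    return_len=True,  # ask the server to also report token length information
)
print(result)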
xinference/core/event.py CHANGED
@@ -12,8 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import queue
- from collections import defaultdict
+ from collections import defaultdict, deque
  from enum import Enum
  from typing import Dict, List, TypedDict
 
@@ -37,8 +36,8 @@ class Event(TypedDict):
  class EventCollectorActor(xo.StatelessActor):
  def __init__(self):
  super().__init__()
- self._model_uid_to_events: Dict[str, queue.Queue] = defaultdict( # type: ignore
- lambda: queue.Queue(maxsize=MAX_EVENT_COUNT_PER_MODEL)
+ self._model_uid_to_events: Dict[str, deque] = defaultdict( # type: ignore
+ lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL)
  )
 
  @classmethod
@@ -50,7 +49,7 @@ class EventCollectorActor(xo.StatelessActor):
  if event_queue is None:
  return []
  else:
- return [dict(e, event_type=e["event_type"].name) for e in event_queue.queue]
+ return [dict(e, event_type=e["event_type"].name) for e in iter(event_queue)]
 
  def report_event(self, model_uid: str, event: Event):
- self._model_uid_to_events[model_uid].put(event)
+ self._model_uid_to_events[model_uid].append(event)
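The switch from queue.Queue(maxsize=...) to deque(maxlen=...) matters because a full Queue blocks (or raises) on put, while a bounded deque simply evicts its oldest entry, so reporting a new event never stalls. A small self-contained sketch of that behavior; the constant value is illustrative only, not the one xinference uses:

from collections import deque

MAX_EVENT_COUNT_PER_MODEL = 3  # illustrative value only

events = deque(maxlen=MAX_EVENT_COUNT_PER_MODEL)
for i in range(5):
    # append() never blocks; once the deque is full, the oldest event is dropped
    events.append({"event_type": "INFO", "event_content": f"event-{i}"})

print(list(events))  # only the 3 most recent events remain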
xinference/core/model.py CHANGED
@@ -264,13 +264,14 @@ class ModelActor(xo.StatelessActor):
  return isinstance(self._model, VLLMModel)
 
  def allow_batching(self) -> bool:
- from ..model.llm.pytorch.core import PytorchChatModel, PytorchModel
+ from ..model.llm.pytorch.core import PytorchModel
+
+ model_ability = self._model_description.get("model_ability", [])
 
  return (
  XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
  and isinstance(self._model, PytorchModel)
- and self._model.__class__.__name__
- in (PytorchChatModel.__name__, PytorchModel.__name__)
+ and "vision" not in model_ability
  )
 
  async def load(self):
@@ -399,6 +400,7 @@ class ModelActor(xo.StatelessActor):
  prompt, "generate", *args, **kwargs
  )
  else:
+ kwargs.pop("raw_params", None)
  if hasattr(self._model, "generate"):
  return await self._call_wrapper(
  self._model.generate, prompt, *args, **kwargs
@@ -481,6 +483,7 @@ class ModelActor(xo.StatelessActor):
  prompt, "chat", *args, **kwargs
  )
  else:
+ kwargs.pop("raw_params", None)
  if hasattr(self._model, "chat"):
  response = await self._call_wrapper(
  self._model.chat, prompt, *args, **kwargs
@@ -540,6 +543,7 @@ class ModelActor(xo.StatelessActor):
  top_n: Optional[int],
  max_chunks_per_doc: Optional[int],
  return_documents: Optional[bool],
+ return_len: Optional[bool],
  *args,
  **kwargs,
  ):
@@ -551,6 +555,7 @@ class ModelActor(xo.StatelessActor):
  top_n,
  max_chunks_per_doc,
  return_documents,
+ return_len,
  *args,
  **kwargs,
  )
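Taken together, the allow_batching change earlier in this file replaces an allow-list of class names with a capability check: any transformers-backed PytorchModel can be batched as long as it is not a vision model. A stand-alone sketch of the same predicate, with the environment flag and model description mocked up for illustration:

# Illustrative stand-ins for the real env flag and model description.
XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = True
model_description = {"model_ability": ["chat", "vision"]}

def allow_batching(is_pytorch_model: bool) -> bool:
    # Capability check instead of matching PytorchModel/PytorchChatModel by name:
    # any transformers model without "vision" ability qualifies for batching.
    model_ability = model_description.get("model_ability", [])
    return (
        XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
        and is_pytorch_model
        and "vision" not in model_ability
    )

print(allow_batching(True))  # False: "vision" is in model_ability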
xinference/core/scheduler.py CHANGED
@@ -18,7 +18,7 @@ import logging
  import uuid
  from collections import deque
  from enum import Enum
- from typing import List, Optional, Set
+ from typing import List, Optional, Set, Tuple
 
  import xoscar as xo
 
@@ -53,7 +53,8 @@ class InferenceRequest:
  self._kv_cache = None
  # use passed args from upstream interface
  self._inference_args = args
- # use passed kwargs from upstream interface, basically not used for now
+ # use passed kwargs from upstream interface, currently for getting raw generate config from upstream,
+ # which is useful for some special models
  self._inference_kwargs = kwargs
  # should this request be stopped
  self._stopped = False
@@ -66,6 +67,8 @@ class InferenceRequest:
  self._sanitized_generate_config = None
  # Chunk id for results. In stream mode, all the chunk ids should be same.
  self._stream_chunk_id = str(uuid.uuid4())
+ # For calculate attention mask if needed
+ self.padding_len = 0
  # Use in stream mode
  self.last_output_length = 0
  # inference results,
@@ -172,6 +175,10 @@ class InferenceRequest:
  def sanitized_generate_config(self, value: dict):
  self._sanitized_generate_config = value
 
+ @property
+ def inference_kwargs(self):
+ return self._inference_kwargs
+
  @property
  def stopped(self):
  return self._stopped
@@ -231,7 +238,9 @@ class InferenceRequest:
  )
 
  @functools.lru_cache
- def get_generate_configs(self, eos_token_id: int):
+ def get_generate_configs(
+ self, eos_token_id: int, builtin_stop_token_ids: Optional[Tuple[int]] = None
+ ):
  from ..types import max_tokens_field
 
  max_new_tokens = int(
@@ -245,6 +254,7 @@ class InferenceRequest:
  )
  stop_token_ids = set(stop_token_ids)
  stop_token_ids.add(eos_token_id)
+ stop_token_ids.update(builtin_stop_token_ids or [])
  temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
  repetition_penalty = float(
  self.sanitized_generate_config.get("repetition_penalty", 1.0)
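Note that get_generate_configs is decorated with functools.lru_cache, which is why the new builtin_stop_token_ids parameter is typed as a tuple rather than a list or set: lru_cache requires hashable arguments. A quick standalone illustration (not xinference code) of why the tuple matters:

import functools

@functools.lru_cache
def get_generate_configs(eos_token_id, builtin_stop_token_ids=None):
    # Merge the caller-supplied eos token with the family's built-in stop tokens.
    stop_token_ids = {eos_token_id}
    stop_token_ids.update(builtin_stop_token_ids or [])
    return sorted(stop_token_ids)

print(get_generate_configs(2, (2, 92542)))  # tuples are hashable, so caching works
# get_generate_configs(2, [2, 92542]) would raise TypeError: unhashable type: 'list'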
xinference/model/llm/llm_family.json CHANGED
@@ -2290,7 +2290,8 @@
  "zh"
  ],
  "model_ability": [
- "chat"
+ "chat",
+ "tools"
  ],
  "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
  "model_specs": [
@@ -2595,7 +2596,8 @@
  "zh"
  ],
  "model_ability": [
- "chat"
+ "chat",
+ "tools"
  ],
  "model_description": "Qwen2 is the new series of Qwen large language models. ",
  "model_specs": [
@@ -5675,9 +5677,11 @@
  ],
  "intra_message_sep": "<|im_end|>",
  "stop_token_ids": [
+ 2,
  92542
  ],
  "stop": [
+ "</s>",
  "<|im_end|>"
  ]
  }
xinference/model/llm/llm_family_modelscope.json CHANGED
@@ -2644,7 +2644,8 @@
  "zh"
  ],
  "model_ability": [
- "chat"
+ "chat",
+ "tools"
  ],
  "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
  "model_specs": [
@@ -2968,7 +2969,8 @@
  "zh"
  ],
  "model_ability": [
- "chat"
+ "chat",
+ "tools"
  ],
  "model_description": "Qwen2 is the new series of Qwen large language models. ",
  "model_specs": [
@@ -3350,9 +3352,11 @@
  ],
  "intra_message_sep": "<|im_end|>",
  "stop_token_ids": [
+ 2,
  92542
  ],
  "stop": [
+ "</s>",
  "<|im_end|>"
  ]
  }
xinference/model/llm/pytorch/chatglm.py CHANGED
@@ -15,6 +15,7 @@ import time
  import uuid
  from typing import Any, Dict, Iterator, List, Optional, Union
 
+ from ....core.scheduler import InferenceRequest
  from ....types import (
  SPECIAL_TOOL_PROMPT,
  ChatCompletion,
@@ -244,3 +245,25 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
  ),
  )
+
+ @staticmethod
+ def require_attention_mask():
+ """
+ GLM4 needs to use attention mask and position ids during inference.
+ Otherwise, the inference result would be not available.
+ """
+ return True
+
+ def prepare_sanitize_generate_config(self, req: InferenceRequest):
+ """
+ Set temperature and top_p to 0.8 by default
+ """
+ raw_config = req.inference_kwargs.get("raw_params", {})
+ temperature = raw_config.get("temperature", None)
+ if temperature is None:
+ raw_config["temperature"] = 0.8
+ top_p = raw_config.get("top_p", None)
+ if top_p is None:
+ raw_config["top_p"] = 0.8
+
+ return raw_config
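The override above means ChatGLM/GLM-4 batching builds its generate config from the raw request parameters and only fills in temperature and top_p when the caller left them out, rather than inheriting the generic defaults. A tiny standalone sketch mirroring that fill-in behavior:

def fill_defaults(raw_params: dict) -> dict:
    # Only set values the caller omitted; user-supplied values are kept as-is.
    config = dict(raw_params)
    config.setdefault("temperature", 0.8)
    config.setdefault("top_p", 0.8)
    return config

print(fill_defaults({}))                    # {'temperature': 0.8, 'top_p': 0.8}
print(fill_defaults({"temperature": 0.2}))  # user value kept, top_p defaulted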
xinference/model/llm/pytorch/core.py CHANGED
@@ -16,7 +16,7 @@ import json
  import logging
  import os
  from functools import lru_cache
- from typing import Iterable, Iterator, List, Optional, Union
+ from typing import Iterable, Iterator, List, Optional, Tuple, Union
 
  from ....core.scheduler import InferenceRequest
  from ....device_utils import (
@@ -283,35 +283,21 @@ class PytorchModel(LLM):
  def generate(
  self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
  ) -> Union[Completion, Iterator[CompletionChunk]]:
- from .utils import generate_stream, generate_stream_falcon
-
- model_family_name = self.model_family.model_name.lower()
+ from .utils import generate_stream
 
  def generator_wrapper(
  prompt: str, generate_config: PytorchGenerateConfig
  ) -> Iterator[CompletionChunk]:
- if "falcon" in model_family_name:
- for completion_chunk, completion_usage in generate_stream_falcon(
- self.model_uid,
- self._model,
- self._tokenizer,
- prompt,
- self._device,
- generate_config,
- ):
- completion_chunk["usage"] = completion_usage
- yield completion_chunk
- else:
- for completion_chunk, completion_usage in generate_stream(
- self.model_uid,
- self._model,
- self._tokenizer,
- prompt,
- self._device,
- generate_config,
- ):
- completion_chunk["usage"] = completion_usage
- yield completion_chunk
+ for completion_chunk, completion_usage in generate_stream(
+ self.model_uid,
+ self._model,
+ self._tokenizer,
+ prompt,
+ self._device,
+ generate_config,
+ ):
+ completion_chunk["usage"] = completion_usage
+ yield completion_chunk
 
  logger.debug(
  "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -336,26 +322,15 @@ class PytorchModel(LLM):
  stream = generate_config.get("stream", False)
  if not stream:
- if "falcon" in model_family_name:
- for completion_chunk, completion_usage in generate_stream_falcon(
- self.model_uid,
- self._model,
- self._tokenizer,
- prompt,
- self._device,
- generate_config,
- ):
- pass
- else:
- for completion_chunk, completion_usage in generate_stream(
- self.model_uid,
- self._model,
- self._tokenizer,
- prompt,
- self._device,
- generate_config,
- ):
- pass
+ for completion_chunk, completion_usage in generate_stream(
+ self.model_uid,
+ self._model,
+ self._tokenizer,
+ prompt,
+ self._device,
+ generate_config,
+ ):
+ pass
  completion = Completion(
  id=completion_chunk["id"],
  object=completion_chunk["object"],
@@ -368,6 +343,10 @@ class PytorchModel(LLM):
  else:
  return generator_wrapper(prompt, generate_config)
 
+ @staticmethod
+ def require_attention_mask():
+ return False
+
  @lru_cache
  def get_context_len(self):
  return get_context_length(self._model.config)
@@ -375,13 +354,14 @@ class PytorchModel(LLM):
  def get_max_num_seqs(self) -> int:
  return self._pytorch_model_config.get("max_num_seqs") # type: ignore
 
+ def prepare_sanitize_generate_config(self, req: InferenceRequest):
+ return self._sanitize_generate_config(req.generate_config)
+
  def prepare_batch_inference(self, req_list: List[InferenceRequest]):
  # check some parameters
  for r in req_list:
  if r.sanitized_generate_config is None:
- r.sanitized_generate_config = self._sanitize_generate_config(
- r.generate_config
- )
+ r.sanitized_generate_config = self.prepare_sanitize_generate_config(r)
  if r.is_prefill:
  # check some generate params
  max_src_len = get_max_src_len(self.get_context_len(), r) # type: ignore
@@ -401,6 +381,14 @@ class PytorchModel(LLM):
  r.error_msg = "Invalid `stop` field type"
  continue
 
+ def _get_builtin_stop_token_ids(self) -> Tuple:
+ return (
+ tuple(self.model_family.prompt_style.stop_token_ids)
+ if self.model_family.prompt_style
+ and self.model_family.prompt_style.stop_token_ids
+ else tuple()
+ )
+
  def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
  for req in req_list:
  if req.error_msg is None:
@@ -449,6 +437,8 @@ class PytorchModel(LLM):
  self._tokenizer,
  self._device,
  context_len,
+ self._get_builtin_stop_token_ids(),
+ require_attention_mask=self.require_attention_mask(),
  )
  self.handle_batch_inference_results(req_list)
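The batch path above now consults two new hooks on PytorchModel: _get_builtin_stop_token_ids() feeds the family's built-in stop tokens into the cached generate config, and require_attention_mask() lets subclasses such as the GLM-4 model class opt in to explicit attention masks. A schematic sketch of the hook pattern, independent of the real class hierarchy:

class BaseModel:
    @staticmethod
    def require_attention_mask() -> bool:
        # Default: most transformers models get by without an explicit mask here.
        return False

class Glm4LikeModel(BaseModel):
    @staticmethod
    def require_attention_mask() -> bool:
        # GLM-4-style models need an explicit attention mask during batching.
        return True

def batch_inference(model: BaseModel) -> str:
    return "with mask" if model.require_attention_mask() else "without mask"

print(batch_inference(BaseModel()))      # without mask
print(batch_inference(Glm4LikeModel()))  # with mask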
xinference/model/llm/pytorch/glm4v.py CHANGED
@@ -64,6 +64,8 @@ class Glm4VModel(PytorchChatModel):
 
  kwargs = {"device_map": self._device}
  quantization = self.quantization
+
+ # referenced from PytorchModel.load
  if quantization != "none":
  if self._device == "cuda" and self._is_linux():
  kwargs["device_map"] = "auto"
@@ -72,6 +74,15 @@ class Glm4VModel(PytorchChatModel):
  kwargs["load_in_4bit"] = True
  elif quantization == "8-bit":
  kwargs["load_in_8bit"] = True
+ else:
+ raise ValueError(
+ f"Quantization {quantization} is not supported in temporary"
+ )
+ else:
+ if quantization != "8-bit":
+ raise ValueError(
+ f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+ )
 
  model = AutoModelForCausalLM.from_pretrained(
  self.model_path,
xinference/model/llm/pytorch/internlm2.py CHANGED
@@ -15,6 +15,7 @@ import time
  import uuid
  from typing import Any, Dict, Iterator, List, Optional, Union
 
+ from ....core.scheduler import InferenceRequest
  from ....types import (
  ChatCompletion,
  ChatCompletionChoice,
@@ -88,6 +89,20 @@ class Internlm2PytorchChatModel(PytorchChatModel):
  return False
  return True
 
+ def prepare_sanitize_generate_config(self, req: InferenceRequest):
+ """
+ Overwrite this func for this special model.
+ Cannot use the default configuration, which works poorly on this model.
+ """
+ raw_config = req.inference_kwargs.get("raw_params", {})
+ temperature = raw_config.get("temperature", None)
+ if temperature is None:
+ raw_config["temperature"] = 0.8
+ top_p = raw_config.get("top_p", None)
+ if top_p is None:
+ raw_config["top_p"] = 0.8
+ return raw_config
+
  def chat(
  self,
  prompt: str,