xinference 0.15.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (38)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/constants.py +4 -4
  4. xinference/core/model.py +89 -18
  5. xinference/core/scheduler.py +10 -7
  6. xinference/core/utils.py +9 -0
  7. xinference/deploy/supervisor.py +4 -0
  8. xinference/model/__init__.py +4 -0
  9. xinference/model/image/scheduler/__init__.py +13 -0
  10. xinference/model/image/scheduler/flux.py +533 -0
  11. xinference/model/image/stable_diffusion/core.py +6 -31
  12. xinference/model/image/utils.py +39 -3
  13. xinference/model/llm/__init__.py +2 -0
  14. xinference/model/llm/llm_family.json +169 -1
  15. xinference/model/llm/llm_family_modelscope.json +108 -0
  16. xinference/model/llm/transformers/chatglm.py +104 -0
  17. xinference/model/llm/transformers/core.py +37 -111
  18. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  19. xinference/model/llm/transformers/internlm2.py +3 -95
  20. xinference/model/llm/transformers/opt.py +68 -0
  21. xinference/model/llm/transformers/utils.py +4 -284
  22. xinference/model/llm/utils.py +2 -2
  23. xinference/model/llm/vllm/core.py +16 -1
  24. xinference/utils.py +2 -3
  25. xinference/web/ui/build/asset-manifest.json +3 -3
  26. xinference/web/ui/build/index.html +1 -1
  27. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  28. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  29. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  30. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/METADATA +36 -4
  31. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/RECORD +36 -33
  32. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  33. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  34. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  35. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  36. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  37. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  38. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
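
A large part of this release is the set of MLX specs for Qwen2.5-Instruct added to llm_family.json and llm_family_modelscope.json (see the hunks below). As a rough, hedged sketch of what those entries enable once this wheel is installed — not something taken from the diff itself — the endpoint URL, engine name, and chat payload below are illustrative assumptions:

# Hedged sketch: launch one of the newly registered MLX Qwen2.5 specs through the
# xinference Python client. Assumes a local supervisor is already running on port 9997;
# the engine name "MLX" and the messages-style chat call are assumptions for this sketch.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="qwen2.5-instruct",
    model_engine="MLX",
    model_format="mlx",            # matches the "model_format": "mlx" specs added below
    model_size_in_billions=7,
    quantization="4-bit",          # must pair with a size/quantization combination listed below
)
model = client.get_model(model_uid)
print(model.chat(messages=[{"role": "user", "content": "Hello"}]))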

xinference/model/llm/llm_family.json

@@ -206,7 +206,7 @@
  "none"
  ],
  "model_id": "THUDM/glm-4-9b-chat",
- "model_revision": "f6e0743b285dd808084530f070ad08e504386750"
+ "model_revision": "eb55a443d66541f30869f6caac5ad0d2e95bcbaa"
  },
  {
  "model_format": "ggufv2",
@@ -7923,6 +7923,174 @@
  "00021-of-00021"
  ]
  }
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-0.5B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-0.5B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "0_5",
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-0.5B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "1_5",
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-1.5B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "1_5",
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-1.5B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": "1_5",
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-1.5B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-3B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-3B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-3B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-7B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-7B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-7B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-14B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-14B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-14B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-32B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-32B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-32B-Instruct-bf16"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-72B-Instruct-4bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "mlx-community/Qwen2.5-72B-Instruct-8bit"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "none"
+ ],
+ "model_id": "mlx-community/Qwen2.5-72B-Instruct-bf16"
  }
  ],
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",

xinference/model/llm/llm_family_modelscope.json

@@ -5681,6 +5681,114 @@
  "00021-of-00021"
  ]
  }
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-3B-Instruct-MLX-4bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 3,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-3B-Instruct-MLX-8bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-7B-Instruct-MLX-4bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-7B-Instruct-MLX-8bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-14B-Instruct-MLX-4bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 14,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-14B-Instruct-MLX-8bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "2-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-2bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-4bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 32,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-8bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "2-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-32B-Instruct-MLX-2bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "4-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-72B-Instruct-MLX-4bit",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_format": "mlx",
+ "model_size_in_billions": 72,
+ "quantizations": [
+ "8-bit"
+ ],
+ "model_id": "okwinds/Qwen2.5-72B-Instruct-MLX-8bit",
+ "model_hub": "modelscope"
  }
  ],
  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",

xinference/model/llm/transformers/chatglm.py

@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import json
+ import logging
  import typing
  import uuid
  from threading import Thread

@@ -29,6 +30,8 @@ from ..utils import (
  )
  from .core import PytorchChatModel, PytorchModelConfig

+ logger = logging.getLogger(__name__)
+

  class ChatglmPytorchChatModel(PytorchChatModel):
      def __init__(

@@ -445,3 +448,104 @@ class ChatglmPytorchChatModel(PytorchChatModel):
              raw_config["top_p"] = 0.8

          return raw_config
+
+     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+         super(PytorchChatModel, self).prepare_batch_inference(req_list)
+         for r in req_list:
+             try:
+                 if not r.stopped and r.is_prefill:
+                     tools = r.generate_config.get("tools", None)
+                     tools = list(tools) if tools is not None else None
+                     tool_choice = r.generate_config.get("tool_choice", "none")
+
+                     r.prompt = self._process_messages(
+                         r.prompt, tools=tools, tool_choice=tool_choice
+                     )
+                     r.full_prompt = self.get_full_context(
+                         r.prompt,
+                         self.model_family.chat_template,  # type: ignore
+                         tokenizer=self._tokenizer,
+                     )
+                     if tools:
+                         r.tools = tools
+             except Exception as e:
+                 logger.exception(f"prepare inference error with {e}")
+                 r.stopped = True
+                 r.error_msg = str(e)
+
+     def handle_chat_result_non_streaming(self, req: InferenceRequest):
+         if req.tools:
+             response = req.completion[0]["choices"][0]["text"]
+             usage = req.completion[0]["usage"]
+             function_call = self._process_response_non_streaming(
+                 response, req.tools, use_tool=True
+             )
+             req.completion[0] = self._tool_calls_completion(
+                 self.model_family, self.model_uid, function_call
+             )
+             req.completion[0]["usage"] = usage
+         else:
+             req.completion[0] = self._to_chat_completion(req.completion[0])
+
+     def handle_chat_result_streaming(self, req: InferenceRequest):
+         results = []
+         tools = {tool["function"]["name"] for tool in req.tools} if req.tools else {}
+         response = "".join(req.outputs)
+         eos_pos = response.find("<eos_stream>")
+         if eos_pos != -1:
+             response = response[:eos_pos]
+
+         if "<bos_stream>" in req.completion:
+             bos_pos = req.completion.index("<bos_stream>")
+             results.append(
+                 self._get_first_chat_completion_chunk(req.completion[bos_pos + 1])
+             )
+
+         if req.stopped:
+             if tools:
+                 new_response = self._process_response_streaming(
+                     response, tools, end=True
+                 )
+                 if new_response:
+                     if isinstance(new_response, dict):  # tool call case
+                         chunk_id = [
+                             c for c in req.completion if not isinstance(c, str)
+                         ][0]["id"]
+                         results.append(
+                             self._tool_calls_completion_chunk(
+                                 self.model_family,
+                                 self.model_uid,
+                                 new_response,
+                                 chunk_id=chunk_id,
+                             )
+                         )
+                     else:  # normal case
+                         for c in req.completion:
+                             if c == "<bos_stream>":
+                                 continue
+                             elif c == "<eos_stream>":
+                                 break
+                             else:
+                                 results.append(self._to_chat_completion_chunk(c))
+             else:
+                 for c in req.completion:
+                     if c == "<bos_stream>":
+                         continue
+                     elif c == "<eos_stream>":
+                         break
+                     else:
+                         results.append(self._to_chat_completion_chunk(c))
+         else:
+             if response and response[-1] != "�":
+                 new_response = self._process_response_streaming(
+                     response, tools, end=False
+                 )
+                 if new_response is not None:  # normal case
+                     for c in req.completion:
+                         if c == "<bos_stream>":
+                             continue
+                         results.append(self._to_chat_completion_chunk(c))
+
+         if req.stopped and req.include_usage:
+             results.append(self._get_final_chat_completion_chunk(req.completion[-1]))
+         req.completion = results

xinference/model/llm/transformers/core.py

@@ -29,7 +29,6 @@ from ....device_utils import (
  from ....types import (
      ChatCompletion,
      ChatCompletionChunk,
-     Completion,
      CompletionChoice,
      CompletionChunk,
      CreateCompletionTorch,

@@ -46,9 +45,7 @@ from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
  logger = logging.getLogger(__name__)

  NON_DEFAULT_MODEL_LIST: List[str] = [
-     "chatglm3",
-     "chatglm3-32k",
-     "chatglm3-128k",
+     "opt",
      "glm4-chat",
      "glm4-chat-1m",
      "internlm2-chat",

@@ -345,69 +342,6 @@ class PytorchModel(LLM):
              return False
          return True

-     def generate(
-         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
-     ) -> Union[Completion, Iterator[CompletionChunk]]:
-         from .utils import generate_stream
-
-         def generator_wrapper(
-             prompt: str, generate_config: PytorchGenerateConfig
-         ) -> Iterator[CompletionChunk]:
-             for completion_chunk, completion_usage in generate_stream(
-                 self.model_uid,
-                 self._model,
-                 self._tokenizer,
-                 prompt,
-                 self._device,
-                 generate_config,
-             ):
-                 completion_chunk["usage"] = completion_usage
-                 yield completion_chunk
-
-         logger.debug(
-             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
-         )
-
-         generate_config = self._sanitize_generate_config(generate_config)
-
-         assert self._model is not None
-         assert self._tokenizer is not None
-
-         lora_model = generate_config.pop("lora_name")
-
-         if lora_model is not None and self._peft_model is not None:
-             for lora in self._peft_model:
-                 if lora_model == lora.lora_name:
-                     self._model.set_adapter(lora_model)
-                     logger.info(f"Set lora model to {lora_model}")
-                     break
-             else:
-                 self._model.disable_adapter()
-                 logger.info(f"No lora model {lora_model} found, skip setting")
-
-         stream = generate_config.get("stream", False)
-         if not stream:
-             for completion_chunk, completion_usage in generate_stream(
-                 self.model_uid,
-                 self._model,
-                 self._tokenizer,
-                 prompt,
-                 self._device,
-                 generate_config,
-             ):
-                 pass
-             completion = Completion(
-                 id=completion_chunk["id"],
-                 object=completion_chunk["object"],
-                 created=completion_chunk["created"],
-                 model=completion_chunk["model"],
-                 choices=completion_chunk["choices"],
-                 usage=completion_usage,
-             )
-             return completion
-         else:
-             return generator_wrapper(prompt, generate_config)
-
      def build_prefill_attention_mask(
          self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
      ):

@@ -730,7 +664,12 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
          messages: List[Dict],
          generate_config: Optional[PytorchGenerateConfig] = None,
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         tools = generate_config.pop("tools", []) if generate_config else None
+         raise NotImplementedError
+
+     def load(self):
+         super().load()
+
+     def _get_full_prompt(self, messages: List[Dict], tools):
          model_family = self.model_family.model_family or self.model_family.model_name
          full_context_kwargs = {}
          if (

@@ -746,29 +685,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
              tokenizer=self._tokenizer,
              **full_context_kwargs,
          )
-
-         generate_config = self._sanitize_generate_config(generate_config)
-
-         stream = generate_config.get("stream", False)
-         if stream:
-             it = self.generate(full_prompt, generate_config)
-             assert isinstance(it, Iterator)
-             return self._to_chat_completion_chunks(it)
-         else:
-             c = self.generate(full_prompt, generate_config)
-             assert not isinstance(c, Iterator)
-             if tools:
-                 return self._tool_calls_completion(self.model_family, self.model_uid, c)
-             return self._to_chat_completion(c)
-
-     def load(self):
-         super().load()
-
-     def _get_full_prompt(self, messages: List[Dict], tools):
-         assert self.model_family.chat_template is not None
-         full_prompt = self.get_full_context(
-             messages, self.model_family.chat_template, tokenizer=self._tokenizer
-         )
          return full_prompt

      def prepare_batch_inference(self, req_list: List[InferenceRequest]):

@@ -776,12 +692,39 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
          for r in req_list:
              try:
                  if not r.stopped and r.is_prefill:
-                     r.full_prompt = self._get_full_prompt(r.prompt, None)
+                     tools = r.generate_config.get("tools", None)
+                     r.full_prompt = self._get_full_prompt(r.prompt, tools)
+                     if tools:
+                         r.tools = tools
              except Exception as e:
                  logger.exception(f"prepare inference error with {e}")
                  r.stopped = True
                  r.error_msg = str(e)

+     def handle_chat_result_non_streaming(self, req: InferenceRequest):
+         if req.tools:
+             req.completion[0] = self._tool_calls_completion(
+                 self.model_family, self.model_uid, req.completion[0]
+             )
+         else:
+             req.completion[0] = self._to_chat_completion(req.completion[0])
+
+     def handle_chat_result_streaming(self, req: InferenceRequest):
+         results = []
+         for i, c in enumerate(req.completion):
+             if c == "<bos_stream>":
+                 results.append(
+                     self._get_first_chat_completion_chunk(req.completion[i + 1])
+                 )
+             elif c == "<eos_stream>":
+                 break
+             else:
+                 results.append(self._to_chat_completion_chunk(c))
+
+         if req.stopped and req.include_usage:
+             results.append(self._get_final_chat_completion_chunk(req.completion[-1]))
+         req.completion = results
+
      def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
          for req in req_list:
              if req.error_msg is None and req.completion:

@@ -800,23 +743,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                      continue

                  if req.stream:
-                     results = []
-                     for i, c in enumerate(req.completion):
-                         if c == "<bos_stream>":
-                             results.append(
-                                 self._get_first_chat_completion_chunk(
-                                     req.completion[i + 1]
-                                 )
-                             )
-                         elif c == "<eos_stream>":
-                             break
-                         else:
-                             results.append(self._to_chat_completion_chunk(c))
-
-                     if req.stopped and req.include_usage:
-                         results.append(
-                             self._get_final_chat_completion_chunk(req.completion[-1])
-                         )
-                     req.completion = results
+                     self.handle_chat_result_streaming(req)
                  else:
-                     req.completion[0] = self._to_chat_completion(req.completion[0])
+                     self.handle_chat_result_non_streaming(req)
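
As a closing note on the llm_family_modelscope.json hunk above: the okwinds entries carry "model_hub": "modelscope", and deployments that prefer that hub typically opt in through the XINFERENCE_MODEL_SRC environment variable before starting the server. A hedged sketch, with host and port as placeholders and the entry-point flags assumed from the project's documented usage:

# Hedged sketch: start a single-node xinference instance that resolves model downloads
# from ModelScope, so the okwinds MLX conversions registered above can be used.
import os
import subprocess

env = dict(os.environ, XINFERENCE_MODEL_SRC="modelscope")
subprocess.Popen(
    ["xinference-local", "--host", "127.0.0.1", "--port", "9997"],  # xinference-local is the bundled entry point
    env=env,
)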