xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (85)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +108 -14
  3. xinference/client/restful/restful_client.py +78 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/event.py +5 -6
  7. xinference/core/model.py +59 -42
  8. xinference/core/scheduler.py +46 -18
  9. xinference/core/supervisor.py +73 -24
  10. xinference/core/worker.py +68 -2
  11. xinference/deploy/cmdline.py +86 -2
  12. xinference/deploy/test/test_cmdline.py +19 -10
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/core.py +12 -1
  15. xinference/model/audio/custom.py +6 -4
  16. xinference/model/audio/model_spec_modelscope.json +20 -0
  17. xinference/model/llm/__init__.py +34 -2
  18. xinference/model/llm/llm_family.json +8 -2
  19. xinference/model/llm/llm_family.py +86 -1
  20. xinference/model/llm/llm_family_csghub.json +66 -0
  21. xinference/model/llm/llm_family_modelscope.json +8 -2
  22. xinference/model/llm/pytorch/chatglm.py +41 -12
  23. xinference/model/llm/pytorch/core.py +128 -88
  24. xinference/model/llm/pytorch/glm4v.py +24 -3
  25. xinference/model/llm/pytorch/internlm2.py +15 -0
  26. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  27. xinference/model/llm/pytorch/utils.py +69 -189
  28. xinference/model/llm/utils.py +27 -14
  29. xinference/model/llm/vllm/core.py +10 -4
  30. xinference/model/rerank/core.py +35 -6
  31. xinference/model/utils.py +8 -2
  32. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  33. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  34. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  35. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  36. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  38. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  39. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  40. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  41. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  42. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  43. xinference/types.py +28 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
  47. xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
  49. xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  63. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
  64. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
  65. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  66. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  67. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  68. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
  72. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  74. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
  81. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
  82. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
  83. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
  84. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
  85. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0

xinference/model/llm/llm_family.py
@@ -32,10 +32,15 @@ from ..._compat import (
     load_str_bytes,
     validator,
 )
-from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
+from ...constants import (
+    XINFERENCE_CACHE_DIR,
+    XINFERENCE_ENV_CSG_TOKEN,
+    XINFERENCE_MODEL_DIR,
+)
 from ..utils import (
     IS_NEW_HUGGINGFACE_HUB,
     create_symlink,
+    download_from_csghub,
     download_from_modelscope,
     is_valid_model_uri,
     parse_uri,
@@ -232,6 +237,7 @@ LLAMA_CLASSES: List[Type[LLM]] = []

 BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
 BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []
+BUILTIN_CSGHUB_LLM_FAMILIES: List["LLMFamilyV1"] = []

 SGLANG_CLASSES: List[Type[LLM]] = []
 TRANSFORMERS_CLASSES: List[Type[LLM]] = []
@@ -292,6 +298,9 @@ def cache(
     elif llm_spec.model_hub == "modelscope":
         logger.info(f"Caching from Modelscope: {llm_spec.model_id}")
         return cache_from_modelscope(llm_family, llm_spec, quantization)
+    elif llm_spec.model_hub == "csghub":
+        logger.info(f"Caching from CSGHub: {llm_spec.model_id}")
+        return cache_from_csghub(llm_family, llm_spec, quantization)
     else:
         raise ValueError(f"Unknown model hub: {llm_spec.model_hub}")

@@ -566,6 +575,7 @@ def _skip_download(
         "modelscope": _get_meta_path(
             cache_dir, model_format, "modelscope", quantization
         ),
+        "csghub": _get_meta_path(cache_dir, model_format, "csghub", quantization),
     }
     if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
         logger.info(f"Cache {cache_dir} exists")
@@ -650,6 +660,75 @@ def _merge_cached_files(
     logger.info(f"Merge complete.")


+def cache_from_csghub(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    """
+    Cache model from CSGHub. Return the cache directory.
+    """
+    from pycsghub.file_download import file_download
+    from pycsghub.snapshot_download import snapshot_download
+
+    cache_dir = _get_cache_dir(llm_family, llm_spec)
+
+    if _skip_download(
+        cache_dir,
+        llm_spec.model_format,
+        llm_spec.model_hub,
+        llm_spec.model_revision,
+        quantization,
+    ):
+        return cache_dir
+
+    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
+        download_dir = retry_download(
+            snapshot_download,
+            llm_family.model_name,
+            {
+                "model_size": llm_spec.model_size_in_billions,
+                "model_format": llm_spec.model_format,
+            },
+            llm_spec.model_id,
+            endpoint="https://hub-stg.opencsg.com",
+            token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+        )
+        create_symlink(download_dir, cache_dir)
+
+    elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
+        file_names, final_file_name, need_merge = _generate_model_file_names(
+            llm_spec, quantization
+        )
+
+        for filename in file_names:
+            download_path = retry_download(
+                file_download,
+                llm_family.model_name,
+                {
+                    "model_size": llm_spec.model_size_in_billions,
+                    "model_format": llm_spec.model_format,
+                },
+                llm_spec.model_id,
+                file_name=filename,
+                endpoint="https://hub-stg.opencsg.com",
+                token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+            )
+            symlink_local_file(download_path, cache_dir, filename)
+
+        if need_merge:
+            _merge_cached_files(cache_dir, file_names, final_file_name)
+    else:
+        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
+
+    meta_path = _get_meta_path(
+        cache_dir, llm_spec.model_format, llm_spec.model_hub, quantization
+    )
+    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
+
+    return cache_dir
+
+
 def cache_from_modelscope(
     llm_family: LLMFamilyV1,
     llm_spec: "LLMSpecV1",
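
For orientation, cache_from_csghub above follows the same flow as cache_from_modelscope: resolve the cache directory, return early if a valid meta file exists, snapshot-download pytorch/gptq/awq repos or fetch individual GGUF files, symlink them into the cache, and write the meta file. A minimal usage sketch (illustrative only, not part of this diff), assuming the built-in CSGHub families have already been registered by xinference's model setup and the token variable named by XINFERENCE_ENV_CSG_TOKEN is exported:

# Hedged sketch: cache the CSGHub-hosted qwen2-instruct pytorch spec directly.
# Assumptions: xinference >= 0.12.2 is installed, the builtin family lists have
# been populated by xinference's model registration step, and the env var named
# by XINFERENCE_ENV_CSG_TOKEN (the CSGHub access token) is set.
from xinference.model.llm.llm_family import (
    BUILTIN_CSGHUB_LLM_FAMILIES,
    cache_from_csghub,
)

family = next(
    f for f in BUILTIN_CSGHUB_LLM_FAMILIES if f.model_name == "qwen2-instruct"
)
spec = family.model_specs[0]  # the pytorch 0.5B spec from llm_family_csghub.json
print(cache_from_csghub(family, spec, quantization="none"))
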
@@ -931,6 +1010,12 @@ def match_llm(
             + BUILTIN_LLM_FAMILIES
             + user_defined_llm_families
         )
+    elif download_from_csghub():
+        all_families = (
+            BUILTIN_CSGHUB_LLM_FAMILIES
+            + BUILTIN_LLM_FAMILIES
+            + user_defined_llm_families
+        )
     else:
         all_families = BUILTIN_LLM_FAMILIES + user_defined_llm_families


xinference/model/llm/llm_family_csghub.json (new file)
@@ -0,0 +1,66 @@
+[
+    {
+        "version": 1,
+        "context_length": 32768,
+        "model_name": "qwen2-instruct",
+        "model_lang": [
+            "en",
+            "zh"
+        ],
+        "model_ability": [
+            "chat",
+            "tools"
+        ],
+        "model_description": "Qwen2 is the new series of Qwen large language models",
+        "model_specs": [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "4-bit",
+                    "8-bit",
+                    "none"
+                ],
+                "model_id": "Qwen/Qwen2-0.5B-Instruct",
+                "model_hub": "csghub"
+            },
+            {
+                "model_format": "ggufv2",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "q2_k",
+                    "q3_k_m",
+                    "q4_0",
+                    "q4_k_m",
+                    "q5_0",
+                    "q5_k_m",
+                    "q6_k",
+                    "q8_0",
+                    "fp16"
+                ],
+                "model_id": "qwen/Qwen2-0.5B-Instruct-GGUF",
+                "model_file_name_template": "qwen2-0_5b-instruct-{quantization}.gguf",
+                "model_hub": "csghub"
+            }
+        ],
+        "prompt_style": {
+            "style_name": "QWEN",
+            "system_prompt": "You are a helpful assistant.",
+            "roles": [
+                "user",
+                "assistant"
+            ],
+            "intra_message_sep": "\n",
+            "stop_token_ids": [
+                151643,
+                151644,
+                151645
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>"
+            ]
+        }
+    }
+]
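
For the GGUF spec above, model_file_name_template is expanded with the selected quantization when a file is fetched from the hub; a trivial illustration (not part of the diff):

# How the gguf file name template in the csghub spec resolves for one quantization.
template = "qwen2-0_5b-instruct-{quantization}.gguf"
print(template.format(quantization="q4_k_m"))  # qwen2-0_5b-instruct-q4_k_m.gguf
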

xinference/model/llm/llm_family_modelscope.json
@@ -632,6 +632,8 @@
             "model_format": "pytorch",
             "model_size_in_billions": 9,
             "quantizations": [
+                "4-bit",
+                "8-bit",
                 "none"
             ],
             "model_hub": "modelscope",
@@ -2642,7 +2644,8 @@
             "zh"
         ],
         "model_ability": [
-            "chat"
+            "chat",
+            "tools"
         ],
         "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
         "model_specs": [
@@ -2966,7 +2969,8 @@
             "zh"
         ],
         "model_ability": [
-            "chat"
+            "chat",
+            "tools"
         ],
         "model_description": "Qwen2 is the new series of Qwen large language models. ",
         "model_specs": [
@@ -3348,9 +3352,11 @@
             ],
             "intra_message_sep": "<|im_end|>",
             "stop_token_ids": [
+                2,
                 92542
             ],
             "stop": [
+                "</s>",
                 "<|im_end|>"
             ]
         }

xinference/model/llm/pytorch/chatglm.py
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
@@ -89,24 +90,30 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             return False
         return True

-    @staticmethod
-    def _handle_tools(generate_config) -> Optional[dict]:
+    def _handle_tools(self, generate_config) -> Optional[dict]:
         """Convert openai tools to ChatGLM tools."""
         if generate_config is None:
             return None
         tools = generate_config.pop("tools", None)
         if tools is None:
             return None
-        chatglm_tools = []
-        for elem in tools:
-            if elem.get("type") != "function" or "function" not in elem:
-                raise ValueError("ChatGLM tools only support function type.")
-            chatglm_tools.append(elem["function"])
-        return {
-            "role": "system",
-            "content": f"Answer the following questions as best as you can. You have access to the following tools:",
-            "tools": chatglm_tools,
-        }
+        if self.model_family.model_name == "glm4-chat":
+            return {
+                "role": "system",
+                "content": None,
+                "tools": tools,
+            }
+        else:
+            chatglm_tools = []
+            for elem in tools:
+                if elem.get("type") != "function" or "function" not in elem:
+                    raise ValueError("ChatGLM tools only support function type.")
+                chatglm_tools.append(elem["function"])
+            return {
+                "role": "system",
+                "content": f"Answer the following questions as best as you can. You have access to the following tools:",
+                "tools": chatglm_tools,
+            }

     def chat(
         self,
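
To illustrate the branch above: glm4-chat receives the OpenAI-style tools list unchanged, while other ChatGLM models get a system message that carries only the inner function objects. A hedged example of the expected input shape (the weather tool is hypothetical, for illustration only):

# Hypothetical OpenAI-style tool definition as it would arrive in generate_config["tools"].
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, not part of the diff
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
generate_config = {"tools": tools}
# For non-glm4 ChatGLM models, _handle_tools would return a system message whose
# "tools" field contains only the inner function dicts, i.e. [tools[0]["function"]].
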
@@ -238,3 +245,25 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
             ),
         )
+
+    @staticmethod
+    def require_attention_mask():
+        """
+        GLM4 needs to use attention mask and position ids during inference.
+        Otherwise, the inference result would be not available.
+        """
+        return True
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Set temperature and top_p to 0.8 by default
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+
+        return raw_config

xinference/model/llm/pytorch/core.py
@@ -15,7 +15,8 @@
 import json
 import logging
 import os
-from typing import Iterable, Iterator, List, Optional, Union
+from functools import lru_cache
+from typing import Iterable, Iterator, List, Optional, Tuple, Union

 from ....core.scheduler import InferenceRequest
 from ....device_utils import (
@@ -28,6 +29,7 @@ from ....types import (
     ChatCompletionChunk,
     ChatCompletionMessage,
     Completion,
+    CompletionChoice,
     CompletionChunk,
     CreateCompletionTorch,
     Embedding,
@@ -281,35 +283,21 @@ class PytorchModel(LLM):
     def generate(
         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
-        from .utils import generate_stream, generate_stream_falcon
-
-        model_family_name = self.model_family.model_name.lower()
+        from .utils import generate_stream

         def generator_wrapper(
             prompt: str, generate_config: PytorchGenerateConfig
         ) -> Iterator[CompletionChunk]:
-            if "falcon" in model_family_name:
-                for completion_chunk, completion_usage in generate_stream_falcon(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    completion_chunk["usage"] = completion_usage
-                    yield completion_chunk
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    completion_chunk["usage"] = completion_usage
-                    yield completion_chunk
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                completion_chunk["usage"] = completion_usage
+                yield completion_chunk

         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -334,26 +322,15 @@

         stream = generate_config.get("stream", False)
         if not stream:
-            if "falcon" in model_family_name:
-                for completion_chunk, completion_usage in generate_stream_falcon(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    pass
-            else:
-                for completion_chunk, completion_usage in generate_stream(
-                    self.model_uid,
-                    self._model,
-                    self._tokenizer,
-                    prompt,
-                    self._device,
-                    generate_config,
-                ):
-                    pass
+            for completion_chunk, completion_usage in generate_stream(
+                self.model_uid,
+                self._model,
+                self._tokenizer,
+                prompt,
+                self._device,
+                generate_config,
+            ):
+                pass
             completion = Completion(
                 id=completion_chunk["id"],
                 object=completion_chunk["object"],
@@ -366,6 +343,105 @@
         else:
             return generator_wrapper(prompt, generate_config)

+    @staticmethod
+    def require_attention_mask():
+        return False
+
+    @lru_cache
+    def get_context_len(self):
+        return get_context_length(self._model.config)
+
+    def get_max_num_seqs(self) -> int:
+        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        return self._sanitize_generate_config(req.generate_config)
+
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        # check some parameters
+        for r in req_list:
+            if r.sanitized_generate_config is None:
+                r.sanitized_generate_config = self.prepare_sanitize_generate_config(r)
+            if r.is_prefill:
+                # check some generate params
+                max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                if max_src_len < 0:
+                    r.stopped = True
+                    r.error_msg = "Max tokens exceeds model's max length"
+                    continue
+                if r.stream_interval <= 0:
+                    r.stopped = True
+                    r.error_msg = "`stream_interval` must be greater than 0"
+                    continue
+                stop_str = r.sanitized_generate_config.get("stop", None)
+                if stop_str and (
+                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                ):
+                    r.stopped = True
+                    r.error_msg = "Invalid `stop` field type"
+                    continue
+
+    def _get_builtin_stop_token_ids(self) -> Tuple:
+        return (
+            tuple(self.model_family.prompt_style.stop_token_ids)
+            if self.model_family.prompt_style
+            and self.model_family.prompt_style.stop_token_ids
+            else tuple()
+        )
+
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
+        for req in req_list:
+            if req.error_msg is None:
+                # nothing need handle for non-stream case
+                if req.stream:
+                    results = []
+                    for i, c in enumerate(req.completion):
+                        if c == "<bos_stream>":
+                            chunk = req.completion[i + 1]
+                            results.append(
+                                CompletionChunk(
+                                    id=chunk["id"],
+                                    object=chunk["object"],
+                                    created=chunk["created"],
+                                    model=chunk["model"],
+                                    choices=[
+                                        CompletionChoice(
+                                            text="",
+                                            index=0,
+                                            logprobs=None,
+                                            finish_reason=None,
+                                        )
+                                    ],
+                                )
+                            )
+                            continue
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(c)
+
+                    if req.stopped and req.include_usage:
+                        results.append(req.completion[-1])
+                    req.completion = results
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        from .utils import batch_inference_one_step
+
+        self.prepare_batch_inference(req_list)
+        context_len = self.get_context_len()
+        assert isinstance(context_len, int)
+        batch_inference_one_step(
+            req_list,
+            self.model_uid,
+            self._model,
+            self._tokenizer,
+            self._device,
+            context_len,
+            self._get_builtin_stop_token_ids(),
+            require_attention_mask=self.require_attention_mask(),
+        )
+        self.handle_batch_inference_results(req_list)
+
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         try:
             import torch
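
The <bos_stream>/<eos_stream> markers read by handle_batch_inference_results are sentinels that the batching code places in req.completion; the method rewrites the list so a streaming client sees a leading empty chunk and nothing after the end marker. A runnable toy simulation of that filtering, with plain dicts standing in for CompletionChunk objects (an assumption for illustration, not the actual scheduler types):

# Toy simulation of the sentinel filtering above (dicts stand in for CompletionChunk).
def filter_stream(completion):
    results = []
    for i, c in enumerate(completion):
        if c == "<bos_stream>":
            nxt = completion[i + 1]
            # emit an empty first chunk derived from the chunk that follows the marker
            results.append(
                {
                    **nxt,
                    "choices": [
                        {"text": "", "index": 0, "logprobs": None, "finish_reason": None}
                    ],
                }
            )
            continue
        elif c == "<eos_stream>":
            break
        results.append(c)
    return results


chunks = [
    "<bos_stream>",
    {"id": "c1", "choices": [{"text": "Hel"}]},
    {"id": "c1", "choices": [{"text": "lo"}]},
    "<eos_stream>",
]
print(filter_stream(chunks))
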
@@ -464,7 +540,6 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             pytorch_model_config,
             peft_model,
         )
-        self._context_len = None

     def _sanitize_generate_config(
         self,
@@ -540,7 +615,6 @@

     def load(self):
         super().load()
-        self._context_len = get_context_length(self._model.config)

     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
         assert self.model_family.prompt_style is not None
@@ -553,48 +627,14 @@
         )
         return full_prompt

-    def get_max_num_seqs(self) -> int:
-        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
-
-    def batch_inference(self, req_list: List[InferenceRequest]):
-        from .utils import batch_inference_one_step
-
+    def prepare_batch_inference(self, req_list: List[InferenceRequest]):
+        super().prepare_batch_inference(req_list)
         for r in req_list:
-            if r.sanitized_generate_config is None:
-                r.sanitized_generate_config = self._sanitize_generate_config(
-                    r.generate_config
-                )
-            if r.is_prefill:
-                # check some generate params
-                max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
-                if max_src_len < 0:
-                    r.stopped = True
-                    r.error_msg = "Max tokens exceeds model's max length"
-                    continue
-                if r.stream_interval <= 0:
-                    r.stopped = True
-                    r.error_msg = "`stream_interval` must be greater than 0"
-                    continue
-                stop_str = r.sanitized_generate_config.get("stop", None)
-                if stop_str and (
-                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
-                ):
-                    r.stopped = True
-                    r.error_msg = "Invalid `stop` field type"
-                    continue
-            r.full_prompt = self._get_full_prompt(
-                r.prompt, r.system_prompt, r.chat_history, None
-            )
+            r.full_prompt = self._get_full_prompt(
+                r.prompt, r.system_prompt, r.chat_history, None
+            )

-        assert isinstance(self._context_len, int)
-        batch_inference_one_step(
-            req_list,
-            self.model_uid,
-            self._model,
-            self._tokenizer,
-            self._device,
-            self._context_len,
-        )
+    def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.stream and req.error_msg is None:
                 if req.completion:

xinference/model/llm/pytorch/glm4v.py
@@ -56,19 +56,40 @@ class Glm4VModel(PytorchChatModel):
             return True
         return False

-    def load(self, **kwargs):
+    def load(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer

         device = self._pytorch_model_config.get("device", "auto")
         self._device = select_device(device)
-        self._device = "auto" if self._device == "cuda" else self._device
+
+        kwargs = {"device_map": self._device}
+        quantization = self.quantization
+
+        # referenced from PytorchModel.load
+        if quantization != "none":
+            if self._device == "cuda" and self._is_linux():
+                kwargs["device_map"] = "auto"
+                self._device = "auto"
+                if quantization == "4-bit":
+                    kwargs["load_in_4bit"] = True
+                elif quantization == "8-bit":
+                    kwargs["load_in_8bit"] = True
+                else:
+                    raise ValueError(
+                        f"Quantization {quantization} is not supported in temporary"
+                    )
+            else:
+                if quantization != "8-bit":
+                    raise ValueError(
+                        f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+                    )

         model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             low_cpu_mem_usage=True,
             trust_remote_code=True,
             torch_dtype=torch.float16,
-            device_map=self._device,
+            **kwargs,
         )
         self._model = model.eval()


xinference/model/llm/pytorch/internlm2.py
@@ -15,6 +15,7 @@ import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....types import (
     ChatCompletion,
     ChatCompletionChoice,
@@ -88,6 +89,20 @@ class Internlm2PytorchChatModel(PytorchChatModel):
             return False
         return True

+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Overwrite this func for this special model.
+        Cannot use the default configuration, which works poorly on this model.
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        return raw_config
+
     def chat(
         self,
         prompt: str,

xinference/model/llm/pytorch/qwen_vl.py
@@ -45,7 +45,7 @@ class QwenVLChatModel(PytorchChatModel):
     def match(
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if "qwen" in model_family.model_name:
+        if "qwen" in model_family.model_name and "vision" in model_family.model_ability:
             return True
         return False
