xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/deepseek_v2.py (new file)
@@ -0,0 +1,340 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import uuid
+ from typing import Dict, Iterator, List, Optional, Union
+
+ import torch
+
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     Completion,
+     CompletionChunk,
+     PytorchGenerateConfig,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import (
+     generate_chat_completion,
+     generate_completion,
+     generate_completion_chunk,
+ )
+ from .core import PytorchChatModel, PytorchModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class DeepSeekV2PytorchModel(PytorchModel):
+     def _load_model(self, **kwargs):
+         try:
+             from transformers import (
+                 AutoModelForCausalLM,
+                 AutoTokenizer,
+                 GenerationConfig,
+             )
+         except ImportError:
+             error_message = "Failed to import module 'transformers'"
+             installation_guide = [
+                 "Please make sure 'transformers' is installed. ",
+                 "You can install it by `pip install transformers`\n",
+             ]
+
+             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=kwargs["trust_remote_code"],
+         )
+         model = AutoModelForCausalLM.from_pretrained(
+             self.model_path,
+             attn_implementation="eager",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             device_map="auto",
+         )
+         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+         model.generation_config.pad_token_id = model.generation_config.eos_token_id
+         return model, tokenizer
+
+     @classmethod
+     def match(
+         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         if llm_spec.model_format != "pytorch":
+             return False
+         model_family = llm_family.model_family or llm_family.model_name
+         if "deepseek-v2" not in model_family:
+             return False
+         if "generate" not in llm_family.model_ability:
+             return False
+         return True
+
+     def generate(
+         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
+     ) -> Union[Completion, Iterator[CompletionChunk]]:
+         input_tensor = self._tokenizer(prompt, return_tensors="pt")
+         generate_config = self._sanitize_generate_config(generate_config)
+         default_generate_config = self._model.generation_config
+         generate_kwargs = {
+             "input_ids": input_tensor["input_ids"].cuda(),
+             "attention_mask": input_tensor["attention_mask"].cuda(),
+             "temperature": float(
+                 generate_config.get("temperature", default_generate_config.temperature)
+             ),
+             "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+             "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+             "top_k": int(generate_config.get("top_k", -1)),
+             "max_new_tokens": generate_config.get("max_tokens", 512),
+             "bos_token_id": default_generate_config.bos_token_id,
+             "do_sample": default_generate_config.do_sample,
+             "eos_token_id": default_generate_config.eos_token_id,
+         }
+
+         stream = generate_config.get("stream", False)
+         if stream:
+             return self._generate_stream(generate_kwargs, input_tensor)
+         else:
+             return self._generate(generate_kwargs, input_tensor)
+
+     def _generate(self, generate_kwargs, input_ids) -> Completion:
+         prompt_tokens = len(input_ids[0])
+         logger.info(f"generate_kwargs:{generate_kwargs}")
+         generation_output = self._model.generate(**generate_kwargs)
+         completion_tokens = len(generation_output[0])
+         response = self._tokenizer.decode(
+             generation_output[0], skip_special_tokens=True
+         )
+         return generate_completion(
+             self.model_uid,
+             response,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+
+     def _generate_stream(self, generate_kwargs, input_ids):
+         from threading import Thread
+
+         from transformers import TextIteratorStreamer
+
+         # Initialize the streamer
+         streamer = TextIteratorStreamer(
+             self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+         )
+         # Define the generation configuration
+         generate_kwargs["streamer"] = streamer
+         # Start the model chat in a separate thread
+         thread = Thread(
+             target=self._model.generate,
+             kwargs=generate_kwargs,
+         )
+         thread.start()
+
+         completion_id = str(uuid.uuid1())
+         prompt_tokens = len(input_ids[0])
+         total_tokens, completion_tokens = 0, 0
+         # Loop through the streamer to get the new text as it is generated
+         for i, new_text in enumerate(streamer):
+             completion_tokens = i
+             total_tokens = prompt_tokens + completion_tokens
+             yield generate_completion_chunk(
+                 chunk_text=new_text,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=total_tokens,
+             )
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=total_tokens,
+             has_choice=True,
+             has_content=False,
+         )
+
+
+ class DeepSeekV2PytorchChatModel(PytorchChatModel):
+     def _load_model(self, **kwargs):
+         try:
+             from transformers import (
+                 AutoModelForCausalLM,
+                 AutoTokenizer,
+                 GenerationConfig,
+             )
+         except ImportError:
+             error_message = "Failed to import module 'transformers'"
+             installation_guide = [
+                 "Please make sure 'transformers' is installed. ",
+                 "You can install it by `pip install transformers`\n",
+             ]
+
+             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=kwargs["trust_remote_code"],
+         )
+         logger.info(f"kwargs:{kwargs}")
+         model = AutoModelForCausalLM.from_pretrained(
+             self.model_path,
+             attn_implementation="eager",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             device_map="auto",
+         )
+         model.generation_config = GenerationConfig.from_pretrained(self.model_path)
+         model.generation_config.pad_token_id = model.generation_config.eos_token_id
+         return model, tokenizer
+
+     @classmethod
+     def match(
+         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         if llm_spec.model_format != "pytorch":
+             return False
+         model_family = llm_family.model_family or llm_family.model_name
+         if "deepseek-v2" not in model_family:
+             return False
+         if "chat" not in llm_family.model_ability:
+             return False
+         return True
+
+     def chat(
+         self,
+         messages: List[Dict],
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         assert self.model_family.chat_template is not None
+         full_prompt = self.get_full_context(
+             messages,
+             self.model_family.chat_template,
+             tokenizer=self._tokenizer,
+         )
+         input_tensor = self._tokenizer.encode(
+             full_prompt,
+             padding=False,
+             truncation=False,
+             max_length=None,
+             add_special_tokens=False,
+             return_tensors="pt",
+         )
+
+         generate_config = self._sanitize_generate_config(generate_config)
+         default_generate_config = self._model.generation_config
+         generate_kwargs = {
+             "input_ids": input_tensor.cuda(),
+             "temperature": float(
+                 generate_config.get("temperature", default_generate_config.temperature)
+             ),
+             "repetition_penalty": float(generate_config.get("repetition_penalty", 1.0)),
+             "top_p": float(generate_config.get("top_p", default_generate_config.top_p)),
+             "top_k": int(generate_config.get("top_k", -1)),
+             "max_new_tokens": generate_config.get("max_tokens", 512),
+             "bos_token_id": default_generate_config.bos_token_id,
+             "do_sample": default_generate_config.do_sample,
+             "eos_token_id": default_generate_config.eos_token_id,
+         }
+
+         stream = generate_config.get("stream", False)
+         stream_options = generate_config.get("stream_options", None)
+         include_usage = (
+             stream_options["include_usage"]
+             if isinstance(stream_options, dict)
+             else False
+         )
+         if stream:
+             chunk = self._generate_stream(generate_kwargs, input_tensor, include_usage)
+             return self._to_chat_completion_chunks(chunk)
+         else:
+             return self._generate(generate_kwargs, input_tensor)
+
+     def _generate(self, generate_kwargs, input_ids) -> ChatCompletion:
+         prompt_tokens = len(input_ids[0])
+         generation_output = self._model.generate(**generate_kwargs)
+         completion_tokens = len(generation_output[0])
+         response = self._tokenizer.decode(
+             generation_output[0][input_ids.shape[1] :], skip_special_tokens=True
+         )
+         return generate_chat_completion(
+             self.model_uid,
+             response,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+         )
+
+     def _generate_stream(self, generate_kwargs, input_ids, include_usage):
+         from threading import Thread
+
+         from transformers import TextIteratorStreamer
+
+         # Initialize the streamer
+         streamer = TextIteratorStreamer(
+             self._tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
+         )
+         # Define the generation configuration
+         generate_kwargs["streamer"] = streamer
+         # Start the model chat in a separate thread
+         thread = Thread(
+             target=self._model.generate,
+             kwargs=generate_kwargs,
+         )
+         thread.start()
+
+         completion_id = str(uuid.uuid1())
+         prompt_tokens = len(input_ids[0])
+         total_tokens, completion_tokens = 0, 0
+         # Loop through the streamer to get the new text as it is generated
+         for i, new_text in enumerate(streamer):
+             completion_tokens = max(completion_tokens, len(streamer.token_cache))
+             total_tokens = prompt_tokens + completion_tokens
+             yield generate_completion_chunk(
+                 chunk_text=new_text,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=total_tokens,
+             )
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=total_tokens,
+             has_choice=True,
+             has_content=False,
+         )
+
+         if include_usage:
+             yield generate_completion_chunk(
+                 chunk_text=None,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=total_tokens,
+                 has_choice=False,
+                 has_content=False,
+             )
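
The streaming path in both new DeepSeek-V2 classes is the standard transformers pattern: run generate() in a background thread with a TextIteratorStreamer attached and iterate the streamer on the calling thread. A minimal, self-contained sketch of that pattern; the gpt2 checkpoint and prompt below are stand-ins for illustration only, not part of this release.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Small stand-in model; the deepseek_v2.py classes above apply the same
# pattern to DeepSeek-V2 checkpoints loaded with trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# The streamer yields decoded text pieces as generate() emits tokens.
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10
)

# generate() blocks, so it runs in a background thread while the caller
# consumes chunks, mirroring _generate_stream() above.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer),
)
thread.start()

for piece in streamer:
    print(piece, end="", flush=True)
thread.join()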
xinference/model/llm/transformers/deepseek_vl.py
@@ -15,7 +15,6 @@ import base64
  import logging
  import os.path
  import tempfile
- import time
  import uuid
  from concurrent.futures import ThreadPoolExecutor
  from io import BytesIO
@@ -25,16 +24,9 @@ import requests
  import torch

  from ....model.utils import select_device
- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     Completion,
-     CompletionChoice,
-     CompletionChunk,
-     CompletionUsage,
- )
+ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
  from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import generate_chat_completion, generate_completion_chunk
  from .core import PytorchChatModel, PytorchGenerateConfig

  logger = logging.getLogger(__name__)
@@ -147,9 +139,7 @@ class DeepSeekVLChatModel(PytorchChatModel):

      def chat(
          self,
-         prompt: Union[str, List[Dict]],
-         system_prompt: Optional[str] = None,
-         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         messages: List[Dict],
          generate_config: Optional[PytorchGenerateConfig] = None,
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
          if not generate_config:
@@ -162,44 +152,40 @@ class DeepSeekVLChatModel(PytorchChatModel):
              if isinstance(stream_options, dict)
              else False
          )
-         prompt, images = self._message_content_to_deepseek(prompt)
-         prompt_messages: List[Dict[str, Any]] = [
-             {
-                 "role": "User",
-                 "content": prompt,
-             },
-             {"role": "Assistant", "content": ""},
-         ]
-         if images:
-             prompt_messages[0]["images"] = images
-
-         # Convert openai history to qwen vl history
-         deepseek_history = []
-         for h in chat_history or []:
-             role = h["role"]
+
+         prompt = ""
+         deepseek_messages = []
+         for i, message in enumerate(messages):
+             role = message["role"]
+             content = message["content"]
              if role == "user":
-                 content, images = self._message_content_to_deepseek(h["content"])
-                 msg: Dict[str, Any] = {
-                     "role": "User",
-                     "content": content,
-                 }
-                 if images:
-                     msg["images"] = images
-                 deepseek_history.append(msg)
+                 if isinstance(content, str):
+                     deepseek_messages.append({"role": "User", "content": content})
+                 else:
+                     content, images = self._message_content_to_deepseek(content)
+                     msg: Dict[str, Any] = {
+                         "role": "User",
+                         "content": content,
+                     }
+                     if images:
+                         msg["images"] = images
+                     deepseek_messages.append(msg)
+                 if i == len(messages) - 1:
+                     prompt = content
              elif role == "assistant":
-                 deepseek_history.append({"role": "Assistant", "content": h["content"]})
+                 deepseek_messages.append({"role": "Assistant", "content": content})
              else:
-                 logger.error("Unexpected msg in chat history: %s", h)
-
-         deepseek_history.extend(prompt_messages)
+                 logger.error(
+                     f"Unexpected message in messages: role: {role}, message: {message}"
+                 )

          from ....thirdparty.deepseek_vl.serve.inference import generate
          from ....thirdparty.deepseek_vl.utils.io import load_pil_images

          # load images and prepare for inputs
-         pil_images = load_pil_images(deepseek_history)
+         pil_images = load_pil_images(deepseek_messages)
          prepare_inputs = self._vl_chat_processor(
-             conversations=deepseek_history, images=pil_images, force_batchify=True
+             conversations=deepseek_messages, images=pil_images, force_batchify=True
          ).to(self._model.device, self._model.dtype)

          temperature = generate_config.get("temperature", 0.2)
@@ -226,31 +212,16 @@ class DeepSeekVLChatModel(PytorchChatModel):
              it = self._generate_stream(streamer, stop_str, include_usage, prompt)
              return self._to_chat_completion_chunks(it)
          else:
-             c = self._generate(streamer, stop_str)
-             return self._to_chat_completion(c)
+             return self._generate(streamer, stop_str)

-     def _generate(self, streamer, stop_str) -> Completion:
+     def _generate(self, streamer, stop_str) -> ChatCompletion:
          generated_text = ""
          for new_text in streamer:
              if new_text.endswith(stop_str):
                  new_text = new_text[: -len(stop_str)]
              generated_text += new_text

-         c = Completion(
-             id=str(uuid.uuid1()),
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[
-                 CompletionChoice(
-                     index=0, text=generated_text, finish_reason="stop", logprobs=None
-                 )
-             ],
-             usage=CompletionUsage(
-                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-             ),
-         )
-         return c
+         return generate_chat_completion(self.model_uid, generated_text)

      def _generate_stream(
          self, streamer, stop_str, include_usage, prompt
@@ -262,54 +233,40 @@ class DeepSeekVLChatModel(PytorchChatModel):
          for i, new_text in enumerate(streamer):
              if new_text.endswith(stop_str):
                  new_text = new_text[: -len(stop_str)]
-             completion_choice = CompletionChoice(
-                 text=new_text, index=0, logprobs=None, finish_reason=None
-             )
-             chunk = CompletionChunk(
-                 id=completion_id,
-                 object="text_completion",
-                 created=int(time.time()),
-                 model=self.model_uid,
-                 choices=[completion_choice],
-             )
              completion_tokens = i
              total_tokens = prompt_tokens + completion_tokens
-             completion_usage = CompletionUsage(
+             yield generate_completion_chunk(
+                 chunk_text=new_text,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
                  total_tokens=total_tokens,
+                 has_choice=True,
+                 has_content=True,
              )
-             chunk["usage"] = completion_usage
-             yield chunk
-
-         completion_choice = CompletionChoice(
-             text="", index=0, logprobs=None, finish_reason="stop"
-         )
-         chunk = CompletionChunk(
-             id=completion_id,
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[completion_choice],
-         )
-         completion_usage = CompletionUsage(
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
+             has_choice=True,
+             has_content=False,
         )
-         chunk["usage"] = completion_usage
-         yield chunk
+
          if include_usage:
-             chunk = CompletionChunk(
-                 id=completion_id,
-                 object="text_completion",
-                 created=int(time.time()),
-                 model=self.model_uid,
-                 choices=[],
-             )
-             chunk["usage"] = CompletionUsage(
+             yield generate_completion_chunk(
+                 chunk_text=None,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
                  total_tokens=total_tokens,
+                 has_choice=False,
+                 has_content=False,
              )
-             yield chunk
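
After this refactor, DeepSeekVLChatModel.chat() accepts a single OpenAI-style messages list in place of the old prompt/system_prompt/chat_history parameters. A hedged sketch of the input shape the new code path consumes; the image URL and the commented chat() call are illustrative assumptions, not taken from the diff.

# OpenAI-style messages: plain strings for text-only turns, and a content
# list with "image_url" parts for vision turns (what
# _message_content_to_deepseek unpacks). The URL is a hypothetical placeholder.
messages = [
    {"role": "user", "content": "Hello, who are you?"},
    {"role": "assistant", "content": "I am a vision-language assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/pipeline.png"}},
        ],
    },
]

# Assumed call shape on a loaded DeepSeek-VL model handle; with stream=True
# the generator yields ChatCompletionChunk dicts built via
# generate_completion_chunk().
# for chunk in model.chat(messages, generate_config={"stream": True}):
#     ...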