xinference-0.14.2-py3-none-any.whl → xinference-0.14.3-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +48 -41
  6. xinference/model/audio/chattts.py +24 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/fish_speech.py +228 -0
  9. xinference/model/audio/model_spec.json +8 -0
  10. xinference/model/embedding/core.py +23 -1
  11. xinference/model/image/model_spec.json +2 -1
  12. xinference/model/image/model_spec_modelscope.json +2 -1
  13. xinference/model/image/stable_diffusion/core.py +49 -1
  14. xinference/model/llm/__init__.py +6 -0
  15. xinference/model/llm/llm_family.json +54 -9
  16. xinference/model/llm/llm_family.py +2 -0
  17. xinference/model/llm/llm_family_modelscope.json +56 -10
  18. xinference/model/llm/lmdeploy/__init__.py +0 -0
  19. xinference/model/llm/lmdeploy/core.py +557 -0
  20. xinference/model/llm/transformers/cogvlm2.py +4 -45
  21. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/glm4v.py +2 -23
  24. xinference/model/llm/transformers/intern_vl.py +94 -11
  25. xinference/model/llm/transformers/minicpmv25.py +2 -23
  26. xinference/model/llm/transformers/minicpmv26.py +2 -22
  27. xinference/model/llm/transformers/yi_vl.py +2 -24
  28. xinference/model/llm/utils.py +10 -1
  29. xinference/model/llm/vllm/core.py +1 -1
  30. xinference/thirdparty/fish_speech/__init__.py +0 -0
  31. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  32. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  33. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  34. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  35. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  36. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  37. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  38. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  39. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  40. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  41. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  42. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  44. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  46. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  48. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  49. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  50. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  52. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  56. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  57. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  58. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  59. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  60. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  63. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  64. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  67. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  68. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  69. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  70. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  71. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  72. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  73. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  74. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  75. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  76. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  77. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  78. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  79. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  83. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  84. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  85. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  86. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  87. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  88. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  89. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  90. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  91. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  92. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  93. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  94. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  95. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  96. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  99. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  100. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  101. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  102. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  103. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  104. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  105. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  106. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  107. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  108. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  109. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  110. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  111. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  112. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  113. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  114. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  115. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  116. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  117. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  118. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  119. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  120. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  121. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  122. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  123. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  124. xinference/web/ui/build/asset-manifest.json +3 -3
  125. xinference/web/ui/build/index.html +1 -1
  126. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  127. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  129. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
  130. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
  131. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  133. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  134. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  135. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  136. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  137. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/model/llm/lmdeploy/core.py (new file)
@@ -0,0 +1,557 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+import uuid
+from typing import AsyncGenerator, Dict, Iterator, List, Optional, TypedDict, Union
+
+import torch
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionChunkChoice,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionUsage,
+    LoRA,
+)
+from ..core import LLM
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import ChatModelMixin
+
+logger = logging.getLogger(__name__)
+
+try:
+    import lmdeploy  # noqa: F401
+
+    LMDEPLOY_INSTALLED = True
+except ImportError:
+    LMDEPLOY_INSTALLED = False
+
+LMDEPLOY_SUPPORTED_CHAT_MODELS = ["internvl2"]
+LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME = {
+    "internvl2": "internvl-internlm2",
+}
+
+
+class LMDeployModelConfig(TypedDict, total=False):
+    model_format: Optional[str]
+    tp: Optional[int]
+    session_len: Optional[int]
+    max_batch_size: Optional[int]
+    cache_max_entry_count: Optional[float]
+    cache_block_seq_len: Optional[int]
+    enable_prefix_caching: Optional[bool]
+    quant_policy: Optional[int]
+    rope_scaling_factor: Optional[float]
+    use_logn_attn: Optional[bool]
+    download_dir: Optional[str]
+    revision: Optional[str]
+    max_prefill_token_num: Optional[int]
+    num_tokens_per_iter: Optional[int]
+    max_prefill_iters: Optional[int]
+
+
+class LMDeployGenerateConfig(TypedDict, total=False):
+    n: Optional[int]
+    max_new_tokens: Optional[int]
+    top_p: Optional[float]
+    top_k: Optional[int]
+    temperature: Optional[float]
+    repetition_penalty: Optional[float]
+    ignore_eos: Optional[bool]
+    random_seed: Optional[int]
+    stop_words: Optional[List[str]]
+    bad_words: Optional[List[str]]
+    min_new_tokens: Optional[int]
+    skip_special_tokens: Optional[bool]
+    logprobs: Optional[int]
+
+
+class LMDeployModel(LLM):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        model_config: Optional[LMDeployModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(model_uid, model_family, model_spec, quantization, model_path)
+        self._model_config: LMDeployModelConfig = self._sanitize_model_config(
+            model_config
+        )
+        if peft_model is not None:
+            raise ValueError("LMDEPLOY engine has not supported lora yet.")
+
+    def _sanitize_model_config(
+        self, model_config: Optional[LMDeployModelConfig]
+    ) -> LMDeployModelConfig:
+        if model_config is None:
+            model_config = LMDeployModelConfig()
+        model_config.setdefault("session_len", 8192)
+        if self.model_spec.model_format == "awq":
+            model_config.setdefault("model_format", "awq")
+        return model_config
+
+    def load(self):
+        try:
+            import lmdeploy  # noqa: F401, F811
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        raise ValueError("LMDEPLOY engine has not supported generate yet.")
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        return False
+
+    def generate(
+        self,
+        prompt: str,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[Completion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError("LMDeploy generate ablility does not support now.")
+
+
+class LMDeployChatModel(LMDeployModel, ChatModelMixin):
+    def load(self):
+        try:
+            from lmdeploy import (
+                ChatTemplateConfig,
+                TurbomindEngineConfig,
+                VisionConfig,
+                pipeline,
+            )
+        except ImportError:
+            error_message = "Failed to import module 'lmdeploy'"
+            installation_guide = [
+                "Please make sure 'lmdeploy' is installed. ",
+                "You can install it by `pip install lmdeploy`\n",
+            ]

+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        chat_temp_name = ""
+        family = self.model_family.model_family or self.model_family.model_name
+        for key in LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME.keys():
+            if family in key:
+                chat_temp_name = LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME[key]
+                break
+        if chat_temp_name == "":
+            raise ValueError(f"Can not find correct chat template.")
+
+        chat_template_config = ChatTemplateConfig(chat_temp_name)
+        chat_template_config.meta_instruction = (
+            self.model_family.prompt_style.system_prompt
+        )
+        count = torch.cuda.device_count()
+        if count > 1:
+            self._model_config.setdefault("tp", torch.cuda.device_count())
+
+        self._model = pipeline(
+            self.model_path,
+            chat_template_config=chat_template_config,
+            backend_config=TurbomindEngineConfig(**self._model_config),
+            vision_config=VisionConfig(thread_safe=True),
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format == "awq":
+            # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits.
+            if "4" not in quantization:
+                return False
+        if llm_family.model_name not in LMDEPLOY_SUPPORTED_CHAT_MODELS:
+            return False
+        return LMDEPLOY_INSTALLED
+
+    async def async_chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]:
+        stream = (
+            generate_config.get("stream", False)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        stream_options = (
+            generate_config.get("stream_options", None)
+            if isinstance(generate_config, dict)
+            else False
+        )
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+
+        chat_history = chat_history or []
+
+        if stream:
+            chunk = self._chat_stream(prompt, chat_history, include_usage)
+            return self._async_to_chat_completion_chunks(chunk)
+        else:
+            chunk = await self._chat(prompt, chat_history)
+            return self._to_chat_completion(chunk)
+
+    async def _chat_stream(self, prompt, chat_history, include_usage):
+        from lmdeploy.messages import Response
+
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        completion_id = str(uuid.uuid1())
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=True,
+        ):
+            new_text = output.text if isinstance(output, Response) else output.response
+
+            completion_choice = ChatCompletionChunkChoice(
+                text=new_text,
+                index=0,
+                logprobs=None,
+                finish_reason=output.finish_reason,
+            )
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = prompt_tokens + completion_tokens
+            completion_usage = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            chunk["usage"] = completion_usage
+            print(chunk)
+            yield chunk
+        if include_usage:
+            chunk = ChatCompletionChunk(
+                id=completion_id,
+                object="chat.completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[],
+            )
+            chunk["usage"] = CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            )
+            yield chunk
+
+    async def _chat(self, prompt, chat_history):
+        from lmdeploy.messages import Response
+
+        response, finish_reason = "", ""
+        prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
+        async for output in self._generate(
+            prompt,
+            chat_history,
+            session_id=-1,
+            stream_response=False,
+        ):
+            response += output.text if isinstance(output, Response) else output.response
+            prompt_tokens = output.input_token_len
+            completion_tokens = output.generate_token_len
+            total_tokens = output.input_token_len + output.generate_token_len
+            finish_reason = output.finish_reason
+
+        chunk = ChatCompletion(
+            id=str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                CompletionChoice(
+                    index=0, text=response, finish_reason=finish_reason, logprobs=None
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            ),
+        )
+        return chunk
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.async_engine.py
+    async def _generate(
+        self,
+        prompt,
+        chat_history,
+        session_id: int,
+        generate_config: Optional[Dict] = None,
+        tools: Optional[List[object]] = None,
+        stream_response: bool = True,
+        sequence_start: bool = True,
+        sequence_end: bool = True,  # no interactive mode by default
+        step: int = 0,
+        do_preprocess: bool = False,
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
+        import random
+
+        from lmdeploy.messages import EngineGenerationConfig, GenerationConfig
+        from lmdeploy.serve.async_engine import GenOut
+        from lmdeploy.tokenizer import DetokenizeState
+
+        session_id = -1
+
+        if str(session_id) not in self._model.id2step:
+            self._model.id2step[str(session_id)] = 0
+        if generate_config is None:
+            generate_config = GenerationConfig()
+        if type(generate_config) is GenerationConfig:
+            generate_config = EngineGenerationConfig.From(
+                generate_config, self._model.tokenizer
+            )
+        if generate_config.stop_words is None:  # type: ignore
+            generate_config.stop_words = self._model.stop_words  # type: ignore
+        if generate_config.random_seed is None and sequence_start:  # type: ignore
+            generate_config.random_seed = random.getrandbits(64)  # type: ignore
+        if generate_config.n > 1:  # type: ignore
+            logger.warning(
+                f"n({generate_config.n}) > 1 hasn't been supported yet. "  # type: ignore
+                f"Fallback to 1"
+            )
+            generate_config.n = 1  # type: ignore
+
+        prompt_input = await self._get_prompt_input(prompt, chat_history)
+        prompt = prompt_input["prompt"]
+        input_ids = prompt_input["input_ids"]
+        finish_reason = None
+        logger.info(
+            f"prompt={prompt!r}, "
+            f"gen_config={generate_config}, "
+            f"prompt_token_id={input_ids}, "
+            f"adapter_name={adapter_name}."
+        )
+        logger.info(
+            f"session_id={session_id}, "  # type: ignore
+            f"history_tokens={self._model.id2step[str(session_id)]}, "
+            f"input_tokens={len(input_ids)}, "
+            f"max_new_tokens={generate_config.max_new_tokens}, "
+            f"seq_start={sequence_start}, seq_end={sequence_end}, "
+            f"step={step}, prep={do_preprocess}"
+        )
+
+        if generate_config.max_new_tokens is None:  # type: ignore
+            # for interactive endpoint, will try maximum possible token num
+            generate_config.max_new_tokens = max(  # type: ignore
+                128,
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+            )
+        elif (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            generate_config.max_new_tokens = max(  # type: ignore
+                self._model.session_len
+                - self._model.id2step[str(session_id)]
+                - len(input_ids),
+                128,
+            )
+            logger.error(f"Truncate max_new_tokens to {generate_config.max_new_tokens}")  # type: ignore
+
+        if (
+            self._model.id2step[str(session_id)]
+            + len(input_ids)
+            + generate_config.max_new_tokens  # type: ignore
+            > self._model.session_len
+        ):
+            logger.error(f"run out of tokens. session_id={session_id}.")
+            yield GenOut(
+                "", self._model.id2step[str(session_id)], len(input_ids), 0, "length"
+            )
+            if sequence_end is True and sequence_start is False:
+                await self._model.end_session(session_id)
+        else:
+            generator = await self._model.get_generator(False, session_id)
+            async with self._model.safe_run(session_id):
+                state = DetokenizeState(len(input_ids))
+                start_ids_offset = state.ids_offset
+                response = ""
+                async for outputs in generator.async_stream_infer(
+                    session_id=session_id,
+                    **prompt_input,
+                    gen_config=generate_config,
+                    adapter_name=adapter_name,
+                    stream_output=stream_response,
+                    sequence_start=sequence_start,
+                    sequence_end=sequence_end,
+                    step=self._model.id2step[str(session_id)],
+                ):
+                    # decode res
+                    res, tokens = (
+                        input_ids + outputs.token_ids,
+                        outputs.num_token,
+                    )  # noqa
+                    if len(res) <= state.ids_offset:
+                        continue
+
+                    ids_offset = state.ids_offset
+                    response, state = self._model.tokenizer.detokenize_incrementally(
+                        res,
+                        state,
+                        skip_special_tokens=generate_config.skip_special_tokens,  # type: ignore
+                    )
+
+                    res = res[ids_offset:]
+                    logprobs = None
+                    if outputs.logprobs:
+                        log_offset = ids_offset - start_ids_offset
+                        logprobs = outputs.logprobs[log_offset:]
+
+                    # response, history token len,
+                    # input token len, gen token len
+                    yield GenOut(
+                        response,
+                        self._model.id2step[str(session_id)],
+                        len(input_ids),
+                        tokens,
+                        finish_reason,
+                        res,
+                        logprobs,
+                    )
+
+                finish_reason = (
+                    "length" if tokens >= generate_config.max_new_tokens else "stop"  # type: ignore
+                )
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence
+                if not response.endswith("�"):
+                    response = ""  # avaid returning the last response twice
+                yield GenOut(
+                    response,
+                    self._model.id2step[str(session_id)],
+                    len(input_ids),
+                    tokens,
+                    finish_reason,
+                )
+            # update step
+            self._model.id2step[str(session_id)] += len(input_ids) + tokens
+            if sequence_end:
+                self._model.id2step[str(session_id)] = 0
+            # manually end pytorch session
+            # TODO modify pytorch or turbomind api
+            if self._model.backend == "pytorch" and sequence_end:
+                await self._model.end_session(session_id)
+
+    # copy from lmdeploy
+    # Reference: lmdeploy.serve.vl_async_engine.py
+    async def _get_prompt_input(
+        self,
+        prompt: Union[str, List[Dict]],
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        sequence_start: bool = True,
+        tools: Optional[List[object]] = None,
+        **kwargs,
+    ):
+        """get input_ids, embeddings and offsets."""
+        IMAGE_TOKEN = "<IMAGE_TOKEN>"
+        IMAGE_DUMMY_TOKEN_INDEX = 0
+        import numpy as np
+
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        chat_history = chat_history or []
+
+        decorated, _ = self.get_prompt(prompt, chat_history, prompt_style)  # type: ignore
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))  # type: ignore
+        prompt = chat_history  # type: ignore
+
+        decorated = decorated.replace("<image>", "<img><IMAGE_TOKEN></img>")
+
+        segs = decorated.split(IMAGE_TOKEN)
+
+        results = {}
+        input_ids = []  # type: ignore
+        if len(segs) > 1:
+            images = await self._model.vl_prompt_template.async_collect_pil_images(
+                prompt
+            )
+
+            features = await self._model.vl_encoder.async_infer(images)
+
+            from lmdeploy.vl.templates import MiniCPMVTempateWrapper
+
+            if isinstance(self._model.vl_prompt_template, MiniCPMVTempateWrapper):
+                (
+                    decorated,
+                    features,
+                ) = self._model.vl_prompt_template.update_image_token(  # noqa: E501
+                    decorated, features
+                )
+                segs = decorated.split(IMAGE_TOKEN)
+
+            features = [x.cpu().numpy() for x in features]
+            input_ids = []
+            begins = []
+            ends = []
+            if len(segs) != len(features) + 1:
+                logger.error(
+                    f"the number of {IMAGE_TOKEN} is not equal "
+                    f"to input images, {len(segs) - 1} vs {len(features)}"
+                )
+                features = features[: len(segs) - 1]
+            for i, seg in enumerate(segs):
+                if i > 0 and i <= len(features):
+                    image_dim = features[i - 1].shape[0]
+                    begins.append(len(input_ids))
+                    ends.append(begins[-1] + image_dim)
+                    input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
+                seg_ids = self._model.tokenizer.encode(
+                    seg, add_bos=((i == 0) and sequence_start)
+                )
+                input_ids.extend(seg_ids)
+            ranges = np.stack([begins, ends], axis=1).tolist()
+            results["input_embeddings"] = features
+            results["input_embedding_ranges"] = ranges
+        else:
+            input_ids = self._model.tokenizer.encode(decorated, add_bos=sequence_start)
+
+        results["input_ids"] = input_ids
+        results["prompt"] = decorated
+
+        return results
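
For orientation, here is a minimal standalone sketch (not part of the diff) of the pipeline construction that LMDeployChatModel.load() performs above. It assumes lmdeploy is installed with the 0.5.x-era API used by this file; the model path and system prompt are placeholders, and the chat template name is taken from LMDEPLOY_MODEL_CHAT_TEMPLATE_NAME. In xinference itself the engine is driven through the async _generate() path rather than this synchronous call.

# Sketch only: mirrors the pipeline setup in LMDeployChatModel.load() above.
# The model path below is a hypothetical local directory, not a real asset.
from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, VisionConfig, pipeline

model_path = "/path/to/InternVL2-8B-AWQ"  # placeholder

# Resolve the chat template the same way load() does for "internvl2".
chat_template_config = ChatTemplateConfig("internvl-internlm2")
chat_template_config.meta_instruction = "You are a helpful assistant."  # system prompt

pipe = pipeline(
    model_path,
    chat_template_config=chat_template_config,
    backend_config=TurbomindEngineConfig(session_len=8192, model_format="awq"),
    vision_config=VisionConfig(thread_safe=True),
)

# Simple blocking call; xinference instead wraps the async streaming API
# (generator.async_stream_infer) to emit ChatCompletionChunk objects.
print(pipe("Describe the image in one sentence."))
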
xinference/model/llm/transformers/cogvlm2.py
@@ -11,17 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Tuple, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 from .utils import get_max_src_len
 
@@ -75,7 +72,7 @@ class CogVLM2Model(PytorchChatModel):
         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         family = model_family.model_family or model_family.model_name
-        if "cogvlm" in family.lower():
+        if "cogvlm2" in family.lower() and "video" not in family.lower():
             return True
         return False
 
@@ -116,24 +113,6 @@ class CogVLM2Model(PytorchChatModel):
         self._save_tensorizer()
 
     def _message_content_to_cogvlm2(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -146,7 +125,7 @@ class CogVLM2Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)
@@ -163,24 +142,6 @@ class CogVLM2Model(PytorchChatModel):
     def _history_content_to_cogvlm2(
         self, system_prompt: str, chat_history: List[ChatCompletionMessage]
     ):
-        def _image_to_piexl_values(image):
-            if image.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = image.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(image)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(image).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         query = system_prompt
         history: List[Tuple] = []
         pixel_values = None
@@ -192,9 +153,7 @@ class CogVLM2Model(PytorchChatModel):
                     if c_type == "text":
                         user = content["text"]
                     elif c_type == "image_url" and not pixel_values:
-                        pixel_values = _image_to_piexl_values(
-                            content["image_url"]["url"]
-                        )
+                        pixel_values = _decode_image(content["image_url"]["url"])
                 assistant = chat_history[i + 1]["content"]
                 history.append((user, assistant))
                 query = assistant  # type: ignore
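
These hunks replace the per-model image-loading helpers with the shared _decode_image from xinference/model/llm/utils.py; the similarly sized removals in glm4v.py, minicpmv25.py, minicpmv26.py and yi_vl.py listed above appear to be the same refactor. The shared helper's body is not shown in this diff; based on the removed _load_image/_image_to_piexl_values code it presumably looks roughly like the sketch below, which handles a base64 data: URL, an HTTP(S) URL, or a local file path. Everything beyond what the removed code shows is an assumption.

# Rough sketch of what the shared helper consolidates, reconstructed from the
# removed _load_image / _image_to_piexl_values functions above. The actual
# implementation lives in xinference/model/llm/utils.py and may differ.
import base64
import logging
from io import BytesIO

import requests
from PIL import Image


def _decode_image(_url):
    if _url.startswith("data:"):
        logging.info("Parse url by base64 decoder.")
        # e.g. f"data:image/jpeg;base64,{base64_image}"
        _type, data = _url.split(";")
        _, ext = _type.split("/")
        data = data[len("base64,") :]
        data = base64.b64decode(data.encode("utf-8"))
        return Image.open(BytesIO(data)).convert("RGB")
    else:
        try:
            response = requests.get(_url)
        except requests.exceptions.MissingSchema:
            # Not a URL at all: treat the string as a local file path.
            return Image.open(_url).convert("RGB")
        else:
            return Image.open(BytesIO(response.content)).convert("RGB")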