xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to a supported registry, exactly as they appear in that registry. It is provided for informational purposes only.

Potentially problematic release.
Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/cogvlm2_video.py
@@ -0,0 +1,524 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import time
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+ import torch
+
+ from ....core.scheduler import InferenceRequest
+ from ....model.utils import select_device
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     Completion,
+     CompletionChoice,
+     CompletionChunk,
+     CompletionUsage,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import _decode_image
+ from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import get_max_src_len
+
+ logger = logging.getLogger(__name__)
+
+
+ LANGUAGE_TOKEN_TYPE = 0
+ VISION_TOKEN_TYPE = 1
+
+
+ def recur_move_to(item, tgt, criterion_func):
+     """
+     This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+     """
+     if criterion_func(item):
+         device_copy = item.to(tgt)
+         return device_copy
+     elif isinstance(item, list):
+         return [recur_move_to(v, tgt, criterion_func) for v in item]
+     elif isinstance(item, tuple):
+         return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+     elif isinstance(item, dict):
+         return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+     else:
+         return item
+
+
+ class CogVLM2VideoModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._torch_type = None
+         self._device = None
+         self._tokenizer = None
+         self._model = None
+
+     @classmethod
+     def match(
+         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         family = model_family.model_family or model_family.model_name
+         if "cogvlm2" in family.lower() and "video" in family.lower():
+             return True
+         return False
+
+     def load(self, **kwargs):
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+         from transformers.generation import GenerationConfig
+
+         device = self._pytorch_model_config.get("device", "auto")
+         self._device = select_device(device)
+         self._torch_type = (
+             torch.bfloat16
+             if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+             else torch.float16
+         )
+
+         if self._check_tensorizer_integrity():
+             self._model, self._tokenizer = self._load_tensorizer()
+             return
+
+         if "8-bit" in self.quantization.lower():
+             kwargs["load_in_8bit"] = True
+         elif "4-bit" in self.quantization.lower():
+             kwargs["load_in_4bit"] = True
+
+         self._tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+         self._model = AutoModelForCausalLM.from_pretrained(
+             self.model_path,
+             torch_dtype=self._torch_type,
+             trust_remote_code=True,
+             low_cpu_mem_usage=True,
+             device_map="auto",
+             **kwargs
+         ).eval()
+
+         # Specify hyperparameters for generation
+         self._model.generation_config = GenerationConfig.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+         self._save_tensorizer()
+
+     def _load_video(self, video_path):
+         import numpy as np
+         from decord import VideoReader, bridge, cpu
+
+         bridge.set_bridge("torch")
+         num_frames = 24
+
+         decord_vr = VideoReader(video_path, ctx=cpu(0))
+         frame_id_list = None
+         total_frames = len(decord_vr)
+         timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
+         timestamps = [i[0] for i in timestamps]
+         max_second = round(max(timestamps)) + 1
+         frame_id_list = []
+         for second in range(max_second):
+             closest_num = min(timestamps, key=lambda x: abs(x - second))
+             index = timestamps.index(closest_num)
+             frame_id_list.append(index)
+             if len(frame_id_list) >= num_frames:
+                 break
+         video_data = decord_vr.get_batch(frame_id_list)
+         video_data = video_data.permute(3, 0, 1, 2)
+         return video_data
+
+     def _message_content_to_cogvlm2(self, content):
+         if not isinstance(content, str):
+             texts = []
+             image_urls = []
+             video_urls = []
+             for c in content:
+                 c_type = c.get("type")
+                 if c_type == "text":
+                     texts.append(c["text"])
+                 elif c_type == "image_url":
+                     image_urls.append(c["image_url"]["url"])
+                 elif c_type == "video_url":
+                     video_urls.append(c["video_url"]["url"])
+             if len(video_urls) > 1:
+                 raise RuntimeError("Only one video per message is supported")
+             image_futures = []
+             video = None
+             with ThreadPoolExecutor() as executor:
+                 for image_url in image_urls:
+                     fut = executor.submit(_decode_image, image_url)
+                     image_futures.append(fut)
+             images = [fut.result() for fut in image_futures]
+             for v in video_urls:
+                 video = self._load_video(v)
+             text = " ".join(texts)
+             return text, images, video
+         return content, [], None
+
+     def _history_content_to_cogvlm2(
+         self, system_prompt: str, chat_history: List[ChatCompletionMessage]
+     ):
+         query = system_prompt
+         history: List[Tuple] = []
+         pixel_values = None
+         video_urls: List[str] = []
+         for i in range(0, len(chat_history), 2):
+             user = chat_history[i]["content"]
+             if isinstance(user, List):
+                 for content in user:
+                     c_type = content.get("type")
+                     if c_type == "text":
+                         user = content["text"]
+                     elif c_type == "image_url" and not pixel_values:
+                         pixel_values = _decode_image(content["image_url"]["url"])
+                     elif c_type == "video_url":
+                         video_urls.append(content["video_url"]["url"])
+             assistant = chat_history[i + 1]["content"]
+             history.append((user, assistant))
+             query = assistant  # type: ignore
+         if len(video_urls) > 1:
+             raise RuntimeError("Only one video per message is supported")
+         video = None
+         for v in video_urls:
+             video = self._load_video(v)
+         return query, history, [pixel_values], video
+
+     def get_query_and_history(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+     ):
+         content, image, video = self._message_content_to_cogvlm2(prompt)
+
+         history = []
+         history_image = None
+         history_video = None
+         if chat_history:
+             (
+                 query,
+                 history,
+                 history_image,
+                 history_video,
+             ) = self._history_content_to_cogvlm2(
+                 system_prompt, chat_history  # type: ignore
+             )
+
+         if image and history_image:
+             history = []
+             query = content
+         else:
+             image = image if image else history_image
+             query = content
+
+         if video is not None and history_video is not None:
+             history = []
+             query = content
+         else:
+             video = video if video is not None else history_video
+             query = content
+
+         return query, image, video, history
+
+     def chat(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         system_prompt = system_prompt if system_prompt else ""
+         stream = generate_config.get("stream", False) if generate_config else False
+
+         sanitized_config = {
+             "pad_token_id": 128002,
+             "max_new_tokens": generate_config.get("max_tokens", 512)
+             if generate_config
+             else 512,
+         }
+
+         query, image, video, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )
+
+         if video is not None:
+             image = [video]
+
+         input_by_model = self._model.build_conversation_input_ids(
+             self._tokenizer,
+             query=query,
+             history=history,
+             images=image,
+             template_version="chat",
+         )
+
+         inputs = {
+             "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device),
+             "token_type_ids": input_by_model["token_type_ids"]
+             .unsqueeze(0)
+             .to(self._device),
+             "attention_mask": input_by_model["attention_mask"]
+             .unsqueeze(0)
+             .to(self._device),
+             "images": [
+                 [input_by_model["images"][0].to(self._device).to(self._torch_type)]
+             ]
+             if image is not None
+             else None,
+         }
+
+         if stream:
+             it = self._streaming_chat_response(inputs, sanitized_config)
+             return self._to_chat_completion_chunks(it)
+         else:
+             with torch.no_grad():
+                 outputs = self._model.generate(**inputs, **sanitized_config)
+                 outputs = outputs[:, inputs["input_ids"].shape[1] :]
+                 response = self._tokenizer.decode(outputs[0])
+                 response = response.split("<|end_of_text|>")[0]
+
+             chunk = Completion(
+                 id=str(uuid.uuid1()),
+                 object="text_completion",
+                 created=int(time.time()),
+                 model=self.model_uid,
+                 choices=[
+                     CompletionChoice(
+                         index=0, text=response, finish_reason="stop", logprobs=None
+                     )
+                 ],
+                 usage=CompletionUsage(
+                     prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                 ),
+             )
+             return self._to_chat_completion(chunk)
+
+     def _streaming_chat_response(
+         self, inputs: Dict, config: Dict
+     ) -> Iterator[CompletionChunk]:
+         from threading import Thread
+
+         from transformers import TextIteratorStreamer
+
+         streamer = TextIteratorStreamer(
+             self._tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         generation_kwargs = {
+             "input_ids": inputs["input_ids"],
+             "attention_mask": inputs["attention_mask"],
+             "token_type_ids": inputs["token_type_ids"],
+             "images": inputs["images"],
+             "max_new_tokens": config["max_new_tokens"],
+             "pad_token_id": config["pad_token_id"],
+             "streamer": streamer,
+         }
+
+         thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         completion_id = str(uuid.uuid1())
+         for new_text in streamer:
+             chunk = CompletionChunk(
+                 id=completion_id,
+                 object="text_completion",
+                 created=int(time.time()),
+                 model=self.model_uid,
+                 choices=[
+                     CompletionChoice(
+                         index=0, text=new_text, finish_reason=None, logprobs=None
+                     )
+                 ],
+                 usage=CompletionUsage(
+                     prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                 ),
+             )
+             yield chunk
+
+         completion_choice = CompletionChoice(
+             text="", index=0, logprobs=None, finish_reason="stop"
+         )
+         chunk = CompletionChunk(
+             id=completion_id,
+             object="text_completion",
+             created=int(time.time()),
+             model=self.model_uid,
+             choices=[completion_choice],
+             usage=CompletionUsage(
+                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+             ),
+         )
+         yield chunk
+
+     @staticmethod
+     def build_position_ids(x, attention_mask=None):
+         """
+         Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+         """
+         # Fix: follow the official open-source implementation
+         if attention_mask is not None:
+             tmp = x.clone()
+             tmp[~(attention_mask.bool())] = -1
+         else:
+             tmp = x.clone()
+         # image boi eoi token as LANGUAGE_TOKEN_TYPE
+         is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+         is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+             tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+         is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+             tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+         tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+         # final position ids
+         y = torch.zeros_like(x, dtype=torch.long)
+         y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+             (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+         )
+         y = y.cumsum(dim=-1)
+         return y
+
+     def get_dtype(self):
+         return self._torch_type
+
+     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+         query, image, video, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )
+
+         if video:
+             image = [video]
+
+         input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
+             self._tokenizer,
+             query=query,
+             history=history,
+             images=image,
+             template_version="chat",
+         )
+         return {
+             "input_ids": input_by_model["input_ids"],  # seq_len
+             "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+             "attention_mask": input_by_model["attention_mask"],  # seq_len
+             "images": input_by_model["images"],
+         }
+
+     def prepare_sanitize_generate_config(self, req: InferenceRequest):
+         """
+         See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+         """
+         raw_config = req.inference_kwargs.get("raw_params", {})
+         temperature = raw_config.get("temperature", None)
+         if temperature is None:
+             raw_config["temperature"] = 0.6
+         top_p = raw_config.get("top_p", None)
+         if top_p is None:
+             raw_config["top_p"] = 0.9
+         return raw_config
+
+     def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+         context_len = self.get_context_len()
+         assert isinstance(prompts[0], dict)
+         images = []
+         max_length = float("-inf")
+         for i, feature in enumerate(prompts):
+             req = req_list[i]
+             if "images" in feature:
+                 images.append(feature.pop("images", None))
+             max_src_len = get_max_src_len(context_len, req)
+             input_ids = feature["input_ids"][-max_src_len:]
+             req.prompt_tokens = input_ids.tolist()
+             feature["input_ids"] = input_ids
+             feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+             feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+             req.extra_kwargs["attention_mask_seq_len"] = feature[
+                 "attention_mask"
+             ].shape[0]
+             max_length = max(len(input_ids), max_length)
+
+         def pad_to_max_length_internal(feature, max_len, idx):
+             padding_length = max_len - len(feature["input_ids"])
+             req_list[idx].padding_len = padding_length
+             feature["input_ids"] = torch.cat(
+                 [torch.full((padding_length,), 0), feature["input_ids"]]
+             )
+             feature["token_type_ids"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["token_type_ids"],
+                 ]
+             )
+             feature["attention_mask"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["attention_mask"],
+                 ]
+             )
+             return feature
+
+         features = [
+             pad_to_max_length_internal(feature, max_length, i)
+             for i, feature in enumerate(prompts)
+         ]
+         batch = {
+             key: torch.stack([feature[key] for feature in features])
+             for key in features[0].keys()
+         }
+
+         position_ids = self.build_position_ids(batch["token_type_ids"])
+         batch["position_ids"] = position_ids
+
+         for i in range(len(prompts)):
+             req = req_list[i]
+             req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+         if images:
+             batch["images"] = images
+
+         batch = recur_move_to(
+             batch, self._device, lambda x: isinstance(x, torch.Tensor)
+         )
+         dtype = self.get_dtype()
+         if dtype:
+             batch = recur_move_to(
+                 batch,
+                 dtype,
+                 lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+             )
+         return batch
+
+     def build_decode_token_type_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         token_type_ids = torch.full(
+             (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+         )
+         return token_type_ids
+
+     def build_decode_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         tmp = []
+         for r in reqs:
+             r.extra_kwargs["max_position_id"] += 1
+             tmp.append(r.extra_kwargs["max_position_id"])
+         position_ids = torch.as_tensor(
+             tmp, device=self._device, dtype=torch.long
+         ).unsqueeze(1)
+         return position_ids
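
The new CogVLM2VideoModel.chat above accepts message content either as a plain string or as a list of parts tagged "text", "image_url", or "video_url", with at most one video per message (_message_content_to_cogvlm2 raises otherwise). As a rough orientation only, a request against this model through the xinference Python client might look like the sketch below, assuming the prompt/system_prompt/chat_history form of chat used by the 0.14-series client; the endpoint, the model_engine choice, and the video path are illustrative assumptions, not taken from this diff.

# Hedged sketch: exercising the new cogvlm2-video-llama3-chat model via the
# xinference Python client. Host/port, model_engine, and the video path are
# assumptions for illustration, not taken from this diff.
from xinference.client import Client

client = Client("http://localhost:9997")  # assumed local supervisor endpoint
model_uid = client.launch_model(
    model_name="cogvlm2-video-llama3-chat",  # name registered in this release
    model_type="LLM",
    model_engine="transformers",  # assumed; the new class is a transformers backend model
)
model = client.get_model(model_uid)

# The prompt mirrors what _message_content_to_cogvlm2 parses: a list of parts
# typed "text", "image_url", or "video_url", with at most one video per message.
completion = model.chat(
    prompt=[
        {"type": "text", "text": "Describe what happens in this clip."},
        {"type": "video_url", "video_url": {"url": "/path/to/clip.mp4"}},  # assumed path
    ],
    generate_config={"max_tokens": 512},
)
print(completion["choices"][0]["message"]["content"])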
xinference/model/llm/__init__.py
@@ -63,6 +63,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
      "internvl-chat",
      "internvl2",
      "cogvlm2",
+     "cogvlm2-video-llama3-chat",
      "MiniCPM-Llama3-V-2_5",
      "MiniCPM-V-2.6",
      "glm-4v",
xinference/model/llm/transformers/core.py
@@ -318,6 +319,8 @@ class PytorchModel(LLM):
          else:
              self._model, self._tokenizer = self._load_model(**kwargs)

+         self._apply_lora()
+
          if not is_device_map_auto:
              self._model.to(self._device)

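
The self._apply_lora() call added above runs right after the base model is loaded, so any LoRA adapters configured at launch time are attached before the model is moved to its device. Purely as a hedged illustration of what such a step typically involves (the actual _apply_lora is not shown in this diff and may differ), attaching an adapter with peft usually looks like this:

# Hedged illustration: attaching a LoRA adapter to a freshly loaded HF model
# with peft. This is NOT the xinference _apply_lora implementation; the
# function name and arguments here are assumptions.
from peft import PeftModel


def apply_lora(model, lora_model_path: str, adapter_name: str = "default"):
    # Wraps the base model; adapter weights are read from lora_model_path.
    return PeftModel.from_pretrained(model, lora_model_path, adapter_name=adapter_name)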
xinference/model/llm/transformers/glm4v.py
@@ -11,19 +11,15 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import base64
  import logging
  import time
  import typing
  import uuid
  from concurrent.futures import ThreadPoolExecutor
- from io import BytesIO
  from threading import Thread
  from typing import Dict, Iterator, List, Optional, Union

- import requests
  import torch
- from PIL import Image

  from ....core.scheduler import InferenceRequest
  from ....types import (
@@ -37,6 +33,7 @@ from ....types import (
  )
  from ...utils import select_device
  from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import _decode_image
  from .core import PytorchChatModel, PytorchGenerateConfig
  from .utils import get_max_src_len

@@ -106,24 +103,6 @@ class Glm4VModel(PytorchChatModel):
          self._save_tensorizer()

      def _message_content_to_chat(self, content):
-         def _load_image(_url):
-             if _url.startswith("data:"):
-                 logging.info("Parse url by base64 decoder.")
-                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                 # e.g. f"data:image/jpeg;base64,{base64_image}"
-                 _type, data = _url.split(";")
-                 _, ext = _type.split("/")
-                 data = data[len("base64,") :]
-                 data = base64.b64decode(data.encode("utf-8"))
-                 return Image.open(BytesIO(data)).convert("RGB")
-             else:
-                 try:
-                     response = requests.get(_url)
-                 except requests.exceptions.MissingSchema:
-                     return Image.open(_url).convert("RGB")
-                 else:
-                     return Image.open(BytesIO(response.content)).convert("RGB")
-
          if not isinstance(content, str):
              texts = []
              image_urls = []
@@ -136,7 +115,7 @@ class Glm4VModel(PytorchChatModel):
          image_futures = []
          with ThreadPoolExecutor() as executor:
              for image_url in image_urls:
-                 fut = executor.submit(_load_image, image_url)
+                 fut = executor.submit(_decode_image, image_url)
                  image_futures.append(fut)
          images = [fut.result() for fut in image_futures]
          text = " ".join(texts)
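
The deleted inline _load_image above is replaced by the shared _decode_image helper imported from ..utils; the similar-sized changes to cogvlm2.py, minicpmv25.py, minicpmv26.py, and yi_vl.py in the file list are consistent with the same refactor. Judging from the removed code, the shared helper presumably behaves roughly as sketched below; the actual implementation lives in xinference/model/llm/utils.py and may differ in detail.

# Hedged sketch of a shared image decoder along the lines of the deleted
# _load_image; the real _decode_image in xinference/model/llm/utils.py may differ.
import base64
from io import BytesIO

import requests
from PIL import Image


def _decode_image(_url) -> Image.Image:
    if _url.startswith("data:"):
        # Base64 data URL, e.g. "data:image/jpeg;base64,<payload>"
        _, data = _url.split(";")
        raw = base64.b64decode(data[len("base64,"):].encode("utf-8"))
        return Image.open(BytesIO(raw)).convert("RGB")
    try:
        response = requests.get(_url)  # remote URL: fetch the bytes over HTTP
    except requests.exceptions.MissingSchema:
        return Image.open(_url).convert("RGB")  # no scheme: treat as a local file path
    else:
        return Image.open(BytesIO(response.content)).convert("RGB")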