xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
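The diff content below covers only the CogVLM2-related Python modules from this list. To reproduce the full comparison locally, one approach is to fetch both wheels (for example with pip download xinference==0.14.1.post1 --no-deps and pip download xinference==0.14.3 --no-deps) and diff their archive members. The sketch below assumes the two .whl files are already present in the working directory; the file names are placeholders for whatever pip actually produced.

# Minimal sketch, assuming both wheels were downloaded beforehand; the wheel
# file names below are assumptions, adjust them to the files pip fetched.
import zipfile

OLD_WHL = "xinference-0.14.1.post1-py3-none-any.whl"
NEW_WHL = "xinference-0.14.3-py3-none-any.whl"

with zipfile.ZipFile(OLD_WHL) as old, zipfile.ZipFile(NEW_WHL) as new:
    old_names = set(old.namelist())
    new_names = set(new.namelist())

print("added:", len(new_names - old_names), "removed:", len(old_names - new_names))
for name in sorted(new_names - old_names):
    # Skip vendored third-party code and web UI build caches to keep the output short.
    if "thirdparty" not in name and "node_modules" not in name:
        print("  +", name)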
xinference/model/llm/{pytorch → transformers}/cogvlm2.py
@@ -11,17 +11,13 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import base64
  import logging
  import time
  import uuid
  from concurrent.futures import ThreadPoolExecutor
- from io import BytesIO
  from typing import Dict, Iterator, List, Optional, Tuple, Union

- import requests
  import torch
- from PIL import Image

  from ....core.scheduler import InferenceRequest
  from ....model.utils import select_device
@@ -35,6 +31,7 @@ from ....types import (
      CompletionUsage,
  )
  from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import _decode_image
  from .core import PytorchChatModel, PytorchGenerateConfig
  from .utils import get_max_src_len
@@ -75,7 +72,7 @@ class CogVLM2Model(PytorchChatModel):
          cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          family = model_family.model_family or model_family.model_name
-         if "cogvlm" in family.lower():
+         if "cogvlm2" in family.lower() and "video" not in family.lower():
              return True
          return False
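The tightened match() above keeps plain CogVLM2 checkpoints on CogVLM2Model while leaving the new video variants to the CogVLM2VideoModel added later in this diff, whose predicate additionally requires "video" in the family name. A standalone illustration of how the two conditions partition family names (mirroring this diff, not the actual classes):

# Illustration only: the two family predicates introduced by this diff.
def is_cogvlm2(family: str) -> bool:
    return "cogvlm2" in family.lower() and "video" not in family.lower()

def is_cogvlm2_video(family: str) -> bool:
    return "cogvlm2" in family.lower() and "video" in family.lower()

assert is_cogvlm2("cogvlm2")
assert not is_cogvlm2("cogvlm2-video-llama3-chat")
assert is_cogvlm2_video("cogvlm2-video-llama3-chat")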
@@ -116,24 +113,6 @@ class CogVLM2Model(PytorchChatModel):
          self._save_tensorizer()

      def _message_content_to_cogvlm2(self, content):
-         def _load_image(_url):
-             if _url.startswith("data:"):
-                 logging.info("Parse url by base64 decoder.")
-                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                 # e.g. f"data:image/jpeg;base64,{base64_image}"
-                 _type, data = _url.split(";")
-                 _, ext = _type.split("/")
-                 data = data[len("base64,") :]
-                 data = base64.b64decode(data.encode("utf-8"))
-                 return Image.open(BytesIO(data)).convert("RGB")
-             else:
-                 try:
-                     response = requests.get(_url)
-                 except requests.exceptions.MissingSchema:
-                     return Image.open(_url).convert("RGB")
-                 else:
-                     return Image.open(BytesIO(response.content)).convert("RGB")
-
          if not isinstance(content, str):
              texts = []
              image_urls = []
@@ -146,7 +125,7 @@ class CogVLM2Model(PytorchChatModel):
              image_futures = []
              with ThreadPoolExecutor() as executor:
                  for image_url in image_urls:
-                     fut = executor.submit(_load_image, image_url)
+                     fut = executor.submit(_decode_image, image_url)
                      image_futures.append(fut)
              images = [fut.result() for fut in image_futures]
              text = " ".join(texts)
@@ -163,24 +142,6 @@ class CogVLM2Model(PytorchChatModel):
      def _history_content_to_cogvlm2(
          self, system_prompt: str, chat_history: List[ChatCompletionMessage]
      ):
-         def _image_to_piexl_values(image):
-             if image.startswith("data:"):
-                 logging.info("Parse url by base64 decoder.")
-                 # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                 # e.g. f"data:image/jpeg;base64,{base64_image}"
-                 _type, data = image.split(";")
-                 _, ext = _type.split("/")
-                 data = data[len("base64,") :]
-                 data = base64.b64decode(data.encode("utf-8"))
-                 return Image.open(BytesIO(data)).convert("RGB")
-             else:
-                 try:
-                     response = requests.get(image)
-                 except requests.exceptions.MissingSchema:
-                     return Image.open(image).convert("RGB")
-                 else:
-                     return Image.open(BytesIO(response.content)).convert("RGB")
-
          query = system_prompt
          history: List[Tuple] = []
          pixel_values = None
@@ -192,9 +153,7 @@ class CogVLM2Model(PytorchChatModel):
                      if c_type == "text":
                          user = content["text"]
                      elif c_type == "image_url" and not pixel_values:
-                         pixel_values = _image_to_piexl_values(
-                             content["image_url"]["url"]
-                         )
+                         pixel_values = _decode_image(content["image_url"]["url"])
              assistant = chat_history[i + 1]["content"]
              history.append((user, assistant))
              query = assistant  # type: ignore
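Both inline image loaders in this file are replaced by the shared _decode_image helper imported from ..utils (xinference/model/llm/utils.py, +85 -70 in the file list above). Its body is not shown in this diff; the sketch below only mirrors the logic of the deleted _load_image / _image_to_piexl_values closures (base64 data URL, remote URL, or local path) and should not be read as the actual implementation.

# Sketch of a _decode_image-style helper, reconstructed from the removed closures.
import base64
from io import BytesIO

import requests
from PIL import Image

def decode_image(url: str) -> Image.Image:
    if url.startswith("data:"):
        # e.g. f"data:image/jpeg;base64,{base64_image}"
        _type, data = url.split(";")
        data = data[len("base64,"):]
        return Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
    try:
        response = requests.get(url)
    except requests.exceptions.MissingSchema:
        # No URL scheme: treat the value as a local file path.
        return Image.open(url).convert("RGB")
    return Image.open(BytesIO(response.content)).convert("RGB")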
xinference/model/llm/transformers/cogvlm2_video.py (new file)
@@ -0,0 +1,524 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import time
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+ import torch
+
+ from ....core.scheduler import InferenceRequest
+ from ....model.utils import select_device
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     Completion,
+     CompletionChoice,
+     CompletionChunk,
+     CompletionUsage,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import _decode_image
+ from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import get_max_src_len
+
+ logger = logging.getLogger(__name__)
+
+
+ LANGUAGE_TOKEN_TYPE = 0
+ VISION_TOKEN_TYPE = 1
+
+
+ def recur_move_to(item, tgt, criterion_func):
+     """
+     This function is copied from https://github.com/THUDM/CogVLM2/blob/main/basic_demo/cli_demo_batch_inference.py
+     """
+     if criterion_func(item):
+         device_copy = item.to(tgt)
+         return device_copy
+     elif isinstance(item, list):
+         return [recur_move_to(v, tgt, criterion_func) for v in item]
+     elif isinstance(item, tuple):
+         return tuple([recur_move_to(v, tgt, criterion_func) for v in item])
+     elif isinstance(item, dict):
+         return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
+     else:
+         return item
+
+
+ class CogVLM2VideoModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._torch_type = None
+         self._device = None
+         self._tokenizer = None
+         self._model = None
+
+     @classmethod
+     def match(
+         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         family = model_family.model_family or model_family.model_name
+         if "cogvlm2" in family.lower() and "video" in family.lower():
+             return True
+         return False
+
+     def load(self, **kwargs):
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+         from transformers.generation import GenerationConfig
+
+         device = self._pytorch_model_config.get("device", "auto")
+         self._device = select_device(device)
+         self._torch_type = (
+             torch.bfloat16
+             if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+             else torch.float16
+         )
+
+         if self._check_tensorizer_integrity():
+             self._model, self._tokenizer = self._load_tensorizer()
+             return
+
+         if "8-bit" in self.quantization.lower():
+             kwargs["load_in_8bit"] = True
+         elif "4-bit" in self.quantization.lower():
+             kwargs["load_in_4bit"] = True
+
+         self._tokenizer = AutoTokenizer.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+
+         self._model = AutoModelForCausalLM.from_pretrained(
+             self.model_path,
+             torch_dtype=self._torch_type,
+             trust_remote_code=True,
+             low_cpu_mem_usage=True,
+             device_map="auto",
+             **kwargs
+         ).eval()
+
+         # Specify hyperparameters for generation
+         self._model.generation_config = GenerationConfig.from_pretrained(
+             self.model_path,
+             trust_remote_code=True,
+         )
+         self._save_tensorizer()
+
+     def _load_video(self, video_path):
+         import numpy as np
+         from decord import VideoReader, bridge, cpu
+
+         bridge.set_bridge("torch")
+         num_frames = 24
+
+         decord_vr = VideoReader(video_path, ctx=cpu(0))
+         frame_id_list = None
+         total_frames = len(decord_vr)
+         timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
+         timestamps = [i[0] for i in timestamps]
+         max_second = round(max(timestamps)) + 1
+         frame_id_list = []
+         for second in range(max_second):
+             closest_num = min(timestamps, key=lambda x: abs(x - second))
+             index = timestamps.index(closest_num)
+             frame_id_list.append(index)
+             if len(frame_id_list) >= num_frames:
+                 break
+         video_data = decord_vr.get_batch(frame_id_list)
+         video_data = video_data.permute(3, 0, 1, 2)
+         return video_data
+
+     def _message_content_to_cogvlm2(self, content):
+         if not isinstance(content, str):
+             texts = []
+             image_urls = []
+             video_urls = []
+             for c in content:
+                 c_type = c.get("type")
+                 if c_type == "text":
+                     texts.append(c["text"])
+                 elif c_type == "image_url":
+                     image_urls.append(c["image_url"]["url"])
+                 elif c_type == "video_url":
+                     video_urls.append(c["video_url"]["url"])
+             if len(video_urls) > 1:
+                 raise RuntimeError("Only one video per message is supported")
+             image_futures = []
+             video = None
+             with ThreadPoolExecutor() as executor:
+                 for image_url in image_urls:
+                     fut = executor.submit(_decode_image, image_url)
+                     image_futures.append(fut)
+             images = [fut.result() for fut in image_futures]
+             for v in video_urls:
+                 video = self._load_video(v)
+             text = " ".join(texts)
+             return text, images, video
+         return content, [], None
+
+     def _history_content_to_cogvlm2(
+         self, system_prompt: str, chat_history: List[ChatCompletionMessage]
+     ):
+         query = system_prompt
+         history: List[Tuple] = []
+         pixel_values = None
+         video_urls: List[str] = []
+         for i in range(0, len(chat_history), 2):
+             user = chat_history[i]["content"]
+             if isinstance(user, List):
+                 for content in user:
+                     c_type = content.get("type")
+                     if c_type == "text":
+                         user = content["text"]
+                     elif c_type == "image_url" and not pixel_values:
+                         pixel_values = _decode_image(content["image_url"]["url"])
+                     elif c_type == "video_url":
+                         video_urls.append(content["video_url"]["url"])
+             assistant = chat_history[i + 1]["content"]
+             history.append((user, assistant))
+             query = assistant  # type: ignore
+         if len(video_urls) > 1:
+             raise RuntimeError("Only one video per message is supported")
+         video = None
+         for v in video_urls:
+             video = self._load_video(v)
+         return query, history, [pixel_values], video
+
+     def get_query_and_history(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+     ):
+         content, image, video = self._message_content_to_cogvlm2(prompt)
+
+         history = []
+         history_image = None
+         history_video = None
+         if chat_history:
+             (
+                 query,
+                 history,
+                 history_image,
+                 history_video,
+             ) = self._history_content_to_cogvlm2(
+                 system_prompt, chat_history  # type: ignore
+             )
+
+         if image and history_image:
+             history = []
+             query = content
+         else:
+             image = image if image else history_image
+             query = content
+
+         if video is not None and history_video is not None:
+             history = []
+             query = content
+         else:
+             video = video if video is not None else history_video
+             query = content
+
+         return query, image, video, history
+
+     def chat(
+         self,
+         prompt: Union[str, List[Dict]],
+         system_prompt: Optional[str] = None,
+         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         system_prompt = system_prompt if system_prompt else ""
+         stream = generate_config.get("stream", False) if generate_config else False
+
+         sanitized_config = {
+             "pad_token_id": 128002,
+             "max_new_tokens": generate_config.get("max_tokens", 512)
+             if generate_config
+             else 512,
+         }
+
+         query, image, video, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )
+
+         if video is not None:
+             image = [video]
+
+         input_by_model = self._model.build_conversation_input_ids(
+             self._tokenizer,
+             query=query,
+             history=history,
+             images=image,
+             template_version="chat",
+         )
+
+         inputs = {
+             "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device),
+             "token_type_ids": input_by_model["token_type_ids"]
+             .unsqueeze(0)
+             .to(self._device),
+             "attention_mask": input_by_model["attention_mask"]
+             .unsqueeze(0)
+             .to(self._device),
+             "images": [
+                 [input_by_model["images"][0].to(self._device).to(self._torch_type)]
+             ]
+             if image is not None
+             else None,
+         }
+
+         if stream:
+             it = self._streaming_chat_response(inputs, sanitized_config)
+             return self._to_chat_completion_chunks(it)
+         else:
+             with torch.no_grad():
+                 outputs = self._model.generate(**inputs, **sanitized_config)
+                 outputs = outputs[:, inputs["input_ids"].shape[1] :]
+                 response = self._tokenizer.decode(outputs[0])
+                 response = response.split("<|end_of_text|>")[0]
+
+             chunk = Completion(
+                 id=str(uuid.uuid1()),
+                 object="text_completion",
+                 created=int(time.time()),
+                 model=self.model_uid,
+                 choices=[
+                     CompletionChoice(
+                         index=0, text=response, finish_reason="stop", logprobs=None
+                     )
+                 ],
+                 usage=CompletionUsage(
+                     prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                 ),
+             )
+             return self._to_chat_completion(chunk)
+
+     def _streaming_chat_response(
+         self, inputs: Dict, config: Dict
+     ) -> Iterator[CompletionChunk]:
+         from threading import Thread
+
+         from transformers import TextIteratorStreamer
+
+         streamer = TextIteratorStreamer(
+             self._tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         generation_kwargs = {
+             "input_ids": inputs["input_ids"],
+             "attention_mask": inputs["attention_mask"],
+             "token_type_ids": inputs["token_type_ids"],
+             "images": inputs["images"],
+             "max_new_tokens": config["max_new_tokens"],
+             "pad_token_id": config["pad_token_id"],
+             "streamer": streamer,
+         }
+
+         thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         completion_id = str(uuid.uuid1())
+         for new_text in streamer:
+             chunk = CompletionChunk(
+                 id=completion_id,
+                 object="text_completion",
+                 created=int(time.time()),
+                 model=self.model_uid,
+                 choices=[
+                     CompletionChoice(
+                         index=0, text=new_text, finish_reason=None, logprobs=None
+                     )
+                 ],
+                 usage=CompletionUsage(
+                     prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                 ),
+             )
+             yield chunk
+
+         completion_choice = CompletionChoice(
+             text="", index=0, logprobs=None, finish_reason="stop"
+         )
+         chunk = CompletionChunk(
+             id=completion_id,
+             object="text_completion",
+             created=int(time.time()),
+             model=self.model_uid,
+             choices=[completion_choice],
+             usage=CompletionUsage(
+                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+             ),
+         )
+         yield chunk
+
+     @staticmethod
+     def build_position_ids(x, attention_mask=None):
+         """
+         Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
+         """
+         # Fix: follow the official open-source implementation
+         if attention_mask is not None:
+             tmp = x.clone()
+             tmp[~(attention_mask.bool())] = -1
+         else:
+             tmp = x.clone()
+         # image boi eoi token as LANGUAGE_TOKEN_TYPE
+         is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+         is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
+             tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
+         is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
+             tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+         )
+         is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
+         tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+         # final position ids
+         y = torch.zeros_like(x, dtype=torch.long)
+         y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+             (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+         )
+         y = y.cumsum(dim=-1)
+         return y
+
+     def get_dtype(self):
+         return self._torch_type
+
+     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+         query, image, video, history = self.get_query_and_history(
+             prompt, system_prompt=system_prompt, chat_history=chat_history
+         )
+
+         if video:
+             image = [video]
+
+         input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
+             self._tokenizer,
+             query=query,
+             history=history,
+             images=image,
+             template_version="chat",
+         )
+         return {
+             "input_ids": input_by_model["input_ids"],  # seq_len
+             "token_type_ids": input_by_model["token_type_ids"],  # seq_len
+             "attention_mask": input_by_model["attention_mask"],  # seq_len
+             "images": input_by_model["images"],
+         }
+
+     def prepare_sanitize_generate_config(self, req: InferenceRequest):
+         """
+         See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
+         """
+         raw_config = req.inference_kwargs.get("raw_params", {})
+         temperature = raw_config.get("temperature", None)
+         if temperature is None:
+             raw_config["temperature"] = 0.6
+         top_p = raw_config.get("top_p", None)
+         if top_p is None:
+             raw_config["top_p"] = 0.9
+         return raw_config
+
+     def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+         context_len = self.get_context_len()
+         assert isinstance(prompts[0], dict)
+         images = []
+         max_length = float("-inf")
+         for i, feature in enumerate(prompts):
+             req = req_list[i]
+             if "images" in feature:
+                 images.append(feature.pop("images", None))
+             max_src_len = get_max_src_len(context_len, req)
+             input_ids = feature["input_ids"][-max_src_len:]
+             req.prompt_tokens = input_ids.tolist()
+             feature["input_ids"] = input_ids
+             feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
+             feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
+             req.extra_kwargs["attention_mask_seq_len"] = feature[
+                 "attention_mask"
+             ].shape[0]
+             max_length = max(len(input_ids), max_length)
+
+         def pad_to_max_length_internal(feature, max_len, idx):
+             padding_length = max_len - len(feature["input_ids"])
+             req_list[idx].padding_len = padding_length
+             feature["input_ids"] = torch.cat(
+                 [torch.full((padding_length,), 0), feature["input_ids"]]
+             )
+             feature["token_type_ids"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["token_type_ids"],
+                 ]
+             )
+             feature["attention_mask"] = torch.cat(
+                 [
+                     torch.zeros(padding_length, dtype=torch.long),
+                     feature["attention_mask"],
+                 ]
+             )
+             return feature
+
+         features = [
+             pad_to_max_length_internal(feature, max_length, i)
+             for i, feature in enumerate(prompts)
+         ]
+         batch = {
+             key: torch.stack([feature[key] for feature in features])
+             for key in features[0].keys()
+         }
+
+         position_ids = self.build_position_ids(batch["token_type_ids"])
+         batch["position_ids"] = position_ids
+
+         for i in range(len(prompts)):
+             req = req_list[i]
+             req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
+
+         if images:
+             batch["images"] = images
+
+         batch = recur_move_to(
+             batch, self._device, lambda x: isinstance(x, torch.Tensor)
+         )
+         dtype = self.get_dtype()
+         if dtype:
+             batch = recur_move_to(
+                 batch,
+                 dtype,
+                 lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
+             )
+         return batch
+
+     def build_decode_token_type_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         token_type_ids = torch.full(
+             (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
+         )
+         return token_type_ids
+
+     def build_decode_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         tmp = []
+         for r in reqs:
+             r.extra_kwargs["max_position_id"] += 1
+             tmp.append(r.extra_kwargs["max_position_id"])
+         position_ids = torch.as_tensor(
+             tmp, device=self._device, dtype=torch.long
+         ).unsqueeze(1)
+         return position_ids
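The new CogVLM2VideoModel consumes OpenAI-style multimodal message content, with an extra "video_url" part handled by _message_content_to_cogvlm2 (at most one video per message; _load_video samples up to 24 frames via decord). A standalone illustration of the payload shape and how it is split, using a hypothetical local path:

# Illustration only: the content-part shapes parsed by
# CogVLM2VideoModel._message_content_to_cogvlm2 in the new file above.
prompt = [
    {"type": "text", "text": "Describe what happens in this clip."},
    {"type": "video_url", "video_url": {"url": "/tmp/example.mp4"}},  # hypothetical path
]

texts = [c["text"] for c in prompt if c.get("type") == "text"]
video_urls = [c["video_url"]["url"] for c in prompt if c.get("type") == "video_url"]
if len(video_urls) > 1:
    raise RuntimeError("Only one video per message is supported")
print(" ".join(texts), "->", video_urls)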
xinference/model/llm/{pytorch → transformers}/core.py
@@ -47,15 +47,6 @@ from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
  logger = logging.getLogger(__name__)

  NON_DEFAULT_MODEL_LIST: List[str] = [
-     "baichuan-chat",
-     "baichuan-2-chat",
-     "vicuna-v1.3",
-     "falcon",
-     "falcon-instruct",
-     "chatglm",
-     "chatglm2",
-     "chatglm2-32k",
-     "chatglm2-128k",
      "chatglm3",
      "chatglm3-32k",
      "chatglm3-128k",
@@ -64,13 +55,15 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
      "llama-2",
      "llama-2-chat",
      "internlm2-chat",
+     "internlm2.5-chat",
      "qwen-vl-chat",
      "OmniLMM",
      "yi-vl-chat",
      "deepseek-vl-chat",
      "internvl-chat",
-     "mini-internvl-chat",
+     "internvl2",
      "cogvlm2",
+     "cogvlm2-video-llama3-chat",
      "MiniCPM-Llama3-V-2_5",
      "MiniCPM-V-2.6",
      "glm-4v",