xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +48 -41
  6. xinference/model/audio/chattts.py +24 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/fish_speech.py +228 -0
  9. xinference/model/audio/model_spec.json +8 -0
  10. xinference/model/embedding/core.py +23 -1
  11. xinference/model/image/model_spec.json +2 -1
  12. xinference/model/image/model_spec_modelscope.json +2 -1
  13. xinference/model/image/stable_diffusion/core.py +49 -1
  14. xinference/model/llm/__init__.py +6 -0
  15. xinference/model/llm/llm_family.json +54 -9
  16. xinference/model/llm/llm_family.py +2 -0
  17. xinference/model/llm/llm_family_modelscope.json +56 -10
  18. xinference/model/llm/lmdeploy/__init__.py +0 -0
  19. xinference/model/llm/lmdeploy/core.py +557 -0
  20. xinference/model/llm/transformers/cogvlm2.py +4 -45
  21. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/glm4v.py +2 -23
  24. xinference/model/llm/transformers/intern_vl.py +94 -11
  25. xinference/model/llm/transformers/minicpmv25.py +2 -23
  26. xinference/model/llm/transformers/minicpmv26.py +2 -22
  27. xinference/model/llm/transformers/yi_vl.py +2 -24
  28. xinference/model/llm/utils.py +10 -1
  29. xinference/model/llm/vllm/core.py +1 -1
  30. xinference/thirdparty/fish_speech/__init__.py +0 -0
  31. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  32. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  33. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  34. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  35. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  36. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  37. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  38. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  39. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  40. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  41. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  42. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  44. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  46. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  48. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  49. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  50. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  52. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  56. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  57. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  58. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  59. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  60. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  63. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  64. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  67. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  68. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  69. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  70. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  71. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  72. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  73. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  74. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  75. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  76. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  77. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  78. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  79. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  83. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  84. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  85. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  86. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  87. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  88. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  89. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  90. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  91. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  92. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  93. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  94. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  95. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  96. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  99. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  100. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  101. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  102. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  103. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  104. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  105. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  106. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  107. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  108. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  109. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  110. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  111. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  112. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  113. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  114. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  115. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  116. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  117. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  118. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  119. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  120. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  121. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  122. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  123. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  124. xinference/web/ui/build/asset-manifest.json +3 -3
  125. xinference/web/ui/build/index.html +1 -1
  126. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  127. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  129. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
  130. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
  131. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  133. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  134. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  135. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  136. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  137. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/intern_vl.py
@@ -42,27 +42,38 @@ def _message_content_to_intern(content, image_cnt):
     if not isinstance(content, str):
         texts = []
         image_urls = []
+        video_urls = []
         for c in content:
            c_type = c.get("type")
            if c_type == "text":
                texts.append(c["text"])
            elif c_type == "image_url":
                image_urls.append(c["image_url"]["url"])
+           elif c_type == "video_url":
+               video_urls.append(c["video_url"]["url"])
+        if len(video_urls) > 1:
+            raise RuntimeError("Only one video per message is supported")
         image_futures = []
         with ThreadPoolExecutor() as executor:
            for image_url in image_urls:
                fut = executor.submit(_decode_image, image_url)
                image_futures.append(fut)
         images = [fut.result() for fut in image_futures]
+        videos = []
+        for vid_url in video_urls:
+            videos.append(_load_video(vid_url, num_segments=8, max_num=1))
         prefix = ""
         for i, _ in enumerate(images):
            prefix += f"Image-{image_cnt + i + 1}: <image>\n\n"
+
+        if len(videos) > 0:
+            prefix = "".join(
+                [f"Frame{i+1}: <image>\n" for i in range(len(videos[0][1]))]
+            )
+
         text = prefix + " ".join(texts)
-        if len(images) == 0:
-            return text, []
-        else:
-            return text, images
-    return content, []
+        return text, images, videos
+    return content, [], []
 
 
 def _get_prompt_and_chat_history(
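
Note (not part of the diff): a minimal sketch of the message shape the new video_url branch parses; it mirrors the existing OpenAI-style image_url content parts, and the file path is a placeholder.

# Hypothetical client payload; the new branch reads c["video_url"]["url"]
# and raises if more than one video appears in a single message.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe what happens in this clip."},
        {"type": "video_url", "video_url": {"url": "/path/to/clip.mp4"}},
    ],
}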
@@ -71,18 +82,21 @@ def _get_prompt_and_chat_history(
 ):
     # Convert openai history to intern vl history
     images = []
+    videos = []
     history = []
     image_cnt = 0
     for h1, h2 in zip(*[iter(chat_history or [])] * 2):
-        content1, img = _message_content_to_intern(h1["content"], image_cnt)
-        content2, _ = _message_content_to_intern(h2["content"], image_cnt)
+        content1, img, vid = _message_content_to_intern(h1["content"], image_cnt)
+        content2, _, _ = _message_content_to_intern(h2["content"], image_cnt)
         history.append([content1, content2])
         images.extend(img)
         image_cnt += len(img)
+        videos.extend(vid)
 
-    question, img = _message_content_to_intern(prompt, image_cnt)
+    question, img, vid = _message_content_to_intern(prompt, image_cnt)
     images.extend(img)
-    return question, history, images
+    videos.extend(vid)
+    return question, history, images, videos
 
 
 def _build_transform(input_size=448):
@@ -174,6 +188,53 @@ def _load_image(image_file, input_size=448, max_num=12):
     return pixel_values
 
 
+# video multi-round conversation
+def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    import numpy as np
+
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array(
+        [
+            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+            for idx in range(num_segments)
+        ]
+    )
+    return frame_indices
+
+
+def _load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    from decord import VideoReader, cpu
+    from PIL import Image
+
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = _build_transform(input_size=input_size)
+    frame_indices = _get_index(
+        bound, fps, max_frame, first_idx=0, num_segments=num_segments
+    )
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = _dynamic_preprocess(
+            img, image_size=input_size, use_thumbnail=True, max_num=max_num
+        )
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(torch.bfloat16).cuda()
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
 class InternVLChatModel(PytorchChatModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
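
Note (not part of the diff): a rough sketch of calling the new helper directly, assuming decord is installed and a CUDA device is available (frames are moved with .cuda() inside _load_video); the file name is a placeholder.

# Returns stacked frame tiles plus per-frame patch counts, which the chat
# path later uses to build the "Frame{i}: <image>" prefix.
pixel_values, num_patches_list = _load_video("clip.mp4", num_segments=8, max_num=1)
print(pixel_values.shape, num_patches_list)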
@@ -305,7 +366,9 @@ class InternVLChatModel(PytorchChatModel):
             else False
         )
 
-        content, history, images = _get_prompt_and_chat_history(prompt, chat_history)
+        content, history, images, videos = _get_prompt_and_chat_history(
+            prompt, chat_history
+        )
 
         num_patches_list = []
         if len(images) == 1:
@@ -327,6 +390,10 @@ class InternVLChatModel(PytorchChatModel):
         else:
             pixel_values = None
 
+        if len(videos) > 0:
+            pixel_values = videos[0][0]
+            num_patches_list = videos[0][1]
+
         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
 
         img_context_token_id = self._tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
@@ -440,7 +507,23 @@ class InternVLChatModel(PytorchChatModel):
             )
             chunk["usage"] = completion_usage
             yield chunk
-
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
         if include_usage:
             chunk = CompletionChunk(
                 id=completion_id,

xinference/model/llm/transformers/minicpmv25.py
@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import json
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....types import (
     ChatCompletion,
@@ -35,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -102,24 +99,6 @@ class MiniCPMV25Model(PytorchChatModel):
         self._save_tensorizer()
 
     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -132,7 +111,7 @@ class MiniCPMV25Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)

xinference/model/llm/transformers/minicpmv26.py
@@ -11,15 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
 from PIL import Image
 
@@ -34,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -105,24 +103,6 @@ class MiniCPMV26Model(PytorchChatModel):
         self._save_tensorizer()
 
     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         MAX_NUM_FRAMES = 64
 
         def encode_video(video_path):
@@ -166,7 +146,7 @@ class MiniCPMV26Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             frames = []

xinference/model/llm/transformers/yi_vl.py
@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from threading import Thread
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....model.utils import select_device
 from ....types import (
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -78,25 +75,6 @@ class YiVLChatModel(PytorchChatModel):
 
     @staticmethod
     def _message_content_to_yi(content) -> Union[str, tuple]:
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-
-                return Image.open(BytesIO(data))
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url)
-                else:
-                    return Image.open(BytesIO(response.content))
-
         if not isinstance(content, str):
             from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
 
@@ -111,7 +89,7 @@ class YiVLChatModel(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)

xinference/model/llm/utils.py
@@ -459,7 +459,16 @@ Begin!"""
             role = get_role(message["role"])
             content = message["content"]
             if isinstance(content, str):
-                ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
+                if content:
+                    ret += (
+                        role
+                        + "\n"
+                        + content
+                        + prompt_style.intra_message_sep
+                        + "\n"
+                    )
+                else:
+                    ret += role + "\n"
             elif isinstance(content, list):
                 text = ""
                 image_urls = []

xinference/model/llm/vllm/core.py
@@ -721,7 +721,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         prompt_style = self.model_family.prompt_style.copy()
         chat_history = chat_history or []
         prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
-        logger.info(f"messages:{prompt}")
+
         if len(images) == 0:
             inputs = {
                 "prompt": prompt,

xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]

xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+    parameters: Union[Tensor, list[Tensor]],
+    norm_type: float = 2.0,
+) -> float:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """  # noqa: E501
+
+    if isinstance(parameters, Tensor):
+        parameters = [parameters]
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return None
+
+    first_device = grads[0].device
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], list[list[Tensor]]
+    ] = _group_tensors_by_device_and_dtype(
+        [[g.detach() for g in grads]]
+    )  # type: ignore[assignment]
+
+    norms = []
+    for (device, _), ([grads], _) in grouped_grads.items():
+        if _has_foreach_support(grads, device=device):
+            norms.extend(torch._foreach_norm(grads, norm_type))
+        else:
+            norms.extend([torch.norm(g, norm_type) for g in grads])
+
+    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+    """
+    Callback that computes the gradient norm of the model parameters.
+    """
+
+    def __init__(
+        self,
+        norm_type: float = 2.0,
+        logging_interval: str = "step",
+        sub_module: Optional[Union[str, list[str]]] = None,
+    ) -> None:
+        """
+        Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+        """
+        super().__init__()
+
+        self.norm_type = norm_type
+        self.logging_interval = logging_interval
+        self.sub_module = sub_module
+
+    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+        """
+        Computes the gradient norm of the model parameters and logs it to the logger.
+
+        Args:
+            trainer (Trainer): The trainer object
+            model (LightningModule): The current lightningModule
+        """
+
+        lightning_model = model
+
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+        if grad_norm_val is None:
+            return
+
+        on_step = self.logging_interval == "step"
+        lightning_model.log(
+            f"train{path}/grad_norm",
+            grad_norm_val,
+            on_step=on_step,
+            on_epoch=not on_step,
+        )
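
Note (not part of the diff): a minimal sketch of how GradNormMonitor would be attached to a Lightning trainer; MyLitModule and train_loader are placeholders, and the import path assumes the vendored fish_speech package.

import lightning.pytorch as pl

from xinference.thirdparty.fish_speech.fish_speech.callbacks import GradNormMonitor

# Logs train/grad_norm every step; pass sub_module="generator" (for example)
# to monitor a single child module instead of the whole model.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
)
# trainer.fit(MyLitModule(), train_loader)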

xinference/thirdparty/fish_speech/fish_speech/conversation.py
@@ -0,0 +1,2 @@
+SEMANTIC_TOKEN = "<|semantic|>"
+CODEBOOK_PAD_TOKEN_ID = 0

xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+    datasets: list[Dataset]
+    cumulative_sizes: list[int]
+    repeats: list[int]
+
+    @staticmethod
+    def cumsum(sequence, repeats):
+        r, s = [], 0
+        for dataset, repeat in zip(sequence, repeats):
+            l = len(dataset) * repeat
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+        super().__init__()
+
+        self.datasets = list(datasets)
+        self.repeats = repeats
+
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+        assert len(self.datasets) == len(
+            repeats
+        ), "datasets and repeats should have the same length"
+
+        for d in self.datasets:
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatRepeatDataset does not support IterableDataset"
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        dataset = self.datasets[dataset_idx]
+
+        return dataset[sample_idx % len(dataset)]
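
Note (not part of the diff): a small usage sketch of ConcatRepeatDataset with two toy TensorDatasets; the repeat counts are arbitrary.

import torch
from torch.utils.data import TensorDataset

a = TensorDataset(torch.arange(3))         # 3 samples, repeated twice
b = TensorDataset(torch.arange(100, 102))  # 2 samples, repeated once
combined = ConcatRepeatDataset([a, b], repeats=[2, 1])
assert len(combined) == 2 * 3 + 2  # 8
print(combined[4])  # index 4 wraps back into `a` via sample_idx % len(dataset)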

xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_SEMANTICS"]._serialized_start = 30
+    _globals["_SEMANTICS"]._serialized_end = 57
+    _globals["_SENTENCE"]._serialized_start = 59
+    _globals["_SENTENCE"]._serialized_end = 125
+    _globals["_TEXTDATA"]._serialized_start = 127
+    _globals["_TEXTDATA"]._serialized_end = 207
+    _globals["_SAMPLEDDATA"]._serialized_start = 209
+    _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)

xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+    while True:
+        buf = f.read(4)
+        if len(buf) == 0:
+            break
+        size = struct.unpack("I", buf)[0]
+        buf = f.read(size)
+        text_data = TextData()
+        text_data.ParseFromString(buf)
+        yield text_data
+
+
+def write_pb_stream(f, text_data):
+    buf = text_data.SerializeToString()
+    f.write(struct.pack("I", len(buf)))
+    f.write(buf)
+
+
+def pack_pb_stream(text_data):
+    buf = text_data.SerializeToString()
+    return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+    while True:
+        head = f.read(4)
+        if len(head) == 0:
+            break
+        size = struct.unpack("I", head)[0]
+        buf = f.read(size)
+        yield head + buf
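
Note (not part of the diff): a round-trip sketch over an in-memory buffer, assuming the TextData fields (source, name) from the serialized text-data.proto above.

import io

buf = io.BytesIO()
write_pb_stream(buf, TextData(source="demo", name="sample-0"))
buf.seek(0)
for record in read_pb_stream(buf):
    print(record.source, record.name)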