xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/intern_vl.py CHANGED
@@ -42,27 +42,38 @@ def _message_content_to_intern(content, image_cnt):
     if not isinstance(content, str):
         texts = []
         image_urls = []
+        video_urls = []
         for c in content:
             c_type = c.get("type")
             if c_type == "text":
                 texts.append(c["text"])
             elif c_type == "image_url":
                 image_urls.append(c["image_url"]["url"])
+            elif c_type == "video_url":
+                video_urls.append(c["video_url"]["url"])
+        if len(video_urls) > 1:
+            raise RuntimeError("Only one video per message is supported")
         image_futures = []
         with ThreadPoolExecutor() as executor:
             for image_url in image_urls:
                 fut = executor.submit(_decode_image, image_url)
                 image_futures.append(fut)
         images = [fut.result() for fut in image_futures]
+        videos = []
+        for vid_url in video_urls:
+            videos.append(_load_video(vid_url, num_segments=8, max_num=1))
         prefix = ""
         for i, _ in enumerate(images):
             prefix += f"Image-{image_cnt + i + 1}: <image>\n\n"
+
+        if len(videos) > 0:
+            prefix = "".join(
+                [f"Frame{i+1}: <image>\n" for i in range(len(videos[0][1]))]
+            )
+
         text = prefix + " ".join(texts)
-        if len(images) == 0:
-            return text, []
-        else:
-            return text, images
-    return content, []
+        return text, images, videos
+    return content, [], []
 
 
 def _get_prompt_and_chat_history(
@@ -71,18 +82,21 @@ def _get_prompt_and_chat_history(
 ):
     # Convert openai history to intern vl history
     images = []
+    videos = []
     history = []
     image_cnt = 0
     for h1, h2 in zip(*[iter(chat_history or [])] * 2):
-        content1, img = _message_content_to_intern(h1["content"], image_cnt)
-        content2, _ = _message_content_to_intern(h2["content"], image_cnt)
+        content1, img, vid = _message_content_to_intern(h1["content"], image_cnt)
+        content2, _, _ = _message_content_to_intern(h2["content"], image_cnt)
         history.append([content1, content2])
         images.extend(img)
         image_cnt += len(img)
+        videos.extend(vid)
 
-    question, img = _message_content_to_intern(prompt, image_cnt)
+    question, img, vid = _message_content_to_intern(prompt, image_cnt)
     images.extend(img)
-    return question, history, images
+    videos.extend(vid)
+    return question, history, images, videos
 
 
 def _build_transform(input_size=448):
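Note: the new video_url branch follows the same OpenAI-style content-part layout that the image_url branch already uses. A message this parser would accept looks roughly like the sketch below; only the key names come from the code above, the URLs and text are illustrative:

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe what happens in this clip."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cover.jpg"}},
        # at most one video_url part per message, otherwise the RuntimeError above fires
        {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
    ],
}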
@@ -174,6 +188,53 @@ def _load_image(image_file, input_size=448, max_num=12):
     return pixel_values
 
 
+# video multi-round conversation
+def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    import numpy as np
+
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array(
+        [
+            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+            for idx in range(num_segments)
+        ]
+    )
+    return frame_indices
+
+
+def _load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    from decord import VideoReader, cpu
+    from PIL import Image
+
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = _build_transform(input_size=input_size)
+    frame_indices = _get_index(
+        bound, fps, max_frame, first_idx=0, num_segments=num_segments
+    )
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = _dynamic_preprocess(
+            img, image_size=input_size, use_thumbnail=True, max_num=max_num
+        )
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(torch.bfloat16).cuda()
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
 class InternVLChatModel(PytorchChatModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
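Note: _get_index samples one frame from the midpoint of each of num_segments equal time slices between start and end, which is what lets _load_video reduce an arbitrarily long clip to a fixed number of frames. A minimal, self-contained sketch of that sampling rule (NumPy only; it assumes start_idx is 0 and works from a frame count rather than fps bounds):

import numpy as np

def segment_midpoint_indices(num_frames: int, num_segments: int = 8) -> np.ndarray:
    # Split [0, num_frames) into num_segments equal slices and take the frame
    # closest to the middle of each slice, mirroring the arithmetic in _get_index.
    seg_size = float(num_frames) / num_segments
    return np.array(
        [int(seg_size / 2 + np.round(seg_size * i)) for i in range(num_segments)]
    )

# e.g. pick 8 evenly spaced frames from a 300-frame clip
print(segment_midpoint_indices(300, 8))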
@@ -305,7 +366,9 @@ class InternVLChatModel(PytorchChatModel):
             else False
         )
 
-        content, history, images = _get_prompt_and_chat_history(prompt, chat_history)
+        content, history, images, videos = _get_prompt_and_chat_history(
+            prompt, chat_history
+        )
 
         num_patches_list = []
         if len(images) == 1:
@@ -327,6 +390,10 @@ class InternVLChatModel(PytorchChatModel):
         else:
             pixel_values = None
 
+        if len(videos) > 0:
+            pixel_values = videos[0][0]
+            num_patches_list = videos[0][1]
+
         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
 
         img_context_token_id = self._tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
@@ -440,7 +507,23 @@ class InternVLChatModel(PytorchChatModel):
             )
             chunk["usage"] = completion_usage
             yield chunk
-
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
         if include_usage:
             chunk = CompletionChunk(
                 id=completion_id,
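Note: the added block emits one final empty chunk whose choice carries finish_reason="stop" together with the token usage, so streaming clients get an explicit end-of-generation marker before the optional usage-only chunk. On the wire that terminating chunk looks roughly like the sketch below (all values are illustrative):

{
    "id": "cmpl-0123456789",
    "object": "text_completion",
    "created": 1720000000,
    "model": "my-internvl-model",
    "choices": [{"text": "", "index": 0, "logprobs": None, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 51, "completion_tokens": 32, "total_tokens": 83},
}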
xinference/model/llm/transformers/minicpmv25.py CHANGED
@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import json
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....types import (
     ChatCompletion,
@@ -35,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -102,24 +99,6 @@ class MiniCPMV25Model(PytorchChatModel):
             self._save_tensorizer()
 
     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -132,7 +111,7 @@ class MiniCPMV25Model(PytorchChatModel):
            image_futures = []
            with ThreadPoolExecutor() as executor:
                for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                    image_futures.append(fut)
            images = [fut.result() for fut in image_futures]
            text = " ".join(texts)
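Note: MiniCPM-V 2.5 above, and the MiniCPM-V 2.6 and Yi-VL diffs below, drop their copy-pasted inline _load_image helpers in favour of a shared _decode_image imported from ..utils. Judging from the removed code, the shared helper behaves roughly like this sketch (not the exact implementation); it accepts data: URLs, remote URLs, and local file paths:

import base64
from io import BytesIO

import requests
from PIL import Image


def _decode_image(url: str) -> Image.Image:
    # data: URL, e.g. "data:image/jpeg;base64,<payload>"
    if url.startswith("data:"):
        _, data = url.split(";", 1)
        raw = base64.b64decode(data[len("base64,"):].encode("utf-8"))
        return Image.open(BytesIO(raw)).convert("RGB")
    try:
        response = requests.get(url)
    except requests.exceptions.MissingSchema:
        # not a URL at all -- treat it as a local file path
        return Image.open(url).convert("RGB")
    return Image.open(BytesIO(response.content)).convert("RGB")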
xinference/model/llm/transformers/minicpmv26.py CHANGED
@@ -11,15 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
 from PIL import Image
 
@@ -34,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -105,24 +103,6 @@ class MiniCPMV26Model(PytorchChatModel):
             self._save_tensorizer()
 
     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         MAX_NUM_FRAMES = 64
 
         def encode_video(video_path):
@@ -166,7 +146,7 @@ class MiniCPMV26Model(PytorchChatModel):
            image_futures = []
            with ThreadPoolExecutor() as executor:
                for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                    image_futures.append(fut)
            images = [fut.result() for fut in image_futures]
            frames = []
xinference/model/llm/transformers/yi_vl.py CHANGED
@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from threading import Thread
 from typing import Dict, Iterator, List, Optional, Union
 
-import requests
 import torch
-from PIL import Image
 
 from ....model.utils import select_device
 from ....types import (
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig
 
 logger = logging.getLogger(__name__)
@@ -78,25 +75,6 @@ class YiVLChatModel(PytorchChatModel):
 
     @staticmethod
     def _message_content_to_yi(content) -> Union[str, tuple]:
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-
-                return Image.open(BytesIO(data))
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url)
-                else:
-                    return Image.open(BytesIO(response.content))
-
         if not isinstance(content, str):
             from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN
 
@@ -111,7 +89,7 @@ class YiVLChatModel(PytorchChatModel):
            image_futures = []
            with ThreadPoolExecutor() as executor:
                for image_url in image_urls:
-                    fut = executor.submit(_load_image, image_url)
+                    fut = executor.submit(_decode_image, image_url)
                    image_futures.append(fut)
            images = [fut.result() for fut in image_futures]
            text = " ".join(texts)
xinference/model/llm/utils.py CHANGED
@@ -32,6 +32,7 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from ..utils import ensure_cache_cleared
 from .llm_family import (
     LlamaCppLLMSpecV1,
     LLMFamilyV1,
@@ -459,7 +460,16 @@ Begin!"""
             role = get_role(message["role"])
             content = message["content"]
             if isinstance(content, str):
-                ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
+                if content:
+                    ret += (
+                        role
+                        + "\n"
+                        + content
+                        + prompt_style.intra_message_sep
+                        + "\n"
+                    )
+                else:
+                    ret += role + "\n"
             elif isinstance(content, list):
                 text = ""
                 image_urls = []
@@ -567,6 +577,7 @@ Begin!"""
         return cast(ChatCompletionChunk, chat_chunk)
 
     @classmethod
+    @ensure_cache_cleared
     def _to_chat_completion_chunks(
         cls,
         chunks: Iterator[CompletionChunk],
@@ -599,6 +610,7 @@ Begin!"""
             i += 1
 
     @staticmethod
+    @ensure_cache_cleared
     def _to_chat_completion(completion: Completion) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
xinference/model/llm/vllm/core.py CHANGED
@@ -643,39 +643,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
 
 
 class VLLMVisionModel(VLLMModel, ChatModelMixin):
-    def load(self):
-        try:
-            import vllm
-            from vllm.engine.arg_utils import AsyncEngineArgs
-            from vllm.engine.async_llm_engine import AsyncLLMEngine
-        except ImportError:
-            error_message = "Failed to import module 'vllm'"
-            installation_guide = [
-                "Please make sure 'vllm' is installed. ",
-                "You can install it by `pip install vllm`\n",
-            ]
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        if vllm.__version__ >= "0.3.1":
-            # from vllm v0.3.1, it uses cupy as NCCL backend
-            # in which cupy will fork a process
-            # only for xoscar >= 0.3.0, new process is allowed in subpool
-            # besides, xinference set start method as forkserver for unix
-            # we need to set it to fork to make cupy NCCL work
-            multiprocessing.set_start_method("fork", force=True)
-
-        self._model_config = self._sanitize_model_config(self._model_config)
-
-        logger.info(
-            f"Loading {self.model_uid} with following model config: {self._model_config}"
-        )
-
-        engine_args = AsyncEngineArgs(
-            model=self.model_path,
-            **self._model_config,
-        )
-        self._engine = AsyncLLMEngine.from_engine_args(engine_args)
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
@@ -721,7 +688,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         prompt_style = self.model_family.prompt_style.copy()
         chat_history = chat_history or []
         prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
-        logger.info(f"messages:{prompt}")
+
         if len(images) == 0:
             inputs = {
                 "prompt": prompt,
xinference/model/rerank/custom.py CHANGED
@@ -48,6 +48,10 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_RERANK_LOCK:
         for model_name in (
             list(BUILTIN_RERANK_MODELS.keys())
@@ -62,11 +66,6 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
         UD_RERANKS.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "rerank", f"{model_spec.model_name}.json"
         )
xinference/model/utils.py CHANGED
@@ -11,17 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import functools
+import gc
+import inspect
 import json
 import logging
 import os
+import random
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import huggingface_hub
+import numpy as np
+import torch
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import get_available_device, is_device_available
+from ..device_utils import empty_cache, get_available_device, is_device_available
 from .core import CacheableModelSpec
 
 logger = logging.getLogger(__name__)
@@ -348,3 +355,36 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
         return int(model_size)
     else:
         return str(model_size)
+
+
+def ensure_cache_cleared(func: Callable):
+    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
+        func
+    )
+    if inspect.isgeneratorfunction(func):
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            for obj in func(*args, **kwargs):
+                yield obj
+            gc.collect()
+            empty_cache()
+
+    else:
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            finally:
+                gc.collect()
+                empty_cache()
+
+    return inner
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
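Note: ensure_cache_cleared wraps both plain functions and generator functions so that gc.collect() and a device cache flush run once the call returns (or once the generator is exhausted); it is the decorator applied to the chat-completion conversion helpers in llm/utils.py above. A minimal usage sketch (the decorated generator here is hypothetical):

from xinference.model.utils import ensure_cache_cleared


@ensure_cache_cleared
def stream_tokens(n: int):
    # hypothetical generator: caches are cleared after the last item is consumed
    for i in range(n):
        yield f"token-{i}"


for token in stream_tokens(3):
    print(token)
# by this point gc.collect() and empty_cache() have run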
xinference/model/video/core.py CHANGED
@@ -14,7 +14,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
@@ -44,6 +44,8 @@ class VideoModelFamilyV1(CacheableModelSpec):
     model_revision: str
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
+    default_model_config: Optional[Dict[str, Any]]
+    default_generate_config: Optional[Dict[str, Any]]
 
 
 class VideoModelDescription(ModelDescription):
xinference/model/video/diffusers.py CHANGED
@@ -15,7 +15,6 @@
 import base64
 import logging
 import os
-import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -24,10 +23,9 @@ from typing import TYPE_CHECKING, List, Union
 
 import numpy as np
 import PIL.Image
-import torch
 
 from ...constants import XINFERENCE_VIDEO_DIR
-from ...device_utils import move_model_to_available_device
+from ...device_utils import gpu_count, move_model_to_available_device
 from ...types import Video, VideoList
 
 if TYPE_CHECKING:
@@ -76,41 +74,58 @@ class DiffUsersVideoModel:
     def load(self):
         import torch
 
-        torch_dtype = self._kwargs.get("torch_dtype")
-        if sys.platform != "darwin" and torch_dtype is None:
-            # The following params crashes on Mac M2
-            self._kwargs["torch_dtype"] = torch.float16
-            self._kwargs["variant"] = "fp16"
-            self._kwargs["use_safetensors"] = True
+        kwargs = self._model_spec.default_model_config.copy()
+        kwargs.update(self._kwargs)
+
+        scheduler_cls_name = kwargs.pop("scheduler", None)
+
+        torch_dtype = kwargs.get("torch_dtype")
         if isinstance(torch_dtype, str):
-            self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+            kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+        logger.debug("Loading video model with kwargs: %s", kwargs)
 
         if self._model_spec.model_family == "CogVideoX":
+            import diffusers
             from diffusers import CogVideoXPipeline
 
-            self._model = CogVideoXPipeline.from_pretrained(
-                self._model_path, **self._kwargs
+            pipeline = self._model = CogVideoXPipeline.from_pretrained(
+                self._model_path, **kwargs
             )
         else:
             raise Exception(
                 f"Unsupported model family: {self._model_spec.model_family}"
             )
 
-        if self._kwargs.get("cpu_offload", False):
+        if scheduler_cls_name:
+            logger.debug("Using scheduler: %s", scheduler_cls_name)
+            pipeline.scheduler = getattr(diffusers, scheduler_cls_name).from_config(
+                pipeline.scheduler.config, timestep_spacing="trailing"
+            )
+        if kwargs.get("compile_graph", False):
+            pipeline.transformer = torch.compile(
+                pipeline.transformer, mode="max-autotune", fullgraph=True
+            )
+        if kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
-            self._model.enable_model_cpu_offload()
-        elif not self._kwargs.get("device_map"):
+            pipeline.enable_model_cpu_offload()
+            if kwargs.get("sequential_cpu_offload", True):
+                pipeline.enable_sequential_cpu_offload()
+                pipeline.vae.enable_slicing()
+                pipeline.vae.enable_tiling()
+        elif not kwargs.get("device_map"):
             logger.debug("Loading model to available device")
-            self._model = move_model_to_available_device(self._model)
+            if gpu_count() > 1:
+                kwargs["device_map"] = "balanced"
+            else:
+                pipeline = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
-        self._model.enable_attention_slicing()
+        pipeline.enable_attention_slicing()
 
     def text_to_video(
         self,
         prompt: str,
         n: int = 1,
         num_inference_steps: int = 50,
-        guidance_scale: int = 6,
         response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
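Note: load() now merges caller kwargs on top of self._model_spec.default_model_config and acts on a handful of spec-driven keys (scheduler, torch_dtype, compile_graph, cpu_offload, sequential_cpu_offload, device_map); the text_to_video hunk that follows does the same with default_generate_config. The key names in the sketch below are the ones read in this diff, but the values are only illustrative guesses at what a CogVideoX entry in model_spec.json might carry:

# Illustrative only: key names come from the code above, values are guesses.
default_model_config = {
    "torch_dtype": "bfloat16",              # resolved via getattr(torch, ...) in load()
    "scheduler": "CogVideoXDDIMScheduler",  # swapped in with timestep_spacing="trailing"
    "cpu_offload": True,                    # enable_model_cpu_offload()
    "sequential_cpu_offload": True,         # plus VAE slicing/tiling while offloading
    "compile_graph": False,                 # torch.compile the transformer when True
}
default_generate_config = {
    "guidance_scale": 6,                    # replaces the old hard-coded parameter
    "num_frames": 49,                       # illustrative; passed through to the pipeline
}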
@@ -121,31 +136,19 @@ class DiffUsersVideoModel:
         # from diffusers.utils import export_to_video
         from ...device_utils import empty_cache
 
+        assert self._model is not None
+        assert callable(self._model)
+        generate_kwargs = self._model_spec.default_generate_config.copy()
+        generate_kwargs.update(kwargs)
+        generate_kwargs["num_videos_per_prompt"] = n
         logger.debug(
             "diffusers text_to_video args: %s",
-            kwargs,
+            generate_kwargs,
         )
-        assert self._model is not None
-        if self._kwargs.get("cpu_offload"):
-            # if enabled cpu offload,
-            # the model.device would be CPU
-            device = "cuda"
-        else:
-            device = self._model.device
-        prompt_embeds, _ = self._model.encode_prompt(
-            prompt=prompt,
-            do_classifier_free_guidance=True,
-            num_videos_per_prompt=n,
-            max_sequence_length=226,
-            device=device,
-            dtype=torch.float16,
-        )
-        assert callable(self._model)
         output = self._model(
+            prompt=prompt,
             num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            prompt_embeds=prompt_embeds,
-            **kwargs,
+            **generate_kwargs,
         )
 
         # clean cache