xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the changes between publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/intern_vl.py
CHANGED

@@ -42,27 +42,38 @@ def _message_content_to_intern(content, image_cnt):
     if not isinstance(content, str):
         texts = []
         image_urls = []
+        video_urls = []
         for c in content:
             c_type = c.get("type")
             if c_type == "text":
                 texts.append(c["text"])
             elif c_type == "image_url":
                 image_urls.append(c["image_url"]["url"])
+            elif c_type == "video_url":
+                video_urls.append(c["video_url"]["url"])
+        if len(video_urls) > 1:
+            raise RuntimeError("Only one video per message is supported")
         image_futures = []
         with ThreadPoolExecutor() as executor:
             for image_url in image_urls:
                 fut = executor.submit(_decode_image, image_url)
                 image_futures.append(fut)
         images = [fut.result() for fut in image_futures]
+        videos = []
+        for vid_url in video_urls:
+            videos.append(_load_video(vid_url, num_segments=8, max_num=1))
         prefix = ""
         for i, _ in enumerate(images):
             prefix += f"Image-{image_cnt + i + 1}: <image>\n\n"
+
+        if len(videos) > 0:
+            prefix = "".join(
+                [f"Frame{i+1}: <image>\n" for i in range(len(videos[0][1]))]
+            )
+
         text = prefix + " ".join(texts)
-
-
-        else:
-            return text, images
-    return content, []
+        return text, images, videos
+    return content, [], []



 def _get_prompt_and_chat_history(
@@ -71,18 +82,21 @@ def _get_prompt_and_chat_history(
 ):
     # Convert openai history to intern vl history
     images = []
+    videos = []
     history = []
     image_cnt = 0
     for h1, h2 in zip(*[iter(chat_history or [])] * 2):
-        content1, img = _message_content_to_intern(h1["content"], image_cnt)
-        content2, _ = _message_content_to_intern(h2["content"], image_cnt)
+        content1, img, vid = _message_content_to_intern(h1["content"], image_cnt)
+        content2, _, _ = _message_content_to_intern(h2["content"], image_cnt)
         history.append([content1, content2])
         images.extend(img)
         image_cnt += len(img)
+        videos.extend(vid)

-    question, img = _message_content_to_intern(prompt, image_cnt)
+    question, img, vid = _message_content_to_intern(prompt, image_cnt)
     images.extend(img)
-
+    videos.extend(vid)
+    return question, history, images, videos


 def _build_transform(input_size=448):
@@ -174,6 +188,53 @@ def _load_image(image_file, input_size=448, max_num=12):
     return pixel_values


+# video multi-round conversation
+def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+    import numpy as np
+
+    if bound:
+        start, end = bound[0], bound[1]
+    else:
+        start, end = -100000, 100000
+    start_idx = max(first_idx, round(start * fps))
+    end_idx = min(round(end * fps), max_frame)
+    seg_size = float(end_idx - start_idx) / num_segments
+    frame_indices = np.array(
+        [
+            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+            for idx in range(num_segments)
+        ]
+    )
+    return frame_indices
+
+
+def _load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+    from decord import VideoReader, cpu
+    from PIL import Image
+
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    max_frame = len(vr) - 1
+    fps = float(vr.get_avg_fps())
+
+    pixel_values_list, num_patches_list = [], []
+    transform = _build_transform(input_size=input_size)
+    frame_indices = _get_index(
+        bound, fps, max_frame, first_idx=0, num_segments=num_segments
+    )
+    for frame_index in frame_indices:
+        img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+        img = _dynamic_preprocess(
+            img, image_size=input_size, use_thumbnail=True, max_num=max_num
+        )
+        pixel_values = [transform(tile) for tile in img]
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(torch.bfloat16).cuda()
+        num_patches_list.append(pixel_values.shape[0])
+        pixel_values_list.append(pixel_values)
+    pixel_values = torch.cat(pixel_values_list)
+    return pixel_values, num_patches_list
+
+
 class InternVLChatModel(PytorchChatModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -305,7 +366,9 @@ class InternVLChatModel(PytorchChatModel):
             else False
         )

-        content, history, images = _get_prompt_and_chat_history(
+        content, history, images, videos = _get_prompt_and_chat_history(
+            prompt, chat_history
+        )

         num_patches_list = []
         if len(images) == 1:
@@ -327,6 +390,10 @@ class InternVLChatModel(PytorchChatModel):
         else:
             pixel_values = None

+        if len(videos) > 0:
+            pixel_values = videos[0][0]
+            num_patches_list = videos[0][1]
+
         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

         img_context_token_id = self._tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
@@ -440,7 +507,23 @@ class InternVLChatModel(PytorchChatModel):
             )
             chunk["usage"] = completion_usage
             yield chunk
-
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
         if include_usage:
             chunk = CompletionChunk(
                 id=completion_id,
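
The InternVL hunks above add single-video support: a "video_url" content part is sampled into frames by _load_video and the prompt is prefixed with one Frame{i}: <image> placeholder per frame, with at most one video_url part allowed per message. A minimal sketch of the message content shape this branch accepts (the file path is illustrative, not taken from the diff):

    content = [
        {"type": "text", "text": "Describe what happens in this clip."},
        {"type": "video_url", "video_url": {"url": "/data/example_clip.mp4"}},
    ]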

xinference/model/llm/transformers/minicpmv25.py
CHANGED

@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import json
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union

-import requests
 import torch
-from PIL import Image

 from ....types import (
     ChatCompletion,
@@ -35,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig

 logger = logging.getLogger(__name__)
@@ -102,24 +99,6 @@ class MiniCPMV25Model(PytorchChatModel):
         self._save_tensorizer()

     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         if not isinstance(content, str):
             texts = []
             image_urls = []
@@ -132,7 +111,7 @@ class MiniCPMV25Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)

xinference/model/llm/transformers/minicpmv26.py
CHANGED

@@ -11,15 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Dict, Iterator, List, Optional, Union

-import requests
 import torch
 from PIL import Image

@@ -34,6 +31,7 @@ from ....types import (
 )
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig

 logger = logging.getLogger(__name__)
@@ -105,24 +103,6 @@ class MiniCPMV26Model(PytorchChatModel):
         self._save_tensorizer()

     def _message_content_to_chat(self, content):
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-                return Image.open(BytesIO(data)).convert("RGB")
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url).convert("RGB")
-                else:
-                    return Image.open(BytesIO(response.content)).convert("RGB")
-
         MAX_NUM_FRAMES = 64

         def encode_video(video_path):
@@ -166,7 +146,7 @@ class MiniCPMV26Model(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             frames = []

xinference/model/llm/transformers/yi_vl.py
CHANGED

@@ -11,18 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64
 import logging
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from threading import Thread
 from typing import Dict, Iterator, List, Optional, Union

-import requests
 import torch
-from PIL import Image

 from ....model.utils import select_device
 from ....types import (
@@ -35,6 +31,7 @@ from ....types import (
     CompletionUsage,
 )
 from ..llm_family import LLMFamilyV1, LLMSpecV1
+from ..utils import _decode_image
 from .core import PytorchChatModel, PytorchGenerateConfig

 logger = logging.getLogger(__name__)
@@ -78,25 +75,6 @@ class YiVLChatModel(PytorchChatModel):

     @staticmethod
     def _message_content_to_yi(content) -> Union[str, tuple]:
-        def _load_image(_url):
-            if _url.startswith("data:"):
-                logging.info("Parse url by base64 decoder.")
-                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
-                # e.g. f"data:image/jpeg;base64,{base64_image}"
-                _type, data = _url.split(";")
-                _, ext = _type.split("/")
-                data = data[len("base64,") :]
-                data = base64.b64decode(data.encode("utf-8"))
-
-                return Image.open(BytesIO(data))
-            else:
-                try:
-                    response = requests.get(_url)
-                except requests.exceptions.MissingSchema:
-                    return Image.open(_url)
-                else:
-                    return Image.open(BytesIO(response.content))
-
         if not isinstance(content, str):
             from ....thirdparty.llava.model.constants import DEFAULT_IMAGE_TOKEN

@@ -111,7 +89,7 @@ class YiVLChatModel(PytorchChatModel):
             image_futures = []
             with ThreadPoolExecutor() as executor:
                 for image_url in image_urls:
-                    fut = executor.submit(
+                    fut = executor.submit(_decode_image, image_url)
                     image_futures.append(fut)
             images = [fut.result() for fut in image_futures]
             text = " ".join(texts)
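
The three removed per-model _load_image copies above are replaced by the shared _decode_image helper from xinference/model/llm/utils.py, which intern_vl.py already used. A minimal illustration of the input forms the removed helpers handled and the shared one is expected to cover (paths and URLs are placeholders; treating the result as a PIL image follows the removed code, so read this as a hedged sketch):

    from xinference.model.llm.utils import _decode_image

    image = _decode_image("/tmp/example.png")                # local file path
    # image = _decode_image("https://example.com/cat.png")   # remote URL fetched via requests
    # image = _decode_image("data:image/png;base64,....")    # OpenAI-style base64 data URL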
xinference/model/llm/utils.py
CHANGED

@@ -32,6 +32,7 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from ..utils import ensure_cache_cleared
 from .llm_family import (
     LlamaCppLLMSpecV1,
     LLMFamilyV1,
@@ -459,7 +460,16 @@ Begin!"""
         role = get_role(message["role"])
         content = message["content"]
         if isinstance(content, str):
-
+            if content:
+                ret += (
+                    role
+                    + "\n"
+                    + content
+                    + prompt_style.intra_message_sep
+                    + "\n"
+                )
+            else:
+                ret += role + "\n"
         elif isinstance(content, list):
             text = ""
             image_urls = []
@@ -567,6 +577,7 @@ Begin!"""
         return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
+    @ensure_cache_cleared
     def _to_chat_completion_chunks(
         cls,
         chunks: Iterator[CompletionChunk],
@@ -599,6 +610,7 @@ Begin!"""
             i += 1

     @staticmethod
+    @ensure_cache_cleared
     def _to_chat_completion(completion: Completion) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
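
The @@ -459 hunk changes how a plain-string message is rendered into the prompt: an empty string now contributes only the role header, while non-empty content is wrapped with the intra-message separator. A small standalone sketch of the new branch (role and separator values are illustrative, not taken from the code):

    def render_turn(role: str, content: str, sep: str) -> str:
        if content:
            return role + "\n" + content + sep + "\n"
        return role + "\n"

    print(render_turn("ASSISTANT", "Hello!", "<|im_end|>"))
    print(render_turn("ASSISTANT", "", "<|im_end|>"))  # role header only, cueing the model's reply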

xinference/model/llm/vllm/core.py
CHANGED

@@ -643,39 +643,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):


 class VLLMVisionModel(VLLMModel, ChatModelMixin):
-    def load(self):
-        try:
-            import vllm
-            from vllm.engine.arg_utils import AsyncEngineArgs
-            from vllm.engine.async_llm_engine import AsyncLLMEngine
-        except ImportError:
-            error_message = "Failed to import module 'vllm'"
-            installation_guide = [
-                "Please make sure 'vllm' is installed. ",
-                "You can install it by `pip install vllm`\n",
-            ]
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        if vllm.__version__ >= "0.3.1":
-            # from vllm v0.3.1, it uses cupy as NCCL backend
-            # in which cupy will fork a process
-            # only for xoscar >= 0.3.0, new process is allowed in subpool
-            # besides, xinference set start method as forkserver for unix
-            # we need to set it to fork to make cupy NCCL work
-            multiprocessing.set_start_method("fork", force=True)
-
-        self._model_config = self._sanitize_model_config(self._model_config)
-
-        logger.info(
-            f"Loading {self.model_uid} with following model config: {self._model_config}"
-        )
-
-        engine_args = AsyncEngineArgs(
-            model=self.model_path,
-            **self._model_config,
-        )
-        self._engine = AsyncLLMEngine.from_engine_args(engine_args)
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
@@ -721,7 +688,7 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
         prompt_style = self.model_family.prompt_style.copy()
         chat_history = chat_history or []
         prompt, images = self.get_prompt(prompt, chat_history, prompt_style)
-
+
         if len(images) == 0:
             inputs = {
                 "prompt": prompt,

xinference/model/rerank/custom.py
CHANGED

@@ -48,6 +48,10 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_RERANK_LOCK:
         for model_name in (
             list(BUILTIN_RERANK_MODELS.keys())
@@ -62,11 +66,6 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
         UD_RERANKS.append(model_spec)

     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "rerank", f"{model_spec.model_name}.json"
         )
xinference/model/utils.py
CHANGED

@@ -11,17 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import functools
+import gc
+import inspect
 import json
 import logging
 import os
+import random
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import huggingface_hub
+import numpy as np
+import torch

 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import get_available_device, is_device_available
+from ..device_utils import empty_cache, get_available_device, is_device_available
 from .core import CacheableModelSpec

 logger = logging.getLogger(__name__)
@@ -348,3 +355,36 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
         return int(model_size)
     else:
         return str(model_size)
+
+
+def ensure_cache_cleared(func: Callable):
+    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
+        func
+    )
+    if inspect.isgeneratorfunction(func):
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            for obj in func(*args, **kwargs):
+                yield obj
+            gc.collect()
+            empty_cache()
+
+    else:
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            finally:
+                gc.collect()
+                empty_cache()
+
+    return inner
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
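
ensure_cache_cleared frees Python and accelerator memory after a decorated call returns, and it also wraps generator functions so the cleanup runs once the stream is exhausted; set_all_random_seed seeds random, NumPy and torch (including CUDA) in one call. A minimal usage sketch (the decorated functions are illustrative; only the two helpers come from the diff):

    from xinference.model.utils import ensure_cache_cleared, set_all_random_seed

    @ensure_cache_cleared
    def run_once(x):
        return x * 2

    @ensure_cache_cleared
    def stream_numbers(n):
        for i in range(n):
            yield i

    set_all_random_seed(42)
    print(run_once(21))              # gc.collect()/empty_cache() run after the return
    print(list(stream_numbers(3)))   # ...and after the generator is fully consumed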
xinference/model/video/core.py
CHANGED

@@ -14,7 +14,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple

 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
@@ -44,6 +44,8 @@ class VideoModelFamilyV1(CacheableModelSpec):
     model_revision: str
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
+    default_model_config: Optional[Dict[str, Any]]
+    default_generate_config: Optional[Dict[str, Any]]


 class VideoModelDescription(ModelDescription):

xinference/model/video/diffusers.py
CHANGED

@@ -15,7 +15,6 @@
 import base64
 import logging
 import os
-import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -24,10 +23,9 @@ from typing import TYPE_CHECKING, List, Union

 import numpy as np
 import PIL.Image
-import torch

 from ...constants import XINFERENCE_VIDEO_DIR
-from ...device_utils import move_model_to_available_device
+from ...device_utils import gpu_count, move_model_to_available_device
 from ...types import Video, VideoList

 if TYPE_CHECKING:
@@ -76,41 +74,58 @@ class DiffUsersVideoModel:
     def load(self):
         import torch

-
-
-
-
-
-
+        kwargs = self._model_spec.default_model_config.copy()
+        kwargs.update(self._kwargs)
+
+        scheduler_cls_name = kwargs.pop("scheduler", None)
+
+        torch_dtype = kwargs.get("torch_dtype")
         if isinstance(torch_dtype, str):
-
+            kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+        logger.debug("Loading video model with kwargs: %s", kwargs)

         if self._model_spec.model_family == "CogVideoX":
+            import diffusers
             from diffusers import CogVideoXPipeline

-            self._model = CogVideoXPipeline.from_pretrained(
-                self._model_path, **
+            pipeline = self._model = CogVideoXPipeline.from_pretrained(
+                self._model_path, **kwargs
             )
         else:
             raise Exception(
                 f"Unsupported model family: {self._model_spec.model_family}"
             )

-        if
+        if scheduler_cls_name:
+            logger.debug("Using scheduler: %s", scheduler_cls_name)
+            pipeline.scheduler = getattr(diffusers, scheduler_cls_name).from_config(
+                pipeline.scheduler.config, timestep_spacing="trailing"
+            )
+        if kwargs.get("compile_graph", False):
+            pipeline.transformer = torch.compile(
+                pipeline.transformer, mode="max-autotune", fullgraph=True
+            )
+        if kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
-
-
+            pipeline.enable_model_cpu_offload()
+            if kwargs.get("sequential_cpu_offload", True):
+                pipeline.enable_sequential_cpu_offload()
+            pipeline.vae.enable_slicing()
+            pipeline.vae.enable_tiling()
+        elif not kwargs.get("device_map"):
             logger.debug("Loading model to available device")
-
+            if gpu_count() > 1:
+                kwargs["device_map"] = "balanced"
+            else:
+                pipeline = move_model_to_available_device(self._model)
             # Recommended if your computer has < 64 GB of RAM
-
+            pipeline.enable_attention_slicing()

     def text_to_video(
         self,
         prompt: str,
         n: int = 1,
         num_inference_steps: int = 50,
-        guidance_scale: int = 6,
         response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
@@ -121,31 +136,19 @@ class DiffUsersVideoModel:
         # from diffusers.utils import export_to_video
         from ...device_utils import empty_cache

+        assert self._model is not None
+        assert callable(self._model)
+        generate_kwargs = self._model_spec.default_generate_config.copy()
+        generate_kwargs.update(kwargs)
+        generate_kwargs["num_videos_per_prompt"] = n
         logger.debug(
             "diffusers text_to_video args: %s",
-
+            generate_kwargs,
         )
-        assert self._model is not None
-        if self._kwargs.get("cpu_offload"):
-            # if enabled cpu offload,
-            # the model.device would be CPU
-            device = "cuda"
-        else:
-            device = self._model.device
-        prompt_embeds, _ = self._model.encode_prompt(
-            prompt=prompt,
-            do_classifier_free_guidance=True,
-            num_videos_per_prompt=n,
-            max_sequence_length=226,
-            device=device,
-            dtype=torch.float16,
-        )
-        assert callable(self._model)
         output = self._model(
+            prompt=prompt,
             num_inference_steps=num_inference_steps,
-
-            prompt_embeds=prompt_embeds,
-            **kwargs,
+            **generate_kwargs,
         )

         # clean cache
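
With default_model_config and default_generate_config on the video model spec, load-time and generate-time options are now resolved by merging the spec defaults with whatever the caller passes, caller values winning, and n is always forced into num_videos_per_prompt. A small sketch of that merge (the concrete default values are illustrative, not taken from the spec files):

    # defaults would come from the video model_spec.json entry
    default_model_config = {"torch_dtype": "bfloat16", "cpu_offload": False}
    launch_kwargs = {"cpu_offload": True}        # supplied when the model is launched

    kwargs = default_model_config.copy()
    kwargs.update(launch_kwargs)                 # {"torch_dtype": "bfloat16", "cpu_offload": True}

    default_generate_config = {"guidance_scale": 6}
    generate_kwargs = default_generate_config.copy()
    generate_kwargs.update({"num_inference_steps": 30})
    generate_kwargs["num_videos_per_prompt"] = 2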