xinference 1.5.1__py3-none-any.whl → 1.6.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +97 -8
- xinference/client/restful/restful_client.py +51 -11
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/worker.py +31 -37
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +1 -0
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +20 -3
- xinference/model/audio/model_spec_modelscope.json +18 -1
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +37 -110
- xinference/model/llm/core.py +15 -6
- xinference/model/llm/llama_cpp/core.py +25 -353
- xinference/model/llm/llm_family.json +613 -89
- xinference/model/llm/llm_family.py +9 -1
- xinference/model/llm/llm_family_modelscope.json +540 -90
- xinference/model/llm/mlx/core.py +6 -3
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +16 -3
- xinference/model/llm/transformers/chatglm.py +2 -2
- xinference/model/llm/transformers/cogagent.py +1 -1
- xinference/model/llm/transformers/cogvlm2.py +1 -1
- xinference/model/llm/transformers/core.py +9 -3
- xinference/model/llm/transformers/glm4v.py +1 -1
- xinference/model/llm/transformers/minicpmv26.py +1 -1
- xinference/model/llm/transformers/qwen-omni.py +6 -0
- xinference/model/llm/transformers/qwen_vl.py +1 -1
- xinference/model/llm/utils.py +68 -45
- xinference/model/llm/vllm/core.py +38 -18
- xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +133 -16
- xinference/model/video/model_spec.json +54 -0
- xinference/model/video/model_spec_modelscope.json +56 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +0 -71
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/src/locales/en.json +6 -4
- xinference/web/ui/src/locales/zh.json +6 -4
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/METADATA +59 -39
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/RECORD +87 -87
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
- xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/top_level.txt +0 -0
xinference/core/media_interface.py
@@ -0,0 +1,758 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import logging
+import os
+import threading
+import time
+import uuid
+from typing import Dict, List, Optional, Tuple, Union
+
+import gradio as gr
+import PIL.Image
+from gradio import Markdown
+
+from ..client.restful.restful_client import (
+    RESTfulAudioModelHandle,
+    RESTfulImageModelHandle,
+    RESTfulVideoModelHandle,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MediaInterface:
+    def __init__(
+        self,
+        endpoint: str,
+        model_uid: str,
+        model_family: str,
+        model_name: str,
+        model_id: str,
+        model_revision: str,
+        model_ability: List[str],
+        model_type: str,
+        controlnet: Union[None, List[Dict[str, Union[str, None]]]],
+        access_token: Optional[str],
+    ):
+        self.endpoint = endpoint
+        self.model_uid = model_uid
+        self.model_family = model_family
+        self.model_name = model_name
+        self.model_id = model_id
+        self.model_revision = model_revision
+        self.model_ability = model_ability
+        self.model_type = model_type
+        self.controlnet = controlnet
+        self.access_token = (
+            access_token.replace("Bearer ", "") if access_token is not None else None
+        )
+
+    def build(self) -> gr.Blocks:
+        if self.model_type == "image":
+            assert "stable_diffusion" in self.model_family
+
+        interface = self.build_main_interface()
+        interface.queue()
+        # Gradio initiates the queue during a startup event, but since the app has already been
+        # started, that event will not run, so manually invoke the startup events.
+        # See: https://github.com/gradio-app/gradio/issues/5228
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
+        favicon_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            os.path.pardir,
+            "web",
+            "ui",
+            "public",
+            "favicon.svg",
+        )
+        interface.favicon_path = favicon_path
+        return interface
+
+    def text2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
+        def text_generate_image(
+            prompt: str,
+            n: int,
+            size_width: int,
+            size_height: int,
+            guidance_scale: int,
+            num_inference_steps: int,
+            negative_prompt: Optional[str] = None,
+            sampler_name: Optional[str] = None,
+            progress=gr.Progress(),
+        ) -> PIL.Image.Image:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert isinstance(model, RESTfulImageModelHandle)
+
+            size = f"{int(size_width)}*{int(size_height)}"
+            guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
+            num_inference_steps = (
+                None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+            )
+            sampler_name = None if sampler_name == "default" else sampler_name
+
+            response = None
+            exc = None
+            request_id = str(uuid.uuid4())
+
+            def run_in_thread():
+                nonlocal exc, response
+                try:
+                    response = model.text_to_image(
+                        request_id=request_id,
+                        prompt=prompt,
+                        n=n,
+                        size=size,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale,
+                        negative_prompt=negative_prompt,
+                        sampler_name=sampler_name,
+                        response_format="b64_json",
+                    )
+                except Exception as e:
+                    exc = e
+
+            t = threading.Thread(target=run_in_thread)
+            t.start()
+            while t.is_alive():
+                try:
+                    cur_progress = client.get_progress(request_id)["progress"]
+                except (KeyError, RuntimeError):
+                    cur_progress = 0.0
+
+                progress(cur_progress, desc="Generating images")
+                time.sleep(1)
+
+            if exc:
+                raise exc
+
+            images = []
+            for image_dict in response["data"]:  # type: ignore
+                assert image_dict["b64_json"] is not None
+                image_data = base64.b64decode(image_dict["b64_json"])
+                image = PIL.Image.open(io.BytesIO(image_data))
+                images.append(image)
+
+            return images
+
+        with gr.Blocks() as text2image_vl_interface:
+            with gr.Column():
+                with gr.Row():
+                    with gr.Column(scale=10):
+                        prompt = gr.Textbox(
+                            label="Prompt",
+                            show_label=True,
+                            placeholder="Enter prompt here...",
+                        )
+                        negative_prompt = gr.Textbox(
+                            label="Negative prompt",
+                            show_label=True,
+                            placeholder="Enter negative prompt here...",
+                        )
+                    with gr.Column(scale=1):
+                        generate_button = gr.Button("Generate")
+
+                with gr.Row():
+                    n = gr.Number(label="Number of Images", value=1)
+                    size_width = gr.Number(label="Width", value=1024)
+                    size_height = gr.Number(label="Height", value=1024)
+                with gr.Row():
+                    guidance_scale = gr.Number(label="Guidance scale", value=-1)
+                    num_inference_steps = gr.Number(
+                        label="Inference Step Number", value=-1
+                    )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
+
+            with gr.Column():
+                image_output = gr.Gallery()
+
+            generate_button.click(
+                text_generate_image,
+                inputs=[
+                    prompt,
+                    n,
+                    size_width,
+                    size_height,
+                    guidance_scale,
+                    num_inference_steps,
+                    negative_prompt,
+                    sampler_name,
+                ],
+                outputs=image_output,
+            )
+
+        return text2image_vl_interface
+
+    def image2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
+        def image_generate_image(
+            prompt: str,
+            negative_prompt: str,
+            image: PIL.Image.Image,
+            n: int,
+            size_width: int,
+            size_height: int,
+            num_inference_steps: int,
+            padding_image_to_multiple: int,
+            sampler_name: Optional[str] = None,
+            progress=gr.Progress(),
+        ) -> PIL.Image.Image:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert isinstance(model, RESTfulImageModelHandle)
+
+            if size_width > 0 and size_height > 0:
+                size = f"{int(size_width)}*{int(size_height)}"
+            else:
+                size = None
+            num_inference_steps = (
+                None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+            )
+            padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
+            sampler_name = None if sampler_name == "default" else sampler_name
+
+            bio = io.BytesIO()
+            image.save(bio, format="png")
+
+            response = None
+            exc = None
+            request_id = str(uuid.uuid4())
+
+            def run_in_thread():
+                nonlocal exc, response
+                try:
+                    response = model.image_to_image(
+                        request_id=request_id,
+                        prompt=prompt,
+                        negative_prompt=negative_prompt,
+                        n=n,
+                        image=bio.getvalue(),
+                        size=size,
+                        response_format="b64_json",
+                        num_inference_steps=num_inference_steps,
+                        padding_image_to_multiple=padding_image_to_multiple,
+                        sampler_name=sampler_name,
+                    )
+                except Exception as e:
+                    exc = e
+
+            t = threading.Thread(target=run_in_thread)
+            t.start()
+            while t.is_alive():
+                try:
+                    cur_progress = client.get_progress(request_id)["progress"]
+                except (KeyError, RuntimeError):
+                    cur_progress = 0.0
+
+                progress(cur_progress, desc="Generating images")
+                time.sleep(1)
+
+            if exc:
+                raise exc
+
+            images = []
+            for image_dict in response["data"]:  # type: ignore
+                assert image_dict["b64_json"] is not None
+                image_data = base64.b64decode(image_dict["b64_json"])
+                image = PIL.Image.open(io.BytesIO(image_data))
+                images.append(image)
+
+            return images
+
+        with gr.Blocks() as image2image_inteface:
+            with gr.Column():
+                with gr.Row():
+                    with gr.Column(scale=10):
+                        prompt = gr.Textbox(
+                            label="Prompt",
+                            show_label=True,
+                            placeholder="Enter prompt here...",
+                        )
+                        negative_prompt = gr.Textbox(
+                            label="Negative Prompt",
+                            show_label=True,
+                            placeholder="Enter negative prompt here...",
+                        )
+                    with gr.Column(scale=1):
+                        generate_button = gr.Button("Generate")
+
+                with gr.Row():
+                    n = gr.Number(label="Number of image", value=1)
+                    size_width = gr.Number(label="Width", value=-1)
+                    size_height = gr.Number(label="Height", value=-1)
+
+                with gr.Row():
+                    num_inference_steps = gr.Number(
+                        label="Inference Step Number", value=-1
+                    )
+                    padding_image_to_multiple = gr.Number(
+                        label="Padding image to multiple", value=-1
+                    )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
+
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        uploaded_image = gr.Image(type="pil", label="Upload Image")
+                    with gr.Column(scale=1):
+                        output_gallery = gr.Gallery()
+
+            generate_button.click(
+                image_generate_image,
+                inputs=[
+                    prompt,
+                    negative_prompt,
+                    uploaded_image,
+                    n,
+                    size_width,
+                    size_height,
+                    num_inference_steps,
+                    padding_image_to_multiple,
+                    sampler_name,
+                ],
+                outputs=output_gallery,
+            )
+        return image2image_inteface
+
+    def text2video_interface(self) -> "gr.Blocks":
+        def text_generate_video(
+            prompt: str,
+            negative_prompt: str,
+            num_frames: int,
+            fps: int,
+            num_inference_steps: int,
+            guidance_scale: float,
+            width: int,
+            height: int,
+            progress=gr.Progress(),
+        ) -> List[Tuple[str, str]]:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert isinstance(model, RESTfulVideoModelHandle)
+
+            request_id = str(uuid.uuid4())
+            response = None
+            exc = None
+
+            # Run generation in a separate thread to allow progress tracking
+            def run_in_thread():
+                nonlocal exc, response
+                try:
+                    response = model.text_to_video(
+                        request_id=request_id,
+                        prompt=prompt,
+                        negative_prompt=negative_prompt,
+                        num_frames=num_frames,
+                        fps=fps,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale,
+                        width=width,
+                        height=height,
+                        response_format="b64_json",
+                    )
+                except Exception as e:
+                    exc = e
+
+            t = threading.Thread(target=run_in_thread)
+            t.start()
+
+            # Update progress bar during generation
+            while t.is_alive():
+                try:
+                    cur_progress = client.get_progress(request_id)["progress"]
+                except Exception:
+                    cur_progress = 0.0
+                progress(cur_progress, desc="Generating video")
+                time.sleep(1)
+
+            if exc:
+                raise exc
+
+            # Decode and return the generated video
+            videos = []
+            for video_dict in response["data"]:  # type: ignore
+                video_data = base64.b64decode(video_dict["b64_json"])
+                video_path = f"/tmp/{uuid.uuid4()}.mp4"
+                with open(video_path, "wb") as f:
+                    f.write(video_data)
+                videos.append((video_path, "Generated Video"))
+
+            return videos
+
+        # Gradio UI definition
+        with gr.Blocks() as text2video_ui:
+            # Prompt & Negative Prompt (stacked vertically)
+            prompt = gr.Textbox(label="Prompt", placeholder="Enter video prompt")
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt", placeholder="Enter negative prompt"
+            )
+
+            # Parameters (2-column layout)
+            with gr.Row():
+                with gr.Column():
+                    width = gr.Number(label="Width", value=512)
+                    num_frames = gr.Number(label="Frames", value=16)
+                    steps = gr.Number(label="Inference Steps", value=25)
+                with gr.Column():
+                    height = gr.Number(label="Height", value=512)
+                    fps = gr.Number(label="FPS", value=8)
+                    guidance_scale = gr.Slider(
+                        label="Guidance Scale", minimum=1, maximum=20, value=7.5
+                    )
+
+            # Generate button
+            generate = gr.Button("Generate")
+
+            # Output gallery
+            gallery = gr.Gallery(label="Generated Videos", columns=2)
+
+            # Button click logic
+            generate.click(
+                fn=text_generate_video,
+                inputs=[
+                    prompt,
+                    negative_prompt,
+                    num_frames,
+                    fps,
+                    steps,
+                    guidance_scale,
+                    width,
+                    height,
+                ],
+                outputs=gallery,
+            )
+
+        return text2video_ui
+
+    def image2video_interface(self) -> "gr.Blocks":
+        def image_generate_video(
+            image: "PIL.Image",
+            prompt: str,
+            negative_prompt: str,
+            num_frames: int,
+            fps: int,
+            num_inference_steps: int,
+            guidance_scale: float,
+            width: int,
+            height: int,
+            progress=gr.Progress(),
+        ) -> List[Tuple[str, str]]:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert isinstance(model, RESTfulVideoModelHandle)
+
+            request_id = str(uuid.uuid4())
+            response = None
+            exc = None
+
+            # Convert uploaded image to base64
+            buffered = io.BytesIO()
+            image.save(buffered, format="PNG")
+
+            # Run generation in a separate thread
+            def run_in_thread():
+                nonlocal exc, response
+                try:
+                    response = model.image_to_video(
+                        request_id=request_id,
+                        image=buffered.getvalue(),
+                        prompt=prompt,
+                        negative_prompt=negative_prompt,
+                        num_frames=num_frames,
+                        fps=fps,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale,
+                        width=width,
+                        height=height,
+                        response_format="b64_json",
+                    )
+                except Exception as e:
+                    exc = e
+
+            t = threading.Thread(target=run_in_thread)
+            t.start()
+
+            # Progress loop
+            while t.is_alive():
+                try:
+                    cur_progress = client.get_progress(request_id)["progress"]
+                except Exception:
+                    cur_progress = 0.0
+                progress(cur_progress, desc="Generating video from image")
+                time.sleep(1)
+
+            if exc:
+                raise exc
+
+            # Decode and return video files
+            videos = []
+            for video_dict in response["data"]:  # type: ignore
+                video_data = base64.b64decode(video_dict["b64_json"])
+                video_path = f"/tmp/{uuid.uuid4()}.mp4"
+                with open(video_path, "wb") as f:
+                    f.write(video_data)
+                videos.append((video_path, "Generated Video"))
+
+            return videos
+
+        # Gradio UI
+        with gr.Blocks() as image2video_ui:
+            image = gr.Image(label="Input Image", type="pil")
+
+            prompt = gr.Textbox(label="Prompt", placeholder="Enter video prompt")
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt", placeholder="Enter negative prompt"
+            )
+
+            with gr.Row():
+                with gr.Column():
+                    width = gr.Number(label="Width", value=512)
+                    num_frames = gr.Number(label="Frames", value=16)
+                    steps = gr.Number(label="Inference Steps", value=25)
+                with gr.Column():
+                    height = gr.Number(label="Height", value=512)
+                    fps = gr.Number(label="FPS", value=8)
+                    guidance_scale = gr.Slider(
+                        label="Guidance Scale", minimum=1, maximum=20, value=7.5
+                    )
+
+            generate = gr.Button("Generate")
+            gallery = gr.Gallery(label="Generated Videos", columns=2)
+
+            generate.click(
+                fn=image_generate_video,
+                inputs=[
+                    image,
+                    prompt,
+                    negative_prompt,
+                    num_frames,
+                    fps,
+                    steps,
+                    guidance_scale,
+                    width,
+                    height,
+                ],
+                outputs=gallery,
+            )
+
+        return image2video_ui
+
+    def audio2text_interface(self) -> "gr.Blocks":
+        def transcribe_audio(
+            audio_path: str,
+            language: Optional[str],
+            prompt: Optional[str],
+            temperature: float,
+        ) -> str:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert isinstance(model, RESTfulAudioModelHandle)
+
+            with open(audio_path, "rb") as f:
+                audio_data = f.read()
+
+            response = model.transcriptions(
+                audio=audio_data,
+                language=language or None,
+                prompt=prompt or None,
+                temperature=temperature,
+                response_format="json",
+            )
+
+            return response.get("text", "No transcription result.")
+
+        with gr.Blocks() as audio2text_ui:
+            with gr.Row():
+                audio_input = gr.Audio(
+                    type="filepath",
+                    label="Upload or Record Audio",
+                    sources=["upload", "microphone"],  # ✅ support both
+                )
+            with gr.Row():
+                language = gr.Textbox(
+                    label="Language", placeholder="e.g. en or zh", value=""
+                )
+                prompt = gr.Textbox(
+                    label="Prompt (optional)",
+                    placeholder="Provide context or vocabulary",
+                )
+                temperature = gr.Slider(
+                    label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1
+                )
+            transcribe_btn = gr.Button("Transcribe")
+            output_text = gr.Textbox(label="Transcription", lines=5)
+
+            transcribe_btn.click(
+                fn=transcribe_audio,
+                inputs=[audio_input, language, prompt, temperature],
+                outputs=output_text,
+            )
+
+        return audio2text_ui
+
+    def text2speech_interface(self) -> "gr.Blocks":
+        def tts_generate(
+            input_text: str,
+            voice: str,
+            speed: float,
+            prompt_speech_file,
+            prompt_text: Optional[str],
+        ) -> str:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert hasattr(model, "speech")
+
+            prompt_speech_bytes = None
+            if prompt_speech_file is not None:
+                with open(prompt_speech_file, "rb") as f:
+                    prompt_speech_bytes = f.read()
+
+            response = model.speech(
+                input=input_text,
+                voice=voice,
+                speed=speed,
+                response_format="mp3",
+                prompt_speech=prompt_speech_bytes,
+                prompt_text=prompt_text,
+            )
+
+            # Write to a temp .mp3 file and return its path
+            audio_path = f"/tmp/{uuid.uuid4()}.mp3"
+            with open(audio_path, "wb") as f:
+                f.write(response)
+
+            return audio_path
+
+        # Gradio UI
+        with gr.Blocks() as tts_ui:
+            with gr.Row():
+                with gr.Column():
+                    input_text = gr.Textbox(
+                        label="Text", placeholder="Enter text to synthesize"
+                    )
+                    voice = gr.Textbox(
+                        label="Voice", placeholder="Optional voice ID", value=""
+                    )
+                    speed = gr.Slider(
+                        label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1
+                    )
+
+                    prompt_speech = gr.Audio(
+                        label="Prompt Speech (for cloning)", type="filepath"
+                    )
+                    prompt_text = gr.Textbox(
+                        label="Prompt Text (for cloning)",
+                        placeholder="Text of the prompt speech",
+                    )
+
+                    generate = gr.Button("Generate")
+
+                with gr.Column():
+                    audio_output = gr.Audio(label="Generated Audio", type="filepath")
+
+            generate.click(
+                fn=tts_generate,
+                inputs=[input_text, voice, speed, prompt_speech, prompt_text],
+                outputs=audio_output,
+            )
+
+        return tts_ui
+
+    def build_main_interface(self) -> "gr.Blocks":
+        if self.model_type == "image":
+            title = f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨"
+        elif self.model_type == "video":
+            title = f"🎨 Xinference Video Generation: {self.model_name} 🎨"
+        else:
+            assert self.model_type == "audio"
+            title = f"🎨 Xinference Audio Model: {self.model_name} 🎨"
+        with gr.Blocks(
+            title=title,
+            css="""
+            .center{
+                display: flex;
+                justify-content: center;
+                align-items: center;
+                padding: 0px;
+                color: #9ea4b0 !important;
+            }
+            """,
+            analytics_enabled=False,
+        ) as app:
+            Markdown(
+                f"""
+                <h1 class="center" style='text-align: center; margin-bottom: 1rem'>{title}</h1>
+                """
+            )
+            Markdown(
+                f"""
+                <div class="center">
+                Model ID: {self.model_uid}
+                </div>
+                """
+            )
+            if "text2image" in self.model_ability:
+                with gr.Tab("Text to Image"):
+                    self.text2image_interface()
+            if "image2image" in self.model_ability:
+                with gr.Tab("Image to Image"):
+                    self.image2image_interface()
+            if "text2video" in self.model_ability:
+                with gr.Tab("Text to Video"):
+                    self.text2video_interface()
+            if "image2video" in self.model_ability:
+                with gr.Tab("Image to Video"):
+                    self.image2video_interface()
+            if "audio2text" in self.model_ability:
+                with gr.Tab("Audio to Text"):
+                    self.audio2text_interface()
+            if "text2audio" in self.model_ability:
+                with gr.Tab("Text to Audio"):
+                    self.text2speech_interface()
+        return app
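For context, the added MediaInterface generalizes the Gradio UI that the removed xinference/core/image_interface.py previously provided for image models, extending it to video and audio models. The following is a minimal standalone sketch of how the class could be driven, assuming a locally running Xinference server; the endpoint, model uid, and model metadata below are placeholders, and inside xinference the Blocks app returned by build() is normally mounted by the REST server rather than launched like this:

# Hypothetical usage; all metadata values here are illustrative, not from the diff.
from xinference.core.media_interface import MediaInterface

blocks = MediaInterface(
    endpoint="http://127.0.0.1:9997",          # assumed local Xinference endpoint
    model_uid="my-image-model",                # assumed uid of a launched image model
    model_family="stable_diffusion",
    model_name="stable-diffusion-xl-base-1.0",
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    model_revision="",
    model_ability=["text2image", "image2image"],
    model_type="image",
    controlnet=None,
    access_token=None,
).build()

blocks.launch()  # build() returns a gr.Blocks app, so it can be served directly for testing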