xinference 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/core/media_interface.py (new file)
@@ -0,0 +1,758 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import base64
+ import io
+ import logging
+ import os
+ import threading
+ import time
+ import uuid
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import gradio as gr
+ import PIL.Image
+ from gradio import Markdown
+
+ from ..client.restful.restful_client import (
+     RESTfulAudioModelHandle,
+     RESTfulImageModelHandle,
+     RESTfulVideoModelHandle,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class MediaInterface:
+     def __init__(
+         self,
+         endpoint: str,
+         model_uid: str,
+         model_family: str,
+         model_name: str,
+         model_id: str,
+         model_revision: str,
+         model_ability: List[str],
+         model_type: str,
+         controlnet: Union[None, List[Dict[str, Union[str, None]]]],
+         access_token: Optional[str],
+     ):
+         self.endpoint = endpoint
+         self.model_uid = model_uid
+         self.model_family = model_family
+         self.model_name = model_name
+         self.model_id = model_id
+         self.model_revision = model_revision
+         self.model_ability = model_ability
+         self.model_type = model_type
+         self.controlnet = controlnet
+         self.access_token = (
+             access_token.replace("Bearer ", "") if access_token is not None else None
+         )
+
+     def build(self) -> gr.Blocks:
+         if self.model_type == "image":
+             assert "stable_diffusion" in self.model_family
+
+         interface = self.build_main_interface()
+         interface.queue()
+         # Gradio initiates the queue during a startup event, but since the app has already been
+         # started, that event will not run, so manually invoke the startup events.
+         # See: https://github.com/gradio-app/gradio/issues/5228
+         try:
+             interface.run_startup_events()
+         except AttributeError:
+             # compatibility with older Gradio versions
+             interface.startup_events()
+         favicon_path = os.path.join(
+             os.path.dirname(os.path.abspath(__file__)),
+             os.path.pardir,
+             "web",
+             "ui",
+             "public",
+             "favicon.svg",
+         )
+         interface.favicon_path = favicon_path
+         return interface
+
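build() returns a plain gr.Blocks app, so serving it amounts to mounting it on the running ASGI server, as the removed image_interface.py did before. A minimal sketch, assuming a FastAPI app and a route path of our choosing (both illustrative, not taken from this diff):

    import gradio as gr
    from fastapi import FastAPI

    app = FastAPI()  # stand-in for Xinference's own ASGI app

    def mount_media_ui(interface, path: str) -> None:
        # build() enables the queue and runs Gradio's startup events itself,
        # so the returned Blocks can be mounted directly.
        blocks = interface.build()
        gr.mount_gradio_app(app, blocks, path=path)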
+     def text2image_interface(self) -> "gr.Blocks":
+         from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
+         def text_generate_image(
+             prompt: str,
+             n: int,
+             size_width: int,
+             size_height: int,
+             guidance_scale: int,
+             num_inference_steps: int,
+             negative_prompt: Optional[str] = None,
+             sampler_name: Optional[str] = None,
+             progress=gr.Progress(),
+         ) -> List[PIL.Image.Image]:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert isinstance(model, RESTfulImageModelHandle)
+
+             size = f"{int(size_width)}*{int(size_height)}"
+             guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
+             num_inference_steps = (
+                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+             )
+             sampler_name = None if sampler_name == "default" else sampler_name
+
+             response = None
+             exc = None
+             request_id = str(uuid.uuid4())
+
+             def run_in_thread():
+                 nonlocal exc, response
+                 try:
+                     response = model.text_to_image(
+                         request_id=request_id,
+                         prompt=prompt,
+                         n=n,
+                         size=size,
+                         num_inference_steps=num_inference_steps,
+                         guidance_scale=guidance_scale,
+                         negative_prompt=negative_prompt,
+                         sampler_name=sampler_name,
+                         response_format="b64_json",
+                     )
+                 except Exception as e:
+                     exc = e
+
+             t = threading.Thread(target=run_in_thread)
+             t.start()
+             while t.is_alive():
+                 try:
+                     cur_progress = client.get_progress(request_id)["progress"]
+                 except (KeyError, RuntimeError):
+                     cur_progress = 0.0
+
+                 progress(cur_progress, desc="Generating images")
+                 time.sleep(1)
+
+             if exc:
+                 raise exc
+
+             images = []
+             for image_dict in response["data"]:  # type: ignore
+                 assert image_dict["b64_json"] is not None
+                 image_data = base64.b64decode(image_dict["b64_json"])
+                 image = PIL.Image.open(io.BytesIO(image_data))
+                 images.append(image)
+
+             return images
+
+         with gr.Blocks() as text2image_vl_interface:
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column(scale=10):
+                         prompt = gr.Textbox(
+                             label="Prompt",
+                             show_label=True,
+                             placeholder="Enter prompt here...",
+                         )
+                         negative_prompt = gr.Textbox(
+                             label="Negative prompt",
+                             show_label=True,
+                             placeholder="Enter negative prompt here...",
+                         )
+                     with gr.Column(scale=1):
+                         generate_button = gr.Button("Generate")
+
+                 with gr.Row():
+                     n = gr.Number(label="Number of Images", value=1)
+                     size_width = gr.Number(label="Width", value=1024)
+                     size_height = gr.Number(label="Height", value=1024)
+                 with gr.Row():
+                     guidance_scale = gr.Number(label="Guidance scale", value=-1)
+                     num_inference_steps = gr.Number(
+                         label="Inference Step Number", value=-1
+                     )
+                     sampler_name = gr.Dropdown(
+                         choices=SAMPLING_METHODS,
+                         value="default",
+                         label="Sampling method",
+                     )
+
+             with gr.Column():
+                 image_output = gr.Gallery()
+
+             generate_button.click(
+                 text_generate_image,
+                 inputs=[
+                     prompt,
+                     n,
+                     size_width,
+                     size_height,
+                     guidance_scale,
+                     num_inference_steps,
+                     negative_prompt,
+                     sampler_name,
+                 ],
+                 outputs=image_output,
+             )
+
+         return text2image_vl_interface
+
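The Gradio callback above is a thin wrapper over the REST client, so the same generation can be driven headlessly. A sketch assuming a server at http://127.0.0.1:9997 with an image model already loaded (endpoint and model uid are placeholders):

    import base64

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
    model = client.get_model("my-image-model")  # placeholder model uid

    # Same call the UI issues, minus the progress-polling thread.
    response = model.text_to_image(
        prompt="an astronaut riding a horse",
        n=1,
        size="1024*1024",
        response_format="b64_json",
    )
    with open("out.png", "wb") as f:
        f.write(base64.b64decode(response["data"][0]["b64_json"]))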
+     def image2image_interface(self) -> "gr.Blocks":
+         from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
+         def image_generate_image(
+             prompt: str,
+             negative_prompt: str,
+             image: PIL.Image.Image,
+             n: int,
+             size_width: int,
+             size_height: int,
+             num_inference_steps: int,
+             padding_image_to_multiple: int,
+             sampler_name: Optional[str] = None,
+             progress=gr.Progress(),
+         ) -> List[PIL.Image.Image]:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert isinstance(model, RESTfulImageModelHandle)
+
+             if size_width > 0 and size_height > 0:
+                 size = f"{int(size_width)}*{int(size_height)}"
+             else:
+                 size = None
+             num_inference_steps = (
+                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+             )
+             padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
+             sampler_name = None if sampler_name == "default" else sampler_name
+
+             bio = io.BytesIO()
+             image.save(bio, format="png")
+
+             response = None
+             exc = None
+             request_id = str(uuid.uuid4())
+
+             def run_in_thread():
+                 nonlocal exc, response
+                 try:
+                     response = model.image_to_image(
+                         request_id=request_id,
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         n=n,
+                         image=bio.getvalue(),
+                         size=size,
+                         response_format="b64_json",
+                         num_inference_steps=num_inference_steps,
+                         padding_image_to_multiple=padding_image_to_multiple,
+                         sampler_name=sampler_name,
+                     )
+                 except Exception as e:
+                     exc = e
+
+             t = threading.Thread(target=run_in_thread)
+             t.start()
+             while t.is_alive():
+                 try:
+                     cur_progress = client.get_progress(request_id)["progress"]
+                 except (KeyError, RuntimeError):
+                     cur_progress = 0.0
+
+                 progress(cur_progress, desc="Generating images")
+                 time.sleep(1)
+
+             if exc:
+                 raise exc
+
+             images = []
+             for image_dict in response["data"]:  # type: ignore
+                 assert image_dict["b64_json"] is not None
+                 image_data = base64.b64decode(image_dict["b64_json"])
+                 image = PIL.Image.open(io.BytesIO(image_data))
+                 images.append(image)
+
+             return images
+
+         with gr.Blocks() as image2image_interface:
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column(scale=10):
+                         prompt = gr.Textbox(
+                             label="Prompt",
+                             show_label=True,
+                             placeholder="Enter prompt here...",
+                         )
+                         negative_prompt = gr.Textbox(
+                             label="Negative Prompt",
+                             show_label=True,
+                             placeholder="Enter negative prompt here...",
+                         )
+                     with gr.Column(scale=1):
+                         generate_button = gr.Button("Generate")
+
+                 with gr.Row():
+                     n = gr.Number(label="Number of Images", value=1)
+                     size_width = gr.Number(label="Width", value=-1)
+                     size_height = gr.Number(label="Height", value=-1)
+
+                 with gr.Row():
+                     num_inference_steps = gr.Number(
+                         label="Inference Step Number", value=-1
+                     )
+                     padding_image_to_multiple = gr.Number(
+                         label="Padding image to multiple", value=-1
+                     )
+                     sampler_name = gr.Dropdown(
+                         choices=SAMPLING_METHODS,
+                         value="default",
+                         label="Sampling method",
+                     )
+
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         uploaded_image = gr.Image(type="pil", label="Upload Image")
+                     with gr.Column(scale=1):
+                         output_gallery = gr.Gallery()
+
+             generate_button.click(
+                 image_generate_image,
+                 inputs=[
+                     prompt,
+                     negative_prompt,
+                     uploaded_image,
+                     n,
+                     size_width,
+                     size_height,
+                     num_inference_steps,
+                     padding_image_to_multiple,
+                     sampler_name,
+                 ],
+                 outputs=output_gallery,
+             )
+         return image2image_interface
+
+     def text2video_interface(self) -> "gr.Blocks":
+         def text_generate_video(
+             prompt: str,
+             negative_prompt: str,
+             num_frames: int,
+             fps: int,
+             num_inference_steps: int,
+             guidance_scale: float,
+             width: int,
+             height: int,
+             progress=gr.Progress(),
+         ) -> List[Tuple[str, str]]:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert isinstance(model, RESTfulVideoModelHandle)
+
+             request_id = str(uuid.uuid4())
+             response = None
+             exc = None
+
+             # Run generation in a separate thread to allow progress tracking
+             def run_in_thread():
+                 nonlocal exc, response
+                 try:
+                     response = model.text_to_video(
+                         request_id=request_id,
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_frames=num_frames,
+                         fps=fps,
+                         num_inference_steps=num_inference_steps,
+                         guidance_scale=guidance_scale,
+                         width=width,
+                         height=height,
+                         response_format="b64_json",
+                     )
+                 except Exception as e:
+                     exc = e
+
+             t = threading.Thread(target=run_in_thread)
+             t.start()
+
+             # Update progress bar during generation
+             while t.is_alive():
+                 try:
+                     cur_progress = client.get_progress(request_id)["progress"]
+                 except Exception:
+                     cur_progress = 0.0
+                 progress(cur_progress, desc="Generating video")
+                 time.sleep(1)
+
+             if exc:
+                 raise exc
+
+             # Decode and return the generated video
+             videos = []
+             for video_dict in response["data"]:  # type: ignore
+                 video_data = base64.b64decode(video_dict["b64_json"])
+                 video_path = f"/tmp/{uuid.uuid4()}.mp4"
+                 with open(video_path, "wb") as f:
+                     f.write(video_data)
+                 videos.append((video_path, "Generated Video"))
+
+             return videos
+
+         # Gradio UI definition
+         with gr.Blocks() as text2video_ui:
+             # Prompt & Negative Prompt (stacked vertically)
+             prompt = gr.Textbox(label="Prompt", placeholder="Enter video prompt")
+             negative_prompt = gr.Textbox(
+                 label="Negative Prompt", placeholder="Enter negative prompt"
+             )
+
+             # Parameters (2-column layout)
+             with gr.Row():
+                 with gr.Column():
+                     width = gr.Number(label="Width", value=512)
+                     num_frames = gr.Number(label="Frames", value=16)
+                     steps = gr.Number(label="Inference Steps", value=25)
+                 with gr.Column():
+                     height = gr.Number(label="Height", value=512)
+                     fps = gr.Number(label="FPS", value=8)
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale", minimum=1, maximum=20, value=7.5
+                     )
+
+             # Generate button
+             generate = gr.Button("Generate")
+
+             # Output gallery
+             gallery = gr.Gallery(label="Generated Videos", columns=2)
+
+             # Button click logic
+             generate.click(
+                 fn=text_generate_video,
+                 inputs=[
+                     prompt,
+                     negative_prompt,
+                     num_frames,
+                     fps,
+                     steps,
+                     guidance_scale,
+                     width,
+                     height,
+                 ],
+                 outputs=gallery,
+             )
+
+         return text2video_ui
+
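All four generation callbacks in this file repeat the same shape: issue the request on a worker thread, poll client.get_progress(request_id) once a second, then re-raise any captured exception. Distilled into a generic helper as a sketch (the name run_with_progress is ours, not the diff's):

    import threading
    import time
    import uuid

    def run_with_progress(call, client, progress, desc):
        # call: callable taking a request_id and returning the response dict
        # client: RESTfulClient, used here only for get_progress()
        # progress: gr.Progress instance; desc: text shown on the progress bar
        request_id = str(uuid.uuid4())
        result, exc = None, None

        def target():
            nonlocal result, exc
            try:
                result = call(request_id)
            except Exception as e:
                exc = e

        t = threading.Thread(target=target)
        t.start()
        while t.is_alive():
            try:
                cur = client.get_progress(request_id)["progress"]
            except Exception:
                cur = 0.0
            progress(cur, desc=desc)
            time.sleep(1)
        if exc:
            raise exc
        return result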
+     def image2video_interface(self) -> "gr.Blocks":
+         def image_generate_video(
+             image: PIL.Image.Image,
+             prompt: str,
+             negative_prompt: str,
+             num_frames: int,
+             fps: int,
+             num_inference_steps: int,
+             guidance_scale: float,
+             width: int,
+             height: int,
+             progress=gr.Progress(),
+         ) -> List[Tuple[str, str]]:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert isinstance(model, RESTfulVideoModelHandle)
+
+             request_id = str(uuid.uuid4())
+             response = None
+             exc = None
+
+             # Serialize the uploaded image to PNG bytes
+             buffered = io.BytesIO()
+             image.save(buffered, format="PNG")
+
+             # Run generation in a separate thread
+             def run_in_thread():
+                 nonlocal exc, response
+                 try:
+                     response = model.image_to_video(
+                         request_id=request_id,
+                         image=buffered.getvalue(),
+                         prompt=prompt,
+                         negative_prompt=negative_prompt,
+                         num_frames=num_frames,
+                         fps=fps,
+                         num_inference_steps=num_inference_steps,
+                         guidance_scale=guidance_scale,
+                         width=width,
+                         height=height,
+                         response_format="b64_json",
+                     )
+                 except Exception as e:
+                     exc = e
+
+             t = threading.Thread(target=run_in_thread)
+             t.start()
+
+             # Progress loop
+             while t.is_alive():
+                 try:
+                     cur_progress = client.get_progress(request_id)["progress"]
+                 except Exception:
+                     cur_progress = 0.0
+                 progress(cur_progress, desc="Generating video from image")
+                 time.sleep(1)
+
+             if exc:
+                 raise exc
+
+             # Decode and return video files
+             videos = []
+             for video_dict in response["data"]:  # type: ignore
+                 video_data = base64.b64decode(video_dict["b64_json"])
+                 video_path = f"/tmp/{uuid.uuid4()}.mp4"
+                 with open(video_path, "wb") as f:
+                     f.write(video_data)
+                 videos.append((video_path, "Generated Video"))
+
+             return videos
+
+         # Gradio UI
+         with gr.Blocks() as image2video_ui:
+             image = gr.Image(label="Input Image", type="pil")
+
+             prompt = gr.Textbox(label="Prompt", placeholder="Enter video prompt")
+             negative_prompt = gr.Textbox(
+                 label="Negative Prompt", placeholder="Enter negative prompt"
+             )
+
+             with gr.Row():
+                 with gr.Column():
+                     width = gr.Number(label="Width", value=512)
+                     num_frames = gr.Number(label="Frames", value=16)
+                     steps = gr.Number(label="Inference Steps", value=25)
+                 with gr.Column():
+                     height = gr.Number(label="Height", value=512)
+                     fps = gr.Number(label="FPS", value=8)
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale", minimum=1, maximum=20, value=7.5
+                     )
+
+             generate = gr.Button("Generate")
+             gallery = gr.Gallery(label="Generated Videos", columns=2)
+
+             generate.click(
+                 fn=image_generate_video,
+                 inputs=[
+                     image,
+                     prompt,
+                     negative_prompt,
+                     num_frames,
+                     fps,
+                     steps,
+                     guidance_scale,
+                     width,
+                     height,
+                 ],
+                 outputs=gallery,
+             )
+
+         return image2video_ui
+
+     def audio2text_interface(self) -> "gr.Blocks":
+         def transcribe_audio(
+             audio_path: str,
+             language: Optional[str],
+             prompt: Optional[str],
+             temperature: float,
+         ) -> str:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert isinstance(model, RESTfulAudioModelHandle)
+
+             with open(audio_path, "rb") as f:
+                 audio_data = f.read()
+
+             response = model.transcriptions(
+                 audio=audio_data,
+                 language=language or None,
+                 prompt=prompt or None,
+                 temperature=temperature,
+                 response_format="json",
+             )
+
+             return response.get("text", "No transcription result.")
+
+         with gr.Blocks() as audio2text_ui:
+             with gr.Row():
+                 audio_input = gr.Audio(
+                     type="filepath",
+                     label="Upload or Record Audio",
+                     sources=["upload", "microphone"],  # ✅ support both
+                 )
+             with gr.Row():
+                 language = gr.Textbox(
+                     label="Language", placeholder="e.g. en or zh", value=""
+                 )
+                 prompt = gr.Textbox(
+                     label="Prompt (optional)",
+                     placeholder="Provide context or vocabulary",
+                 )
+                 temperature = gr.Slider(
+                     label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1
+                 )
+             transcribe_btn = gr.Button("Transcribe")
+             output_text = gr.Textbox(label="Transcription", lines=5)
+
+             transcribe_btn.click(
+                 fn=transcribe_audio,
+                 inputs=[audio_input, language, prompt, temperature],
+                 outputs=output_text,
+             )
+
+         return audio2text_ui
+
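transcribe_audio maps straight onto the audio handle's transcriptions endpoint, so the same call works outside Gradio. A sketch with placeholder endpoint, model uid, and file name:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
    model = client.get_model("my-audio-model")  # placeholder model uid

    with open("sample.wav", "rb") as f:
        audio_data = f.read()

    # language and prompt are optional, as in the UI callback above
    result = model.transcriptions(audio=audio_data, response_format="json")
    print(result.get("text", ""))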
+     def text2speech_interface(self) -> "gr.Blocks":
+         def tts_generate(
+             input_text: str,
+             voice: str,
+             speed: float,
+             prompt_speech_file,
+             prompt_text: Optional[str],
+         ) -> str:
+             from ..client import RESTfulClient
+
+             client = RESTfulClient(self.endpoint)
+             client._set_token(self.access_token)
+             model = client.get_model(self.model_uid)
+             assert hasattr(model, "speech")
+
+             prompt_speech_bytes = None
+             if prompt_speech_file is not None:
+                 with open(prompt_speech_file, "rb") as f:
+                     prompt_speech_bytes = f.read()
+
+             response = model.speech(
+                 input=input_text,
+                 voice=voice,
+                 speed=speed,
+                 response_format="mp3",
+                 prompt_speech=prompt_speech_bytes,
+                 prompt_text=prompt_text,
+             )
+
+             # Write to a temp .mp3 file and return its path
+             audio_path = f"/tmp/{uuid.uuid4()}.mp3"
+             with open(audio_path, "wb") as f:
+                 f.write(response)
+
+             return audio_path
+
+         # Gradio UI
+         with gr.Blocks() as tts_ui:
+             with gr.Row():
+                 with gr.Column():
+                     input_text = gr.Textbox(
+                         label="Text", placeholder="Enter text to synthesize"
+                     )
+                     voice = gr.Textbox(
+                         label="Voice", placeholder="Optional voice ID", value=""
+                     )
+                     speed = gr.Slider(
+                         label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1
+                     )
+
+                     prompt_speech = gr.Audio(
+                         label="Prompt Speech (for cloning)", type="filepath"
+                     )
+                     prompt_text = gr.Textbox(
+                         label="Prompt Text (for cloning)",
+                         placeholder="Text of the prompt speech",
+                     )
+
+                     generate = gr.Button("Generate")
+
+                 with gr.Column():
+                     audio_output = gr.Audio(label="Generated Audio", type="filepath")
+
+             generate.click(
+                 fn=tts_generate,
+                 inputs=[input_text, voice, speed, prompt_speech, prompt_text],
+                 outputs=audio_output,
+             )
+
+         return tts_ui
+
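model.speech returns raw audio bytes, which is why tts_generate can write the response straight to an .mp3 file. The same call without the UI, as a sketch (placeholder endpoint and uid; the voice-cloning arguments are omitted):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
    model = client.get_model("my-tts-model")  # placeholder model uid

    audio_bytes = model.speech(input="Hello from Xinference", response_format="mp3")
    with open("speech.mp3", "wb") as f:
        f.write(audio_bytes)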
+     def build_main_interface(self) -> "gr.Blocks":
+         if self.model_type == "image":
+             title = f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨"
+         elif self.model_type == "video":
+             title = f"🎨 Xinference Video Generation: {self.model_name} 🎨"
+         else:
+             assert self.model_type == "audio"
+             title = f"🎨 Xinference Audio Model: {self.model_name} 🎨"
+         with gr.Blocks(
+             title=title,
+             css="""
+             .center{
+                 display: flex;
+                 justify-content: center;
+                 align-items: center;
+                 padding: 0px;
+                 color: #9ea4b0 !important;
+             }
+             """,
+             analytics_enabled=False,
+         ) as app:
+             Markdown(
+                 f"""
+                 <h1 class="center" style='text-align: center; margin-bottom: 1rem'>{title}</h1>
+                 """
+             )
+             Markdown(
+                 f"""
+                 <div class="center">
+                 Model ID: {self.model_uid}
+                 </div>
+                 """
+             )
+             if "text2image" in self.model_ability:
+                 with gr.Tab("Text to Image"):
+                     self.text2image_interface()
+             if "image2image" in self.model_ability:
+                 with gr.Tab("Image to Image"):
+                     self.image2image_interface()
+             if "text2video" in self.model_ability:
+                 with gr.Tab("Text to Video"):
+                     self.text2video_interface()
+             if "image2video" in self.model_ability:
+                 with gr.Tab("Image to Video"):
+                     self.image2video_interface()
+             if "audio2text" in self.model_ability:
+                 with gr.Tab("Audio to Text"):
+                     self.audio2text_interface()
+             if "text2audio" in self.model_ability:
+                 with gr.Tab("Text to Audio"):
+                     self.text2speech_interface()
+         return app
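Taken together, MediaInterface replaces the deleted image_interface.py (file 84 above) and extends the same Gradio pattern to video and audio models. A sketch of constructing it directly; every argument value here is a placeholder, since these are presumably supplied by the serving layer when a model's web UI is launched:

    from xinference.core.media_interface import MediaInterface

    ui = MediaInterface(
        endpoint="http://127.0.0.1:9997",
        model_uid="my-model-uid",
        model_family="stable_diffusion",
        model_name="my-model-name",
        model_id="org/my-model-id",
        model_revision="main",
        model_ability=["text2image", "image2image"],
        model_type="image",
        controlnet=None,
        access_token=None,
    )
    blocks = ui.build()
    blocks.launch()  # or mount it on an existing ASGI app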