xinference 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic (details are available on the package's advisory page).

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +413 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +447 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/core/chat_interface.py CHANGED

```diff
@@ -16,11 +16,13 @@ import base64
 import html
 import logging
 import os
+import tempfile
 from io import BytesIO
-from typing import Dict, Generator, List, Optional
+from typing import Generator, List, Optional
 
 import gradio as gr
 import PIL.Image
+from gradio import ChatMessage
 from gradio.components import Markdown, Textbox
 from gradio.layouts import Accordion, Column, Row
 
```
```diff
@@ -65,13 +67,13 @@ class GradioInterface:
 
     def build(self) -> "gr.Blocks":
         if "vision" in self.model_ability:
-            interface = self.build_chat_vl_interface()
+            interface = self.build_chat_multimodel_interface()
         elif "chat" in self.model_ability:
             interface = self.build_chat_interface()
         else:
             interface = self.build_generate_interface()
 
-        interface.queue()
+        interface.queue(default_concurrency_limit=os.cpu_count())
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
```
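The `queue()` change above lifts Gradio's per-event concurrency, which defaults to 1 in Gradio 4.x, up to the host's CPU count, so concurrent chat sessions are no longer serialized behind a single queue slot. A minimal standalone sketch of the same knob (assuming Gradio 4.x, where `default_concurrency_limit` replaced the older `concurrency_count`):

```python
import os

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("hello")

# Allow up to cpu_count() handlers of the same event to run at once
# instead of Gradio's default of 1.
demo.queue(default_concurrency_limit=os.cpu_count())
demo.launch()
```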
```diff
@@ -91,25 +93,10 @@ class GradioInterface:
         interface.favicon_path = favicon_path
         return interface
 
-    def build_chat_interface(
-        self,
-    ) -> "gr.Blocks":
-        def flatten(matrix: List[List[str]]) -> List[str]:
-            flat_list = []
-            for row in matrix:
-                flat_list += row
-            return flat_list
-
-        def to_chat(lst: List[str]) -> List[Dict]:
-            res = []
-            for i in range(len(lst)):
-                role = "assistant" if i % 2 == 1 else "user"
-                res.append(dict(role=role, content=lst[i]))
-            return res
-
+    def build_chat_interface(self) -> "gr.Blocks":
         def generate_wrapper(
             message: str,
-            history: List[List[str]],
+            history: List[ChatMessage],
             max_tokens: int,
             temperature: float,
             lora_name: str,
```
```diff
@@ -121,13 +108,22 @@ class GradioInterface:
             client._set_token(self._access_token)
             model = client.get_model(self.model_uid)
             assert isinstance(model, RESTfulChatModelHandle)
-            messages = to_chat(flatten(history))
-            messages.append(dict(role="user", content=message))
+
+            # Convert history to messages format
+            messages = []
+            for msg in history:
+                # ignore thinking content
+                if msg["metadata"]:
+                    continue
+                messages.append({"role": msg["role"], "content": msg["content"]})
 
             if stream:
                 response_content = ""
+                reasoning_content = ""
+                is_first_reasoning_content = True
+                is_first_content = True
                 for chunk in model.chat(
-                    messages,
+                    messages=messages,
                     generate_config={
                         "max_tokens": int(max_tokens),
                         "temperature": temperature,
```
```diff
@@ -137,46 +133,79 @@
                 ):
                     assert isinstance(chunk, dict)
                     delta = chunk["choices"][0]["delta"]
-                    if "content" not in delta or delta["content"] is None:
-                        continue
-                    else:
-                        # some model like deepseek-r1-distill-qwen
-                        # will generate <think>...</think> ...
-                        # in gradio, no output will be rendered,
-                        # thus escape html tags in advance
-                        response_content += html.escape(delta["content"])
-                        yield response_content
 
-                yield response_content
+                    if (
+                        "reasoning_content" in delta
+                        and delta["reasoning_content"] is not None
+                        and is_first_reasoning_content
+                    ):
+                        reasoning_content += html.escape(delta["reasoning_content"])
+                        history.append(
+                            ChatMessage(
+                                role="assistant",
+                                content=reasoning_content,
+                                metadata={"title": "💭 Thinking Process"},
+                            )
+                        )
+                        is_first_reasoning_content = False
+                    elif (
+                        "reasoning_content" in delta
+                        and delta["reasoning_content"] is not None
+                    ):
+                        reasoning_content += html.escape(delta["reasoning_content"])
+                        history[-1] = ChatMessage(
+                            role="assistant",
+                            content=reasoning_content,
+                            metadata={"title": "💭 Thinking Process"},
+                        )
+                    elif (
+                        "content" in delta
+                        and delta["content"] is not None
+                        and is_first_content
+                    ):
+                        response_content += html.escape(delta["content"])
+                        history.append(
+                            ChatMessage(role="assistant", content=response_content)
+                        )
+                        is_first_content = False
+                    elif "content" in delta and delta["content"] is not None:
+                        response_content += html.escape(delta["content"])
+                        # Replace thinking message with actual response
+                        history[-1] = ChatMessage(
+                            role="assistant", content=response_content
+                        )
+                    yield history
             else:
                 result = model.chat(
-                    messages,
+                    messages=messages,
                     generate_config={
                         "max_tokens": int(max_tokens),
                         "temperature": temperature,
                         "lora_name": lora_name,
                     },
                 )
-                yield html.escape(result["choices"][0]["message"]["content"])  # type: ignore
+                assert isinstance(result, dict)
+                mg = result["choices"][0]["message"]
+                if "reasoning_content" in mg:
+                    reasoning_content = mg["reasoning_content"]
+                    if reasoning_content is not None:
+                        reasoning_content = html.escape(str(reasoning_content))
+                        history.append(
+                            ChatMessage(
+                                role="assistant",
+                                content=reasoning_content,
+                                metadata={"title": "💭 Thinking Process"},
+                            )
+                        )
 
-        return gr.ChatInterface(
-            fn=generate_wrapper,
-            additional_inputs=[
-                gr.Slider(
-                    minimum=1,
-                    maximum=self.context_length,
-                    value=512
-                    if "reasoning" not in self.model_ability
-                    else self.context_length // 2,
-                    step=1,
-                    label="Max Tokens",
-                ),
-                gr.Slider(
-                    minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
-                ),
-                gr.Text(label="LoRA Name"),
-                gr.Checkbox(label="Stream", value=True),
-            ],
+                content = mg["content"]
+                response_content = (
+                    html.escape(str(content)) if content is not None else ""
+                )
+                history.append(ChatMessage(role="assistant", content=response_content))
+                yield history
+
+        with gr.Blocks(
             title=f"🚀 Xinference Chat Bot : {self.model_name} 🚀",
             css="""
             .center{
```
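The streaming branch above renders `reasoning_content` deltas as their own assistant turn whose `metadata["title"]` makes `gr.Chatbot(type="messages")` draw a collapsible section, separate from the final answer bubble. A minimal illustration of the two kinds of turns involved:

```python
from gradio import ChatMessage

# Rendered by gr.Chatbot(type="messages") as a fold-out "💭 Thinking Process"
# block; content without metadata renders as a normal answer bubble.
thinking_turn = ChatMessage(
    role="assistant",
    content="First I should check the units...",
    metadata={"title": "💭 Thinking Process"},
)
answer_turn = ChatMessage(role="assistant", content="The result is 42.")
```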
```diff
@@ -186,25 +215,123 @@
             padding: 0px;
             color: #9ea4b0 !important;
             }
-            """,
-            description=f"""
-            <div class="center">
-            Model ID: {self.model_uid}
-            </div>
-            <div class="center">
-            Model Size: {self.model_size_in_billions} Billion Parameters
-            </div>
-            <div class="center">
-            Model Format: {self.model_format}
-            </div>
-            <div class="center">
-            Model Quantization: {self.quantization}
-            </div>
+            .main-container {
+                display: flex;
+                flex-direction: column;
+                padding: 0.5rem;
+                box-sizing: border-box;
+                gap: 0.25rem;
+                flex-grow: 1;
+                min-width: min(320px, 100%);
+                height: calc(100vh - 70px)!important;
+            }
+            .header {
+                flex-grow: 0!important;
+            }
+            .header h1 {
+                margin: 0.5rem 0;
+                font-size: 1.5rem;
+            }
+            .center {
+                font-size: 0.9rem;
+                margin: 0.1rem 0;
+            }
+            .chat-container {
+                flex: 1;
+                display: flex;
+                min-height: 0;
+                margin: 0.25rem 0;
+            }
+            .chat-container .block {
+                height: 100%!important;
+            }
+            .input-container {
+                flex-grow: 0!important;
+            }
             """,
             analytics_enabled=False,
-        )
+        ) as chat_interface:
+            with gr.Column(elem_classes="main-container"):
+                # Header section
+                with gr.Column(elem_classes="header"):
+                    gr.Markdown(
+                        f"""<h1 style='text-align: center; margin-bottom: 1rem'>🚀 Xinference Chat Bot : {self.model_name} 🚀</h1>"""
+                    )
+                    gr.Markdown(
+                        f"""
+                        <div class="center">Model ID: {self.model_uid}</div>
+                        <div class="center">Model Size: {self.model_size_in_billions} Billion Parameters</div>
+                        <div class="center">Model Format: {self.model_format}</div>
+                        <div class="center">Model Quantization: {self.quantization}</div>
+                        """
+                    )
+
+                # Chat container
+                with gr.Column(elem_classes="chat-container"):
+                    chatbot = gr.Chatbot(
+                        type="messages",
+                        label=self.model_name,
+                        show_label=True,
+                        render_markdown=True,
+                        container=True,
+                    )
 
-    def build_chat_vl_interface(
+                # Input container
+                with gr.Column(elem_classes="input-container"):
+                    with gr.Row():
+                        with gr.Column(scale=12):
+                            textbox = gr.Textbox(
+                                show_label=False,
+                                placeholder="Type a message...",
+                                container=False,
+                            )
+                        with gr.Column(scale=1, min_width=50):
+                            submit_btn = gr.Button("Enter", variant="primary")
+
+                    with gr.Accordion("Additional Inputs", open=False):
+                        max_tokens = gr.Slider(
+                            minimum=1,
+                            maximum=self.context_length,
+                            value=512
+                            if "reasoning" not in self.model_ability
+                            else self.context_length // 2,
+                            step=1,
+                            label="Max Tokens",
+                        )
+                        temperature = gr.Slider(
+                            minimum=0,
+                            maximum=2,
+                            value=1,
+                            step=0.01,
+                            label="Temperature",
+                        )
+                        stream = gr.Checkbox(label="Stream", value=True)
+                        lora_name = gr.Text(label="LoRA Name")
+
+            # deal with message submit
+            textbox.submit(
+                lambda m, h: ("", h + [ChatMessage(role="user", content=m)]),
+                [textbox, chatbot],
+                [textbox, chatbot],
+            ).then(
+                generate_wrapper,
+                [textbox, chatbot, max_tokens, temperature, lora_name, stream],
+                chatbot,
+            )
+
+            submit_btn.click(
+                lambda m, h: ("", h + [ChatMessage(role="user", content=m)]),
+                [textbox, chatbot],
+                [textbox, chatbot],
+            ).then(
+                generate_wrapper,
+                [textbox, chatbot, max_tokens, temperature, lora_name, stream],
+                chatbot,
+            )
+
+        return chat_interface
+
+    def build_chat_multimodel_interface(
         self,
     ) -> "gr.Blocks":
         def predict(history, bot, max_tokens, temperature, stream):
```
```diff
@@ -251,11 +378,46 @@
                     },
                 )
                 history.append(response["choices"][0]["message"])
-                bot[-1][1] = history[-1]["content"]
-                yield history, bot
+                if "audio" in history[-1]:
+                    # audio output
+                    audio_bytes = base64.b64decode(history[-1]["audio"]["data"])
+                    audio_file = tempfile.NamedTemporaryFile(
+                        delete=False, suffix=".wav"
+                    )
+                    audio_file.write(audio_bytes)
+                    audio_file.close()
+
+                    def audio_to_base64(audio_path):
+                        with open(audio_path, "rb") as audio_file:
+                            return base64.b64encode(audio_file.read()).decode("utf-8")
+
+                    def generate_html_audio(audio_path):
+                        base64_audio = audio_to_base64(audio_path)
+                        audio_format = audio_path.split(".")[-1]
+                        return (
+                            f"<audio controls style='max-width:100%;'>"
+                            f"<source src='data:audio/{audio_format};base64,{base64_audio}' type='audio/{audio_format}'>"
+                            f"Your browser does not support the audio tag.</audio>"
+                        )
 
-        def add_text(history, bot, text, image, video):
-            logger.debug("Add text, text: %s, image: %s, video: %s", text, image, video)
+                    bot[-1] = (bot[-1][0], history[-1]["content"])
+                    yield history, bot
+
+                    # append html audio tag instead of gr.Audio
+                    bot.append((None, generate_html_audio(audio_file.name)))
+                    yield history, bot
+                else:
+                    bot[-1][1] = history[-1]["content"]
+                    yield history, bot
+
+        def add_text(history, bot, text, image, video, audio):
+            logger.debug(
+                "Add text, text: %s, image: %s, video: %s, audio: %s",
+                text,
+                image,
+                video,
+                audio,
+            )
             if image:
                 buffered = BytesIO()
                 with PIL.Image.open(image) as img:
```
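For reference, `generate_html_audio` above inlines the entire waveform as a base64 data URI, so the chatbot cell receives plain HTML instead of a `gr.Audio` component. For a hypothetical `/tmp/reply.wav` the appended markup is shaped like this (payload abbreviated):

```python
# Shape of the string appended to the bot history for /tmp/reply.wav
# (hypothetical file; the base64 payload is truncated here):
html_audio = (
    "<audio controls style='max-width:100%;'>"
    "<source src='data:audio/wav;base64,UklGR...' type='audio/wav'>"
    "Your browser does not support the audio tag.</audio>"
)
```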
```diff
@@ -306,20 +468,54 @@
                         },
                     ],
                 }
+
+            elif audio:
+
+                def audio_to_base64(audio_path):
+                    with open(audio_path, "rb") as audio_file:
+                        encoded_string = base64.b64encode(audio_file.read()).decode(
+                            "utf-8"
+                        )
+                        return encoded_string
+
+                def generate_html_audio(audio_path):
+                    base64_audio = audio_to_base64(audio_path)
+                    audio_format = audio_path.split(".")[-1]
+                    return (
+                        f"<audio controls style='max-width:100%;'>"
+                        f"<source src='data:audio/{audio_format};base64,{base64_audio}' type='audio/{audio_format}'>"
+                        f"Your browser does not support the audio tag.</audio>"
+                    )
+
+                display_content = f"{generate_html_audio(audio)}<br>{text}"
+                message = {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": text},
+                        {
+                            "type": "audio_url",
+                            "audio_url": {"url": audio},
+                        },
+                    ],
+                }
+
             else:
                 display_content = text
                 message = {"role": "user", "content": text}
             history = history + [message]
             bot = bot + [[display_content, None]]
-            return history, bot, "", None, None
+            return history, bot, "", None, None, None
 
         def clear_history():
             logger.debug("Clear history.")
-            return [], None, "", None, None
+            return [], None, "", None, None, None
 
         def update_button(text):
             return gr.update(interactive=bool(text))
 
+        has_vision = "vision" in self.model_ability
+        has_audio = "audio" in self.model_ability
+
         with gr.Blocks(
             title=f"🚀 Xinference Chat Bot : {self.model_name} 🚀",
             css="""
```
```diff
@@ -358,11 +554,29 @@
             state = gr.State([])
             with gr.Row():
                 chatbot = gr.Chatbot(
-                    elem_id="chatbot", label=self.model_name, height=700, scale=7
+                    elem_id="chatbot", label=self.model_name, scale=7, min_height=900
                 )
                 with gr.Column(scale=3):
-                    imagebox = gr.Image(type="filepath")
-                    videobox = gr.Video()
+                    if has_vision:
+                        imagebox = gr.Image(type="filepath")
+                        videobox = gr.Video()
+                    else:
+                        imagebox = gr.Image(type="filepath", visible=False)
+                        videobox = gr.Video(visible=False)
+
+                    if has_audio:
+                        audiobox = gr.Audio(
+                            sources=["microphone", "upload"],
+                            type="filepath",
+                            visible=True,
+                        )
+                    else:
+                        audiobox = gr.Audio(
+                            sources=["microphone", "upload"],
+                            type="filepath",
+                            visible=False,
+                        )
+
                     textbox = gr.Textbox(
                         show_label=False,
                         placeholder="Enter text and press ENTER",
```
```diff
@@ -390,8 +604,8 @@
 
             textbox.submit(
                 add_text,
-                [state, chatbot, textbox, imagebox, videobox],
-                [state, chatbot, textbox, imagebox, videobox],
+                [state, chatbot, textbox, imagebox, videobox, audiobox],
+                [state, chatbot, textbox, imagebox, videobox, audiobox],
                 queue=False,
             ).then(
                 predict,
```
```diff
@@ -401,8 +615,8 @@
             )
 
             submit_btn.click(
                 add_text,
-                [state, chatbot, textbox, imagebox, videobox],
-                [state, chatbot, textbox, imagebox, videobox],
+                [state, chatbot, textbox, imagebox, videobox, audiobox],
+                [state, chatbot, textbox, imagebox, videobox, audiobox],
                 queue=False,
             ).then(
                 predict,
```
```diff
@@ -413,7 +627,7 @@
             clear_btn.click(
                 clear_history,
                 None,
-                [state, chatbot, textbox, imagebox, videobox],
+                [state, chatbot, textbox, imagebox, videobox, audiobox],
                 queue=False,
             )
 
```
xinference/core/model.py CHANGED

```diff
@@ -231,6 +231,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         driver_info: Optional[dict] = None,  # for model across workers
     ):
         super().__init__()
+
         from ..model.llm.llama_cpp.core import XllamaCppModel
         from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
```
xinference/core/progress_tracker.py CHANGED

```diff
@@ -92,16 +92,20 @@ class ProgressTrackerActor(xo.StatelessActor):
 
             await asyncio.sleep(self._check_interval)
 
-    def start(self, request_id: str):
+    def start(self, request_id: str, info: Optional[str] = None):
         self._request_id_to_progress[request_id] = _ProgressInfo(
-            progress=0.0, last_updated=time.time()
+            progress=0.0, last_updated=time.time(), info=info
         )
 
-    def set_progress(self, request_id: str, progress: float):
+    def set_progress(
+        self, request_id: str, progress: float, info: Optional[str] = None
+    ):
         assert progress <= 1.0
-        info = self._request_id_to_progress[request_id]
-        info.progress = progress
-        info.last_updated = time.time()
+        info_ = self._request_id_to_progress[request_id]
+        info_.progress = progress
+        info_.last_updated = time.time()
+        if info:
+            info_.info = info
         logger.debug(
             "Setting progress, request id: %s, progress: %s", request_id, progress
         )
```
```diff
@@ -109,6 +113,10 @@
     def get_progress(self, request_id: str) -> float:
         return self._request_id_to_progress[request_id].progress
 
+    def get_progress_info(self, request_id: str) -> Tuple[float, Optional[str]]:
+        info = self._request_id_to_progress[request_id]
+        return info.progress, info.info
+
 
 class Progressor:
     _sub_progress_stack: List[Tuple[float, float]]
```
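`start()` now passes `info` into `_ProgressInfo` and `get_progress_info` reads it back. The dataclass itself is not part of this excerpt; presumably it gained an optional field along these lines (a sketch, not the released code):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ProgressInfo:
    progress: float
    last_updated: float
    info: Optional[str] = None  # assumed: human-readable stage description
```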
```diff
@@ -169,7 +177,7 @@ class Progressor:
             self.set_progress(1.0)
         return False
 
-    def set_progress(self, progress: float):
+    def set_progress(self, progress: float, info: Optional[str] = None):
         if self.request_id:
             self._current_progress = (
                 self._current_sub_progress_start
```
```diff
@@ -179,7 +187,7 @@ class Progressor:
             if (
                 self._current_progress - self._last_report_progress >= self._upload_span
                 or 1.0 - progress < 1e-5
-            ):
+            ) or info:
                 set_progress = self.progress_tracker_ref.set_progress(
                     self.request_id, self._current_progress
                 )
```
xinference/core/supervisor.py CHANGED

```diff
@@ -18,11 +18,13 @@ import os
 import signal
 import time
 import typing
-from dataclasses import dataclass
+from collections import defaultdict
+from dataclasses import dataclass, field
 from logging import getLogger
 from typing import (
     TYPE_CHECKING,
     Any,
+    DefaultDict,
     Dict,
     Iterator,
     List,
```
```diff
@@ -91,6 +93,9 @@ class WorkerStatus:
 class ReplicaInfo:
     replica: int
     scheduler: Iterator
+    replica_to_worker_refs: DefaultDict[
+        int, List[xo.ActorRefType["WorkerActor"]]
+    ] = field(default_factory=lambda: defaultdict(list))
 
 
 class SupervisorActor(xo.StatelessActor):
```
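The new `replica_to_worker_refs` field uses `field(default_factory=...)` because dataclasses reject mutable defaults; the factory gives every `ReplicaInfo` instance its own `defaultdict`. A minimal demonstration of that behavior:

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import DefaultDict, List


@dataclass
class Info:
    # default_factory runs once per instance; a bare `= defaultdict(list)`
    # would raise ValueError (mutable default shared across instances).
    refs: DefaultDict[int, List[str]] = field(
        default_factory=lambda: defaultdict(list)
    )


a, b = Info(), Info()
a.refs[0].append("worker-1")
assert b.refs == {}  # b got its own empty mapping
```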
```diff
@@ -1113,6 +1118,9 @@
                 if target_ip_worker_ref is not None
                 else await self._choose_worker()
             )
+            self._model_uid_to_replica_info[model_uid].replica_to_worker_refs[
+                _idx
+            ].append(worker_ref)
             if enable_xavier and _idx == 0:
                 """
                 Start the rank 0 model actor on the worker that holds the rank 1 replica,
```
```diff
@@ -1260,6 +1268,9 @@
             driver_info = None
             for i_worker in range(n_worker):
                 worker_ref = await self._choose_worker(available_workers)
+                self._model_uid_to_replica_info[
+                    model_uid
+                ].replica_to_worker_refs[_idx].append(worker_ref)
                 nonlocal model_type
                 model_type = model_type or "LLM"
                 if i_worker > 1:
```
```diff
@@ -1344,6 +1355,39 @@
             task.add_done_callback(lambda _: callback_for_async_launch(model_uid))  # type: ignore
         return model_uid
 
+    async def get_launch_builtin_model_progress(self, model_uid: str) -> float:
+        info = self._model_uid_to_replica_info[model_uid]
+        all_progress = 0.0
+        i = 0
+        for rep_model_uid in iter_replica_model_uid(model_uid, info.replica):
+            request_id = f"launching-{rep_model_uid}"
+            try:
+                all_progress += await self._progress_tracker.get_progress(request_id)
+                i += 1
+            except KeyError:
+                continue
+
+        return all_progress / i if i > 0 else 0.0
+
+    async def cancel_launch_builtin_model(self, model_uid: str):
+        info = self._model_uid_to_replica_info[model_uid]
+        coros = []
+        for i, rep_model_uid in enumerate(
+            iter_replica_model_uid(model_uid, info.replica)
+        ):
+            worker_refs = self._model_uid_to_replica_info[
+                model_uid
+            ].replica_to_worker_refs[i]
+            for worker_ref in worker_refs:
+                coros.append(worker_ref.cancel_launch_model(rep_model_uid))
+        try:
+            await asyncio.gather(*coros)
+        except RuntimeError:
+            # some may have finished
+            pass
+        # remove replica info
+        self._model_uid_to_replica_info.pop(model_uid, None)
+
     async def get_instance_info(
         self, model_name: Optional[str], model_uid: Optional[str]
     ) -> List[Dict]:
```
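Taken together, the two new coroutines let a caller poll a launch and abort it across all replicas. A hedged usage sketch, assuming `supervisor_ref` is an xoscar actor ref to `SupervisorActor` (the matching REST endpoints added in `restful_api.py` and `restful_client.py` are not shown in this excerpt):

```python
import asyncio


async def wait_for_launch(supervisor_ref, model_uid: str, timeout: float = 600.0) -> bool:
    """Poll launch progress; cancel the launch if it exceeds `timeout` seconds."""
    deadline = asyncio.get_running_loop().time() + timeout
    while True:
        # Averaged over replicas; replicas whose tracker entry is missing are skipped.
        progress = await supervisor_ref.get_launch_builtin_model_progress(model_uid)
        if progress >= 1.0:
            return True
        if asyncio.get_running_loop().time() > deadline:
            # Fans cancel_launch_model out to every worker holding a replica.
            await supervisor_ref.cancel_launch_builtin_model(model_uid)
            return False
        await asyncio.sleep(1.0)
```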