xinference 0.14.0.post1__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (50)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +54 -0
  3. xinference/client/handlers.py +0 -3
  4. xinference/client/restful/restful_client.py +51 -134
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +1 -4
  7. xinference/core/image_interface.py +33 -5
  8. xinference/core/model.py +28 -2
  9. xinference/core/supervisor.py +37 -0
  10. xinference/core/worker.py +128 -84
  11. xinference/deploy/cmdline.py +1 -4
  12. xinference/model/audio/core.py +11 -3
  13. xinference/model/audio/funasr.py +114 -0
  14. xinference/model/audio/model_spec.json +20 -0
  15. xinference/model/audio/model_spec_modelscope.json +21 -0
  16. xinference/model/audio/whisper.py +1 -1
  17. xinference/model/core.py +12 -0
  18. xinference/model/image/core.py +3 -4
  19. xinference/model/image/model_spec.json +41 -13
  20. xinference/model/image/model_spec_modelscope.json +30 -10
  21. xinference/model/image/stable_diffusion/core.py +53 -2
  22. xinference/model/llm/__init__.py +2 -0
  23. xinference/model/llm/llm_family.json +83 -1
  24. xinference/model/llm/llm_family_modelscope.json +85 -1
  25. xinference/model/llm/pytorch/core.py +1 -0
  26. xinference/model/llm/pytorch/minicpmv26.py +247 -0
  27. xinference/model/llm/sglang/core.py +72 -34
  28. xinference/model/llm/vllm/core.py +38 -0
  29. xinference/model/video/__init__.py +62 -0
  30. xinference/model/video/core.py +178 -0
  31. xinference/model/video/diffusers.py +180 -0
  32. xinference/model/video/model_spec.json +11 -0
  33. xinference/model/video/model_spec_modelscope.json +12 -0
  34. xinference/types.py +10 -24
  35. xinference/web/ui/build/asset-manifest.json +3 -3
  36. xinference/web/ui/build/index.html +1 -1
  37. xinference/web/ui/build/static/js/{main.ef2a203a.js → main.17ca0398.js} +3 -3
  38. xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
  41. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/METADATA +14 -8
  42. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/RECORD +47 -40
  43. xinference/web/ui/build/static/js/main.ef2a203a.js.map +0 -1
  44. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
  46. /xinference/web/ui/build/static/js/{main.ef2a203a.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
  47. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/LICENSE +0 -0
  48. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/WHEEL +0 -0
  49. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/entry_points.txt +0 -0
  50. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py CHANGED
@@ -68,7 +68,7 @@ class WorkerActor(xo.StatelessActor):
68
68
  # static attrs.
69
69
  self._total_gpu_devices = gpu_devices
70
70
  self._supervisor_address = supervisor_address
71
- self._supervisor_ref = None
71
+ self._supervisor_ref: Optional[xo.ActorRefType] = None
72
72
  self._main_pool = main_pool
73
73
  self._main_pool.recover_sub_pool = self.recover_sub_pool
74
74
 
@@ -147,17 +147,20 @@ class WorkerActor(xo.StatelessActor):
147
147
  )
148
148
  event_model_uid, _, __ = parse_replica_model_uid(model_uid)
149
149
  try:
150
- await self._event_collector_ref.report_event(
151
- event_model_uid,
152
- Event(
153
- event_type=EventType.WARNING,
154
- event_ts=int(time.time()),
155
- event_content="Recreate model",
156
- ),
157
- )
150
+ if self._event_collector_ref is not None:
151
+ await self._event_collector_ref.report_event(
152
+ event_model_uid,
153
+ Event(
154
+ event_type=EventType.WARNING,
155
+ event_ts=int(time.time()),
156
+ event_content="Recreate model",
157
+ ),
158
+ )
158
159
  except Exception as e:
159
160
  # Report callback error can be log and ignore, should not interrupt the Process
160
161
  logger.error("report_event error: %s" % (e))
162
+ finally:
163
+ del event_model_uid
161
164
 
162
165
  self._model_uid_to_recover_count[model_uid] = (
163
166
  recover_count - 1
@@ -175,80 +178,39 @@ class WorkerActor(xo.StatelessActor):
175
178
  return "worker"
176
179
 
177
180
  async def __post_create__(self):
178
- from ..isolation import Isolation
179
- from .cache_tracker import CacheTrackerActor
180
- from .status_guard import StatusGuardActor
181
- from .supervisor import SupervisorActor
182
-
183
- self._status_guard_ref: xo.ActorRefType[ # type: ignore
184
- "StatusGuardActor"
185
- ] = await xo.actor_ref(
186
- address=self._supervisor_address, uid=StatusGuardActor.uid()
187
- )
188
- self._event_collector_ref: xo.ActorRefType[ # type: ignore
189
- EventCollectorActor
190
- ] = await xo.actor_ref(
191
- address=self._supervisor_address, uid=EventCollectorActor.uid()
192
- )
193
- self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
194
- "CacheTrackerActor"
195
- ] = await xo.actor_ref(
196
- address=self._supervisor_address, uid=CacheTrackerActor.uid()
197
- )
198
- self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
199
- address=self._supervisor_address, uid=SupervisorActor.uid()
200
- )
201
- await self._supervisor_ref.add_worker(self.address)
202
- if not XINFERENCE_DISABLE_HEALTH_CHECK:
203
- # Run _periodical_report_status() in a dedicated thread.
204
- self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
205
- self._isolation.start()
206
- asyncio.run_coroutine_threadsafe(
207
- self._periodical_report_status(), loop=self._isolation.loop
208
- )
209
- logger.info(f"Xinference worker {self.address} started")
210
- logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
211
- purge_dir(XINFERENCE_CACHE_DIR)
212
-
213
181
  from ..model.audio import (
214
182
  CustomAudioModelFamilyV1,
215
183
  generate_audio_description,
216
- get_audio_model_descriptions,
217
184
  register_audio,
218
185
  unregister_audio,
219
186
  )
220
187
  from ..model.embedding import (
221
188
  CustomEmbeddingModelSpec,
222
189
  generate_embedding_description,
223
- get_embedding_model_descriptions,
224
190
  register_embedding,
225
191
  unregister_embedding,
226
192
  )
227
193
  from ..model.flexible import (
228
194
  FlexibleModelSpec,
229
195
  generate_flexible_model_description,
230
- get_flexible_model_descriptions,
231
196
  register_flexible_model,
232
197
  unregister_flexible_model,
233
198
  )
234
199
  from ..model.image import (
235
200
  CustomImageModelFamilyV1,
236
201
  generate_image_description,
237
- get_image_model_descriptions,
238
202
  register_image,
239
203
  unregister_image,
240
204
  )
241
205
  from ..model.llm import (
242
206
  CustomLLMFamilyV1,
243
207
  generate_llm_description,
244
- get_llm_model_descriptions,
245
208
  register_llm,
246
209
  unregister_llm,
247
210
  )
248
211
  from ..model.rerank import (
249
212
  CustomRerankModelSpec,
250
213
  generate_rerank_description,
251
- get_rerank_model_descriptions,
252
214
  register_rerank,
253
215
  unregister_rerank,
254
216
  )
@@ -292,24 +254,33 @@ class WorkerActor(xo.StatelessActor):
292
254
  ),
293
255
  }
294
256
 
295
- # record model version
296
- model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
297
- model_version_infos.update(get_llm_model_descriptions())
298
- model_version_infos.update(get_embedding_model_descriptions())
299
- model_version_infos.update(get_rerank_model_descriptions())
300
- model_version_infos.update(get_image_model_descriptions())
301
- model_version_infos.update(get_audio_model_descriptions())
302
- model_version_infos.update(get_flexible_model_descriptions())
303
- await self._cache_tracker_ref.record_model_version(
304
- model_version_infos, self.address
305
- )
257
+ logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
258
+ purge_dir(XINFERENCE_CACHE_DIR)
259
+
260
+ try:
261
+ await self.get_supervisor_ref(add_worker=True)
262
+ except Exception as e:
263
+ # Do not crash the worker if supervisor is down, auto re-connect later
264
+ logger.error(f"cannot connect to supervisor {e}")
265
+
266
+ if not XINFERENCE_DISABLE_HEALTH_CHECK:
267
+ from ..isolation import Isolation
268
+
269
+ # Run _periodical_report_status() in a dedicated thread.
270
+ self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
271
+ self._isolation.start()
272
+ asyncio.run_coroutine_threadsafe(
273
+ self._periodical_report_status(), loop=self._isolation.loop
274
+ )
275
+ logger.info(f"Xinference worker {self.address} started")
306
276
 
307
277
  # Windows does not have signal handler
308
278
  if os.name != "nt":
309
279
 
310
280
  async def signal_handler():
311
281
  try:
312
- await self._supervisor_ref.remove_worker(self.address)
282
+ supervisor_ref = await self.get_supervisor_ref(add_worker=False)
283
+ await supervisor_ref.remove_worker(self.address)
313
284
  except Exception as e:
314
285
  # Ignore the error of rpc, anyway we are exiting
315
286
  logger.exception("remove worker rpc error: %s", e)
@@ -331,6 +302,64 @@ class WorkerActor(xo.StatelessActor):
331
302
  return False
332
303
  return True
333
304
 
305
async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
    """
    Try connect to supervisor and return its ActorRef. Raise exception on error.

    The first successful call also resolves the StatusGuardActor,
    EventCollectorActor and CacheTrackerActor refs at the supervisor address
    and records this worker's builtin model versions with the cache tracker.
    Once that has succeeded, later calls return the cached ref with no RPC.

    Params:
        add_worker: if True and this worker currently hosts no models
            (newly started or restarted), call ``supervisor.add_worker``
            after connecting.

    Raises:
        Whatever ``xo.actor_ref`` or the RPCs above raise. In that case no
        supervisor ref is cached, so the next call retries the whole
        handshake instead of returning a half-initialized state.
    """
    from .status_guard import StatusGuardActor
    from .supervisor import SupervisorActor

    if self._supervisor_ref is not None:
        return self._supervisor_ref

    # Keep the new ref in a local and publish it on ``self`` only after every
    # step below has succeeded. Publishing it first meant that a failure in
    # add_worker / actor_ref / record_model_version left ``_supervisor_ref``
    # cached while ``_status_guard_ref`` and friends might never be set, and
    # every later call short-circuited on the cached ref without reconnecting.
    supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(  # type: ignore
        address=self._supervisor_address, uid=SupervisorActor.uid()
    )
    if add_worker and len(self._model_uid_to_model) == 0:
        # Newly started (or restarted), has no model, notify supervisor
        await supervisor_ref.add_worker(self.address)
        logger.info("Connected to supervisor as a fresh worker")

    self._status_guard_ref: xo.ActorRefType[  # type: ignore
        "StatusGuardActor"
    ] = await xo.actor_ref(
        address=self._supervisor_address, uid=StatusGuardActor.uid()
    )

    self._event_collector_ref: xo.ActorRefType[  # type: ignore
        EventCollectorActor
    ] = await xo.actor_ref(
        address=self._supervisor_address, uid=EventCollectorActor.uid()
    )
    from .cache_tracker import CacheTrackerActor

    self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
        "CacheTrackerActor"
    ] = await xo.actor_ref(
        address=self._supervisor_address, uid=CacheTrackerActor.uid()
    )
    # cache_tracker is on supervisor
    from ..model.audio import get_audio_model_descriptions
    from ..model.embedding import get_embedding_model_descriptions
    from ..model.flexible import get_flexible_model_descriptions
    from ..model.image import get_image_model_descriptions
    from ..model.llm import get_llm_model_descriptions
    from ..model.rerank import get_rerank_model_descriptions

    # record model version
    model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
    model_version_infos.update(get_llm_model_descriptions())
    model_version_infos.update(get_embedding_model_descriptions())
    model_version_infos.update(get_rerank_model_descriptions())
    model_version_infos.update(get_image_model_descriptions())
    model_version_infos.update(get_audio_model_descriptions())
    model_version_infos.update(get_flexible_model_descriptions())
    await self._cache_tracker_ref.record_model_version(
        model_version_infos, self.address
    )

    # Handshake complete: only now make the ref visible to other callers.
    self._supervisor_ref = supervisor_ref
    return supervisor_ref
362
+
334
363
  @staticmethod
335
364
  def get_devices_count():
336
365
  from ..device_utils import gpu_count
@@ -342,9 +371,9 @@ class WorkerActor(xo.StatelessActor):
342
371
  return len(self._model_uid_to_model)
343
372
 
344
373
  async def is_model_vllm_backend(self, model_uid: str) -> bool:
345
- assert self._supervisor_ref is not None
346
374
  _model_uid, _, _ = parse_replica_model_uid(model_uid)
347
- model_ref = await self._supervisor_ref.get_model(_model_uid)
375
+ supervisor_ref = await self.get_supervisor_ref()
376
+ model_ref = await supervisor_ref.get_model(_model_uid)
348
377
  return await model_ref.is_vllm_backend()
349
378
 
350
379
  async def allocate_devices_for_embedding(self, model_uid: str) -> int:
@@ -706,6 +735,8 @@ class WorkerActor(xo.StatelessActor):
706
735
  return ["text_to_image"]
707
736
  elif model_type == "audio":
708
737
  return ["audio_to_text"]
738
+ elif model_type == "video":
739
+ return ["text_to_video"]
709
740
  elif model_type == "flexible":
710
741
  return ["flexible"]
711
742
  else:
@@ -762,14 +793,15 @@ class WorkerActor(xo.StatelessActor):
762
793
  logger.exception(e)
763
794
  raise
764
795
  try:
765
- await self._event_collector_ref.report_event(
766
- origin_uid,
767
- Event(
768
- event_type=EventType.INFO,
769
- event_ts=int(time.time()),
770
- event_content="Launch model",
771
- ),
772
- )
796
+ if self._event_collector_ref is not None:
797
+ await self._event_collector_ref.report_event(
798
+ origin_uid,
799
+ Event(
800
+ event_type=EventType.INFO,
801
+ event_ts=int(time.time()),
802
+ event_content="Launch model",
803
+ ),
804
+ )
773
805
  except Exception as e:
774
806
  # Report callback error can be log and ignore, should not interrupt the Process
775
807
  logger.error("report_event error: %s" % (e))
@@ -865,6 +897,11 @@ class WorkerActor(xo.StatelessActor):
865
897
 
866
898
  # update status to READY
867
899
  abilities = await self._get_model_ability(model, model_type)
900
+ _ = await self.get_supervisor_ref(add_worker=False)
901
+
902
+ if self._status_guard_ref is None:
903
+ _ = await self.get_supervisor_ref()
904
+ assert self._status_guard_ref is not None
868
905
  await self._status_guard_ref.update_instance_info(
869
906
  origin_uid,
870
907
  {"model_ability": abilities, "status": LaunchStatus.READY.name},
@@ -877,21 +914,23 @@ class WorkerActor(xo.StatelessActor):
877
914
  raise ValueError(f"{model_uid} is launching")
878
915
  origin_uid, _, __ = parse_replica_model_uid(model_uid)
879
916
  try:
880
- await self._event_collector_ref.report_event(
881
- origin_uid,
882
- Event(
883
- event_type=EventType.INFO,
884
- event_ts=int(time.time()),
885
- event_content="Terminate model",
886
- ),
887
- )
917
+ if self._event_collector_ref is not None:
918
+ await self._event_collector_ref.report_event(
919
+ origin_uid,
920
+ Event(
921
+ event_type=EventType.INFO,
922
+ event_ts=int(time.time()),
923
+ event_content="Terminate model",
924
+ ),
925
+ )
888
926
  except Exception as e:
889
927
  # Report callback error can be log and ignore, should not interrupt the Process
890
928
  logger.error("report_event error: %s" % (e))
891
929
 
892
- await self._status_guard_ref.update_instance_info(
893
- origin_uid, {"status": LaunchStatus.TERMINATING.name}
894
- )
930
+ if self._status_guard_ref is not None:
931
+ await self._status_guard_ref.update_instance_info(
932
+ origin_uid, {"status": LaunchStatus.TERMINATING.name}
933
+ )
895
934
  model_ref = self._model_uid_to_model.get(model_uid, None)
896
935
  if model_ref is None:
897
936
  logger.debug("Model not found, uid: %s", model_uid)
@@ -916,6 +955,10 @@ class WorkerActor(xo.StatelessActor):
916
955
  self._model_uid_to_addr.pop(model_uid, None)
917
956
  self._model_uid_to_recover_count.pop(model_uid, None)
918
957
  self._model_uid_to_launch_args.pop(model_uid, None)
958
+
959
+ if self._status_guard_ref is None:
960
+ _ = await self.get_supervisor_ref()
961
+ assert self._status_guard_ref is not None
919
962
  await self._status_guard_ref.update_instance_info(
920
963
  origin_uid, {"status": LaunchStatus.TERMINATED.name}
921
964
  )
@@ -968,7 +1011,8 @@ class WorkerActor(xo.StatelessActor):
968
1011
  raise
969
1012
  except Exception:
970
1013
  logger.exception("Report status got error.")
971
- await self._supervisor_ref.report_worker_status(self.address, status)
1014
+ supervisor_ref = await self.get_supervisor_ref()
1015
+ await supervisor_ref.report_worker_status(self.address, status)
972
1016
 
973
1017
  async def _periodical_report_status(self):
974
1018
  while True:
@@ -25,7 +25,6 @@ from xoscar.utils import get_next_port
25
25
  from .. import __version__
26
26
  from ..client import RESTfulClient
27
27
  from ..client.restful.restful_client import (
28
- RESTfulChatglmCppChatModelHandle,
29
28
  RESTfulChatModelHandle,
30
29
  RESTfulGenerateModelHandle,
31
30
  )
@@ -1268,9 +1267,7 @@ def model_chat(
1268
1267
  task.exception()
1269
1268
  else:
1270
1269
  restful_model = client.get_model(model_uid=model_uid)
1271
- if not isinstance(
1272
- restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
1273
- ):
1270
+ if not isinstance(restful_model, RESTfulChatModelHandle):
1274
1271
  raise ValueError(f"model {model_uid} has no chat method")
1275
1272
 
1276
1273
  while True:
@@ -14,13 +14,14 @@
14
14
  import logging
15
15
  import os
16
16
  from collections import defaultdict
17
- from typing import Dict, List, Literal, Optional, Tuple, Union
17
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
18
18
 
19
19
  from ...constants import XINFERENCE_CACHE_DIR
20
20
  from ..core import CacheableModelSpec, ModelDescription
21
21
  from ..utils import valid_model_revision
22
22
  from .chattts import ChatTTSModel
23
23
  from .cosyvoice import CosyVoiceModel
24
+ from .funasr import FunASRModel
24
25
  from .whisper import WhisperModel
25
26
 
26
27
  MAX_ATTEMPTS = 3
@@ -45,6 +46,8 @@ class AudioModelFamilyV1(CacheableModelSpec):
45
46
  model_id: str
46
47
  model_revision: str
47
48
  multilingual: bool
49
+ default_model_config: Optional[Dict[str, Any]]
50
+ default_transcription_config: Optional[Dict[str, Any]]
48
51
 
49
52
 
50
53
  class AudioModelDescription(ModelDescription):
@@ -152,13 +155,18 @@ def create_audio_model_instance(
152
155
  download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
153
156
  model_path: Optional[str] = None,
154
157
  **kwargs,
155
- ) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
158
+ ) -> Tuple[
159
+ Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
160
+ AudioModelDescription,
161
+ ]:
156
162
  model_spec = match_audio(model_name, download_hub)
157
163
  if model_path is None:
158
164
  model_path = cache(model_spec)
159
- model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
165
+ model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
160
166
  if model_spec.model_family == "whisper":
161
167
  model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
168
+ elif model_spec.model_family == "funasr":
169
+ model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
162
170
  elif model_spec.model_family == "ChatTTS":
163
171
  model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
164
172
  elif model_spec.model_family == "CosyVoice":
@@ -0,0 +1,114 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import tempfile
17
+ from typing import TYPE_CHECKING, List, Optional
18
+
19
+ from ...device_utils import get_available_device, is_device_available
20
+
21
+ if TYPE_CHECKING:
22
+ from .core import AudioModelFamilyV1
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class FunASRModel:
    """Audio-to-text model backed by ``funasr.AutoModel`` (e.g. SenseVoiceSmall).

    Only transcription with ``response_format="json"`` is supported;
    translation is not.
    """

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        self._model_uid = model_uid
        self._model_path = model_path
        self._model_spec = model_spec
        # None means "pick automatically" when load() runs.
        self._device = device
        # Set by load(); transcriptions() requires it.
        self._model = None
        # Extra keyword arguments forwarded to ``AutoModel``; they override
        # ``model_spec.default_model_config``.
        self._kwargs = kwargs

    @staticmethod
    def _merge_config(defaults: Optional[dict], overrides: dict) -> dict:
        """Return a new dict of ``defaults`` updated with ``overrides``.

        ``defaults`` may be None because the spec's default config fields are
        Optional. Neither input dict is mutated (shallow copy).
        """
        merged = dict(defaults) if defaults else {}
        merged.update(overrides)
        return merged

    def load(self):
        """Import funasr, resolve the device and instantiate ``AutoModel``.

        Raises:
            ImportError: if the ``funasr`` package is not installed.
            ValueError: if an explicitly requested device is not available.
        """
        try:
            from funasr import AutoModel
        except ImportError:
            error_message = "Failed to import module 'funasr'"
            installation_guide = [
                "Please make sure 'funasr' is installed. ",
                "You can install it by `pip install funasr`\n",
            ]

            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        if self._device is None:
            self._device = get_available_device()
        elif not is_device_available(self._device):
            raise ValueError(f"Device {self._device} is not available!")

        kwargs = self._merge_config(self._model_spec.default_model_config, self._kwargs)
        logger.debug("Loading FunASR model with kwargs: %s", kwargs)
        self._model = AutoModel(model=self._model_path, device=self._device, **kwargs)

    def transcriptions(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        temperature: float = 0,
        timestamp_granularities: Optional[List[str]] = None,
        **kwargs,
    ):
        """Transcribe raw audio bytes and return ``{"text": ...}``.

        ``language`` defaults to ``"auto"``. ``prompt`` is ignored with a
        warning. Extra ``kwargs`` override the spec's
        ``default_transcription_config`` and are passed to ``generate``.

        Raises:
            RuntimeError: for a non-zero ``temperature``, any
                ``timestamp_granularities``, or if ``load()`` was not called.
            ValueError: for any ``response_format`` other than ``"json"``.
        """
        import os

        from funasr.utils.postprocess_utils import rich_transcription_postprocess

        if temperature != 0:
            raise RuntimeError("`temperature` is not supported for FunASR")
        if timestamp_granularities is not None:
            raise RuntimeError("`timestamp_granularities` is not supported for FunASR")
        if prompt is not None:
            logger.warning(
                "Prompt for funasr transcriptions will be ignored: %s", prompt
            )
        # Reject unsupported formats before spending time on inference.
        if response_format != "json":
            raise ValueError(f"Unsupported response format: {response_format}")
        if self._model is None:
            raise RuntimeError("FunASR model is not loaded, call `load()` first")

        language = "auto" if language is None else language
        kw = self._merge_config(self._model_spec.default_transcription_config, kwargs)
        logger.debug("Calling FunASR model with kwargs: %s", kw)

        # funasr only receives the path and must reopen the file. A
        # NamedTemporaryFile that is still open cannot be reopened by name on
        # Windows, so close it first (delete=False) and remove it ourselves.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(audio)
            audio_path = f.name
        try:
            result = self._model.generate(
                input=audio_path, cache={}, language=language, **kw
            )
        finally:
            os.unlink(audio_path)

        # Guard against an empty result list instead of raising IndexError.
        text = rich_transcription_postprocess(result[0]["text"]) if result else ""
        return {"text": text}

    def translations(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: str = "json",
        temperature: float = 0,
        timestamp_granularities: Optional[List[str]] = None,
        **kwargs,
    ):
        """Always raises: FunASR models do not offer translation."""
        raise RuntimeError("FunASR does not support translations API")
@@ -95,6 +95,26 @@
95
95
  "ability": "audio-to-text",
96
96
  "multilingual": false
97
97
  },
98
+ {
99
+ "model_name": "SenseVoiceSmall",
100
+ "model_family": "funasr",
101
+ "model_id": "FunAudioLLM/SenseVoiceSmall",
102
+ "model_revision": "3eb3b4eeffc2f2dde6051b853983753db33e35c3",
103
+ "ability": "audio-to-text",
104
+ "multilingual": true,
105
+ "default_model_config": {
106
+ "vad_model": "fsmn-vad",
107
+ "vad_kwargs": {
108
+ "max_single_segment_time": 30000
109
+ }
110
+ },
111
+ "default_transcription_config": {
112
+ "use_itn": true,
113
+ "batch_size_s": 60,
114
+ "merge_vad": true,
115
+ "merge_length_s": 15
116
+ }
117
+ },
98
118
  {
99
119
  "model_name": "ChatTTS",
100
120
  "model_family": "ChatTTS",
@@ -8,6 +8,27 @@
8
8
  "ability": "audio-to-text",
9
9
  "multilingual": true
10
10
  },
11
+ {
12
+ "model_name": "SenseVoiceSmall",
13
+ "model_family": "funasr",
14
+ "model_hub": "modelscope",
15
+ "model_id": "iic/SenseVoiceSmall",
16
+ "model_revision": "master",
17
+ "ability": "audio-to-text",
18
+ "multilingual": true,
19
+ "default_model_config": {
20
+ "vad_model": "fsmn-vad",
21
+ "vad_kwargs": {
22
+ "max_single_segment_time": 30000
23
+ }
24
+ },
25
+ "default_transcription_config": {
26
+ "use_itn": true,
27
+ "batch_size_s": 60,
28
+ "merge_vad": true,
29
+ "merge_length_s": 15
30
+ }
31
+ },
11
32
  {
12
33
  "model_name": "ChatTTS",
13
34
  "model_family": "ChatTTS",
@@ -14,7 +14,7 @@
14
14
  import logging
15
15
  from typing import TYPE_CHECKING, Dict, List, Optional, Union
16
16
 
17
- from xinference.device_utils import (
17
+ from ...device_utils import (
18
18
  get_available_device,
19
19
  get_device_preferred_dtype,
20
20
  is_device_available,
xinference/model/core.py CHANGED
@@ -65,6 +65,7 @@ def create_model_instance(
65
65
  from .image.core import create_image_model_instance
66
66
  from .llm.core import create_llm_model_instance
67
67
  from .rerank.core import create_rerank_model_instance
68
+ from .video.core import create_video_model_instance
68
69
 
69
70
  if model_type == "LLM":
70
71
  return create_llm_model_instance(
@@ -127,6 +128,17 @@ def create_model_instance(
127
128
  model_path,
128
129
  **kwargs,
129
130
  )
131
+ elif model_type == "video":
132
+ kwargs.pop("trust_remote_code", None)
133
+ return create_video_model_instance(
134
+ subpool_addr,
135
+ devices,
136
+ model_uid,
137
+ model_name,
138
+ download_hub,
139
+ model_path,
140
+ **kwargs,
141
+ )
130
142
  elif model_type == "flexible":
131
143
  kwargs.pop("trust_remote_code", None)
132
144
  return create_flexible_model_instance(
@@ -45,7 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
45
45
  model_id: str
46
46
  model_revision: str
47
47
  model_hub: str = "huggingface"
48
- abilities: Optional[List[str]]
48
+ model_ability: Optional[List[str]]
49
49
  controlnet: Optional[List["ImageModelFamilyV1"]]
50
50
 
51
51
 
@@ -72,7 +72,7 @@ class ImageModelDescription(ModelDescription):
72
72
  "model_name": self._model_spec.model_name,
73
73
  "model_family": self._model_spec.model_family,
74
74
  "model_revision": self._model_spec.model_revision,
75
- "abilities": self._model_spec.abilities,
75
+ "model_ability": self._model_spec.model_ability,
76
76
  "controlnet": controlnet,
77
77
  }
78
78
 
@@ -178,7 +178,6 @@ def get_cache_status(
178
178
  ]
179
179
  )
180
180
  else: # Usually for UT
181
- logger.warning(f"Cannot find builtin image model spec: {model_name}")
182
181
  return valid_model_revision(meta_path, model_spec.model_revision)
183
182
 
184
183
 
@@ -239,7 +238,7 @@ def create_image_model_instance(
239
238
  lora_model_paths=lora_model,
240
239
  lora_load_kwargs=lora_load_kwargs,
241
240
  lora_fuse_kwargs=lora_fuse_kwargs,
242
- abilities=model_spec.abilities,
241
+ abilities=model_spec.model_ability,
243
242
  **kwargs,
244
243
  )
245
244
  model_description = ImageModelDescription(