xinference 0.14.0.post1__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +54 -0
- xinference/client/handlers.py +0 -3
- xinference/client/restful/restful_client.py +51 -134
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +1 -4
- xinference/core/image_interface.py +33 -5
- xinference/core/model.py +28 -2
- xinference/core/supervisor.py +37 -0
- xinference/core/worker.py +128 -84
- xinference/deploy/cmdline.py +1 -4
- xinference/model/audio/core.py +11 -3
- xinference/model/audio/funasr.py +114 -0
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +21 -0
- xinference/model/audio/whisper.py +1 -1
- xinference/model/core.py +12 -0
- xinference/model/image/core.py +3 -4
- xinference/model/image/model_spec.json +41 -13
- xinference/model/image/model_spec_modelscope.json +30 -10
- xinference/model/image/stable_diffusion/core.py +53 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +83 -1
- xinference/model/llm/llm_family_modelscope.json +85 -1
- xinference/model/llm/pytorch/core.py +1 -0
- xinference/model/llm/pytorch/minicpmv26.py +247 -0
- xinference/model/llm/sglang/core.py +72 -34
- xinference/model/llm/vllm/core.py +38 -0
- xinference/model/video/__init__.py +62 -0
- xinference/model/video/core.py +178 -0
- xinference/model/video/diffusers.py +180 -0
- xinference/model/video/model_spec.json +11 -0
- xinference/model/video/model_spec_modelscope.json +12 -0
- xinference/types.py +10 -24
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ef2a203a.js → main.17ca0398.js} +3 -3
- xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/METADATA +14 -8
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/RECORD +47 -40
- xinference/web/ui/build/static/js/main.ef2a203a.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
- /xinference/web/ui/build/static/js/{main.ef2a203a.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py
CHANGED

@@ -68,7 +68,7 @@ class WorkerActor(xo.StatelessActor):
         # static attrs.
         self._total_gpu_devices = gpu_devices
         self._supervisor_address = supervisor_address
-        self._supervisor_ref = None
+        self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool

@@ -147,17 +147,20 @@ class WorkerActor(xo.StatelessActor):
             )
             event_model_uid, _, __ = parse_replica_model_uid(model_uid)
             try:
-                await self._event_collector_ref.report_event(
-                    event_model_uid,
-                    Event(
-                        event_type=EventType.WARNING,
-                        event_ts=int(time.time()),
-                        event_content="Recreate model",
-                    ),
-                )
+                if self._event_collector_ref is not None:
+                    await self._event_collector_ref.report_event(
+                        event_model_uid,
+                        Event(
+                            event_type=EventType.WARNING,
+                            event_ts=int(time.time()),
+                            event_content="Recreate model",
+                        ),
+                    )
             except Exception as e:
                 # Report callback error can be log and ignore, should not interrupt the Process
                 logger.error("report_event error: %s" % (e))
+            finally:
+                del event_model_uid

             self._model_uid_to_recover_count[model_uid] = (
                 recover_count - 1

@@ -175,80 +178,39 @@ class WorkerActor(xo.StatelessActor):
         return "worker"

     async def __post_create__(self):
-        from ..isolation import Isolation
-        from .cache_tracker import CacheTrackerActor
-        from .status_guard import StatusGuardActor
-        from .supervisor import SupervisorActor
-
-        self._status_guard_ref: xo.ActorRefType[  # type: ignore
-            "StatusGuardActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=StatusGuardActor.uid()
-        )
-        self._event_collector_ref: xo.ActorRefType[  # type: ignore
-            EventCollectorActor
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=EventCollectorActor.uid()
-        )
-        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
-            "CacheTrackerActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=CacheTrackerActor.uid()
-        )
-        self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(  # type: ignore
-            address=self._supervisor_address, uid=SupervisorActor.uid()
-        )
-        await self._supervisor_ref.add_worker(self.address)
-        if not XINFERENCE_DISABLE_HEALTH_CHECK:
-            # Run _periodical_report_status() in a dedicated thread.
-            self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            self._isolation.start()
-            asyncio.run_coroutine_threadsafe(
-                self._periodical_report_status(), loop=self._isolation.loop
-            )
-        logger.info(f"Xinference worker {self.address} started")
-        logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
-        purge_dir(XINFERENCE_CACHE_DIR)
-
         from ..model.audio import (
             CustomAudioModelFamilyV1,
             generate_audio_description,
-            get_audio_model_descriptions,
             register_audio,
             unregister_audio,
         )
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
             generate_embedding_description,
-            get_embedding_model_descriptions,
             register_embedding,
             unregister_embedding,
         )
         from ..model.flexible import (
             FlexibleModelSpec,
             generate_flexible_model_description,
-            get_flexible_model_descriptions,
             register_flexible_model,
             unregister_flexible_model,
         )
         from ..model.image import (
             CustomImageModelFamilyV1,
             generate_image_description,
-            get_image_model_descriptions,
             register_image,
             unregister_image,
         )
         from ..model.llm import (
             CustomLLMFamilyV1,
             generate_llm_description,
-            get_llm_model_descriptions,
             register_llm,
             unregister_llm,
         )
         from ..model.rerank import (
             CustomRerankModelSpec,
             generate_rerank_description,
-            get_rerank_model_descriptions,
             register_rerank,
             unregister_rerank,
         )

@@ -292,24 +254,33 @@ class WorkerActor(xo.StatelessActor):
             ),
         }

-        # record model version
-        model_version_infos: Dict[str, List[Dict]] = {}
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
+        purge_dir(XINFERENCE_CACHE_DIR)
+
+        try:
+            await self.get_supervisor_ref(add_worker=True)
+        except Exception as e:
+            # Do not crash the worker if supervisor is down, auto re-connect later
+            logger.error(f"cannot connect to supervisor {e}")
+
+        if not XINFERENCE_DISABLE_HEALTH_CHECK:
+            from ..isolation import Isolation
+
+            # Run _periodical_report_status() in a dedicated thread.
+            self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
+            self._isolation.start()
+            asyncio.run_coroutine_threadsafe(
+                self._periodical_report_status(), loop=self._isolation.loop
+            )
+        logger.info(f"Xinference worker {self.address} started")

         # Windows does not have signal handler
         if os.name != "nt":

             async def signal_handler():
                 try:
-                    await self._supervisor_ref.remove_worker(self.address)
+                    supervisor_ref = await self.get_supervisor_ref(add_worker=False)
+                    await supervisor_ref.remove_worker(self.address)
                 except Exception as e:
                     # Ignore the error of rpc, anyway we are exiting
                     logger.exception("remove worker rpc error: %s", e)

@@ -331,6 +302,64 @@ class WorkerActor(xo.StatelessActor):
             return False
         return True

+    async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
+        """
+        Try connect to supervisor and return ActorRef. Raise exception on error
+        Params:
+            add_worker: By default will call supervisor.add_worker after first connect
+        """
+        from .status_guard import StatusGuardActor
+        from .supervisor import SupervisorActor
+
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(  # type: ignore
+            address=self._supervisor_address, uid=SupervisorActor.uid()
+        )
+        if add_worker and len(self._model_uid_to_model) == 0:
+            # Newly started (or restarted), has no model, notify supervisor
+            await self._supervisor_ref.add_worker(self.address)
+            logger.info("Connected to supervisor as a fresh worker")
+
+        self._status_guard_ref: xo.ActorRefType[  # type: ignore
+            "StatusGuardActor"
+        ] = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        from .cache_tracker import CacheTrackerActor
+
+        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
+            "CacheTrackerActor"
+        ] = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
+        return self._supervisor_ref
+
     @staticmethod
     def get_devices_count():
         from ..device_utils import gpu_count

@@ -342,9 +371,9 @@ class WorkerActor(xo.StatelessActor):
         return len(self._model_uid_to_model)

     async def is_model_vllm_backend(self, model_uid: str) -> bool:
-        assert self._supervisor_ref is not None
         _model_uid, _, _ = parse_replica_model_uid(model_uid)
-        model_ref = await self._supervisor_ref.get_model(_model_uid)
+        supervisor_ref = await self.get_supervisor_ref()
+        model_ref = await supervisor_ref.get_model(_model_uid)
         return await model_ref.is_vllm_backend()

     async def allocate_devices_for_embedding(self, model_uid: str) -> int:

@@ -706,6 +735,8 @@ class WorkerActor(xo.StatelessActor):
             return ["text_to_image"]
         elif model_type == "audio":
             return ["audio_to_text"]
+        elif model_type == "video":
+            return ["text_to_video"]
         elif model_type == "flexible":
             return ["flexible"]
         else:

@@ -762,14 +793,15 @@ class WorkerActor(xo.StatelessActor):
             logger.exception(e)
             raise
         try:
-            await self._event_collector_ref.report_event(
-                origin_uid,
-                Event(
-                    event_type=EventType.INFO,
-                    event_ts=int(time.time()),
-                    event_content="Launch model",
-                ),
-            )
+            if self._event_collector_ref is not None:
+                await self._event_collector_ref.report_event(
+                    origin_uid,
+                    Event(
+                        event_type=EventType.INFO,
+                        event_ts=int(time.time()),
+                        event_content="Launch model",
+                    ),
+                )
         except Exception as e:
             # Report callback error can be log and ignore, should not interrupt the Process
             logger.error("report_event error: %s" % (e))

@@ -865,6 +897,11 @@ class WorkerActor(xo.StatelessActor):

         # update status to READY
         abilities = await self._get_model_ability(model, model_type)
+        _ = await self.get_supervisor_ref(add_worker=False)
+
+        if self._status_guard_ref is None:
+            _ = await self.get_supervisor_ref()
+        assert self._status_guard_ref is not None
         await self._status_guard_ref.update_instance_info(
             origin_uid,
             {"model_ability": abilities, "status": LaunchStatus.READY.name},

@@ -877,21 +914,23 @@ class WorkerActor(xo.StatelessActor):
             raise ValueError(f"{model_uid} is launching")
         origin_uid, _, __ = parse_replica_model_uid(model_uid)
         try:
-            await self._event_collector_ref.report_event(
-                origin_uid,
-                Event(
-                    event_type=EventType.INFO,
-                    event_ts=int(time.time()),
-                    event_content="Terminate model",
-                ),
-            )
+            if self._event_collector_ref is not None:
+                await self._event_collector_ref.report_event(
+                    origin_uid,
+                    Event(
+                        event_type=EventType.INFO,
+                        event_ts=int(time.time()),
+                        event_content="Terminate model",
+                    ),
+                )
         except Exception as e:
             # Report callback error can be log and ignore, should not interrupt the Process
             logger.error("report_event error: %s" % (e))

-        await self._status_guard_ref.update_instance_info(
-            origin_uid, {"status": LaunchStatus.TERMINATING.name}
-        )
+        if self._status_guard_ref is not None:
+            await self._status_guard_ref.update_instance_info(
+                origin_uid, {"status": LaunchStatus.TERMINATING.name}
+            )
         model_ref = self._model_uid_to_model.get(model_uid, None)
         if model_ref is None:
             logger.debug("Model not found, uid: %s", model_uid)

@@ -916,6 +955,10 @@ class WorkerActor(xo.StatelessActor):
         self._model_uid_to_addr.pop(model_uid, None)
         self._model_uid_to_recover_count.pop(model_uid, None)
         self._model_uid_to_launch_args.pop(model_uid, None)
+
+        if self._status_guard_ref is None:
+            _ = await self.get_supervisor_ref()
+        assert self._status_guard_ref is not None
         await self._status_guard_ref.update_instance_info(
             origin_uid, {"status": LaunchStatus.TERMINATED.name}
         )

@@ -968,7 +1011,8 @@ class WorkerActor(xo.StatelessActor):
             raise
         except Exception:
             logger.exception("Report status got error.")
-        await self._supervisor_ref.report_worker_status(self.address, status)
+        supervisor_ref = await self.get_supervisor_ref()
+        await supervisor_ref.report_worker_status(self.address, status)

     async def _periodical_report_status(self):
         while True:
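The bulk of this change replaces the eager actor-ref wiring in `__post_create__` with a cached, lazily resolved `get_supervisor_ref()`, so a worker can start (and keep running) while the supervisor is unreachable, then reconnect and re-register on the next call that needs it. Below is a minimal standalone sketch of the same cache-and-reconnect pattern; the classes are illustrative stand-ins, not xinference's xoscar actors:

import asyncio
from typing import Dict, Optional

class SupervisorStub:
    """Hypothetical stand-in for a remote supervisor handle."""
    async def add_worker(self, address: str) -> None:
        print(f"registered worker at {address}")

class Worker:
    def __init__(self, address: str) -> None:
        self.address = address
        self._supervisor_ref: Optional[SupervisorStub] = None  # resolved lazily
        self._models: Dict[str, object] = {}

    async def get_supervisor_ref(self, add_worker: bool = True) -> SupervisorStub:
        # Reuse the cached ref once connected.
        if self._supervisor_ref is not None:
            return self._supervisor_ref
        # Otherwise (re)connect; the real code awaits xo.actor_ref(...) here.
        self._supervisor_ref = SupervisorStub()
        if add_worker and len(self._models) == 0:
            # Fresh worker with no models: announce ourselves to the supervisor.
            await self._supervisor_ref.add_worker(self.address)
        return self._supervisor_ref

async def main() -> None:
    worker = Worker("127.0.0.1:1234")
    ref = await worker.get_supervisor_ref()           # connects and registers
    assert ref is await worker.get_supervisor_ref()   # cached thereafter

asyncio.run(main())

In the real code, the `add_worker and len(self._model_uid_to_model) == 0` guard means a restarted worker that still hosts models can reconnect without re-registering as a fresh worker.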
xinference/deploy/cmdline.py
CHANGED

@@ -25,7 +25,6 @@ from xoscar.utils import get_next_port
 from .. import __version__
 from ..client import RESTfulClient
 from ..client.restful.restful_client import (
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,
 )

@@ -1268,9 +1267,7 @@ def model_chat(
         task.exception()
     else:
         restful_model = client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        ):
+        if not isinstance(restful_model, RESTfulChatModelHandle):
             raise ValueError(f"model {model_uid} has no chat method")

         while True:
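With `RESTfulChatglmCppChatModelHandle` removed, `xinference chat` now only accepts `RESTfulChatModelHandle` handles. The equivalent check in user code would look roughly like this (endpoint and model uid are placeholders):

from xinference.client import RESTfulClient
from xinference.client.restful.restful_client import RESTfulChatModelHandle

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model")   # placeholder uid
if not isinstance(model, RESTfulChatModelHandle):
    raise ValueError("model my-model has no chat method")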
xinference/model/audio/core.py
CHANGED

@@ -14,13 +14,14 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .funasr import FunASRModel
 from .whisper import WhisperModel

 MAX_ATTEMPTS = 3

@@ -45,6 +46,8 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     multilingual: bool
+    default_model_config: Optional[Dict[str, Any]]
+    default_transcription_config: Optional[Dict[str, Any]]


 class AudioModelDescription(ModelDescription):

@@ -152,13 +155,18 @@ def create_audio_model_instance(
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
+) -> Tuple[
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    AudioModelDescription,
+]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
+    model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "funasr":
+        model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "CosyVoice":
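The two new optional spec fields act as per-model defaults: `default_model_config` is merged into the model-loading kwargs and `default_transcription_config` into each transcription call, with user-supplied kwargs overriding both (see the `copy()`/`update()` pairs in funasr.py below). A tiny illustration of that precedence, with made-up values:

from typing import Any, Dict

# Spec-level defaults (mirrors default_model_config in the spec JSON).
default_model_config: Dict[str, Any] = {
    "vad_model": "fsmn-vad",
    "vad_kwargs": {"max_single_segment_time": 30000},
}
user_kwargs: Dict[str, Any] = {"vad_model": "my-vad"}  # hypothetical launch-time override

kwargs = default_model_config.copy()  # defaults first...
kwargs.update(user_kwargs)            # ...then user values win
assert kwargs["vad_model"] == "my-vad"
assert kwargs["vad_kwargs"]["max_single_segment_time"] == 30000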
xinference/model/audio/funasr.py
ADDED

@@ -0,0 +1,114 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import tempfile
+from typing import TYPE_CHECKING, List, Optional
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class FunASRModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        try:
+            from funasr import AutoModel
+        except ImportError:
+            error_message = "Failed to import module 'funasr'"
+            installation_guide = [
+                "Please make sure 'funasr' is installed. ",
+                "You can install it by `pip install funasr`\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        kwargs = self._model_spec.default_model_config.copy()
+        kwargs.update(self._kwargs)
+        logger.debug("Loading FunASR model with kwargs: %s", kwargs)
+        self._model = AutoModel(model=self._model_path, device=self._device, **kwargs)
+
+    def transcriptions(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+        if temperature != 0:
+            raise RuntimeError("`temperature`is not supported for FunASR")
+        if timestamp_granularities is not None:
+            raise RuntimeError("`timestamp_granularities`is not supported for FunASR")
+        if prompt is not None:
+            logger.warning(
+                "Prompt for funasr transcriptions will be ignored: %s", prompt
+            )
+
+        language = "auto" if language is None else language
+
+        with tempfile.NamedTemporaryFile(buffering=0) as f:
+            f.write(audio)
+
+            kw = self._model_spec.default_transcription_config.copy()  # type: ignore
+            kw.update(kwargs)
+            logger.debug("Calling FunASR model with kwargs: %s", kw)
+            result = self._model.generate(  # type: ignore
+                input=f.name, cache={}, language=language, **kw
+            )
+            text = rich_transcription_postprocess(result[0]["text"])
+
+            if response_format == "json":
+                return {"text": text}
+            else:
+                raise ValueError(f"Unsupported response format: {response_format}")
+
+    def translations(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+        timestamp_granularities: Optional[List[str]] = None,
+    ):
+        raise RuntimeError("FunASR does not support translations API")
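In normal operation this class is constructed by `create_audio_model_instance` above; driving it directly would look roughly like the sketch below. The spec stand-in, model path, and audio file are all placeholders, and `load()` requires `pip install funasr` plus a downloaded SenseVoice-style checkpoint:

from types import SimpleNamespace

from xinference.model.audio.funasr import FunASRModel

# Stand-in for an AudioModelFamilyV1 spec; real specs come from model_spec.json.
spec = SimpleNamespace(
    default_model_config={
        "vad_model": "fsmn-vad",
        "vad_kwargs": {"max_single_segment_time": 30000},
    },
    default_transcription_config={
        "use_itn": True, "batch_size_s": 60, "merge_vad": True, "merge_length_s": 15
    },
)

model = FunASRModel("my-uid", "/path/to/SenseVoiceSmall", spec)  # placeholder path
model.load()

with open("speech.wav", "rb") as f:  # placeholder audio file
    print(model.transcriptions(f.read()))  # -> {"text": "..."}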
xinference/model/audio/model_spec.json
CHANGED

@@ -95,6 +95,26 @@
     "ability": "audio-to-text",
     "multilingual": false
   },
+  {
+    "model_name": "SenseVoiceSmall",
+    "model_family": "funasr",
+    "model_id": "FunAudioLLM/SenseVoiceSmall",
+    "model_revision": "3eb3b4eeffc2f2dde6051b853983753db33e35c3",
+    "ability": "audio-to-text",
+    "multilingual": true,
+    "default_model_config": {
+      "vad_model": "fsmn-vad",
+      "vad_kwargs": {
+        "max_single_segment_time": 30000
+      }
+    },
+    "default_transcription_config": {
+      "use_itn": true,
+      "batch_size_s": 60,
+      "merge_vad": true,
+      "merge_length_s": 15
+    }
+  },
   {
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
xinference/model/audio/model_spec_modelscope.json
CHANGED

@@ -8,6 +8,27 @@
     "ability": "audio-to-text",
     "multilingual": true
   },
+  {
+    "model_name": "SenseVoiceSmall",
+    "model_family": "funasr",
+    "model_hub": "modelscope",
+    "model_id": "iic/SenseVoiceSmall",
+    "model_revision": "master",
+    "ability": "audio-to-text",
+    "multilingual": true,
+    "default_model_config": {
+      "vad_model": "fsmn-vad",
+      "vad_kwargs": {
+        "max_single_segment_time": 30000
+      }
+    },
+    "default_transcription_config": {
+      "use_itn": true,
+      "batch_size_s": 60,
+      "merge_vad": true,
+      "merge_length_s": 15
+    }
+  },
   {
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
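With these spec entries in place, SenseVoiceSmall should be launchable like any other built-in audio model. A rough end-to-end sketch against a running server (endpoint and audio file are placeholders):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
model_uid = client.launch_model(model_name="SenseVoiceSmall", model_type="audio")
model = client.get_model(model_uid)

with open("speech.wav", "rb") as f:  # placeholder audio file
    result = model.transcriptions(f.read())
print(result["text"])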
xinference/model/core.py
CHANGED

@@ -65,6 +65,7 @@ def create_model_instance(
    from .image.core import create_image_model_instance
    from .llm.core import create_llm_model_instance
    from .rerank.core import create_rerank_model_instance
+    from .video.core import create_video_model_instance

    if model_type == "LLM":
        return create_llm_model_instance(

@@ -127,6 +128,17 @@ def create_model_instance(
            model_path,
            **kwargs,
        )
+    elif model_type == "video":
+        kwargs.pop("trust_remote_code", None)
+        return create_video_model_instance(
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
+        )
    elif model_type == "flexible":
        kwargs.pop("trust_remote_code", None)
        return create_flexible_model_instance(
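This wires the new `video` model type into the generic factory, so launching follows the same path as the other modalities. A rough client-side sketch; the model name is assumed from this release's new video model_spec.json and the `text_to_video` handle method is inferred from the restful_client.py changes, so treat both as assumptions:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
# "CogVideoX-2b" is assumed from the new video model_spec.json in this release.
model_uid = client.launch_model(model_name="CogVideoX-2b", model_type="video")
model = client.get_model(model_uid)
video = model.text_to_video(prompt="a panda eating bamboo")  # assumed method name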
xinference/model/image/core.py
CHANGED

@@ -45,7 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
-    abilities: Optional[List[str]]
+    model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]


@@ -72,7 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
-            "abilities": self._model_spec.abilities,
+            "model_ability": self._model_spec.model_ability,
             "controlnet": controlnet,
         }

@@ -178,7 +178,6 @@ def get_cache_status(
             ]
         )
     else:  # Usually for UT
-        logger.warning(f"Cannot find builtin image model spec: {model_name}")
         return valid_model_revision(meta_path, model_spec.model_revision)


@@ -239,7 +238,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-        abilities=model_spec.abilities,
+        abilities=model_spec.model_ability,
         **kwargs,
     )
     model_description = ImageModelDescription(