xinference 0.14.0__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -1
- xinference/client/handlers.py +0 -3
- xinference/client/restful/restful_client.py +51 -134
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +1 -4
- xinference/core/image_interface.py +33 -5
- xinference/core/model.py +28 -2
- xinference/core/supervisor.py +37 -0
- xinference/core/worker.py +130 -84
- xinference/deploy/cmdline.py +1 -4
- xinference/model/audio/core.py +11 -3
- xinference/model/audio/funasr.py +114 -0
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +21 -0
- xinference/model/audio/whisper.py +1 -1
- xinference/model/core.py +12 -0
- xinference/model/embedding/core.py +6 -6
- xinference/model/image/core.py +3 -4
- xinference/model/image/model_spec.json +41 -13
- xinference/model/image/model_spec_modelscope.json +30 -10
- xinference/model/image/stable_diffusion/core.py +53 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +83 -1
- xinference/model/llm/llm_family_modelscope.json +85 -1
- xinference/model/llm/pytorch/core.py +1 -0
- xinference/model/llm/pytorch/minicpmv26.py +247 -0
- xinference/model/llm/sglang/core.py +72 -34
- xinference/model/llm/vllm/core.py +38 -0
- xinference/model/video/__init__.py +62 -0
- xinference/model/video/core.py +178 -0
- xinference/model/video/diffusers.py +180 -0
- xinference/model/video/model_spec.json +11 -0
- xinference/model/video/model_spec_modelscope.json +12 -0
- xinference/types.py +10 -24
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.af906659.js → main.17ca0398.js} +3 -3
- xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/METADATA +128 -122
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/RECORD +49 -42
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.af906659.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
- /xinference/web/ui/build/static/js/{main.af906659.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py
CHANGED
|
@@ -68,7 +68,7 @@ class WorkerActor(xo.StatelessActor):
|
|
|
68
68
|
# static attrs.
|
|
69
69
|
self._total_gpu_devices = gpu_devices
|
|
70
70
|
self._supervisor_address = supervisor_address
|
|
71
|
-
self._supervisor_ref = None
|
|
71
|
+
self._supervisor_ref: Optional[xo.ActorRefType] = None
|
|
72
72
|
self._main_pool = main_pool
|
|
73
73
|
self._main_pool.recover_sub_pool = self.recover_sub_pool
|
|
74
74
|
|
|
@@ -147,17 +147,20 @@ class WorkerActor(xo.StatelessActor):
|
|
|
147
147
|
)
|
|
148
148
|
event_model_uid, _, __ = parse_replica_model_uid(model_uid)
|
|
149
149
|
try:
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
150
|
+
if self._event_collector_ref is not None:
|
|
151
|
+
await self._event_collector_ref.report_event(
|
|
152
|
+
event_model_uid,
|
|
153
|
+
Event(
|
|
154
|
+
event_type=EventType.WARNING,
|
|
155
|
+
event_ts=int(time.time()),
|
|
156
|
+
event_content="Recreate model",
|
|
157
|
+
),
|
|
158
|
+
)
|
|
158
159
|
except Exception as e:
|
|
159
160
|
# Report callback error can be log and ignore, should not interrupt the Process
|
|
160
161
|
logger.error("report_event error: %s" % (e))
|
|
162
|
+
finally:
|
|
163
|
+
del event_model_uid
|
|
161
164
|
|
|
162
165
|
self._model_uid_to_recover_count[model_uid] = (
|
|
163
166
|
recover_count - 1
|
|
@@ -175,79 +178,39 @@ class WorkerActor(xo.StatelessActor):
|
|
|
175
178
|
return "worker"
|
|
176
179
|
|
|
177
180
|
async def __post_create__(self):
|
|
178
|
-
from ..isolation import Isolation
|
|
179
|
-
from .cache_tracker import CacheTrackerActor
|
|
180
|
-
from .status_guard import StatusGuardActor
|
|
181
|
-
from .supervisor import SupervisorActor
|
|
182
|
-
|
|
183
|
-
self._status_guard_ref: xo.ActorRefType[ # type: ignore
|
|
184
|
-
"StatusGuardActor"
|
|
185
|
-
] = await xo.actor_ref(
|
|
186
|
-
address=self._supervisor_address, uid=StatusGuardActor.uid()
|
|
187
|
-
)
|
|
188
|
-
self._event_collector_ref: xo.ActorRefType[ # type: ignore
|
|
189
|
-
EventCollectorActor
|
|
190
|
-
] = await xo.actor_ref(
|
|
191
|
-
address=self._supervisor_address, uid=EventCollectorActor.uid()
|
|
192
|
-
)
|
|
193
|
-
self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
|
|
194
|
-
"CacheTrackerActor"
|
|
195
|
-
] = await xo.actor_ref(
|
|
196
|
-
address=self._supervisor_address, uid=CacheTrackerActor.uid()
|
|
197
|
-
)
|
|
198
|
-
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
|
|
199
|
-
address=self._supervisor_address, uid=SupervisorActor.uid()
|
|
200
|
-
)
|
|
201
|
-
await self._supervisor_ref.add_worker(self.address)
|
|
202
|
-
if not XINFERENCE_DISABLE_HEALTH_CHECK:
|
|
203
|
-
# Run _periodical_report_status() in a dedicated thread.
|
|
204
|
-
self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
|
|
205
|
-
self._isolation.start()
|
|
206
|
-
asyncio.run_coroutine_threadsafe(
|
|
207
|
-
self._periodical_report_status(), loop=self._isolation.loop
|
|
208
|
-
)
|
|
209
|
-
logger.info(f"Xinference worker {self.address} started")
|
|
210
|
-
logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
|
|
211
|
-
purge_dir(XINFERENCE_CACHE_DIR)
|
|
212
|
-
|
|
213
181
|
from ..model.audio import (
|
|
214
182
|
CustomAudioModelFamilyV1,
|
|
215
183
|
generate_audio_description,
|
|
216
|
-
get_audio_model_descriptions,
|
|
217
184
|
register_audio,
|
|
218
185
|
unregister_audio,
|
|
219
186
|
)
|
|
220
187
|
from ..model.embedding import (
|
|
221
188
|
CustomEmbeddingModelSpec,
|
|
222
189
|
generate_embedding_description,
|
|
223
|
-
get_embedding_model_descriptions,
|
|
224
190
|
register_embedding,
|
|
225
191
|
unregister_embedding,
|
|
226
192
|
)
|
|
227
193
|
from ..model.flexible import (
|
|
228
194
|
FlexibleModelSpec,
|
|
229
|
-
|
|
195
|
+
generate_flexible_model_description,
|
|
230
196
|
register_flexible_model,
|
|
231
197
|
unregister_flexible_model,
|
|
232
198
|
)
|
|
233
199
|
from ..model.image import (
|
|
234
200
|
CustomImageModelFamilyV1,
|
|
235
201
|
generate_image_description,
|
|
236
|
-
get_image_model_descriptions,
|
|
237
202
|
register_image,
|
|
238
203
|
unregister_image,
|
|
239
204
|
)
|
|
240
205
|
from ..model.llm import (
|
|
241
206
|
CustomLLMFamilyV1,
|
|
242
207
|
generate_llm_description,
|
|
243
|
-
get_llm_model_descriptions,
|
|
244
208
|
register_llm,
|
|
245
209
|
unregister_llm,
|
|
246
210
|
)
|
|
247
211
|
from ..model.rerank import (
|
|
248
212
|
CustomRerankModelSpec,
|
|
249
213
|
generate_rerank_description,
|
|
250
|
-
get_rerank_model_descriptions,
|
|
251
214
|
register_rerank,
|
|
252
215
|
unregister_rerank,
|
|
253
216
|
)
|
|
@@ -287,27 +250,37 @@ class WorkerActor(xo.StatelessActor):
|
|
|
287
250
|
FlexibleModelSpec,
|
|
288
251
|
register_flexible_model,
|
|
289
252
|
unregister_flexible_model,
|
|
253
|
+
generate_flexible_model_description,
|
|
290
254
|
),
|
|
291
255
|
}
|
|
292
256
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
257
|
+
logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
|
|
258
|
+
purge_dir(XINFERENCE_CACHE_DIR)
|
|
259
|
+
|
|
260
|
+
try:
|
|
261
|
+
await self.get_supervisor_ref(add_worker=True)
|
|
262
|
+
except Exception as e:
|
|
263
|
+
# Do not crash the worker if supervisor is down, auto re-connect later
|
|
264
|
+
logger.error(f"cannot connect to supervisor {e}")
|
|
265
|
+
|
|
266
|
+
if not XINFERENCE_DISABLE_HEALTH_CHECK:
|
|
267
|
+
from ..isolation import Isolation
|
|
268
|
+
|
|
269
|
+
# Run _periodical_report_status() in a dedicated thread.
|
|
270
|
+
self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
|
|
271
|
+
self._isolation.start()
|
|
272
|
+
asyncio.run_coroutine_threadsafe(
|
|
273
|
+
self._periodical_report_status(), loop=self._isolation.loop
|
|
274
|
+
)
|
|
275
|
+
logger.info(f"Xinference worker {self.address} started")
|
|
304
276
|
|
|
305
277
|
# Windows does not have signal handler
|
|
306
278
|
if os.name != "nt":
|
|
307
279
|
|
|
308
280
|
async def signal_handler():
|
|
309
281
|
try:
|
|
310
|
-
await self.
|
|
282
|
+
supervisor_ref = await self.get_supervisor_ref(add_worker=False)
|
|
283
|
+
await supervisor_ref.remove_worker(self.address)
|
|
311
284
|
except Exception as e:
|
|
312
285
|
# Ignore the error of rpc, anyway we are exiting
|
|
313
286
|
logger.exception("remove worker rpc error: %s", e)
|
|
@@ -329,6 +302,64 @@ class WorkerActor(xo.StatelessActor):
|
|
|
329
302
|
return False
|
|
330
303
|
return True
|
|
331
304
|
|
|
305
|
+
async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
|
|
306
|
+
"""
|
|
307
|
+
Try connect to supervisor and return ActorRef. Raise exception on error
|
|
308
|
+
Params:
|
|
309
|
+
add_worker: By default will call supervisor.add_worker after first connect
|
|
310
|
+
"""
|
|
311
|
+
from .status_guard import StatusGuardActor
|
|
312
|
+
from .supervisor import SupervisorActor
|
|
313
|
+
|
|
314
|
+
if self._supervisor_ref is not None:
|
|
315
|
+
return self._supervisor_ref
|
|
316
|
+
self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( # type: ignore
|
|
317
|
+
address=self._supervisor_address, uid=SupervisorActor.uid()
|
|
318
|
+
)
|
|
319
|
+
if add_worker and len(self._model_uid_to_model) == 0:
|
|
320
|
+
# Newly started (or restarted), has no model, notify supervisor
|
|
321
|
+
await self._supervisor_ref.add_worker(self.address)
|
|
322
|
+
logger.info("Connected to supervisor as a fresh worker")
|
|
323
|
+
|
|
324
|
+
self._status_guard_ref: xo.ActorRefType[ # type: ignore
|
|
325
|
+
"StatusGuardActor"
|
|
326
|
+
] = await xo.actor_ref(
|
|
327
|
+
address=self._supervisor_address, uid=StatusGuardActor.uid()
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
self._event_collector_ref: xo.ActorRefType[ # type: ignore
|
|
331
|
+
EventCollectorActor
|
|
332
|
+
] = await xo.actor_ref(
|
|
333
|
+
address=self._supervisor_address, uid=EventCollectorActor.uid()
|
|
334
|
+
)
|
|
335
|
+
from .cache_tracker import CacheTrackerActor
|
|
336
|
+
|
|
337
|
+
self._cache_tracker_ref: xo.ActorRefType[ # type: ignore
|
|
338
|
+
"CacheTrackerActor"
|
|
339
|
+
] = await xo.actor_ref(
|
|
340
|
+
address=self._supervisor_address, uid=CacheTrackerActor.uid()
|
|
341
|
+
)
|
|
342
|
+
# cache_tracker is on supervisor
|
|
343
|
+
from ..model.audio import get_audio_model_descriptions
|
|
344
|
+
from ..model.embedding import get_embedding_model_descriptions
|
|
345
|
+
from ..model.flexible import get_flexible_model_descriptions
|
|
346
|
+
from ..model.image import get_image_model_descriptions
|
|
347
|
+
from ..model.llm import get_llm_model_descriptions
|
|
348
|
+
from ..model.rerank import get_rerank_model_descriptions
|
|
349
|
+
|
|
350
|
+
# record model version
|
|
351
|
+
model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
|
|
352
|
+
model_version_infos.update(get_llm_model_descriptions())
|
|
353
|
+
model_version_infos.update(get_embedding_model_descriptions())
|
|
354
|
+
model_version_infos.update(get_rerank_model_descriptions())
|
|
355
|
+
model_version_infos.update(get_image_model_descriptions())
|
|
356
|
+
model_version_infos.update(get_audio_model_descriptions())
|
|
357
|
+
model_version_infos.update(get_flexible_model_descriptions())
|
|
358
|
+
await self._cache_tracker_ref.record_model_version(
|
|
359
|
+
model_version_infos, self.address
|
|
360
|
+
)
|
|
361
|
+
return self._supervisor_ref
|
|
362
|
+
|
|
332
363
|
@staticmethod
|
|
333
364
|
def get_devices_count():
|
|
334
365
|
from ..device_utils import gpu_count
|
|
@@ -340,9 +371,9 @@ class WorkerActor(xo.StatelessActor):
|
|
|
340
371
|
return len(self._model_uid_to_model)
|
|
341
372
|
|
|
342
373
|
async def is_model_vllm_backend(self, model_uid: str) -> bool:
|
|
343
|
-
assert self._supervisor_ref is not None
|
|
344
374
|
_model_uid, _, _ = parse_replica_model_uid(model_uid)
|
|
345
|
-
|
|
375
|
+
supervisor_ref = await self.get_supervisor_ref()
|
|
376
|
+
model_ref = await supervisor_ref.get_model(_model_uid)
|
|
346
377
|
return await model_ref.is_vllm_backend()
|
|
347
378
|
|
|
348
379
|
async def allocate_devices_for_embedding(self, model_uid: str) -> int:
|
|
@@ -704,6 +735,8 @@ class WorkerActor(xo.StatelessActor):
|
|
|
704
735
|
return ["text_to_image"]
|
|
705
736
|
elif model_type == "audio":
|
|
706
737
|
return ["audio_to_text"]
|
|
738
|
+
elif model_type == "video":
|
|
739
|
+
return ["text_to_video"]
|
|
707
740
|
elif model_type == "flexible":
|
|
708
741
|
return ["flexible"]
|
|
709
742
|
else:
|
|
@@ -760,14 +793,15 @@ class WorkerActor(xo.StatelessActor):
|
|
|
760
793
|
logger.exception(e)
|
|
761
794
|
raise
|
|
762
795
|
try:
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
796
|
+
if self._event_collector_ref is not None:
|
|
797
|
+
await self._event_collector_ref.report_event(
|
|
798
|
+
origin_uid,
|
|
799
|
+
Event(
|
|
800
|
+
event_type=EventType.INFO,
|
|
801
|
+
event_ts=int(time.time()),
|
|
802
|
+
event_content="Launch model",
|
|
803
|
+
),
|
|
804
|
+
)
|
|
771
805
|
except Exception as e:
|
|
772
806
|
# Report callback error can be log and ignore, should not interrupt the Process
|
|
773
807
|
logger.error("report_event error: %s" % (e))
|
|
@@ -863,6 +897,11 @@ class WorkerActor(xo.StatelessActor):
|
|
|
863
897
|
|
|
864
898
|
# update status to READY
|
|
865
899
|
abilities = await self._get_model_ability(model, model_type)
|
|
900
|
+
_ = await self.get_supervisor_ref(add_worker=False)
|
|
901
|
+
|
|
902
|
+
if self._status_guard_ref is None:
|
|
903
|
+
_ = await self.get_supervisor_ref()
|
|
904
|
+
assert self._status_guard_ref is not None
|
|
866
905
|
await self._status_guard_ref.update_instance_info(
|
|
867
906
|
origin_uid,
|
|
868
907
|
{"model_ability": abilities, "status": LaunchStatus.READY.name},
|
|
@@ -875,21 +914,23 @@ class WorkerActor(xo.StatelessActor):
|
|
|
875
914
|
raise ValueError(f"{model_uid} is launching")
|
|
876
915
|
origin_uid, _, __ = parse_replica_model_uid(model_uid)
|
|
877
916
|
try:
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
917
|
+
if self._event_collector_ref is not None:
|
|
918
|
+
await self._event_collector_ref.report_event(
|
|
919
|
+
origin_uid,
|
|
920
|
+
Event(
|
|
921
|
+
event_type=EventType.INFO,
|
|
922
|
+
event_ts=int(time.time()),
|
|
923
|
+
event_content="Terminate model",
|
|
924
|
+
),
|
|
925
|
+
)
|
|
886
926
|
except Exception as e:
|
|
887
927
|
# Report callback error can be log and ignore, should not interrupt the Process
|
|
888
928
|
logger.error("report_event error: %s" % (e))
|
|
889
929
|
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
930
|
+
if self._status_guard_ref is not None:
|
|
931
|
+
await self._status_guard_ref.update_instance_info(
|
|
932
|
+
origin_uid, {"status": LaunchStatus.TERMINATING.name}
|
|
933
|
+
)
|
|
893
934
|
model_ref = self._model_uid_to_model.get(model_uid, None)
|
|
894
935
|
if model_ref is None:
|
|
895
936
|
logger.debug("Model not found, uid: %s", model_uid)
|
|
@@ -914,6 +955,10 @@ class WorkerActor(xo.StatelessActor):
|
|
|
914
955
|
self._model_uid_to_addr.pop(model_uid, None)
|
|
915
956
|
self._model_uid_to_recover_count.pop(model_uid, None)
|
|
916
957
|
self._model_uid_to_launch_args.pop(model_uid, None)
|
|
958
|
+
|
|
959
|
+
if self._status_guard_ref is None:
|
|
960
|
+
_ = await self.get_supervisor_ref()
|
|
961
|
+
assert self._status_guard_ref is not None
|
|
917
962
|
await self._status_guard_ref.update_instance_info(
|
|
918
963
|
origin_uid, {"status": LaunchStatus.TERMINATED.name}
|
|
919
964
|
)
|
|
@@ -966,7 +1011,8 @@ class WorkerActor(xo.StatelessActor):
|
|
|
966
1011
|
raise
|
|
967
1012
|
except Exception:
|
|
968
1013
|
logger.exception("Report status got error.")
|
|
969
|
-
await self.
|
|
1014
|
+
supervisor_ref = await self.get_supervisor_ref()
|
|
1015
|
+
await supervisor_ref.report_worker_status(self.address, status)
|
|
970
1016
|
|
|
971
1017
|
async def _periodical_report_status(self):
|
|
972
1018
|
while True:
|
xinference/deploy/cmdline.py
CHANGED
|
@@ -25,7 +25,6 @@ from xoscar.utils import get_next_port
|
|
|
25
25
|
from .. import __version__
|
|
26
26
|
from ..client import RESTfulClient
|
|
27
27
|
from ..client.restful.restful_client import (
|
|
28
|
-
RESTfulChatglmCppChatModelHandle,
|
|
29
28
|
RESTfulChatModelHandle,
|
|
30
29
|
RESTfulGenerateModelHandle,
|
|
31
30
|
)
|
|
@@ -1268,9 +1267,7 @@ def model_chat(
|
|
|
1268
1267
|
task.exception()
|
|
1269
1268
|
else:
|
|
1270
1269
|
restful_model = client.get_model(model_uid=model_uid)
|
|
1271
|
-
if not isinstance(
|
|
1272
|
-
restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
|
|
1273
|
-
):
|
|
1270
|
+
if not isinstance(restful_model, RESTfulChatModelHandle):
|
|
1274
1271
|
raise ValueError(f"model {model_uid} has no chat method")
|
|
1275
1272
|
|
|
1276
1273
|
while True:
|
xinference/model/audio/core.py
CHANGED
|
@@ -14,13 +14,14 @@
|
|
|
14
14
|
import logging
|
|
15
15
|
import os
|
|
16
16
|
from collections import defaultdict
|
|
17
|
-
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
17
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
18
18
|
|
|
19
19
|
from ...constants import XINFERENCE_CACHE_DIR
|
|
20
20
|
from ..core import CacheableModelSpec, ModelDescription
|
|
21
21
|
from ..utils import valid_model_revision
|
|
22
22
|
from .chattts import ChatTTSModel
|
|
23
23
|
from .cosyvoice import CosyVoiceModel
|
|
24
|
+
from .funasr import FunASRModel
|
|
24
25
|
from .whisper import WhisperModel
|
|
25
26
|
|
|
26
27
|
MAX_ATTEMPTS = 3
|
|
@@ -45,6 +46,8 @@ class AudioModelFamilyV1(CacheableModelSpec):
|
|
|
45
46
|
model_id: str
|
|
46
47
|
model_revision: str
|
|
47
48
|
multilingual: bool
|
|
49
|
+
default_model_config: Optional[Dict[str, Any]]
|
|
50
|
+
default_transcription_config: Optional[Dict[str, Any]]
|
|
48
51
|
|
|
49
52
|
|
|
50
53
|
class AudioModelDescription(ModelDescription):
|
|
@@ -152,13 +155,18 @@ def create_audio_model_instance(
|
|
|
152
155
|
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
|
|
153
156
|
model_path: Optional[str] = None,
|
|
154
157
|
**kwargs,
|
|
155
|
-
) -> Tuple[
|
|
158
|
+
) -> Tuple[
|
|
159
|
+
Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
|
|
160
|
+
AudioModelDescription,
|
|
161
|
+
]:
|
|
156
162
|
model_spec = match_audio(model_name, download_hub)
|
|
157
163
|
if model_path is None:
|
|
158
164
|
model_path = cache(model_spec)
|
|
159
|
-
model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
|
|
165
|
+
model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
|
|
160
166
|
if model_spec.model_family == "whisper":
|
|
161
167
|
model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
|
|
168
|
+
elif model_spec.model_family == "funasr":
|
|
169
|
+
model = FunASRModel(model_uid, model_path, model_spec, **kwargs)
|
|
162
170
|
elif model_spec.model_family == "ChatTTS":
|
|
163
171
|
model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
164
172
|
elif model_spec.model_family == "CosyVoice":
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import tempfile
|
|
17
|
+
from typing import TYPE_CHECKING, List, Optional
|
|
18
|
+
|
|
19
|
+
from ...device_utils import get_available_device, is_device_available
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from .core import AudioModelFamilyV1
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FunASRModel:
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
model_uid: str,
|
|
31
|
+
model_path: str,
|
|
32
|
+
model_spec: "AudioModelFamilyV1",
|
|
33
|
+
device: Optional[str] = None,
|
|
34
|
+
**kwargs,
|
|
35
|
+
):
|
|
36
|
+
self._model_uid = model_uid
|
|
37
|
+
self._model_path = model_path
|
|
38
|
+
self._model_spec = model_spec
|
|
39
|
+
self._device = device
|
|
40
|
+
self._model = None
|
|
41
|
+
self._kwargs = kwargs
|
|
42
|
+
|
|
43
|
+
def load(self):
|
|
44
|
+
try:
|
|
45
|
+
from funasr import AutoModel
|
|
46
|
+
except ImportError:
|
|
47
|
+
error_message = "Failed to import module 'funasr'"
|
|
48
|
+
installation_guide = [
|
|
49
|
+
"Please make sure 'funasr' is installed. ",
|
|
50
|
+
"You can install it by `pip install funasr`\n",
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
54
|
+
|
|
55
|
+
if self._device is None:
|
|
56
|
+
self._device = get_available_device()
|
|
57
|
+
else:
|
|
58
|
+
if not is_device_available(self._device):
|
|
59
|
+
raise ValueError(f"Device {self._device} is not available!")
|
|
60
|
+
|
|
61
|
+
kwargs = self._model_spec.default_model_config.copy()
|
|
62
|
+
kwargs.update(self._kwargs)
|
|
63
|
+
logger.debug("Loading FunASR model with kwargs: %s", kwargs)
|
|
64
|
+
self._model = AutoModel(model=self._model_path, device=self._device, **kwargs)
|
|
65
|
+
|
|
66
|
+
def transcriptions(
|
|
67
|
+
self,
|
|
68
|
+
audio: bytes,
|
|
69
|
+
language: Optional[str] = None,
|
|
70
|
+
prompt: Optional[str] = None,
|
|
71
|
+
response_format: str = "json",
|
|
72
|
+
temperature: float = 0,
|
|
73
|
+
timestamp_granularities: Optional[List[str]] = None,
|
|
74
|
+
**kwargs,
|
|
75
|
+
):
|
|
76
|
+
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
|
77
|
+
|
|
78
|
+
if temperature != 0:
|
|
79
|
+
raise RuntimeError("`temperature`is not supported for FunASR")
|
|
80
|
+
if timestamp_granularities is not None:
|
|
81
|
+
raise RuntimeError("`timestamp_granularities`is not supported for FunASR")
|
|
82
|
+
if prompt is not None:
|
|
83
|
+
logger.warning(
|
|
84
|
+
"Prompt for funasr transcriptions will be ignored: %s", prompt
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
language = "auto" if language is None else language
|
|
88
|
+
|
|
89
|
+
with tempfile.NamedTemporaryFile(buffering=0) as f:
|
|
90
|
+
f.write(audio)
|
|
91
|
+
|
|
92
|
+
kw = self._model_spec.default_transcription_config.copy() # type: ignore
|
|
93
|
+
kw.update(kwargs)
|
|
94
|
+
logger.debug("Calling FunASR model with kwargs: %s", kw)
|
|
95
|
+
result = self._model.generate( # type: ignore
|
|
96
|
+
input=f.name, cache={}, language=language, **kw
|
|
97
|
+
)
|
|
98
|
+
text = rich_transcription_postprocess(result[0]["text"])
|
|
99
|
+
|
|
100
|
+
if response_format == "json":
|
|
101
|
+
return {"text": text}
|
|
102
|
+
else:
|
|
103
|
+
raise ValueError(f"Unsupported response format: {response_format}")
|
|
104
|
+
|
|
105
|
+
def translations(
|
|
106
|
+
self,
|
|
107
|
+
audio: bytes,
|
|
108
|
+
language: Optional[str] = None,
|
|
109
|
+
prompt: Optional[str] = None,
|
|
110
|
+
response_format: str = "json",
|
|
111
|
+
temperature: float = 0,
|
|
112
|
+
timestamp_granularities: Optional[List[str]] = None,
|
|
113
|
+
):
|
|
114
|
+
raise RuntimeError("FunASR does not support translations API")
|
|
@@ -95,6 +95,26 @@
|
|
|
95
95
|
"ability": "audio-to-text",
|
|
96
96
|
"multilingual": false
|
|
97
97
|
},
|
|
98
|
+
{
|
|
99
|
+
"model_name": "SenseVoiceSmall",
|
|
100
|
+
"model_family": "funasr",
|
|
101
|
+
"model_id": "FunAudioLLM/SenseVoiceSmall",
|
|
102
|
+
"model_revision": "3eb3b4eeffc2f2dde6051b853983753db33e35c3",
|
|
103
|
+
"ability": "audio-to-text",
|
|
104
|
+
"multilingual": true,
|
|
105
|
+
"default_model_config": {
|
|
106
|
+
"vad_model": "fsmn-vad",
|
|
107
|
+
"vad_kwargs": {
|
|
108
|
+
"max_single_segment_time": 30000
|
|
109
|
+
}
|
|
110
|
+
},
|
|
111
|
+
"default_transcription_config": {
|
|
112
|
+
"use_itn": true,
|
|
113
|
+
"batch_size_s": 60,
|
|
114
|
+
"merge_vad": true,
|
|
115
|
+
"merge_length_s": 15
|
|
116
|
+
}
|
|
117
|
+
},
|
|
98
118
|
{
|
|
99
119
|
"model_name": "ChatTTS",
|
|
100
120
|
"model_family": "ChatTTS",
|
|
@@ -8,6 +8,27 @@
|
|
|
8
8
|
"ability": "audio-to-text",
|
|
9
9
|
"multilingual": true
|
|
10
10
|
},
|
|
11
|
+
{
|
|
12
|
+
"model_name": "SenseVoiceSmall",
|
|
13
|
+
"model_family": "funasr",
|
|
14
|
+
"model_hub": "modelscope",
|
|
15
|
+
"model_id": "iic/SenseVoiceSmall",
|
|
16
|
+
"model_revision": "master",
|
|
17
|
+
"ability": "audio-to-text",
|
|
18
|
+
"multilingual": true,
|
|
19
|
+
"default_model_config": {
|
|
20
|
+
"vad_model": "fsmn-vad",
|
|
21
|
+
"vad_kwargs": {
|
|
22
|
+
"max_single_segment_time": 30000
|
|
23
|
+
}
|
|
24
|
+
},
|
|
25
|
+
"default_transcription_config": {
|
|
26
|
+
"use_itn": true,
|
|
27
|
+
"batch_size_s": 60,
|
|
28
|
+
"merge_vad": true,
|
|
29
|
+
"merge_length_s": 15
|
|
30
|
+
}
|
|
31
|
+
},
|
|
11
32
|
{
|
|
12
33
|
"model_name": "ChatTTS",
|
|
13
34
|
"model_family": "ChatTTS",
|
xinference/model/core.py
CHANGED
|
@@ -65,6 +65,7 @@ def create_model_instance(
|
|
|
65
65
|
from .image.core import create_image_model_instance
|
|
66
66
|
from .llm.core import create_llm_model_instance
|
|
67
67
|
from .rerank.core import create_rerank_model_instance
|
|
68
|
+
from .video.core import create_video_model_instance
|
|
68
69
|
|
|
69
70
|
if model_type == "LLM":
|
|
70
71
|
return create_llm_model_instance(
|
|
@@ -127,6 +128,17 @@ def create_model_instance(
|
|
|
127
128
|
model_path,
|
|
128
129
|
**kwargs,
|
|
129
130
|
)
|
|
131
|
+
elif model_type == "video":
|
|
132
|
+
kwargs.pop("trust_remote_code", None)
|
|
133
|
+
return create_video_model_instance(
|
|
134
|
+
subpool_addr,
|
|
135
|
+
devices,
|
|
136
|
+
model_uid,
|
|
137
|
+
model_name,
|
|
138
|
+
download_hub,
|
|
139
|
+
model_path,
|
|
140
|
+
**kwargs,
|
|
141
|
+
)
|
|
130
142
|
elif model_type == "flexible":
|
|
131
143
|
kwargs.pop("trust_remote_code", None)
|
|
132
144
|
return create_flexible_model_instance(
|
|
@@ -151,8 +151,8 @@ class EmbeddingModel:
|
|
|
151
151
|
|
|
152
152
|
patch_trust_remote_code()
|
|
153
153
|
if (
|
|
154
|
-
"gte
|
|
155
|
-
|
|
154
|
+
"gte" in self._model_spec.model_name.lower()
|
|
155
|
+
and "qwen2" in self._model_spec.model_name.lower()
|
|
156
156
|
):
|
|
157
157
|
self._model = XSentenceTransformer(
|
|
158
158
|
self._model_path,
|
|
@@ -260,8 +260,8 @@ class EmbeddingModel:
|
|
|
260
260
|
device = model._target_device
|
|
261
261
|
|
|
262
262
|
if (
|
|
263
|
-
"gte
|
|
264
|
-
and "
|
|
263
|
+
"gte" in self._model_spec.model_name.lower()
|
|
264
|
+
and "qwen2" in self._model_spec.model_name.lower()
|
|
265
265
|
):
|
|
266
266
|
model.to(device)
|
|
267
267
|
|
|
@@ -342,8 +342,8 @@ class EmbeddingModel:
|
|
|
342
342
|
return all_embeddings, all_token_nums
|
|
343
343
|
|
|
344
344
|
if (
|
|
345
|
-
"gte
|
|
346
|
-
|
|
345
|
+
"gte" in self._model_spec.model_name.lower()
|
|
346
|
+
and "qwen2" in self._model_spec.model_name.lower()
|
|
347
347
|
):
|
|
348
348
|
all_embeddings, all_token_nums = encode(
|
|
349
349
|
self._model,
|