xinference 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (136) hide show
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/async_restful_client.py +8 -13
  3. xinference/client/restful/restful_client.py +6 -2
  4. xinference/core/chat_interface.py +6 -4
  5. xinference/core/media_interface.py +5 -0
  6. xinference/core/model.py +1 -5
  7. xinference/core/supervisor.py +117 -68
  8. xinference/core/worker.py +49 -37
  9. xinference/deploy/test/test_cmdline.py +2 -6
  10. xinference/model/audio/__init__.py +26 -23
  11. xinference/model/audio/chattts.py +3 -2
  12. xinference/model/audio/core.py +49 -98
  13. xinference/model/audio/cosyvoice.py +3 -2
  14. xinference/model/audio/custom.py +28 -73
  15. xinference/model/audio/f5tts.py +3 -2
  16. xinference/model/audio/f5tts_mlx.py +3 -2
  17. xinference/model/audio/fish_speech.py +3 -2
  18. xinference/model/audio/funasr.py +17 -4
  19. xinference/model/audio/kokoro.py +3 -2
  20. xinference/model/audio/megatts.py +3 -2
  21. xinference/model/audio/melotts.py +3 -2
  22. xinference/model/audio/model_spec.json +572 -171
  23. xinference/model/audio/utils.py +0 -6
  24. xinference/model/audio/whisper.py +3 -2
  25. xinference/model/audio/whisper_mlx.py +3 -2
  26. xinference/model/cache_manager.py +141 -0
  27. xinference/model/core.py +6 -49
  28. xinference/model/custom.py +174 -0
  29. xinference/model/embedding/__init__.py +67 -56
  30. xinference/model/embedding/cache_manager.py +35 -0
  31. xinference/model/embedding/core.py +104 -84
  32. xinference/model/embedding/custom.py +55 -78
  33. xinference/model/embedding/embed_family.py +80 -31
  34. xinference/model/embedding/flag/core.py +21 -5
  35. xinference/model/embedding/llama_cpp/__init__.py +0 -0
  36. xinference/model/embedding/llama_cpp/core.py +234 -0
  37. xinference/model/embedding/model_spec.json +968 -103
  38. xinference/model/embedding/sentence_transformers/core.py +30 -20
  39. xinference/model/embedding/vllm/core.py +11 -5
  40. xinference/model/flexible/__init__.py +8 -2
  41. xinference/model/flexible/core.py +26 -119
  42. xinference/model/flexible/custom.py +69 -0
  43. xinference/model/flexible/launchers/image_process_launcher.py +1 -0
  44. xinference/model/flexible/launchers/modelscope_launcher.py +5 -1
  45. xinference/model/flexible/launchers/transformers_launcher.py +15 -3
  46. xinference/model/flexible/launchers/yolo_launcher.py +5 -1
  47. xinference/model/image/__init__.py +20 -20
  48. xinference/model/image/cache_manager.py +62 -0
  49. xinference/model/image/core.py +70 -182
  50. xinference/model/image/custom.py +28 -72
  51. xinference/model/image/model_spec.json +402 -119
  52. xinference/model/image/ocr/got_ocr2.py +3 -2
  53. xinference/model/image/stable_diffusion/core.py +22 -7
  54. xinference/model/image/stable_diffusion/mlx.py +6 -6
  55. xinference/model/image/utils.py +2 -2
  56. xinference/model/llm/__init__.py +71 -94
  57. xinference/model/llm/cache_manager.py +292 -0
  58. xinference/model/llm/core.py +37 -111
  59. xinference/model/llm/custom.py +88 -0
  60. xinference/model/llm/llama_cpp/core.py +5 -7
  61. xinference/model/llm/llm_family.json +16260 -8151
  62. xinference/model/llm/llm_family.py +138 -839
  63. xinference/model/llm/lmdeploy/core.py +5 -7
  64. xinference/model/llm/memory.py +3 -4
  65. xinference/model/llm/mlx/core.py +6 -8
  66. xinference/model/llm/reasoning_parser.py +3 -1
  67. xinference/model/llm/sglang/core.py +32 -14
  68. xinference/model/llm/transformers/chatglm.py +3 -7
  69. xinference/model/llm/transformers/core.py +49 -27
  70. xinference/model/llm/transformers/deepseek_v2.py +2 -2
  71. xinference/model/llm/transformers/gemma3.py +2 -2
  72. xinference/model/llm/transformers/multimodal/cogagent.py +2 -2
  73. xinference/model/llm/transformers/multimodal/deepseek_vl2.py +2 -2
  74. xinference/model/llm/transformers/multimodal/gemma3.py +2 -2
  75. xinference/model/llm/transformers/multimodal/glm4_1v.py +167 -0
  76. xinference/model/llm/transformers/multimodal/glm4v.py +2 -2
  77. xinference/model/llm/transformers/multimodal/intern_vl.py +2 -2
  78. xinference/model/llm/transformers/multimodal/minicpmv26.py +3 -3
  79. xinference/model/llm/transformers/multimodal/ovis2.py +2 -2
  80. xinference/model/llm/transformers/multimodal/qwen-omni.py +2 -2
  81. xinference/model/llm/transformers/multimodal/qwen2_audio.py +2 -2
  82. xinference/model/llm/transformers/multimodal/qwen2_vl.py +2 -2
  83. xinference/model/llm/transformers/opt.py +3 -7
  84. xinference/model/llm/utils.py +34 -49
  85. xinference/model/llm/vllm/core.py +77 -27
  86. xinference/model/llm/vllm/xavier/engine.py +5 -3
  87. xinference/model/llm/vllm/xavier/scheduler.py +10 -6
  88. xinference/model/llm/vllm/xavier/transfer.py +1 -1
  89. xinference/model/rerank/__init__.py +26 -25
  90. xinference/model/rerank/core.py +47 -87
  91. xinference/model/rerank/custom.py +25 -71
  92. xinference/model/rerank/model_spec.json +158 -33
  93. xinference/model/rerank/utils.py +2 -2
  94. xinference/model/utils.py +115 -54
  95. xinference/model/video/__init__.py +13 -17
  96. xinference/model/video/core.py +44 -102
  97. xinference/model/video/diffusers.py +4 -3
  98. xinference/model/video/model_spec.json +90 -21
  99. xinference/types.py +5 -3
  100. xinference/web/ui/build/asset-manifest.json +3 -3
  101. xinference/web/ui/build/index.html +1 -1
  102. xinference/web/ui/build/static/js/main.7d24df53.js +3 -0
  103. xinference/web/ui/build/static/js/main.7d24df53.js.map +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2704ff66a5f73ca78b341eb3edec60154369df9d87fbc8c6dd60121abc5e1b0a.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/607dfef23d33e6b594518c0c6434567639f24f356b877c80c60575184ec50ed0.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/9be3d56173aacc3efd0b497bcb13c4f6365de30069176ee9403b40e717542326.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/9f9dd6c32c78a222d07da5987ae902effe16bcf20aac00774acdccc4de3c9ff2.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/b2ab5ee972c60d15eb9abf5845705f8ab7e1d125d324d9a9b1bcae5d6fd7ffb2.json +1 -0
  109. xinference/web/ui/src/locales/en.json +0 -1
  110. xinference/web/ui/src/locales/ja.json +0 -1
  111. xinference/web/ui/src/locales/ko.json +0 -1
  112. xinference/web/ui/src/locales/zh.json +0 -1
  113. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/METADATA +9 -11
  114. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/RECORD +119 -119
  115. xinference/model/audio/model_spec_modelscope.json +0 -231
  116. xinference/model/embedding/model_spec_modelscope.json +0 -293
  117. xinference/model/embedding/utils.py +0 -18
  118. xinference/model/image/model_spec_modelscope.json +0 -375
  119. xinference/model/llm/llama_cpp/memory.py +0 -457
  120. xinference/model/llm/llm_family_csghub.json +0 -56
  121. xinference/model/llm/llm_family_modelscope.json +0 -8700
  122. xinference/model/llm/llm_family_openmind_hub.json +0 -1019
  123. xinference/model/rerank/model_spec_modelscope.json +0 -85
  124. xinference/model/video/model_spec_modelscope.json +0 -184
  125. xinference/web/ui/build/static/js/main.9b12b7f9.js +0 -3
  126. xinference/web/ui/build/static/js/main.9b12b7f9.js.map +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +0 -1
  132. /xinference/web/ui/build/static/js/{main.9b12b7f9.js.LICENSE.txt → main.7d24df53.js.LICENSE.txt} +0 -0
  133. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/WHEEL +0 -0
  134. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/entry_points.txt +0 -0
  135. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/licenses/LICENSE +0 -0
  136. {xinference-1.7.1.dist-info → xinference-1.8.0.dist-info}/top_level.txt +0 -0
xinference/core/worker.py CHANGED
@@ -54,7 +54,7 @@ from ..constants import (
54
54
  from ..core.model import ModelActor
55
55
  from ..core.status_guard import LaunchStatus
56
56
  from ..device_utils import get_available_device_env_name, gpu_count
57
- from ..model.core import ModelDescription, VirtualEnvSettings, create_model_instance
57
+ from ..model.core import VirtualEnvSettings, create_model_instance
58
58
  from ..model.utils import CancellableDownloader, get_engine_params_by_name
59
59
  from ..types import PeftModelConfig
60
60
  from ..utils import get_pip_config_args, get_real_path
@@ -131,14 +131,14 @@ class WorkerActor(xo.StatelessActor):
131
131
  self._model_uid_launching_guard: Dict[str, LaunchInfo] = {}
132
132
  # attributes maintained after model launched:
133
133
  self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
134
- self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
134
+ self._model_uid_to_model_spec: Dict[str, Dict[str, Any]] = {}
135
135
  self._model_uid_to_model_status: Dict[str, ModelStatus] = {}
136
136
  self._gpu_to_model_uid: Dict[int, str] = {}
137
137
  self._gpu_to_embedding_model_uids: Dict[int, Set[str]] = defaultdict(set)
138
138
  # Dict structure: gpu_index: {(replica_model_uid, model_type)}
139
- self._user_specified_gpu_to_model_uids: Dict[
140
- int, Set[Tuple[str, str]]
141
- ] = defaultdict(set)
139
+ self._user_specified_gpu_to_model_uids: Dict[int, Set[Tuple[str, str]]] = (
140
+ defaultdict(set)
141
+ )
142
142
  self._model_uid_to_addr: Dict[str, str] = {}
143
143
  self._model_uid_to_recover_count: Dict[str, Optional[int]] = {}
144
144
  self._model_uid_to_launch_args: Dict[str, Dict] = {}
@@ -236,13 +236,13 @@ class WorkerActor(xo.StatelessActor):
236
236
 
237
237
  async def __post_create__(self):
238
238
  from ..model.audio import (
239
- CustomAudioModelFamilyV1,
239
+ CustomAudioModelFamilyV2,
240
240
  generate_audio_description,
241
241
  register_audio,
242
242
  unregister_audio,
243
243
  )
244
244
  from ..model.embedding import (
245
- CustomEmbeddingModelSpec,
245
+ CustomEmbeddingModelFamilyV2,
246
246
  generate_embedding_description,
247
247
  register_embedding,
248
248
  unregister_embedding,
@@ -254,19 +254,19 @@ class WorkerActor(xo.StatelessActor):
254
254
  unregister_flexible_model,
255
255
  )
256
256
  from ..model.image import (
257
- CustomImageModelFamilyV1,
257
+ CustomImageModelFamilyV2,
258
258
  generate_image_description,
259
259
  register_image,
260
260
  unregister_image,
261
261
  )
262
262
  from ..model.llm import (
263
- CustomLLMFamilyV1,
264
- generate_llm_description,
263
+ CustomLLMFamilyV2,
264
+ generate_llm_version_info,
265
265
  register_llm,
266
266
  unregister_llm,
267
267
  )
268
268
  from ..model.rerank import (
269
- CustomRerankModelSpec,
269
+ CustomRerankModelFamilyV2,
270
270
  generate_rerank_description,
271
271
  register_rerank,
272
272
  unregister_rerank,
@@ -274,31 +274,31 @@ class WorkerActor(xo.StatelessActor):
274
274
 
275
275
  self._custom_register_type_to_cls: Dict[str, Tuple] = { # type: ignore
276
276
  "LLM": (
277
- CustomLLMFamilyV1,
277
+ CustomLLMFamilyV2,
278
278
  register_llm,
279
279
  unregister_llm,
280
- generate_llm_description,
280
+ generate_llm_version_info,
281
281
  ),
282
282
  "embedding": (
283
- CustomEmbeddingModelSpec,
283
+ CustomEmbeddingModelFamilyV2,
284
284
  register_embedding,
285
285
  unregister_embedding,
286
286
  generate_embedding_description,
287
287
  ),
288
288
  "rerank": (
289
- CustomRerankModelSpec,
289
+ CustomRerankModelFamilyV2,
290
290
  register_rerank,
291
291
  unregister_rerank,
292
292
  generate_rerank_description,
293
293
  ),
294
294
  "image": (
295
- CustomImageModelFamilyV1,
295
+ CustomImageModelFamilyV2,
296
296
  register_image,
297
297
  unregister_image,
298
298
  generate_image_description,
299
299
  ),
300
300
  "audio": (
301
- CustomAudioModelFamilyV1,
301
+ CustomAudioModelFamilyV2,
302
302
  register_audio,
303
303
  unregister_audio,
304
304
  generate_audio_description,
@@ -396,16 +396,18 @@ class WorkerActor(xo.StatelessActor):
396
396
  from ..model.embedding import get_embedding_model_descriptions
397
397
  from ..model.flexible import get_flexible_model_descriptions
398
398
  from ..model.image import get_image_model_descriptions
399
- from ..model.llm import get_llm_model_descriptions
399
+ from ..model.llm import get_llm_version_infos
400
400
  from ..model.rerank import get_rerank_model_descriptions
401
+ from ..model.video import get_video_model_descriptions
401
402
 
402
403
  # record model version
403
404
  model_version_infos: Dict[str, List[Dict]] = {} # type: ignore
404
- model_version_infos.update(get_llm_model_descriptions())
405
+ model_version_infos.update(get_llm_version_infos())
405
406
  model_version_infos.update(get_embedding_model_descriptions())
406
407
  model_version_infos.update(get_rerank_model_descriptions())
407
408
  model_version_infos.update(get_image_model_descriptions())
408
409
  model_version_infos.update(get_audio_model_descriptions())
410
+ model_version_infos.update(get_video_model_descriptions())
409
411
  model_version_infos.update(get_flexible_model_descriptions())
410
412
  await self._cache_tracker_ref.record_model_version(
411
413
  model_version_infos, self.address
@@ -774,10 +776,7 @@ class WorkerActor(xo.StatelessActor):
774
776
  assert isinstance(model, LLM)
775
777
  return model.model_family.model_ability # type: ignore
776
778
 
777
- async def update_cache_status(
778
- self, model_name: str, model_description: ModelDescription
779
- ):
780
- version_info = model_description.to_version_info()
779
+ async def update_cache_status(self, model_name: str, version_info: Any):
781
780
  if isinstance(version_info, list): # image model
782
781
  model_path = version_info[0]["model_file_location"]
783
782
  await self._cache_tracker_ref.update_cache_status(
@@ -1028,10 +1027,8 @@ class WorkerActor(xo.StatelessActor):
1028
1027
  self._upload_download_progress, progressor, downloader
1029
1028
  )
1030
1029
  )
1031
- model, model_description = await asyncio.to_thread(
1030
+ model = await asyncio.to_thread(
1032
1031
  create_model_instance,
1033
- subpool_address,
1034
- devices,
1035
1032
  model_uid,
1036
1033
  model_type,
1037
1034
  model_name,
@@ -1044,7 +1041,14 @@ class WorkerActor(xo.StatelessActor):
1044
1041
  model_path,
1045
1042
  **model_kwargs,
1046
1043
  )
1047
- await self.update_cache_status(model_name, model_description)
1044
+ model.model_family.address = subpool_address
1045
+ model.model_family.accelerators = devices
1046
+ model.model_family.multimodal_projector = model_kwargs.get(
1047
+ "multimodal_projector", None
1048
+ )
1049
+ await self.update_cache_status(
1050
+ model_name, model.model_family.to_version_info()
1051
+ )
1048
1052
 
1049
1053
  def check_cancel():
1050
1054
  # check downloader first, sometimes download finished
@@ -1063,7 +1067,7 @@ class WorkerActor(xo.StatelessActor):
1063
1067
  await asyncio.to_thread(
1064
1068
  self._prepare_virtual_env,
1065
1069
  virtual_env_manager,
1066
- model_description.spec.virtualenv,
1070
+ model.model_family.virtualenv,
1067
1071
  )
1068
1072
  launch_info.virtual_env_manager = virtual_env_manager
1069
1073
 
@@ -1078,7 +1082,6 @@ class WorkerActor(xo.StatelessActor):
1078
1082
  worker_address=self.address,
1079
1083
  replica_model_uid=model_uid,
1080
1084
  model=model,
1081
- model_description=model_description,
1082
1085
  request_limits=request_limits,
1083
1086
  xavier_config=xavier_config,
1084
1087
  n_worker=n_worker,
@@ -1125,7 +1128,9 @@ class WorkerActor(xo.StatelessActor):
1125
1128
  continue
1126
1129
  raise
1127
1130
  self._model_uid_to_model[model_uid] = model_ref
1128
- self._model_uid_to_model_spec[model_uid] = model_description
1131
+ self._model_uid_to_model_spec[model_uid] = (
1132
+ model.model_family.to_description()
1133
+ )
1129
1134
  self._model_uid_to_model_status[model_uid] = ModelStatus()
1130
1135
  self._model_uid_to_addr[model_uid] = subpool_address
1131
1136
  self._model_uid_to_recover_count.setdefault(
@@ -1301,12 +1306,7 @@ class WorkerActor(xo.StatelessActor):
1301
1306
 
1302
1307
  @log_async(logger=logger)
1303
1308
  async def list_models(self) -> Dict[str, Dict[str, Any]]:
1304
- ret = {}
1305
-
1306
- items = list(self._model_uid_to_model_spec.items())
1307
- for k, v in items:
1308
- ret[k] = v.to_dict()
1309
- return ret
1309
+ return {k: v for k, v in self._model_uid_to_model_spec.items()}
1310
1310
 
1311
1311
  @log_sync(logger=logger)
1312
1312
  def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
@@ -1323,7 +1323,7 @@ class WorkerActor(xo.StatelessActor):
1323
1323
  model_desc = self._model_uid_to_model_spec.get(model_uid, None)
1324
1324
  if model_desc is None:
1325
1325
  raise ValueError(f"Model not found in the model list, uid: {model_uid}")
1326
- return model_desc.to_dict()
1326
+ return model_desc
1327
1327
 
1328
1328
  async def report_status(self):
1329
1329
  status = dict()
@@ -1409,7 +1409,9 @@ class WorkerActor(xo.StatelessActor):
1409
1409
 
1410
1410
  async def confirm_and_remove_model(self, model_version: str) -> bool:
1411
1411
  paths = await self.list_deletable_models(model_version)
1412
+ dir_paths = set()
1412
1413
  for path in paths:
1414
+ dir_paths.add(os.path.dirname(path))
1413
1415
  try:
1414
1416
  if os.path.islink(path):
1415
1417
  os.unlink(path)
@@ -1422,6 +1424,16 @@ class WorkerActor(xo.StatelessActor):
1422
1424
  except Exception as e:
1423
1425
  logger.error(f"Fail to delete {path} with error:{e}.") # noqa: E231
1424
1426
  return False
1427
+
1428
+ for _dir in dir_paths:
1429
+ try:
1430
+ shutil.rmtree(_dir)
1431
+ except Exception as e:
1432
+ logger.error(
1433
+ f"Fail to delete parent dir {_dir} with error:{e}."
1434
+ ) # noqa: E231
1435
+ return False
1436
+
1425
1437
  await self._cache_tracker_ref.confirm_and_remove_model(
1426
1438
  model_version, self.address
1427
1439
  )
@@ -185,7 +185,7 @@ def test_cmdline_of_custom_model(setup):
185
185
 
186
186
  # register custom model
187
187
  custom_model_desc = """{
188
- "version": 1,
188
+ "version": 2,
189
189
  "context_length":2048,
190
190
  "model_name": "custom_model",
191
191
  "model_lang": [
@@ -200,11 +200,7 @@ def test_cmdline_of_custom_model(setup):
200
200
  {
201
201
  "model_format": "pytorch",
202
202
  "model_size_in_billions": 7,
203
- "quantizations": [
204
- "4-bit",
205
- "8-bit",
206
- "none"
207
- ],
203
+ "quantization": "none",
208
204
  "model_id": "ziqingyang/chinese-alpaca-2-7b"
209
205
  }
210
206
  ],
@@ -18,38 +18,41 @@ import os
18
18
  import platform
19
19
  import sys
20
20
  import warnings
21
- from typing import Any, Dict
21
+ from typing import Dict, List
22
22
 
23
23
  from ...constants import XINFERENCE_MODEL_DIR
24
+ from ..utils import flatten_model_src
24
25
  from .core import (
25
26
  AUDIO_MODEL_DESCRIPTIONS,
26
- MODEL_NAME_TO_REVISION,
27
- AudioModelFamilyV1,
27
+ AudioModelFamilyV2,
28
28
  generate_audio_description,
29
29
  get_audio_model_descriptions,
30
- get_cache_status,
31
30
  )
32
31
  from .custom import (
33
- CustomAudioModelFamilyV1,
32
+ CustomAudioModelFamilyV2,
34
33
  get_user_defined_audios,
35
34
  register_audio,
36
35
  unregister_audio,
37
36
  )
38
37
 
39
- BUILTIN_AUDIO_MODELS: Dict[str, Any] = {}
40
- MODELSCOPE_AUDIO_MODELS: Dict[str, Any] = {}
38
+ BUILTIN_AUDIO_MODELS: Dict[str, List["AudioModelFamilyV2"]] = {}
41
39
 
42
40
 
43
41
  def register_custom_model():
42
+ from ..custom import migrate_from_v1_to_v2
43
+
44
+ # migrate from v1 to v2 first
45
+ migrate_from_v1_to_v2("audio", CustomAudioModelFamilyV2)
46
+
44
47
  # if persist=True, load them when init
45
- user_defined_audio_dir = os.path.join(XINFERENCE_MODEL_DIR, "audio")
48
+ user_defined_audio_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "audio")
46
49
  if os.path.isdir(user_defined_audio_dir):
47
50
  for f in os.listdir(user_defined_audio_dir):
48
51
  try:
49
52
  with codecs.open(
50
53
  os.path.join(user_defined_audio_dir, f), encoding="utf-8"
51
54
  ) as fd:
52
- user_defined_audio_family = CustomAudioModelFamilyV1.parse_obj(
55
+ user_defined_audio_family = CustomAudioModelFamilyV2.parse_obj(
53
56
  json.load(fd)
54
57
  )
55
58
  register_audio(user_defined_audio_family, persist=False)
@@ -67,13 +70,12 @@ def _need_filter(spec: dict):
67
70
 
68
71
  def _install():
69
72
  load_model_family_from_json("model_spec.json", BUILTIN_AUDIO_MODELS)
70
- load_model_family_from_json("model_spec_modelscope.json", MODELSCOPE_AUDIO_MODELS)
71
73
 
72
74
  # register model description after recording model revision
73
- for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
74
- for model_name, model_spec in model_spec_info.items():
75
- if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
76
- AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
75
+ for model_name, model_specs in BUILTIN_AUDIO_MODELS.items():
76
+ model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0]
77
+ if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
78
+ AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
77
79
 
78
80
  register_custom_model()
79
81
 
@@ -84,14 +86,15 @@ def _install():
84
86
 
85
87
  def load_model_family_from_json(json_filename, target_families):
86
88
  json_path = os.path.join(os.path.dirname(__file__), json_filename)
87
- target_families.update(
88
- dict(
89
- (spec["model_name"], AudioModelFamilyV1(**spec))
90
- for spec in json.load(codecs.open(json_path, "r", encoding="utf-8"))
91
- if not _need_filter(spec)
92
- )
93
- )
94
- for model_name, model_spec in target_families.items():
95
- MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
89
+ flattened_model_specs = []
90
+ for spec in json.load(codecs.open(json_path, "r", encoding="utf-8")):
91
+ flattened_model_specs.extend(flatten_model_src(spec))
92
+
93
+ for spec in flattened_model_specs:
94
+ if not _need_filter(spec):
95
+ if spec["model_name"] not in target_families:
96
+ target_families[spec["model_name"]] = [AudioModelFamilyV2(**spec)]
97
+ else:
98
+ target_families[spec["model_name"]].append(AudioModelFamilyV2(**spec))
96
99
 
97
100
  del json_path
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Optional
20
20
  from ..utils import set_all_random_seed
21
21
 
22
22
  if TYPE_CHECKING:
23
- from .core import AudioModelFamilyV1
23
+ from .core import AudioModelFamilyV2
24
24
 
25
25
  logger = logging.getLogger(__name__)
26
26
 
@@ -30,10 +30,11 @@ class ChatTTSModel:
30
30
  self,
31
31
  model_uid: str,
32
32
  model_path: str,
33
- model_spec: "AudioModelFamilyV1",
33
+ model_spec: "AudioModelFamilyV2",
34
34
  device: Optional[str] = None,
35
35
  **kwargs,
36
36
  ):
37
+ self.model_family = model_spec
37
38
  self._model_uid = model_uid
38
39
  self._model_path = model_path
39
40
  self._model_spec = model_spec
@@ -12,13 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  import logging
15
- import os
16
15
  from collections import defaultdict
17
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
16
+ from typing import Any, Dict, List, Literal, Optional, Union
18
17
 
19
- from ...constants import XINFERENCE_CACHE_DIR
20
- from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
21
- from ..utils import valid_model_revision
18
+ from ..core import CacheableModelSpec, VirtualEnvSettings
19
+ from ..utils import ModelInstanceInfoMixin
22
20
  from .chattts import ChatTTSModel
23
21
  from .cosyvoice import CosyVoiceModel
24
22
  from .f5tts import F5TTSModel
@@ -33,9 +31,7 @@ from .whisper_mlx import WhisperMLXModel
33
31
 
34
32
  logger = logging.getLogger(__name__)
35
33
 
36
- # Used for check whether the model is cached.
37
34
  # Init when registering all the builtin models.
38
- MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
39
35
  AUDIO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
40
36
 
41
37
 
@@ -45,7 +41,8 @@ def get_audio_model_descriptions():
45
41
  return copy.deepcopy(AUDIO_MODEL_DESCRIPTIONS)
46
42
 
47
43
 
48
- class AudioModelFamilyV1(CacheableModelSpec):
44
+ class AudioModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
45
+ version: Literal[2]
49
46
  model_family: str
50
47
  model_name: str
51
48
  model_id: str
@@ -58,57 +55,37 @@ class AudioModelFamilyV1(CacheableModelSpec):
58
55
  engine: Optional[str]
59
56
  virtualenv: Optional[VirtualEnvSettings]
60
57
 
58
+ class Config:
59
+ extra = "allow"
61
60
 
62
- class AudioModelDescription(ModelDescription):
63
- def __init__(
64
- self,
65
- address: Optional[str],
66
- devices: Optional[List[str]],
67
- model_spec: AudioModelFamilyV1,
68
- model_path: Optional[str] = None,
69
- ):
70
- super().__init__(address, devices, model_path=model_path)
71
- self._model_spec = model_spec
72
-
73
- @property
74
- def spec(self):
75
- return self._model_spec
76
-
77
- def to_dict(self):
61
+ def to_description(self):
78
62
  return {
79
63
  "model_type": "audio",
80
- "address": self.address,
81
- "accelerators": self.devices,
82
- "model_name": self._model_spec.model_name,
83
- "model_family": self._model_spec.model_family,
84
- "model_revision": self._model_spec.model_revision,
85
- "model_ability": self._model_spec.model_ability,
64
+ "address": getattr(self, "address", None),
65
+ "accelerators": getattr(self, "accelerators", None),
66
+ "model_name": self.model_name,
67
+ "model_family": self.model_family,
68
+ "model_revision": self.model_revision,
69
+ "model_ability": self.model_ability,
86
70
  }
87
71
 
88
72
  def to_version_info(self):
89
- from .utils import get_model_version
73
+ from ..cache_manager import CacheManager
90
74
 
91
- if self._model_path is None:
92
- is_cached = get_cache_status(self._model_spec)
93
- file_location = get_cache_dir(self._model_spec)
94
- else:
95
- is_cached = True
96
- file_location = self._model_path
75
+ cache_manager = CacheManager(self)
97
76
 
98
77
  return {
99
- "model_version": get_model_version(self._model_spec),
100
- "model_file_location": file_location,
101
- "cache_status": is_cached,
78
+ "model_version": self.model_name,
79
+ "model_file_location": cache_manager.get_cache_dir(),
80
+ "cache_status": cache_manager.get_cache_status(),
102
81
  }
103
82
 
104
83
 
105
84
  def generate_audio_description(
106
- image_model: AudioModelFamilyV1,
85
+ audio_model: AudioModelFamilyV2,
107
86
  ) -> Dict[str, List[Dict]]:
108
87
  res = defaultdict(list)
109
- res[image_model.model_name].append(
110
- AudioModelDescription(None, None, image_model).to_version_info()
111
- )
88
+ res[audio_model.model_name].append(audio_model.to_version_info())
112
89
  return res
113
90
 
114
91
 
@@ -117,27 +94,24 @@ def match_audio(
117
94
  download_hub: Optional[
118
95
  Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
119
96
  ] = None,
120
- ) -> AudioModelFamilyV1:
97
+ ) -> AudioModelFamilyV2:
121
98
  from ..utils import download_from_modelscope
122
- from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
99
+ from . import BUILTIN_AUDIO_MODELS
123
100
  from .custom import get_user_defined_audios
124
101
 
125
102
  for model_spec in get_user_defined_audios():
126
103
  if model_spec.model_name == model_name:
127
104
  return model_spec
128
105
 
129
- if download_hub == "huggingface" and model_name in BUILTIN_AUDIO_MODELS:
130
- logger.debug(f"Audio model {model_name} found in huggingface.")
131
- return BUILTIN_AUDIO_MODELS[model_name]
132
- elif download_hub == "modelscope" and model_name in MODELSCOPE_AUDIO_MODELS:
133
- logger.debug(f"Audio model {model_name} found in ModelScope.")
134
- return MODELSCOPE_AUDIO_MODELS[model_name]
135
- elif download_from_modelscope() and model_name in MODELSCOPE_AUDIO_MODELS:
136
- logger.debug(f"Audio model {model_name} found in ModelScope.")
137
- return MODELSCOPE_AUDIO_MODELS[model_name]
138
- elif model_name in BUILTIN_AUDIO_MODELS:
139
- logger.debug(f"Audio model {model_name} found in huggingface.")
140
- return BUILTIN_AUDIO_MODELS[model_name]
106
+ if model_name in BUILTIN_AUDIO_MODELS:
107
+ model_families = BUILTIN_AUDIO_MODELS[model_name]
108
+ if download_hub == "modelscope" or download_from_modelscope():
109
+ return (
110
+ [x for x in model_families if x.model_hub == "modelscope"]
111
+ + [x for x in model_families if x.model_hub == "huggingface"]
112
+ )[0]
113
+ else:
114
+ return [x for x in model_families if x.model_hub == "huggingface"][0]
141
115
  else:
142
116
  raise ValueError(
143
117
  f"Audio model {model_name} not found, available"
@@ -145,27 +119,7 @@ def match_audio(
145
119
  )
146
120
 
147
121
 
148
- def cache(model_spec: AudioModelFamilyV1):
149
- from ..utils import cache
150
-
151
- return cache(model_spec, AudioModelDescription)
152
-
153
-
154
- def get_cache_dir(model_spec: AudioModelFamilyV1):
155
- return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
156
-
157
-
158
- def get_cache_status(
159
- model_spec: AudioModelFamilyV1,
160
- ) -> bool:
161
- cache_dir = get_cache_dir(model_spec)
162
- meta_path = os.path.join(cache_dir, "__valid_download")
163
- return valid_model_revision(meta_path, model_spec.model_revision)
164
-
165
-
166
122
  def create_audio_model_instance(
167
- subpool_addr: str,
168
- devices: List[str],
169
123
  model_uid: str,
170
124
  model_name: str,
171
125
  download_hub: Optional[
@@ -173,25 +127,25 @@ def create_audio_model_instance(
173
127
  ] = None,
174
128
  model_path: Optional[str] = None,
175
129
  **kwargs,
176
- ) -> Tuple[
177
- Union[
178
- WhisperModel,
179
- WhisperMLXModel,
180
- FunASRModel,
181
- ChatTTSModel,
182
- CosyVoiceModel,
183
- FishSpeechModel,
184
- F5TTSModel,
185
- F5TTSMLXModel,
186
- MeloTTSModel,
187
- KokoroModel,
188
- MegaTTSModel,
189
- ],
190
- AudioModelDescription,
130
+ ) -> Union[
131
+ WhisperModel,
132
+ WhisperMLXModel,
133
+ FunASRModel,
134
+ ChatTTSModel,
135
+ CosyVoiceModel,
136
+ FishSpeechModel,
137
+ F5TTSModel,
138
+ F5TTSMLXModel,
139
+ MeloTTSModel,
140
+ KokoroModel,
141
+ MegaTTSModel,
191
142
  ]:
143
+ from ..cache_manager import CacheManager
144
+
192
145
  model_spec = match_audio(model_name, download_hub)
193
146
  if model_path is None:
194
- model_path = cache(model_spec)
147
+ cache_manager = CacheManager(model_spec)
148
+ model_path = cache_manager.cache()
195
149
  model: Union[
196
150
  WhisperModel,
197
151
  WhisperMLXModel,
@@ -230,7 +184,4 @@ def create_audio_model_instance(
230
184
  model = MegaTTSModel(model_uid, model_path, model_spec, **kwargs)
231
185
  else:
232
186
  raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
233
- model_description = AudioModelDescription(
234
- subpool_addr, devices, model_spec, model_path
235
- )
236
- return model, model_description
187
+ return model
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
18
18
  from ..utils import set_all_random_seed
19
19
 
20
20
  if TYPE_CHECKING:
21
- from .core import AudioModelFamilyV1
21
+ from .core import AudioModelFamilyV2
22
22
 
23
23
  logger = logging.getLogger(__name__)
24
24
 
@@ -28,10 +28,11 @@ class CosyVoiceModel:
28
28
  self,
29
29
  model_uid: str,
30
30
  model_path: str,
31
- model_spec: "AudioModelFamilyV1",
31
+ model_spec: "AudioModelFamilyV2",
32
32
  device: Optional[str] = None,
33
33
  **kwargs,
34
34
  ):
35
+ self.model_family = model_spec
35
36
  self._model_uid = model_uid
36
37
  self._model_path = model_path
37
38
  self._model_spec = model_spec