xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
  version_json = '''
  {
- "date": "2025-01-10T17:24:10+0800",
+ "date": "2025-02-08T17:06:47+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "df45f11115051929d6296a0c138b99472abf497f",
- "version": "1.2.0"
+ "full-revisionid": "ac97a13a831de6debda52e6fdb8c1bf9366be57c",
+ "version": "1.2.2"
  }
  ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -2000,25 +2000,22 @@ class RESTfulAPI(CancelMixin):
 
      from ..model.llm.utils import (
          GLM4_TOOL_CALL_FAMILY,
-         LLAMA3_TOOL_CALL_FAMILY,
          QWEN_TOOL_CALL_FAMILY,
+         TOOL_CALL_FAMILY,
      )
 
      model_family = desc.get("model_family", "")
-     function_call_models = (
-         QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
-     )
 
-     if model_family not in function_call_models:
+     if model_family not in TOOL_CALL_FAMILY:
          if body.tools:
              raise HTTPException(
                  status_code=400,
-                 detail=f"Only {function_call_models} support tool calls",
+                 detail=f"Only {TOOL_CALL_FAMILY} support tool calls",
              )
          if has_tool_message:
              raise HTTPException(
                  status_code=400,
-                 detail=f"Only {function_call_models} support tool messages",
+                 detail=f"Only {TOOL_CALL_FAMILY} support tool messages",
              )
      if body.tools and body.stream:
          is_vllm = await model.is_vllm_backend()
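Note: tool-call validation is now driven by the single TOOL_CALL_FAMILY list from xinference/model/llm/utils.py. A minimal sketch of what a client sees when it sends tools to a model outside that family (endpoint and model uid are illustrative, not taken from this diff):

    import requests

    resp = requests.post(
        "http://127.0.0.1:9997/v1/chat/completions",  # illustrative local endpoint
        json={
            "model": "my-model",  # hypothetical uid of a model whose family is not in TOOL_CALL_FAMILY
            "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
            "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
        },
    )
    # Expect HTTP 400 with a detail message like "Only [...] support tool calls".
    print(resp.status_code, resp.json())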
xinference/client/handlers.py CHANGED
@@ -13,3 +13,6 @@ from .restful.restful_client import ( # noqa: F401
  from .restful.restful_client import (  # noqa: F401
      RESTfulImageModelHandle as ImageModelHandle,
  )
+ from .restful.restful_client import (  # noqa: F401
+     RESTfulVideoModelHandle as VideoModelHandle,
+ )
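Note: the client package now re-exports RESTfulVideoModelHandle as VideoModelHandle. A hedged usage sketch (the model name and the text_to_video call follow the existing image-handle pattern and are assumptions, not part of this diff):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # illustrative endpoint
    model_uid = client.launch_model(model_name="CogVideoX-2b", model_type="video")  # hypothetical model
    video_model = client.get_model(model_uid)  # served through a VideoModelHandle
    result = video_model.text_to_video("A sailboat drifting at sunset")  # assumed handle method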
xinference/core/chat_interface.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.
 
  import base64
+ import html
  import logging
  import os
  from io import BytesIO
@@ -137,7 +138,11 @@ class GradioInterface:
              if "content" not in delta:
                  continue
              else:
-                 response_content += delta["content"]
+                 # some model like deepseek-r1-distill-qwen
+                 # will generate <think>...</think> ...
+                 # in gradio, no output will be rendered,
+                 # thus escape html tags in advance
+                 response_content += html.escape(delta["content"])
              yield response_content
 
          yield response_content
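Note: the escaping change relies only on the standard library; a quick sketch of what the Gradio stream now renders for a reasoning-style chunk:

    import html

    chunk = "<think>reasoning goes here</think> The answer is 42."
    print(html.escape(chunk))
    # &lt;think&gt;reasoning goes here&lt;/think&gt; The answer is 42.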
xinference/core/model.py CHANGED
@@ -35,6 +35,7 @@ from typing import (
      List,
      Optional,
      Union,
+     no_type_check,
  )
 
  import sse_starlette.sse
@@ -302,6 +303,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
      def decrease_serve_count(self):
          self._serve_count -= 1
 
+     @no_type_check
      async def start_transfer_for_vllm(self, rank_addresses: List[str]):
          from ..model.llm.vllm.core import VLLMModel
          from ..model.llm.vllm.xavier.transfer import TransferActor
xinference/core/scheduler.py CHANGED
@@ -269,16 +269,13 @@ class InferenceRequest:
      )
 
 
- def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
-     from transformers.cache_utils import DynamicCache
-
-     cache = DynamicCache.from_legacy_cache(data)
+ def _get_valid_batch_kv_cache(cache, skipped_indexes: Set[int]):
      batch_size = cache.key_cache[0].shape[0]
      batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
      for idx in range(len(cache)):
-         cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
-         cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
-     return cache.to_legacy_cache()
+         cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
+         cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::].contiguous()
+     return cache
 
 
  class SchedulerActor(xo.StatelessActor):
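Note: _get_valid_batch_kv_cache now receives the transformers DynamicCache object directly instead of round-tripping through the legacy tuple format, and the sliced tensors are made contiguous. A standalone sketch of the same slicing idea (shapes are illustrative):

    import torch
    from transformers.cache_utils import DynamicCache

    # One layer, batch of 3 sequences, 2 heads, 4 cached tokens, head_dim 8.
    cache = DynamicCache()
    cache.update(torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8), layer_idx=0)

    skipped = {1}  # e.g. sequence 1 finished and leaves the batch
    keep = [i for i in range(cache.key_cache[0].shape[0]) if i not in skipped]
    for idx in range(len(cache)):
        cache.key_cache[idx] = cache.key_cache[idx][keep, ::].contiguous()
        cache.value_cache[idx] = cache.value_cache[idx][keep, ::].contiguous()
    print(cache.key_cache[0].shape)  # torch.Size([2, 2, 4, 8])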
xinference/core/supervisor.py CHANGED
@@ -268,8 +268,12 @@ class SupervisorActor(xo.StatelessActor):
          )
 
          from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
+         from ..model.llm.vllm.xavier.collective_manager import CollectiveManager
 
-         self._block_tracker: Optional[xo.ActorRefType[VLLMBlockTracker]] = None
+         self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}
+         self._collective_manager_mapping: Dict[
+             str, xo.ActorRefType[CollectiveManager]
+         ] = {}
 
      @typing.no_type_check
      async def get_cluster_device_info(self, detailed: bool = False) -> List:
@@ -960,26 +964,40 @@ class SupervisorActor(xo.StatelessActor):
          ]:
              raise ValueError("Tensorizer is not supported for %s." % model_name)
 
+         if model_uid is None:
+             model_uid = self._gen_model_uid(model_name)
+
+         # Xavier-related
          enable_xavier: bool = (
              bool(kwargs.pop("enable_xavier", False))
              and model_engine is not None
              and model_engine.lower() == "vllm"
          )
+         store_address = None
+         store_port = None
+         world_size = None
          if enable_xavier:
              if replica <= 1:
                  logger.warning(f"Enabling xavier when `replica<=1` is meaningless.")
                  enable_xavier = False
              else:
                  from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
+                 from ..model.llm.vllm.xavier.collective_manager import CollectiveManager
 
-                 self._block_tracker = await xo.create_actor(
+                 self._block_tracker_mapping[model_uid] = await xo.create_actor(
                      VLLMBlockTracker,
                      address=self.address,
-                     uid=VLLMBlockTracker.default_uid(),
+                     uid=f"{VLLMBlockTracker.default_uid()}-{model_uid}",
                  )
-
-         if model_uid is None:
-             model_uid = self._gen_model_uid(model_name)
+                 world_size = replica + 1
+                 logger.info(f"Going to start xavier with world size: {world_size}")
+                 self._collective_manager_mapping[model_uid] = await xo.create_actor(
+                     CollectiveManager,
+                     address=self.address,
+                     uid=f"{CollectiveManager.default_uid()}-{model_uid}",
+                     model_uid=model_uid,
+                 )
+                 logger.info(f"Start collective manager for {model_uid} done.")
 
          model_size = str(model_size_in_billions) if model_size_in_billions else ""
          logger.debug(
@@ -988,13 +1006,38 @@ class SupervisorActor(xo.StatelessActor):
              f"kwargs: {kwargs}"
          )
 
-         async def _launch_one_model(
-             worker_ref, _replica_model_uid, rank: int, store_port: int
-         ):
+         async def _launch_one_model(worker_ref, _replica_model_uid, rank: int):
              if _replica_model_uid in self._replica_model_uid_to_worker:
                  raise ValueError(
                      f"Model is already in the model list, uid: {_replica_model_uid}"
                  )
+
+             nonlocal store_address
+             nonlocal store_port
+             xavier_config = (
+                 {
+                     "block_tracker_uid": self._block_tracker_mapping[model_uid].uid,
+                     "block_tracker_address": self._block_tracker_mapping[
+                         model_uid
+                     ].address,
+                     "rank": rank,
+                     "world_size": world_size,
+                     "store_address": store_address,
+                     "store_port": store_port,
+                 }
+                 if enable_xavier
+                 else None
+             )
+
+             if enable_xavier and rank == 0:
+                 rank0_address, _port = await worker_ref.launch_rank0_model(
+                     _replica_model_uid, xavier_config
+                 )
+                 self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
+                 store_address = rank0_address.split(":")[0]
+                 store_port = _port
+                 return rank0_address
+
              replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
              nonlocal model_type
 
@@ -1014,17 +1057,7 @@ class SupervisorActor(xo.StatelessActor):
                  gpu_idx=replica_gpu_idx,
                  download_hub=download_hub,
                  model_path=model_path,
-                 xavier_config={
-                     "block_tracker_address": self._block_tracker.address
-                     if self._block_tracker is not None
-                     else None,
-                     "rank": rank,
-                     "world_size": replica,
-                     "store_address": self.address.split(":")[0],
-                     "store_port": store_port,
-                 }
-                 if enable_xavier
-                 else None,
+                 xavier_config=xavier_config,
                  **kwargs,
              )
              self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
@@ -1032,10 +1065,9 @@ class SupervisorActor(xo.StatelessActor):
 
          async def _launch_model():
              try:
-                 store_port = xo.utils.get_next_port()
                  worker_refs = []
                  rank_addresses = []
-                 for rank, rep_model_uid in enumerate(
+                 for _idx, rep_model_uid in enumerate(
                      iter_replica_model_uid(model_uid, replica)
                  ):
                      worker_ref = (
@@ -1043,8 +1075,18 @@ class SupervisorActor(xo.StatelessActor):
                          if target_ip_worker_ref is not None
                          else await self._choose_worker()
                      )
+                     if enable_xavier and _idx == 0:
+                         """
+                         Start the rank 0 model actor on the worker that holds the rank 1 replica,
+                         solely for constructing the collective communication world.
+                         """
+                         _uid = model_uid + "-rank0"
+                         rank0_address = await _launch_one_model(worker_ref, _uid, 0)
+                         worker_refs.append((worker_ref, _uid))
+                         rank_addresses.append(rank0_address)
+
                      subpool_address = await _launch_one_model(
-                         worker_ref, rep_model_uid, rank, store_port
+                         worker_ref, rep_model_uid, _idx + 1
                      )
                      worker_refs.append((worker_ref, rep_model_uid))
                      rank_addresses.append(subpool_address)
@@ -1054,6 +1096,7 @@ class SupervisorActor(xo.StatelessActor):
                  # because the transfer actor needs all the rank addresses used for collective communication
                  if enable_xavier:
                      logger.debug(f"Init transfer component for xavier...")
+                     collective_manager_ref = self._collective_manager_mapping[model_uid]
                      tasks = []
                      for worker_ref, rep_model_uid in worker_refs:
                          tasks.append(
@@ -1064,6 +1107,13 @@ class SupervisorActor(xo.StatelessActor):
                      # Here you must use asyncio.gather, not a for loop,
                      # or you will get stuck.
                      await asyncio.gather(*tasks)
+
+                     # init collective_manager
+                     for idx, addr in enumerate(rank_addresses):
+                         await collective_manager_ref.register_rank(
+                             idx, addr, update=False
+                         )
+
                      logger.debug(f"Init transfer component for xavier done.")
              except Exception:
                  # terminate_model will remove the replica info.
@@ -1193,6 +1243,38 @@ class SupervisorActor(xo.StatelessActor):
                  raise
          self._model_uid_to_replica_info.pop(model_uid, None)
 
+         # clear for xavier
+         rank0_uid = model_uid + "-rank0"
+         if rank0_uid in self._replica_model_uid_to_worker:
+             await _terminate_one_model(rank0_uid)
+
+         collective_manager_ref = self._collective_manager_mapping.pop(model_uid, None)
+         if collective_manager_ref is not None:
+             try:
+                 await xo.destroy_actor(collective_manager_ref)
+             except Exception as e:
+                 logger.debug(
+                     "Destroy collective_manager_ref failed, model uid: %s, error: %s",
+                     model_uid,
+                     e,
+                 )
+             finally:
+                 logger.debug(
+                     f"Destroy collective_manager_ref done. model uid: {model_uid}"
+                 )
+         block_tracker_ref = self._block_tracker_mapping.pop(model_uid, None)
+         if block_tracker_ref is not None:
+             try:
+                 await xo.destroy_actor(block_tracker_ref)
+             except Exception as e:
+                 logger.debug(
+                     "Destroy block_tracker_ref failed, model uid: %s, error: %s",
+                     model_uid,
+                     e,
+                 )
+             finally:
+                 logger.debug(f"Destroy block_tracker_ref done. model uid: {model_uid}")
+
      @log_async(logger=logger)
      async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
          replica_info = self._model_uid_to_replica_info.get(model_uid, None)
@@ -1448,3 +1530,12 @@ class SupervisorActor(xo.StatelessActor):
 
      async def get_progress(self, request_id: str) -> float:
          return await self._progress_tracker.get_progress(request_id)
+
+     async def call_collective_manager(
+         self, model_uid: str, func_name: str, *args, **kwargs
+     ):
+         """
+         Used by worker.
+         """
+         collective_manager_ref = self._collective_manager_mapping[model_uid]
+         await getattr(collective_manager_ref, func_name)(*args, **kwargs)
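Note: Xavier (vLLM KV-cache sharing across replicas) now gets one VLLMBlockTracker and one CollectiveManager per model uid, plus an extra rank-0 actor, so the collective world size is replica + 1. A hedged sketch of a launch that exercises this path (endpoint and model name are illustrative; enable_xavier is forwarded through kwargs as shown in the diff):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # illustrative endpoint
    model_uid = client.launch_model(
        model_name="qwen2.5-instruct",  # any vLLM-served LLM; name is illustrative
        model_engine="vllm",            # Xavier only applies to the vLLM engine
        replica=2,                      # replica <= 1 disables Xavier with a warning
        enable_xavier=True,             # popped from kwargs by the supervisor
    )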
xinference/core/worker.py CHANGED
@@ -24,7 +24,7 @@ import time
  from collections import defaultdict
  from dataclasses import dataclass
  from logging import getLogger
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
+ from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, no_type_check
 
  import xoscar as xo
  from async_timeout import timeout
@@ -184,12 +184,12 @@ class WorkerActor(xo.StatelessActor):
                              self._model_uid_to_recover_count[model_uid] = (
                                  recover_count - 1
                              )
-                             await self.launch_builtin_model(**launch_args)
+                             await self.recover_model(launch_args)
                          else:
                              logger.warning("Stop recreating model actor.")
                      else:
                          logger.warning("Recreating model actor %s ...", model_uid)
-                         await self.launch_builtin_model(**launch_args)
+                         await self.recover_model(launch_args)
                      break
 
      @classmethod
@@ -940,7 +940,11 @@ class WorkerActor(xo.StatelessActor):
          # Terminate model while its launching is not allow
          if model_uid in self._model_uid_launching_guard:
              raise ValueError(f"{model_uid} is launching")
-         origin_uid, _ = parse_replica_model_uid(model_uid)
+         # In special cases, if the suffix is `-rank0`, this is the Xavier's rank 0 model actor.
+         if model_uid.endswith("-rank0"):
+             origin_uid = model_uid.removesuffix("-rank0")
+         else:
+             origin_uid, _ = parse_replica_model_uid(model_uid)
          try:
              _ = await self.get_supervisor_ref()
              if self._event_collector_ref is not None:
@@ -1173,3 +1177,65 @@ class WorkerActor(xo.StatelessActor):
      ):
          model_ref = self._model_uid_to_model[rep_model_uid]
          await model_ref.start_transfer_for_vllm(rank_addresses)
+
+     @log_async(logger=logger, level=logging.INFO)
+     async def launch_rank0_model(
+         self, rep_model_uid: str, xavier_config: Dict[str, Any]
+     ) -> Tuple[str, int]:
+         from ..model.llm.vllm.xavier.collective_manager import Rank0ModelActor
+
+         if os.name != "nt" and platform.system() != "Darwin":
+             # Linux
+             start_method = "forkserver"
+         else:
+             # Windows and macOS
+             start_method = "spawn"
+         subpool_address = await self._main_pool.append_sub_pool(
+             start_method=start_method
+         )
+
+         store_address = subpool_address.split(":")[0]
+         # Note that `store_port` needs to be generated on the worker,
+         # as the TCP store is on rank 0, not on the supervisor.
+         store_port = xo.utils.get_next_port()
+         self._model_uid_launching_guard[rep_model_uid] = True
+         try:
+             try:
+                 xavier_config["rank_address"] = subpool_address
+                 xavier_config["store_address"] = store_address
+                 xavier_config["store_port"] = store_port
+                 model_ref = await xo.create_actor(
+                     Rank0ModelActor,
+                     address=subpool_address,
+                     uid=rep_model_uid,
+                     xavier_config=xavier_config,
+                 )
+             except:
+                 await self._main_pool.remove_sub_pool(subpool_address)
+                 raise
+             self._model_uid_to_model[rep_model_uid] = model_ref
+             self._model_uid_to_addr[rep_model_uid] = subpool_address
+         finally:
+             del self._model_uid_launching_guard[rep_model_uid]
+         return subpool_address, store_port
+
+     @no_type_check
+     async def recover_model(self, launch_args: Dict[str, Any]):
+         rep_model_uid = launch_args.get("model_uid")
+         origin_uid, _ = parse_replica_model_uid(rep_model_uid)
+         xavier_config: Optional[Dict[str, Any]] = launch_args.get("xavier_config", None)
+         is_xavier: bool = xavier_config is not None
+         supervisor_ref = await self.get_supervisor_ref(add_worker=False)
+         if is_xavier:
+             rank = xavier_config.get("rank")
+             await supervisor_ref.call_collective_manager(
+                 origin_uid, "unregister_rank", rank
+             )
+         subpool_address = await self.launch_builtin_model(**launch_args)
+         if is_xavier:
+             model_ref = self._model_uid_to_model[rep_model_uid]
+             await model_ref.start_transfer_for_vllm([])
+             rank = xavier_config.get("rank")
+             await supervisor_ref.call_collective_manager(
+                 origin_uid, "register_rank", rank, subpool_address, update=True
+             )
xinference/deploy/local.py CHANGED
@@ -41,7 +41,8 @@ async def _start_local_cluster(
  ):
      from .utils import create_worker_actor_pool
 
-     logging.config.dictConfig(logging_conf)  # type: ignore
+     if logging_conf:
+         logging.config.dictConfig(logging_conf)  # type: ignore
 
      pool = None
      try:
xinference/model/audio/core.py CHANGED
@@ -25,6 +25,8 @@ from .f5tts import F5TTSModel
  from .f5tts_mlx import F5TTSMLXModel
  from .fish_speech import FishSpeechModel
  from .funasr import FunASRModel
+ from .kokoro import KokoroModel
+ from .melotts import MeloTTSModel
  from .whisper import WhisperModel
  from .whisper_mlx import WhisperMLXModel
 
@@ -48,6 +50,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
      model_id: str
      model_revision: Optional[str]
      multilingual: bool
+     language: Optional[str]
      model_ability: Optional[str]
      default_model_config: Optional[Dict[str, Any]]
      default_transcription_config: Optional[Dict[str, Any]]
@@ -173,6 +176,8 @@ def create_audio_model_instance(
          FishSpeechModel,
          F5TTSModel,
          F5TTSMLXModel,
+         MeloTTSModel,
+         KokoroModel,
      ],
      AudioModelDescription,
  ]:
@@ -188,6 +193,8 @@ def create_audio_model_instance(
          FishSpeechModel,
          F5TTSModel,
          F5TTSMLXModel,
+         MeloTTSModel,
+         KokoroModel,
      ]
      if model_spec.model_family == "whisper":
          if not model_spec.engine:
@@ -206,6 +213,10 @@ def create_audio_model_instance(
          model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
      elif model_spec.model_family == "F5-TTS-MLX":
          model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
+     elif model_spec.model_family == "MeloTTS":
+         model = MeloTTSModel(model_uid, model_path, model_spec, **kwargs)
+     elif model_spec.model_family == "Kokoro":
+         model = KokoroModel(model_uid, model_path, model_spec, **kwargs)
      else:
          raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
      model_description = AudioModelDescription(
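Note: Kokoro and MeloTTS are new TTS families registered next to the existing audio models. A hedged client sketch (the model name follows the new model_spec.json entries and is an assumption; speech mirrors the existing audio handle API):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # illustrative endpoint
    model_uid = client.launch_model(model_name="Kokoro-82M", model_type="audio")  # assumed spec name
    tts = client.get_model(model_uid)
    audio_bytes = tts.speech("Hello from the new Kokoro model.")
    with open("out.mp3", "wb") as f:
        f.write(audio_bytes)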
xinference/model/audio/cosyvoice.py CHANGED
@@ -49,8 +49,11 @@ class CosyVoiceModel:
          import os
          import sys
 
+         import torch
+
          # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
-         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+         thirdparty_dir = os.path.join(os.path.dirname(__file__), "../../thirdparty")
+         sys.path.insert(0, thirdparty_dir)
 
          if "CosyVoice2" in self._model_spec.model_name:
              from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
@@ -61,9 +64,17 @@ class CosyVoiceModel:
 
              self._is_cosyvoice2 = False
 
-         self._model = CosyVoice(
-             self._model_path, load_jit=self._kwargs.get("load_jit", False)
+         # Unify this configuration name as 'compile' to be compatible with the name 'load_jit'.
+         load_jit = self._kwargs.get("load_jit", False) or self._kwargs.get(
+             "compile", False
          )
+         logger.info("Loading CosyVoice model, compile=%s...", load_jit)
+         self._model = CosyVoice(self._model_path, load_jit=load_jit)
+         if self._is_cosyvoice2:
+             spk2info_file = os.path.join(thirdparty_dir, "cosyvoice/bin/spk2info.pt")
+             self._model.frontend.spk2info = torch.load(
+                 spk2info_file, map_location=self._device
+             )
 
      def _speech_handle(
          self,
@@ -101,10 +112,10 @@ class CosyVoiceModel:
                  input, prompt_speech_16k, stream=stream
              )
          else:
-             assert not self._is_cosyvoice2
              available_speakers = self._model.list_avaliable_spks()
              if not voice:
                  voice = available_speakers[0]
+                 logger.info("Auto select speaker: %s", voice)
              else:
                  assert (
                      voice in available_speakers
@@ -184,7 +195,7 @@ class CosyVoiceModel:
                  prompt_text is None
              ), "CosyVoice Instruct model does not support prompt_text"
          elif self._is_cosyvoice2:
-             assert prompt_speech is not None, "CosyVoice2 requires prompt_speech"
+             pass
          else:
              # inference_zero_shot
              # inference_cross_lingual
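Note: CosyVoice now treats compile as an alias for load_jit, CosyVoice2 loads the bundled spk2info.pt so its built-in speakers work, and the prompt_speech assertion for CosyVoice2 is dropped. A hedged launch sketch (endpoint and model name are illustrative):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # illustrative endpoint
    model_uid = client.launch_model(
        model_name="CosyVoice2-0.5B",  # illustrative CosyVoice2 spec name
        model_type="audio",
        compile=True,                  # now equivalent to load_jit=True
    )
    tts = client.get_model(model_uid)
    audio = tts.speech("Testing the built-in speaker.")  # prompt_speech no longer required for CosyVoice2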