xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -7
- xinference/client/handlers.py +3 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +2 -0
- xinference/core/scheduler.py +4 -7
- xinference/core/supervisor.py +114 -23
- xinference/core/worker.py +70 -4
- xinference/deploy/local.py +2 -1
- xinference/model/audio/core.py +11 -0
- xinference/model/audio/cosyvoice.py +16 -5
- xinference/model/audio/kokoro.py +139 -0
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +80 -0
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/llm/llama_cpp/core.py +21 -14
- xinference/model/llm/llm_family.json +527 -1
- xinference/model/llm/llm_family.py +4 -1
- xinference/model/llm/llm_family_modelscope.json +495 -3
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +24 -6
- xinference/model/llm/transformers/core.py +9 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -3
- xinference/model/llm/transformers/utils.py +22 -11
- xinference/model/llm/utils.py +115 -1
- xinference/model/llm/vllm/core.py +14 -4
- xinference/model/llm/vllm/xavier/block.py +3 -4
- xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/executor.py +18 -16
- xinference/model/llm/vllm/xavier/scheduler.py +79 -63
- xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
- xinference/model/llm/vllm/xavier/transfer.py +53 -32
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/melo/__init__.py +0 -0
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
- /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2025-
|
|
11
|
+
"date": "2025-02-08T17:06:47+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.2.
|
|
14
|
+
"full-revisionid": "ac97a13a831de6debda52e6fdb8c1bf9366be57c",
|
|
15
|
+
"version": "1.2.2"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -2000,25 +2000,22 @@ class RESTfulAPI(CancelMixin):
|
|
|
2000
2000
|
|
|
2001
2001
|
from ..model.llm.utils import (
|
|
2002
2002
|
GLM4_TOOL_CALL_FAMILY,
|
|
2003
|
-
LLAMA3_TOOL_CALL_FAMILY,
|
|
2004
2003
|
QWEN_TOOL_CALL_FAMILY,
|
|
2004
|
+
TOOL_CALL_FAMILY,
|
|
2005
2005
|
)
|
|
2006
2006
|
|
|
2007
2007
|
model_family = desc.get("model_family", "")
|
|
2008
|
-
function_call_models = (
|
|
2009
|
-
QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
|
|
2010
|
-
)
|
|
2011
2008
|
|
|
2012
|
-
if model_family not in
|
|
2009
|
+
if model_family not in TOOL_CALL_FAMILY:
|
|
2013
2010
|
if body.tools:
|
|
2014
2011
|
raise HTTPException(
|
|
2015
2012
|
status_code=400,
|
|
2016
|
-
detail=f"Only {
|
|
2013
|
+
detail=f"Only {TOOL_CALL_FAMILY} support tool calls",
|
|
2017
2014
|
)
|
|
2018
2015
|
if has_tool_message:
|
|
2019
2016
|
raise HTTPException(
|
|
2020
2017
|
status_code=400,
|
|
2021
|
-
detail=f"Only {
|
|
2018
|
+
detail=f"Only {TOOL_CALL_FAMILY} support tool messages",
|
|
2022
2019
|
)
|
|
2023
2020
|
if body.tools and body.stream:
|
|
2024
2021
|
is_vllm = await model.is_vllm_backend()
|
xinference/client/handlers.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import base64
|
|
16
|
+
import html
|
|
16
17
|
import logging
|
|
17
18
|
import os
|
|
18
19
|
from io import BytesIO
|
|
@@ -137,7 +138,11 @@ class GradioInterface:
|
|
|
137
138
|
if "content" not in delta:
|
|
138
139
|
continue
|
|
139
140
|
else:
|
|
140
|
-
|
|
141
|
+
# some model like deepseek-r1-distill-qwen
|
|
142
|
+
# will generate <think>...</think> ...
|
|
143
|
+
# in gradio, no output will be rendered,
|
|
144
|
+
# thus escape html tags in advance
|
|
145
|
+
response_content += html.escape(delta["content"])
|
|
141
146
|
yield response_content
|
|
142
147
|
|
|
143
148
|
yield response_content
|
xinference/core/model.py
CHANGED
|
@@ -35,6 +35,7 @@ from typing import (
|
|
|
35
35
|
List,
|
|
36
36
|
Optional,
|
|
37
37
|
Union,
|
|
38
|
+
no_type_check,
|
|
38
39
|
)
|
|
39
40
|
|
|
40
41
|
import sse_starlette.sse
|
|
@@ -302,6 +303,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
|
|
|
302
303
|
def decrease_serve_count(self):
|
|
303
304
|
self._serve_count -= 1
|
|
304
305
|
|
|
306
|
+
@no_type_check
|
|
305
307
|
async def start_transfer_for_vllm(self, rank_addresses: List[str]):
|
|
306
308
|
from ..model.llm.vllm.core import VLLMModel
|
|
307
309
|
from ..model.llm.vllm.xavier.transfer import TransferActor
|
xinference/core/scheduler.py
CHANGED
|
@@ -269,16 +269,13 @@ class InferenceRequest:
|
|
|
269
269
|
)
|
|
270
270
|
|
|
271
271
|
|
|
272
|
-
def _get_valid_batch_kv_cache(
|
|
273
|
-
from transformers.cache_utils import DynamicCache
|
|
274
|
-
|
|
275
|
-
cache = DynamicCache.from_legacy_cache(data)
|
|
272
|
+
def _get_valid_batch_kv_cache(cache, skipped_indexes: Set[int]):
|
|
276
273
|
batch_size = cache.key_cache[0].shape[0]
|
|
277
274
|
batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
|
|
278
275
|
for idx in range(len(cache)):
|
|
279
|
-
cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
|
|
280
|
-
cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
|
|
281
|
-
return cache
|
|
276
|
+
cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::].contiguous()
|
|
277
|
+
cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::].contiguous()
|
|
278
|
+
return cache
|
|
282
279
|
|
|
283
280
|
|
|
284
281
|
class SchedulerActor(xo.StatelessActor):
|
xinference/core/supervisor.py
CHANGED
|
@@ -268,8 +268,12 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
268
268
|
)
|
|
269
269
|
|
|
270
270
|
from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
|
|
271
|
+
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager
|
|
271
272
|
|
|
272
|
-
self.
|
|
273
|
+
self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}
|
|
274
|
+
self._collective_manager_mapping: Dict[
|
|
275
|
+
str, xo.ActorRefType[CollectiveManager]
|
|
276
|
+
] = {}
|
|
273
277
|
|
|
274
278
|
@typing.no_type_check
|
|
275
279
|
async def get_cluster_device_info(self, detailed: bool = False) -> List:
|
|
@@ -960,26 +964,40 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
960
964
|
]:
|
|
961
965
|
raise ValueError("Tensorizer is not supported for %s." % model_name)
|
|
962
966
|
|
|
967
|
+
if model_uid is None:
|
|
968
|
+
model_uid = self._gen_model_uid(model_name)
|
|
969
|
+
|
|
970
|
+
# Xavier-related
|
|
963
971
|
enable_xavier: bool = (
|
|
964
972
|
bool(kwargs.pop("enable_xavier", False))
|
|
965
973
|
and model_engine is not None
|
|
966
974
|
and model_engine.lower() == "vllm"
|
|
967
975
|
)
|
|
976
|
+
store_address = None
|
|
977
|
+
store_port = None
|
|
978
|
+
world_size = None
|
|
968
979
|
if enable_xavier:
|
|
969
980
|
if replica <= 1:
|
|
970
981
|
logger.warning(f"Enabling xavier when `replica<=1` is meaningless.")
|
|
971
982
|
enable_xavier = False
|
|
972
983
|
else:
|
|
973
984
|
from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
|
|
985
|
+
from ..model.llm.vllm.xavier.collective_manager import CollectiveManager
|
|
974
986
|
|
|
975
|
-
self.
|
|
987
|
+
self._block_tracker_mapping[model_uid] = await xo.create_actor(
|
|
976
988
|
VLLMBlockTracker,
|
|
977
989
|
address=self.address,
|
|
978
|
-
uid=VLLMBlockTracker.default_uid(),
|
|
990
|
+
uid=f"{VLLMBlockTracker.default_uid()}-{model_uid}",
|
|
979
991
|
)
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
992
|
+
world_size = replica + 1
|
|
993
|
+
logger.info(f"Going to start xavier with world size: {world_size}")
|
|
994
|
+
self._collective_manager_mapping[model_uid] = await xo.create_actor(
|
|
995
|
+
CollectiveManager,
|
|
996
|
+
address=self.address,
|
|
997
|
+
uid=f"{CollectiveManager.default_uid()}-{model_uid}",
|
|
998
|
+
model_uid=model_uid,
|
|
999
|
+
)
|
|
1000
|
+
logger.info(f"Start collective manager for {model_uid} done.")
|
|
983
1001
|
|
|
984
1002
|
model_size = str(model_size_in_billions) if model_size_in_billions else ""
|
|
985
1003
|
logger.debug(
|
|
@@ -988,13 +1006,38 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
988
1006
|
f"kwargs: {kwargs}"
|
|
989
1007
|
)
|
|
990
1008
|
|
|
991
|
-
async def _launch_one_model(
|
|
992
|
-
worker_ref, _replica_model_uid, rank: int, store_port: int
|
|
993
|
-
):
|
|
1009
|
+
async def _launch_one_model(worker_ref, _replica_model_uid, rank: int):
|
|
994
1010
|
if _replica_model_uid in self._replica_model_uid_to_worker:
|
|
995
1011
|
raise ValueError(
|
|
996
1012
|
f"Model is already in the model list, uid: {_replica_model_uid}"
|
|
997
1013
|
)
|
|
1014
|
+
|
|
1015
|
+
nonlocal store_address
|
|
1016
|
+
nonlocal store_port
|
|
1017
|
+
xavier_config = (
|
|
1018
|
+
{
|
|
1019
|
+
"block_tracker_uid": self._block_tracker_mapping[model_uid].uid,
|
|
1020
|
+
"block_tracker_address": self._block_tracker_mapping[
|
|
1021
|
+
model_uid
|
|
1022
|
+
].address,
|
|
1023
|
+
"rank": rank,
|
|
1024
|
+
"world_size": world_size,
|
|
1025
|
+
"store_address": store_address,
|
|
1026
|
+
"store_port": store_port,
|
|
1027
|
+
}
|
|
1028
|
+
if enable_xavier
|
|
1029
|
+
else None
|
|
1030
|
+
)
|
|
1031
|
+
|
|
1032
|
+
if enable_xavier and rank == 0:
|
|
1033
|
+
rank0_address, _port = await worker_ref.launch_rank0_model(
|
|
1034
|
+
_replica_model_uid, xavier_config
|
|
1035
|
+
)
|
|
1036
|
+
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
|
|
1037
|
+
store_address = rank0_address.split(":")[0]
|
|
1038
|
+
store_port = _port
|
|
1039
|
+
return rank0_address
|
|
1040
|
+
|
|
998
1041
|
replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
|
|
999
1042
|
nonlocal model_type
|
|
1000
1043
|
|
|
@@ -1014,17 +1057,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1014
1057
|
gpu_idx=replica_gpu_idx,
|
|
1015
1058
|
download_hub=download_hub,
|
|
1016
1059
|
model_path=model_path,
|
|
1017
|
-
xavier_config=
|
|
1018
|
-
"block_tracker_address": self._block_tracker.address
|
|
1019
|
-
if self._block_tracker is not None
|
|
1020
|
-
else None,
|
|
1021
|
-
"rank": rank,
|
|
1022
|
-
"world_size": replica,
|
|
1023
|
-
"store_address": self.address.split(":")[0],
|
|
1024
|
-
"store_port": store_port,
|
|
1025
|
-
}
|
|
1026
|
-
if enable_xavier
|
|
1027
|
-
else None,
|
|
1060
|
+
xavier_config=xavier_config,
|
|
1028
1061
|
**kwargs,
|
|
1029
1062
|
)
|
|
1030
1063
|
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
|
|
@@ -1032,10 +1065,9 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1032
1065
|
|
|
1033
1066
|
async def _launch_model():
|
|
1034
1067
|
try:
|
|
1035
|
-
store_port = xo.utils.get_next_port()
|
|
1036
1068
|
worker_refs = []
|
|
1037
1069
|
rank_addresses = []
|
|
1038
|
-
for
|
|
1070
|
+
for _idx, rep_model_uid in enumerate(
|
|
1039
1071
|
iter_replica_model_uid(model_uid, replica)
|
|
1040
1072
|
):
|
|
1041
1073
|
worker_ref = (
|
|
@@ -1043,8 +1075,18 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1043
1075
|
if target_ip_worker_ref is not None
|
|
1044
1076
|
else await self._choose_worker()
|
|
1045
1077
|
)
|
|
1078
|
+
if enable_xavier and _idx == 0:
|
|
1079
|
+
"""
|
|
1080
|
+
Start the rank 0 model actor on the worker that holds the rank 1 replica,
|
|
1081
|
+
solely for constructing the collective communication world.
|
|
1082
|
+
"""
|
|
1083
|
+
_uid = model_uid + "-rank0"
|
|
1084
|
+
rank0_address = await _launch_one_model(worker_ref, _uid, 0)
|
|
1085
|
+
worker_refs.append((worker_ref, _uid))
|
|
1086
|
+
rank_addresses.append(rank0_address)
|
|
1087
|
+
|
|
1046
1088
|
subpool_address = await _launch_one_model(
|
|
1047
|
-
worker_ref, rep_model_uid,
|
|
1089
|
+
worker_ref, rep_model_uid, _idx + 1
|
|
1048
1090
|
)
|
|
1049
1091
|
worker_refs.append((worker_ref, rep_model_uid))
|
|
1050
1092
|
rank_addresses.append(subpool_address)
|
|
@@ -1054,6 +1096,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1054
1096
|
# because the transfer actor needs all the rank addresses used for collective communication
|
|
1055
1097
|
if enable_xavier:
|
|
1056
1098
|
logger.debug(f"Init transfer component for xavier...")
|
|
1099
|
+
collective_manager_ref = self._collective_manager_mapping[model_uid]
|
|
1057
1100
|
tasks = []
|
|
1058
1101
|
for worker_ref, rep_model_uid in worker_refs:
|
|
1059
1102
|
tasks.append(
|
|
@@ -1064,6 +1107,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1064
1107
|
# Here you must use asyncio.gather, not a for loop,
|
|
1065
1108
|
# or you will get stuck.
|
|
1066
1109
|
await asyncio.gather(*tasks)
|
|
1110
|
+
|
|
1111
|
+
# init collective_manager
|
|
1112
|
+
for idx, addr in enumerate(rank_addresses):
|
|
1113
|
+
await collective_manager_ref.register_rank(
|
|
1114
|
+
idx, addr, update=False
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1067
1117
|
logger.debug(f"Init transfer component for xavier done.")
|
|
1068
1118
|
except Exception:
|
|
1069
1119
|
# terminate_model will remove the replica info.
|
|
@@ -1193,6 +1243,38 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1193
1243
|
raise
|
|
1194
1244
|
self._model_uid_to_replica_info.pop(model_uid, None)
|
|
1195
1245
|
|
|
1246
|
+
# clear for xavier
|
|
1247
|
+
rank0_uid = model_uid + "-rank0"
|
|
1248
|
+
if rank0_uid in self._replica_model_uid_to_worker:
|
|
1249
|
+
await _terminate_one_model(rank0_uid)
|
|
1250
|
+
|
|
1251
|
+
collective_manager_ref = self._collective_manager_mapping.pop(model_uid, None)
|
|
1252
|
+
if collective_manager_ref is not None:
|
|
1253
|
+
try:
|
|
1254
|
+
await xo.destroy_actor(collective_manager_ref)
|
|
1255
|
+
except Exception as e:
|
|
1256
|
+
logger.debug(
|
|
1257
|
+
"Destroy collective_manager_ref failed, model uid: %s, error: %s",
|
|
1258
|
+
model_uid,
|
|
1259
|
+
e,
|
|
1260
|
+
)
|
|
1261
|
+
finally:
|
|
1262
|
+
logger.debug(
|
|
1263
|
+
f"Destroy collective_manager_ref done. model uid: {model_uid}"
|
|
1264
|
+
)
|
|
1265
|
+
block_tracker_ref = self._block_tracker_mapping.pop(model_uid, None)
|
|
1266
|
+
if block_tracker_ref is not None:
|
|
1267
|
+
try:
|
|
1268
|
+
await xo.destroy_actor(block_tracker_ref)
|
|
1269
|
+
except Exception as e:
|
|
1270
|
+
logger.debug(
|
|
1271
|
+
"Destroy block_tracker_ref failed, model uid: %s, error: %s",
|
|
1272
|
+
model_uid,
|
|
1273
|
+
e,
|
|
1274
|
+
)
|
|
1275
|
+
finally:
|
|
1276
|
+
logger.debug(f"Destroy block_tracker_ref done. model uid: {model_uid}")
|
|
1277
|
+
|
|
1196
1278
|
@log_async(logger=logger)
|
|
1197
1279
|
async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
|
|
1198
1280
|
replica_info = self._model_uid_to_replica_info.get(model_uid, None)
|
|
@@ -1448,3 +1530,12 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1448
1530
|
|
|
1449
1531
|
async def get_progress(self, request_id: str) -> float:
|
|
1450
1532
|
return await self._progress_tracker.get_progress(request_id)
|
|
1533
|
+
|
|
1534
|
+
async def call_collective_manager(
|
|
1535
|
+
self, model_uid: str, func_name: str, *args, **kwargs
|
|
1536
|
+
):
|
|
1537
|
+
"""
|
|
1538
|
+
Used by worker.
|
|
1539
|
+
"""
|
|
1540
|
+
collective_manager_ref = self._collective_manager_mapping[model_uid]
|
|
1541
|
+
await getattr(collective_manager_ref, func_name)(*args, **kwargs)
|
xinference/core/worker.py
CHANGED
|
@@ -24,7 +24,7 @@ import time
|
|
|
24
24
|
from collections import defaultdict
|
|
25
25
|
from dataclasses import dataclass
|
|
26
26
|
from logging import getLogger
|
|
27
|
-
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
|
27
|
+
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, no_type_check
|
|
28
28
|
|
|
29
29
|
import xoscar as xo
|
|
30
30
|
from async_timeout import timeout
|
|
@@ -184,12 +184,12 @@ class WorkerActor(xo.StatelessActor):
|
|
|
184
184
|
self._model_uid_to_recover_count[model_uid] = (
|
|
185
185
|
recover_count - 1
|
|
186
186
|
)
|
|
187
|
-
await self.
|
|
187
|
+
await self.recover_model(launch_args)
|
|
188
188
|
else:
|
|
189
189
|
logger.warning("Stop recreating model actor.")
|
|
190
190
|
else:
|
|
191
191
|
logger.warning("Recreating model actor %s ...", model_uid)
|
|
192
|
-
await self.
|
|
192
|
+
await self.recover_model(launch_args)
|
|
193
193
|
break
|
|
194
194
|
|
|
195
195
|
@classmethod
|
|
@@ -940,7 +940,11 @@ class WorkerActor(xo.StatelessActor):
|
|
|
940
940
|
# Terminate model while its launching is not allow
|
|
941
941
|
if model_uid in self._model_uid_launching_guard:
|
|
942
942
|
raise ValueError(f"{model_uid} is launching")
|
|
943
|
-
|
|
943
|
+
# In special cases, if the suffix is `-rank0`, this is the Xavier's rank 0 model actor.
|
|
944
|
+
if model_uid.endswith("-rank0"):
|
|
945
|
+
origin_uid = model_uid.removesuffix("-rank0")
|
|
946
|
+
else:
|
|
947
|
+
origin_uid, _ = parse_replica_model_uid(model_uid)
|
|
944
948
|
try:
|
|
945
949
|
_ = await self.get_supervisor_ref()
|
|
946
950
|
if self._event_collector_ref is not None:
|
|
@@ -1173,3 +1177,65 @@ class WorkerActor(xo.StatelessActor):
|
|
|
1173
1177
|
):
|
|
1174
1178
|
model_ref = self._model_uid_to_model[rep_model_uid]
|
|
1175
1179
|
await model_ref.start_transfer_for_vllm(rank_addresses)
|
|
1180
|
+
|
|
1181
|
+
@log_async(logger=logger, level=logging.INFO)
|
|
1182
|
+
async def launch_rank0_model(
|
|
1183
|
+
self, rep_model_uid: str, xavier_config: Dict[str, Any]
|
|
1184
|
+
) -> Tuple[str, int]:
|
|
1185
|
+
from ..model.llm.vllm.xavier.collective_manager import Rank0ModelActor
|
|
1186
|
+
|
|
1187
|
+
if os.name != "nt" and platform.system() != "Darwin":
|
|
1188
|
+
# Linux
|
|
1189
|
+
start_method = "forkserver"
|
|
1190
|
+
else:
|
|
1191
|
+
# Windows and macOS
|
|
1192
|
+
start_method = "spawn"
|
|
1193
|
+
subpool_address = await self._main_pool.append_sub_pool(
|
|
1194
|
+
start_method=start_method
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
store_address = subpool_address.split(":")[0]
|
|
1198
|
+
# Note that `store_port` needs to be generated on the worker,
|
|
1199
|
+
# as the TCP store is on rank 0, not on the supervisor.
|
|
1200
|
+
store_port = xo.utils.get_next_port()
|
|
1201
|
+
self._model_uid_launching_guard[rep_model_uid] = True
|
|
1202
|
+
try:
|
|
1203
|
+
try:
|
|
1204
|
+
xavier_config["rank_address"] = subpool_address
|
|
1205
|
+
xavier_config["store_address"] = store_address
|
|
1206
|
+
xavier_config["store_port"] = store_port
|
|
1207
|
+
model_ref = await xo.create_actor(
|
|
1208
|
+
Rank0ModelActor,
|
|
1209
|
+
address=subpool_address,
|
|
1210
|
+
uid=rep_model_uid,
|
|
1211
|
+
xavier_config=xavier_config,
|
|
1212
|
+
)
|
|
1213
|
+
except:
|
|
1214
|
+
await self._main_pool.remove_sub_pool(subpool_address)
|
|
1215
|
+
raise
|
|
1216
|
+
self._model_uid_to_model[rep_model_uid] = model_ref
|
|
1217
|
+
self._model_uid_to_addr[rep_model_uid] = subpool_address
|
|
1218
|
+
finally:
|
|
1219
|
+
del self._model_uid_launching_guard[rep_model_uid]
|
|
1220
|
+
return subpool_address, store_port
|
|
1221
|
+
|
|
1222
|
+
@no_type_check
|
|
1223
|
+
async def recover_model(self, launch_args: Dict[str, Any]):
|
|
1224
|
+
rep_model_uid = launch_args.get("model_uid")
|
|
1225
|
+
origin_uid, _ = parse_replica_model_uid(rep_model_uid)
|
|
1226
|
+
xavier_config: Optional[Dict[str, Any]] = launch_args.get("xavier_config", None)
|
|
1227
|
+
is_xavier: bool = xavier_config is not None
|
|
1228
|
+
supervisor_ref = await self.get_supervisor_ref(add_worker=False)
|
|
1229
|
+
if is_xavier:
|
|
1230
|
+
rank = xavier_config.get("rank")
|
|
1231
|
+
await supervisor_ref.call_collective_manager(
|
|
1232
|
+
origin_uid, "unregister_rank", rank
|
|
1233
|
+
)
|
|
1234
|
+
subpool_address = await self.launch_builtin_model(**launch_args)
|
|
1235
|
+
if is_xavier:
|
|
1236
|
+
model_ref = self._model_uid_to_model[rep_model_uid]
|
|
1237
|
+
await model_ref.start_transfer_for_vllm([])
|
|
1238
|
+
rank = xavier_config.get("rank")
|
|
1239
|
+
await supervisor_ref.call_collective_manager(
|
|
1240
|
+
origin_uid, "register_rank", rank, subpool_address, update=True
|
|
1241
|
+
)
|
xinference/deploy/local.py
CHANGED
xinference/model/audio/core.py
CHANGED
|
@@ -25,6 +25,8 @@ from .f5tts import F5TTSModel
|
|
|
25
25
|
from .f5tts_mlx import F5TTSMLXModel
|
|
26
26
|
from .fish_speech import FishSpeechModel
|
|
27
27
|
from .funasr import FunASRModel
|
|
28
|
+
from .kokoro import KokoroModel
|
|
29
|
+
from .melotts import MeloTTSModel
|
|
28
30
|
from .whisper import WhisperModel
|
|
29
31
|
from .whisper_mlx import WhisperMLXModel
|
|
30
32
|
|
|
@@ -48,6 +50,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
|
|
|
48
50
|
model_id: str
|
|
49
51
|
model_revision: Optional[str]
|
|
50
52
|
multilingual: bool
|
|
53
|
+
language: Optional[str]
|
|
51
54
|
model_ability: Optional[str]
|
|
52
55
|
default_model_config: Optional[Dict[str, Any]]
|
|
53
56
|
default_transcription_config: Optional[Dict[str, Any]]
|
|
@@ -173,6 +176,8 @@ def create_audio_model_instance(
|
|
|
173
176
|
FishSpeechModel,
|
|
174
177
|
F5TTSModel,
|
|
175
178
|
F5TTSMLXModel,
|
|
179
|
+
MeloTTSModel,
|
|
180
|
+
KokoroModel,
|
|
176
181
|
],
|
|
177
182
|
AudioModelDescription,
|
|
178
183
|
]:
|
|
@@ -188,6 +193,8 @@ def create_audio_model_instance(
|
|
|
188
193
|
FishSpeechModel,
|
|
189
194
|
F5TTSModel,
|
|
190
195
|
F5TTSMLXModel,
|
|
196
|
+
MeloTTSModel,
|
|
197
|
+
KokoroModel,
|
|
191
198
|
]
|
|
192
199
|
if model_spec.model_family == "whisper":
|
|
193
200
|
if not model_spec.engine:
|
|
@@ -206,6 +213,10 @@ def create_audio_model_instance(
|
|
|
206
213
|
model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
207
214
|
elif model_spec.model_family == "F5-TTS-MLX":
|
|
208
215
|
model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
|
|
216
|
+
elif model_spec.model_family == "MeloTTS":
|
|
217
|
+
model = MeloTTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
218
|
+
elif model_spec.model_family == "Kokoro":
|
|
219
|
+
model = KokoroModel(model_uid, model_path, model_spec, **kwargs)
|
|
209
220
|
else:
|
|
210
221
|
raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
|
|
211
222
|
model_description = AudioModelDescription(
|
|
@@ -49,8 +49,11 @@ class CosyVoiceModel:
|
|
|
49
49
|
import os
|
|
50
50
|
import sys
|
|
51
51
|
|
|
52
|
+
import torch
|
|
53
|
+
|
|
52
54
|
# The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
|
|
53
|
-
|
|
55
|
+
thirdparty_dir = os.path.join(os.path.dirname(__file__), "../../thirdparty")
|
|
56
|
+
sys.path.insert(0, thirdparty_dir)
|
|
54
57
|
|
|
55
58
|
if "CosyVoice2" in self._model_spec.model_name:
|
|
56
59
|
from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
|
|
@@ -61,9 +64,17 @@ class CosyVoiceModel:
|
|
|
61
64
|
|
|
62
65
|
self._is_cosyvoice2 = False
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
67
|
+
# Unify this configuration name as 'compile' to be compatible with the name 'load_jit'.
|
|
68
|
+
load_jit = self._kwargs.get("load_jit", False) or self._kwargs.get(
|
|
69
|
+
"compile", False
|
|
66
70
|
)
|
|
71
|
+
logger.info("Loading CosyVoice model, compile=%s...", load_jit)
|
|
72
|
+
self._model = CosyVoice(self._model_path, load_jit=load_jit)
|
|
73
|
+
if self._is_cosyvoice2:
|
|
74
|
+
spk2info_file = os.path.join(thirdparty_dir, "cosyvoice/bin/spk2info.pt")
|
|
75
|
+
self._model.frontend.spk2info = torch.load(
|
|
76
|
+
spk2info_file, map_location=self._device
|
|
77
|
+
)
|
|
67
78
|
|
|
68
79
|
def _speech_handle(
|
|
69
80
|
self,
|
|
@@ -101,10 +112,10 @@ class CosyVoiceModel:
|
|
|
101
112
|
input, prompt_speech_16k, stream=stream
|
|
102
113
|
)
|
|
103
114
|
else:
|
|
104
|
-
assert not self._is_cosyvoice2
|
|
105
115
|
available_speakers = self._model.list_avaliable_spks()
|
|
106
116
|
if not voice:
|
|
107
117
|
voice = available_speakers[0]
|
|
118
|
+
logger.info("Auto select speaker: %s", voice)
|
|
108
119
|
else:
|
|
109
120
|
assert (
|
|
110
121
|
voice in available_speakers
|
|
@@ -184,7 +195,7 @@ class CosyVoiceModel:
|
|
|
184
195
|
prompt_text is None
|
|
185
196
|
), "CosyVoice Instruct model does not support prompt_text"
|
|
186
197
|
elif self._is_cosyvoice2:
|
|
187
|
-
|
|
198
|
+
pass
|
|
188
199
|
else:
|
|
189
200
|
# inference_zero_shot
|
|
190
201
|
# inference_cross_lingual
|