xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -7
- xinference/client/handlers.py +3 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +2 -0
- xinference/core/scheduler.py +4 -7
- xinference/core/supervisor.py +114 -23
- xinference/core/worker.py +70 -4
- xinference/deploy/local.py +2 -1
- xinference/model/audio/core.py +11 -0
- xinference/model/audio/cosyvoice.py +16 -5
- xinference/model/audio/kokoro.py +139 -0
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +80 -0
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/llm/llama_cpp/core.py +21 -14
- xinference/model/llm/llm_family.json +527 -1
- xinference/model/llm/llm_family.py +4 -1
- xinference/model/llm/llm_family_modelscope.json +495 -3
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +24 -6
- xinference/model/llm/transformers/core.py +9 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -3
- xinference/model/llm/transformers/utils.py +22 -11
- xinference/model/llm/utils.py +115 -1
- xinference/model/llm/vllm/core.py +14 -4
- xinference/model/llm/vllm/xavier/block.py +3 -4
- xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/executor.py +18 -16
- xinference/model/llm/vllm/xavier/scheduler.py +79 -63
- xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
- xinference/model/llm/vllm/xavier/transfer.py +53 -32
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/melo/__init__.py +0 -0
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
- /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/xavier/collective_manager.py (new file, +147 -0)

@@ -0,0 +1,147 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+import traceback
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, no_type_check
+
+import xoscar as xo
+
+from .block_tracker import VLLMBlockTracker
+
+if TYPE_CHECKING:
+    from .transfer import Rank0TransferActor, TransferActor
+
+
+logger = logging.getLogger(__name__)
+
+
+class Rank0ModelActor(xo.StatelessActor):
+    @classmethod
+    def default_uid(cls):
+        return "rank0-model-actor"
+
+    def __init__(self, xavier_config: Dict[str, Any]):
+        super().__init__()
+        self._rank = 0
+        self._xavier_config = xavier_config
+        self._transfer_ref: Optional[xo.ActorRefType["Rank0TransferActor"]] = None
+
+    async def __pre_destroy__(self):
+        if self._transfer_ref is not None:
+            try:
+                await xo.destroy_actor(self._transfer_ref)
+                del self._transfer_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy transfer actor failed, rank: {self._rank}, address: {self.address}, error: {e}"
+                )
+
+    @no_type_check
+    async def start_transfer_for_vllm(self, rank_addresses: List[str]):
+        from .transfer import Rank0TransferActor
+
+        self._transfer_ref = await xo.create_actor(
+            Rank0TransferActor,
+            address=self.address,
+            uid=f"{Rank0TransferActor.default_uid()}-{self._rank}",
+            rank=self._rank,
+            world_size=self._xavier_config.get("world_size"),  # type: ignore
+            rank_address=self._xavier_config.get("rank_address"),  # type: ignore
+            store_address=self._xavier_config.get("store_address"),  # type: ignore
+            store_port=self._xavier_config.get("store_port"),  # type: ignore
+            world_addresses=rank_addresses,
+        )
+        logger.debug(
+            f"Init transfer actor: {self._transfer_ref.address}, rank: {self._rank} done for vllm."  # type: ignore
+        )
+
+
+def with_lock(method):
+    async def wrapper(self, *args, **kwargs):
+        async with self._lock:
+            return await method(self, *args, **kwargs)
+
+    return wrapper
+
+
+class CollectiveManager(xo.StatelessActor):
+    @classmethod
+    def default_uid(cls):
+        return f"xavier-collective-manager"
+
+    def __init__(self, model_uid: str):
+        super().__init__()
+        self._model_uid = model_uid
+        self._tracker_ref: Optional[xo.ActorRefType["VLLMBlockTracker"]] = None
+        self._rank_to_ref: Dict[int, xo.ActorRefType["TransferActor"]] = {}
+        self._lock = asyncio.Lock()
+
+    async def __post_create__(self):
+        self._tracker_ref = await xo.actor_ref(
+            address=self.address,
+            uid=f"{VLLMBlockTracker.default_uid()}-{self._model_uid}",
+        )
+
+    async def unregister_rank(self, rank: int):
+        self._rank_to_ref.pop(rank, None)
+        await self._tracker_ref.unregister_rank(rank)  # type: ignore
+        logger.debug(f"Unregister rank: {rank}")
+
+    async def register_rank(self, rank: int, address: str, update: bool = False):
+        from .transfer import TransferActor
+
+        rank_ref = await xo.actor_ref(
+            address=address, uid=f"{TransferActor.default_uid()}-{rank}"
+        )
+        self._rank_to_ref[rank] = rank_ref
+        logger.debug(f"Register rank: {rank}, address: {address}")
+        if update:
+            await self._update_world()
+        await self._tracker_ref.register_rank(rank)  # type: ignore
+
+    @with_lock
+    async def _update_world(self):
+        """
+        Locking is used to prevent chaos when multiple replicas trigger recovery simultaneously.
+        """
+        from .....core.utils import gen_random_string
+
+        prefix = gen_random_string(6)
+        tasks = []
+        rank_to_ref = self._rank_to_ref.copy()
+        world_addresses = [ref.address for _, ref in sorted(rank_to_ref.items())]
+        for rank, ref in rank_to_ref.items():
+            tasks.append(ref.connect_full_mesh(prefix, world_addresses))
+        try:
+            logger.debug(
+                f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix}"
+            )
+            await asyncio.gather(*tasks)
+            logger.debug(
+                f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix} done."
+            )
+        except Exception as e:
+            """
+            The exception here is most likely due to another replica triggering recovery during the recovery process,
+            causing `connect_full_mesh` to time out.
+            Simply log the exception and
+            let the subsequent update process handle the reconstruction of the collective communication world.
+            """
+            logger.error(
+                f"Rebuild collective communication with world_addresses: {world_addresses} failed. "
+                f"Exception: {e}"
+            )
+            # Print the complete error stack
+            traceback.print_exception(type(e), e, e.__traceback__)
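The `with_lock` decorator above routes every `_update_world` call through the manager's `asyncio.Lock`, so simultaneous recovery triggers rebuild the collective world one at a time rather than interleaving `connect_full_mesh` calls. A minimal self-contained sketch of the same pattern; the `Rebuilder` class and its members are illustrative stand-ins, not xinference APIs:

```python
import asyncio


def with_lock(method):
    # Serialize concurrent calls through the instance's asyncio.Lock,
    # mirroring the decorator defined in collective_manager.py.
    async def wrapper(self, *args, **kwargs):
        async with self._lock:
            return await method(self, *args, **kwargs)

    return wrapper


class Rebuilder:
    # Hypothetical stand-in for CollectiveManager: only one world rebuild
    # may run at a time, even if several replicas fail simultaneously.
    def __init__(self):
        self._lock = asyncio.Lock()
        self.rebuilds = 0

    @with_lock
    async def update_world(self):
        self.rebuilds += 1
        await asyncio.sleep(0.01)  # stands in for connect_full_mesh()


async def main():
    r = Rebuilder()
    # Concurrent triggers run back to back, never interleaved.
    await asyncio.gather(*(r.update_world() for _ in range(3)))
    print(r.rebuilds)  # 3


asyncio.run(main())
```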
xinference/model/llm/vllm/xavier/executor.py (+18 -16)

@@ -64,14 +64,13 @@ class XavierExecutor(GPUExecutorAsync):
         )
 
     async def _get_block_tracker_ref(self):
-        from .block_tracker import VLLMBlockTracker
-
         if self._block_tracker_ref is None:
             block_tracker_address = self.vllm_config.xavier_config.get(
                 "block_tracker_address"
             )
+            block_tracker_uid = self.vllm_config.xavier_config.get("block_tracker_uid")
             self._block_tracker_ref = await xo.actor_ref(
-                address=block_tracker_address, uid=
+                address=block_tracker_address, uid=block_tracker_uid
             )
         return self._block_tracker_ref
 
@@ -86,8 +85,8 @@ class XavierExecutor(GPUExecutorAsync):
         )
         return self._transfer_ref
 
-    def
-        return self.vllm_config.xavier_config.get("
+    def get_rank(self) -> int:
+        return self.vllm_config.xavier_config.get("rank")
 
     async def execute_model_async(
         self,
@@ -100,7 +99,7 @@ class XavierExecutor(GPUExecutorAsync):
         virtual_engine = execute_model_req.virtual_engine
         block_tracker_ref = await self._get_block_tracker_ref()
         scheduler = self.scheduler[virtual_engine]  # type: ignore
-
+        rank = self.get_rank()
         executed_blocks_details: Set[Tuple[int, int]] = set()
         for meta in execute_model_req.seq_group_metadata_list:
             block_tables = meta.block_tables
@@ -117,16 +116,19 @@ class XavierExecutor(GPUExecutorAsync):
 
         res = await super().execute_model_async(execute_model_req)
 
-
-
-
-
-
-
-
-
+        if executed_blocks_details:
+            """
+            Why not collect and register the information after execution?
+            Because after execution, the model's execution callback hook will release the block_id,
+            causing the block manager to lose access to the correct information.
+            """
+            await block_tracker_ref.register_blocks(
+                virtual_engine, list(executed_blocks_details), rank
+            )
 
-
-
+            for _, _id in executed_blocks_details:
+                scheduler.block_manager.set_block_status_by_block_id(
+                    "executed", _id, True
+                )
 
         return res
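The comment added to `execute_model_async` captures an ordering constraint: executed block details must be registered with the tracker before vLLM's post-execution callback releases the block ids. A toy illustration of that pitfall under a deliberately simplified block manager (none of these names are vLLM or xinference APIs):

```python
# Toy model: a "block manager" whose release hook frees block ids after a step.
class ToyBlockManager:
    def __init__(self):
        self.block_table = {0: "kv-a", 1: "kv-b"}

    def release_hook(self):
        # Plays the role of the execution callback that releases block_ids.
        self.block_table.clear()


manager = ToyBlockManager()

# Correct order: snapshot the block details first, then let the hook run.
executed = set(manager.block_table)  # {0, 1}
manager.release_hook()
assert executed == {0, 1}  # the registration data survived the release

# Reading after the hook would observe nothing left to register.
assert set(manager.block_table) == set()
```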
xinference/model/llm/vllm/xavier/scheduler.py (+79 -63)

@@ -72,12 +72,11 @@ class XavierScheduler(Scheduler):
         self._transfer_status: Dict[SequenceGroup, Set[int]] = {}
 
     async def _get_block_tracker_ref(self):
-        from .block_tracker import VLLMBlockTracker
-
         if self._block_tracker_ref is None:
             block_tracker_address = self._xavier_config.get("block_tracker_address")
+            block_tracker_uid = self._xavier_config.get("block_tracker_uid")
             self._block_tracker_ref = await xo.actor_ref(
-                address=block_tracker_address, uid=
+                address=block_tracker_address, uid=block_tracker_uid
             )
         return self._block_tracker_ref
 
@@ -97,7 +96,12 @@ class XavierScheduler(Scheduler):
         virtual_engine: int,
         block_tables: Dict[int, List[int]],
         seq_group: SequenceGroup,
-    ) -> Tuple[Set[int], Dict[
+    ) -> Tuple[Set[int], Dict[int, Set[Tuple[int, int, int]]]]:
+        # If the `seq_group` has the `force_calculation` attribute set to `True`,
+        # it indicates that there were issues during the transmission process.
+        # In this case, force the computation and exclude it from the Xavier process.
+        if getattr(seq_group, "force_calculation", False):
+            return set(), dict()
         """
         Retrieve information from other replicas to check if any blocks have already been computed,
         for the purpose of data transfer.
@@ -132,48 +136,63 @@ class XavierScheduler(Scheduler):
                     )
                 ):
                     details.add(detail)
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if details:
+            tracker_ref = await self._get_block_tracker_ref()
+            remote = await tracker_ref.query_blocks(virtual_engine, list(details))
+            # Not all queried blocks have corresponding results in other replicas.
+            # Therefore, it is necessary to record which local block data was actually transferred.
+            local: Set[int] = set()
+            for _, remote_details in remote.items():
+                for _, _, local_block_id in remote_details:
+                    local.add(local_block_id)
+            if local:
+                logger.debug(
+                    f"Data in local blocks: {local} will be transmitted from the remote."
+                )
+            return local, remote
+        else:
+            return set(), dict()
 
     async def _do_transfer_inner(
-        self, virtual_engine: int, remote: Dict[
+        self, virtual_engine: int, remote: Dict[int, Set[Tuple[int, int, int]]]
     ):
         transfer_ref = await self._get_transfer_ref()
-        for
+        for from_rank, hash_and_block_id in remote.items():
            src_to_dst: Dict[int, int] = {x[1]: x[2] for x in hash_and_block_id}
-            await transfer_ref.recv(virtual_engine,
+            await transfer_ref.recv(virtual_engine, from_rank, src_to_dst)
 
     async def _do_transfer(
         self,
         virtual_engine: int,
         local: Set[int],
-        remote: Dict[
+        remote: Dict[int, Set[Tuple[int, int, int]]],
         seq_group: SequenceGroup,
-        is_prefill: bool,
     ):
-
-
-
-
-
-
-
-
+        try:
+            await self._do_transfer_inner(virtual_engine, remote)
+        except Exception as e:
+            """
+            The exception here is most likely due to the sender triggering recovery during the transmission process.
+            In this case, fallback to performing computation during the prefill stage.
+            """
+            logger.error(f"Transfer failed: {e}")
+            # Force this `seq_group` to perform computation.
+            seq_group.force_calculation = True
+            self._transfer_status.pop(seq_group, None)
            self.waiting.appendleft(seq_group)
+            self._transferring.remove(seq_group)
        else:
-
-
+            # After the transfer is completed, update the corresponding metadata.
+            self._transfer_status[seq_group] = local
+            for _id in local:
+                self.block_manager.set_block_status_by_block_id(
+                    "transferred", _id, True
+                )
+            # After the transfer, place the `seq_group` back into the `waiting` queue to
+            # wait for the next scheduling execution.
+            self.waiting.appendleft(seq_group)
+            self._transferring.remove(seq_group)
 
     @no_type_check
     async def schedule(
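The reworked `_do_transfer` adds a recovery path: if the remote KV transfer raises (for example, the sender is itself recovering), the `seq_group` is flagged with `force_calculation` and requeued, so the next scheduling pass computes the prefill locally; `_get_transfer_details` then short-circuits for such groups. A minimal sketch of that control flow with hypothetical stand-ins for the scheduler's queues:

```python
import asyncio


async def flaky_transfer(ok: bool):
    # Stand-in for _do_transfer_inner; fails like a sender lost mid-transfer.
    if not ok:
        raise ConnectionError("sender went away mid-transfer")


async def do_transfer(seq_group: dict, waiting: list, transferring: list, ok: bool):
    try:
        await flaky_transfer(ok)
    except Exception as e:
        print(f"Transfer failed: {e}")
        # Recompute locally instead of reusing remote KV blocks.
        seq_group["force_calculation"] = True
    # Either way the group re-enters the waiting queue for rescheduling.
    waiting.insert(0, seq_group)
    transferring.remove(seq_group)


async def main():
    waiting, transferring = [], []
    sg = {"id": 1}
    transferring.append(sg)
    await do_transfer(sg, waiting, transferring, ok=False)
    assert sg.get("force_calculation") and sg in waiting and sg not in transferring


asyncio.run(main())
```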
xinference/model/llm/vllm/xavier/scheduler.py (continued)

@@ -240,39 +259,36 @@ class XavierScheduler(Scheduler):
             After completing the scheduling, the blocks have been allocated.
             Therefore, it is possible to check whether some blocks have already been computed on other replicas based on this information,
             and subsequently initiate the transfer.
+            According to the internal code comments in vllm,
+            whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage.
+            It is noted that data transmission is only applied during the prefill stage.
+            In the decode stage, it only applies to the last token of the block, which can negatively impact throughput.
             """
-
-
-
-
-            running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-            # According to the internal code comments in vllm,
-            # whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage.
-            is_prefill = token_chunk_size != 1
-            for seq in running_seqs:
-                seq.status = (
-                    SequenceStatus.WAITING if is_prefill else SequenceStatus.RUNNING
-                )
-                # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process.
-                # During the next scheduling, block allocation will no longer be required
-                # since it has already been completed.
-                seq.transferred = True
-                seq.data._stage = (
-                    SequenceStage.PREFILL if is_prefill else SequenceStage.DECODE
-                )
-            self._transfer_status[seq_group] = set()
-            # Use `create_task` to avoid blocking subsequent scheduling.
-            asyncio.create_task(
-                self._do_transfer(
-                    virtual_engine, local, remote, seq_group, is_prefill
-                )
+            is_prefill: bool = token_chunk_size != 1
+            if is_prefill:
+                local, remote = await self._get_transfer_details(
+                    virtual_engine, block_tables, seq_group
                 )
-
-
-
-
-
-
+                if remote:
+                    running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+                    for seq in running_seqs:
+                        seq.status = SequenceStatus.WAITING
+                        # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process.
+                        # During the next scheduling, block allocation will no longer be required
+                        # since it has already been completed.
+                        seq.transferred = True
+                        seq.data._stage = SequenceStage.PREFILL
+                    self._transfer_status[seq_group] = set()
+                    # Use `create_task` to avoid blocking subsequent scheduling.
+                    asyncio.create_task(
+                        self._do_transfer(virtual_engine, local, remote, seq_group)
+                    )
+                    # The `seq_group` that is currently being transferred enters a new queue.
+                    self._transferring.append(seq_group)
+                    has_transferring = True
+                    continue
+                else:
+                    scheduled_seq_groups.append(seq_group)
 
             if self.cache_config.enable_prefix_caching:
                 common_computed_block_nums = (
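The docstring addition records the stage test used above: per vLLM's internal comments, `token_chunk_size == 1` marks a decode step (one new token) while a larger chunk marks prefill, and transfers are attempted only for prefill. The predicate, with toy values (not taken from a real scheduler run):

```python
def is_prefill(token_chunk_size: int) -> bool:
    # token_chunk_size == 1 -> decode step; anything larger -> prefill chunk.
    return token_chunk_size != 1


assert is_prefill(128)    # a prompt chunk processes many tokens at once
assert not is_prefill(1)  # a decode step emits a single token
```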
xinference/model/llm/vllm/xavier/test/test_xavier.py (+60 -35)

@@ -21,11 +21,11 @@ from ..block_tracker import VLLMBlockTracker
 
 
 class ExtendedBlockTracker(VLLMBlockTracker):
-    def
-        return self.
+    def get_hash_to_rank_and_block_id(self):
+        return self._hash_to_rank_and_block_id
 
-    def
-        return self.
+    def get_rank_to_hash_and_block_id(self):
+        return self._rank_to_hash_and_block_id
 
 
 @pytest.fixture
@@ -53,53 +53,54 @@ async def test_block_tracker(actor_pool_context):
     )
 
     virtual_engine = 0
+    rank = 0
     block_infos = [(123, 0), (456, 1), (789, 2)]
 
     # register blocks
-    await tracker_ref.register_blocks(virtual_engine, block_infos,
+    await tracker_ref.register_blocks(virtual_engine, block_infos, rank)
 
     # query blocks
     res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5)])
     assert len(res) == 1
-    assert
-    assert len(res[
-    assert {x[0] for x in res[
-    assert {x[1] for x in res[
-    assert {x[2] for x in res[
+    assert rank in res
+    assert len(res[rank]) == 2
+    assert {x[0] for x in res[rank]} == {123, 789}
+    assert {x[1] for x in res[rank]} == {0, 2}
+    assert {x[2] for x in res[rank]} == {4, 5}
 
     # query with extra info
     res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5), (110, 6)])
     assert len(res) == 1
-    assert
-    assert len(res[
-    assert {x[0] for x in res[
-    assert {x[1] for x in res[
-    assert {x[2] for x in res[
+    assert rank in res
+    assert len(res[rank]) == 2
+    assert {x[0] for x in res[rank]} == {123, 789}
+    assert {x[1] for x in res[rank]} == {0, 2}
+    assert {x[2] for x in res[rank]} == {4, 5}
 
     # unregister block
-    await tracker_ref.unregister_block(virtual_engine,
+    await tracker_ref.unregister_block(virtual_engine, rank, 1)
     res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)])
     assert len(res) == 1
-    assert
-    assert len(res[
-    assert {x[0] for x in res[
-    assert {x[1] for x in res[
+    assert rank in res
+    assert len(res[rank]) == 1
+    assert {x[0] for x in res[rank]} == {123}
+    assert {x[1] for x in res[rank]} == {
         0,
     }
-    assert {x[2] for x in res[
+    assert {x[2] for x in res[rank]} == {
         4,
     }
     # nothing happens
-    await tracker_ref.unregister_block(virtual_engine,
+    await tracker_ref.unregister_block(virtual_engine, rank, 3)
     res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)])
     assert len(res) == 1
-    assert
-    assert len(res[
-    assert {x[0] for x in res[
-    assert {x[1] for x in res[
+    assert rank in res
+    assert len(res[rank]) == 1
+    assert {x[0] for x in res[rank]} == {123}
+    assert {x[1] for x in res[rank]} == {
         0,
     }
-    assert {x[2] for x in res[
+    assert {x[2] for x in res[rank]} == {
         4,
     }
     # query returns empty
@@ -107,16 +108,40 @@ async def test_block_tracker(actor_pool_context):
     assert res == {}
 
     # check internal data
-
-    assert virtual_engine in
-    assert
+    hash_to_rank_and_block_id = await tracker_ref.get_hash_to_rank_and_block_id()
+    assert virtual_engine in hash_to_rank_and_block_id
+    assert hash_to_rank_and_block_id[virtual_engine] == {
         123: {
-            (
+            (rank, 0),
         },
         456: set(),
-        789: {(
+        789: {(rank, 2)},
     }
 
-
-    assert virtual_engine in
-    assert
+    rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
+    assert virtual_engine in rank_to_hash_and_block_id
+    assert rank_to_hash_and_block_id[virtual_engine] == {rank: {(123, 0), (789, 2)}}
+
+    # register blocks
+    new_rank = 1
+    block_infos = [(111, 7), (222, 8), (333, 9), (123, 10)]
+    await tracker_ref.register_blocks(virtual_engine, block_infos, new_rank)
+
+    # test unregister rank
+    await tracker_ref.unregister_rank(0)
+    res = await tracker_ref.query_blocks(virtual_engine, [(789, 5)])
+    assert len(res) == 0
+    res = await tracker_ref.query_blocks(virtual_engine, [(123, 6)])
+    assert len(res) == 1
+    assert new_rank in res
+
+    # check internal data
+    rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
+    assert rank in rank_to_hash_and_block_id[virtual_engine]
+    assert new_rank in rank_to_hash_and_block_id[virtual_engine]
+
+    # test register rank
+    await tracker_ref.register_rank(0)
+    rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
+    assert rank not in rank_to_hash_and_block_id[virtual_engine]
+    assert new_rank in rank_to_hash_and_block_id[virtual_engine]
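To follow the assertions above, here is a plain-dict model of the query semantics the test exercises. The real `VLLMBlockTracker` is an xoscar actor, so this is only an illustration of the hash-to-(rank, block) mapping, not its API:

```python
from collections import defaultdict

# hash -> {(rank, remote_block_id)}: rank 0 registers three blocks,
# matching the test's block_infos = [(123, 0), (456, 1), (789, 2)].
hash_to_rank_and_block_id = defaultdict(set)
for h, block_id in [(123, 0), (456, 1), (789, 2)]:
    hash_to_rank_and_block_id[h].add((0, block_id))


def query_blocks(queries):
    # queries are (hash, local_block_id) pairs; the answer groups, per remote
    # rank, tuples of (hash, remote_block_id, local_block_id) for known hashes.
    res = defaultdict(set)
    for h, local_block_id in queries:
        for rank, remote_block_id in hash_to_rank_and_block_id.get(h, ()):
            res[rank].add((h, remote_block_id, local_block_id))
    return dict(res)


res = query_blocks([(123, 4), (789, 5)])
assert len(res) == 1 and res[0] == {(123, 0, 4), (789, 2, 5)}
# Unknown hashes are simply absent, as in the test's "query with extra info".
assert query_blocks([(110, 6)]) == {}
```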
|