xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (124) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
1
+ # Copyright 2022-2025 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import asyncio
15
+ import logging
16
+ import traceback
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, no_type_check
18
+
19
+ import xoscar as xo
20
+
21
+ from .block_tracker import VLLMBlockTracker
22
+
23
+ if TYPE_CHECKING:
24
+ from .transfer import Rank0TransferActor, TransferActor
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class Rank0ModelActor(xo.StatelessActor):
    """Stateless actor living on rank 0 that owns the rank-0 transfer actor
    used by the vllm Xavier collective-communication setup."""

    @classmethod
    def default_uid(cls):
        return "rank0-model-actor"

    def __init__(self, xavier_config: Dict[str, Any]):
        """
        Args:
            xavier_config: Xavier settings; keys read here include
                ``world_size``, ``rank_address``, ``store_address`` and
                ``store_port`` (forwarded verbatim to the transfer actor).
        """
        super().__init__()
        # This actor is pinned to rank 0 by construction.
        self._rank = 0
        self._xavier_config = xavier_config
        self._transfer_ref: Optional[xo.ActorRefType["Rank0TransferActor"]] = None

    async def __pre_destroy__(self):
        # Best-effort cleanup: destroying the transfer actor can fail if the
        # remote side is already gone, so log at debug level and continue.
        if self._transfer_ref is not None:
            try:
                await xo.destroy_actor(self._transfer_ref)
            except Exception as e:
                logger.debug(
                    f"Destroy transfer actor failed, rank: {self._rank}, address: {self.address}, error: {e}"
                )
            finally:
                # Always drop the (possibly stale) reference; keeping the
                # attribute (set to None) avoids AttributeError on later reads.
                self._transfer_ref = None

    @no_type_check
    async def start_transfer_for_vllm(self, rank_addresses: List[str]):
        """Create the rank-0 transfer actor for vllm.

        Args:
            rank_addresses: addresses of all ranks in the collective world.
        """
        from .transfer import Rank0TransferActor

        self._transfer_ref = await xo.create_actor(
            Rank0TransferActor,
            address=self.address,
            uid=f"{Rank0TransferActor.default_uid()}-{self._rank}",
            rank=self._rank,
            world_size=self._xavier_config.get("world_size"),  # type: ignore
            rank_address=self._xavier_config.get("rank_address"),  # type: ignore
            store_address=self._xavier_config.get("store_address"),  # type: ignore
            store_port=self._xavier_config.get("store_port"),  # type: ignore
            world_addresses=rank_addresses,
        )
        logger.debug(
            f"Init transfer actor: {self._transfer_ref.address}, rank: {self._rank} done for vllm."  # type: ignore
        )
69
+
70
+
71
def with_lock(method):
    """Decorator serializing concurrent calls to an async method.

    Acquires ``self._lock`` (an ``asyncio.Lock`` on the instance) around the
    awaited call, so overlapping invocations run one at a time.
    """
    # Function-scope import keeps the module's import surface unchanged.
    from functools import wraps

    # functools.wraps preserves the wrapped method's __name__/__doc__,
    # which matters for logging and debugging of actor methods.
    @wraps(method)
    async def wrapper(self, *args, **kwargs):
        async with self._lock:
            return await method(self, *args, **kwargs)

    return wrapper
77
+
78
+
79
class CollectiveManager(xo.StatelessActor):
    """Manages the collective-communication world of transfer actors for one
    model: tracks rank -> transfer-actor refs and rebuilds the full mesh when
    ranks (re)join."""

    @classmethod
    def default_uid(cls):
        # Plain literal: the previous f-string had no placeholders (F541).
        return "xavier-collective-manager"

    def __init__(self, model_uid: str):
        super().__init__()
        self._model_uid = model_uid
        self._tracker_ref: Optional[xo.ActorRefType["VLLMBlockTracker"]] = None
        self._rank_to_ref: Dict[int, xo.ActorRefType["TransferActor"]] = {}
        # Guards _update_world against concurrent recovery attempts.
        self._lock = asyncio.Lock()

    async def __post_create__(self):
        # The block tracker is created elsewhere; resolve its ref by uid.
        self._tracker_ref = await xo.actor_ref(
            address=self.address,
            uid=f"{VLLMBlockTracker.default_uid()}-{self._model_uid}",
        )

    async def unregister_rank(self, rank: int):
        """Remove a rank from the world and from the block tracker."""
        self._rank_to_ref.pop(rank, None)
        await self._tracker_ref.unregister_rank(rank)  # type: ignore
        logger.debug(f"Unregister rank: {rank}")

    async def register_rank(self, rank: int, address: str, update: bool = False):
        """Register a rank's transfer actor.

        Args:
            rank: collective rank being registered.
            address: actor-pool address hosting the rank's TransferActor.
            update: when True, rebuild the full communication mesh before
                registering the rank with the block tracker.
        """
        from .transfer import TransferActor

        rank_ref = await xo.actor_ref(
            address=address, uid=f"{TransferActor.default_uid()}-{rank}"
        )
        self._rank_to_ref[rank] = rank_ref
        logger.debug(f"Register rank: {rank}, address: {address}")
        if update:
            await self._update_world()
            await self._tracker_ref.register_rank(rank)  # type: ignore

    @with_lock
    async def _update_world(self):
        """
        Locking is used to prevent chaos when multiple replicas trigger recovery simultaneously.
        """
        from .....core.utils import gen_random_string

        # A fresh random prefix namespaces this rebuild of the mesh.
        prefix = gen_random_string(6)
        tasks = []
        rank_to_ref = self._rank_to_ref.copy()
        # World addresses must be ordered by rank for a consistent mesh.
        world_addresses = [ref.address for _, ref in sorted(rank_to_ref.items())]
        for rank, ref in rank_to_ref.items():
            tasks.append(ref.connect_full_mesh(prefix, world_addresses))
        try:
            logger.debug(
                f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix}"
            )
            await asyncio.gather(*tasks)
            logger.debug(
                f"Rebuild collective communication with world_addresses: {world_addresses}, prefix: {prefix} done."
            )
        except Exception as e:
            """
            The exception here is most likely due to another replica triggering recovery during the recovery process,
            causing `connect_full_mesh` to time out.
            Simply log the exception and
            let the subsequent update process handle the reconstruction of the collective communication world.
            """
            logger.error(
                f"Rebuild collective communication with world_addresses: {world_addresses} failed. "
                f"Exception: {e}"
            )
            # Print the complete error stack
            traceback.print_exception(type(e), e, e.__traceback__)
@@ -64,14 +64,13 @@ class XavierExecutor(GPUExecutorAsync):
64
64
  )
65
65
 
66
66
  async def _get_block_tracker_ref(self):
67
- from .block_tracker import VLLMBlockTracker
68
-
69
67
  if self._block_tracker_ref is None:
70
68
  block_tracker_address = self.vllm_config.xavier_config.get(
71
69
  "block_tracker_address"
72
70
  )
71
+ block_tracker_uid = self.vllm_config.xavier_config.get("block_tracker_uid")
73
72
  self._block_tracker_ref = await xo.actor_ref(
74
- address=block_tracker_address, uid=VLLMBlockTracker.default_uid()
73
+ address=block_tracker_address, uid=block_tracker_uid
75
74
  )
76
75
  return self._block_tracker_ref
77
76
 
@@ -86,8 +85,8 @@ class XavierExecutor(GPUExecutorAsync):
86
85
  )
87
86
  return self._transfer_ref
88
87
 
89
- def get_rank_address(self) -> str:
90
- return self.vllm_config.xavier_config.get("rank_address")
88
+ def get_rank(self) -> int:
89
+ return self.vllm_config.xavier_config.get("rank")
91
90
 
92
91
  async def execute_model_async(
93
92
  self,
@@ -100,7 +99,7 @@ class XavierExecutor(GPUExecutorAsync):
100
99
  virtual_engine = execute_model_req.virtual_engine
101
100
  block_tracker_ref = await self._get_block_tracker_ref()
102
101
  scheduler = self.scheduler[virtual_engine] # type: ignore
103
- rank_address = self.get_rank_address()
102
+ rank = self.get_rank()
104
103
  executed_blocks_details: Set[Tuple[int, int]] = set()
105
104
  for meta in execute_model_req.seq_group_metadata_list:
106
105
  block_tables = meta.block_tables
@@ -117,16 +116,19 @@ class XavierExecutor(GPUExecutorAsync):
117
116
 
118
117
  res = await super().execute_model_async(execute_model_req)
119
118
 
120
- """
121
- Why not collect and register the information after execution?
122
- Because after execution, the model's execution callback hook will release the block_id,
123
- causing the block manager to lose access to the correct information.
124
- """
125
- await block_tracker_ref.register_blocks(
126
- virtual_engine, list(executed_blocks_details), rank_address
127
- )
119
+ if executed_blocks_details:
120
+ """
121
+ Why not collect and register the information after execution?
122
+ Because after execution, the model's execution callback hook will release the block_id,
123
+ causing the block manager to lose access to the correct information.
124
+ """
125
+ await block_tracker_ref.register_blocks(
126
+ virtual_engine, list(executed_blocks_details), rank
127
+ )
128
128
 
129
- for _, _id in executed_blocks_details:
130
- scheduler.block_manager.set_block_status_by_block_id("executed", _id, True)
129
+ for _, _id in executed_blocks_details:
130
+ scheduler.block_manager.set_block_status_by_block_id(
131
+ "executed", _id, True
132
+ )
131
133
 
132
134
  return res
@@ -72,12 +72,11 @@ class XavierScheduler(Scheduler):
72
72
  self._transfer_status: Dict[SequenceGroup, Set[int]] = {}
73
73
 
74
74
  async def _get_block_tracker_ref(self):
75
- from .block_tracker import VLLMBlockTracker
76
-
77
75
  if self._block_tracker_ref is None:
78
76
  block_tracker_address = self._xavier_config.get("block_tracker_address")
77
+ block_tracker_uid = self._xavier_config.get("block_tracker_uid")
79
78
  self._block_tracker_ref = await xo.actor_ref(
80
- address=block_tracker_address, uid=VLLMBlockTracker.default_uid()
79
+ address=block_tracker_address, uid=block_tracker_uid
81
80
  )
82
81
  return self._block_tracker_ref
83
82
 
@@ -97,7 +96,12 @@ class XavierScheduler(Scheduler):
97
96
  virtual_engine: int,
98
97
  block_tables: Dict[int, List[int]],
99
98
  seq_group: SequenceGroup,
100
- ) -> Tuple[Set[int], Dict[str, Set[Tuple[int, int, int]]]]:
99
+ ) -> Tuple[Set[int], Dict[int, Set[Tuple[int, int, int]]]]:
100
+ # If the `seq_group` has the `force_calculation` attribute set to `True`,
101
+ # it indicates that there were issues during the transmission process.
102
+ # In this case, force the computation and exclude it from the Xavier process.
103
+ if getattr(seq_group, "force_calculation", False):
104
+ return set(), dict()
101
105
  """
102
106
  Retrieve information from other replicas to check if any blocks have already been computed,
103
107
  for the purpose of data transfer.
@@ -132,48 +136,63 @@ class XavierScheduler(Scheduler):
132
136
  )
133
137
  ):
134
138
  details.add(detail)
135
- tracker_ref = await self._get_block_tracker_ref()
136
- remote = await tracker_ref.query_blocks(virtual_engine, list(details))
137
- # Not all queried blocks have corresponding results in other replicas.
138
- # Therefore, it is necessary to record which local block data was actually transferred.
139
- local: Set[int] = set()
140
- for _, remote_details in remote.items():
141
- for _, _, local_block_id in remote_details:
142
- local.add(local_block_id)
143
- if local:
144
- logger.debug(
145
- f"Data in local blocks: {local} will be transmitted from the remote."
146
- )
147
- return local, remote
139
+
140
+ if details:
141
+ tracker_ref = await self._get_block_tracker_ref()
142
+ remote = await tracker_ref.query_blocks(virtual_engine, list(details))
143
+ # Not all queried blocks have corresponding results in other replicas.
144
+ # Therefore, it is necessary to record which local block data was actually transferred.
145
+ local: Set[int] = set()
146
+ for _, remote_details in remote.items():
147
+ for _, _, local_block_id in remote_details:
148
+ local.add(local_block_id)
149
+ if local:
150
+ logger.debug(
151
+ f"Data in local blocks: {local} will be transmitted from the remote."
152
+ )
153
+ return local, remote
154
+ else:
155
+ return set(), dict()
148
156
 
149
157
  async def _do_transfer_inner(
150
- self, virtual_engine: int, remote: Dict[str, Set[Tuple[int, int, int]]]
158
+ self, virtual_engine: int, remote: Dict[int, Set[Tuple[int, int, int]]]
151
159
  ):
152
160
  transfer_ref = await self._get_transfer_ref()
153
- for addr, hash_and_block_id in remote.items():
161
+ for from_rank, hash_and_block_id in remote.items():
154
162
  src_to_dst: Dict[int, int] = {x[1]: x[2] for x in hash_and_block_id}
155
- await transfer_ref.recv(virtual_engine, addr, src_to_dst)
163
+ await transfer_ref.recv(virtual_engine, from_rank, src_to_dst)
156
164
 
157
165
  async def _do_transfer(
158
166
  self,
159
167
  virtual_engine: int,
160
168
  local: Set[int],
161
- remote: Dict[str, Set[Tuple[int, int, int]]],
169
+ remote: Dict[int, Set[Tuple[int, int, int]]],
162
170
  seq_group: SequenceGroup,
163
- is_prefill: bool,
164
171
  ):
165
- await self._do_transfer_inner(virtual_engine, remote)
166
- # After the transfer is completed, update the corresponding metadata.
167
- self._transfer_status[seq_group] = local
168
- for _id in local:
169
- self.block_manager.set_block_status_by_block_id("transferred", _id, True)
170
- # After the transfer, place the `seq_group` back into the appropriate queue to
171
- # wait for the next scheduling execution.
172
- if is_prefill:
172
+ try:
173
+ await self._do_transfer_inner(virtual_engine, remote)
174
+ except Exception as e:
175
+ """
176
+ The exception here is most likely due to the sender triggering recovery during the transmission process.
177
+ In this case, fallback to performing computation during the prefill stage.
178
+ """
179
+ logger.error(f"Transfer failed: {e}")
180
+ # Force this `seq_group` to perform computation.
181
+ seq_group.force_calculation = True
182
+ self._transfer_status.pop(seq_group, None)
173
183
  self.waiting.appendleft(seq_group)
184
+ self._transferring.remove(seq_group)
174
185
  else:
175
- self.running.appendleft(seq_group)
176
- self._transferring.remove(seq_group)
186
+ # After the transfer is completed, update the corresponding metadata.
187
+ self._transfer_status[seq_group] = local
188
+ for _id in local:
189
+ self.block_manager.set_block_status_by_block_id(
190
+ "transferred", _id, True
191
+ )
192
+ # After the transfer, place the `seq_group` back into the `waiting` queue to
193
+ # wait for the next scheduling execution.
194
+ self.waiting.appendleft(seq_group)
195
+ self._transferring.remove(seq_group)
177
196
 
178
197
  @no_type_check
179
198
  async def schedule(
@@ -240,39 +259,36 @@ class XavierScheduler(Scheduler):
240
259
  After completing the scheduling, the blocks have been allocated.
241
260
  Therefore, it is possible to check whether some blocks have already been computed on other replicas based on this information,
242
261
  and subsequently initiate the transfer.
262
+ According to the internal code comments in vllm,
263
+ whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage.
264
+ It is noted that data transmission is only applied during the prefill stage.
265
+ In the decode stage, it only applies to the last token of the block, which can negatively impact throughput.
243
266
  """
244
- local, remote = await self._get_transfer_details(
245
- virtual_engine, block_tables, seq_group
246
- )
247
- if remote:
248
- running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
249
- # According to the internal code comments in vllm,
250
- # whether `token_chunk_size` is 1 can indicate whether the `seq_group` is in the decode or prefill stage.
251
- is_prefill = token_chunk_size != 1
252
- for seq in running_seqs:
253
- seq.status = (
254
- SequenceStatus.WAITING if is_prefill else SequenceStatus.RUNNING
255
- )
256
- # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process.
257
- # During the next scheduling, block allocation will no longer be required
258
- # since it has already been completed.
259
- seq.transferred = True
260
- seq.data._stage = (
261
- SequenceStage.PREFILL if is_prefill else SequenceStage.DECODE
262
- )
263
- self._transfer_status[seq_group] = set()
264
- # Use `create_task` to avoid blocking subsequent scheduling.
265
- asyncio.create_task(
266
- self._do_transfer(
267
- virtual_engine, local, remote, seq_group, is_prefill
268
- )
267
+ is_prefill: bool = token_chunk_size != 1
268
+ if is_prefill:
269
+ local, remote = await self._get_transfer_details(
270
+ virtual_engine, block_tables, seq_group
269
271
  )
270
- # The `seq_group` that is currently being transferred enters a new queue.
271
- self._transferring.append(seq_group)
272
- has_transferring = True
273
- continue
274
- else:
275
- scheduled_seq_groups.append(seq_group)
272
+ if remote:
273
+ running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
274
+ for seq in running_seqs:
275
+ seq.status = SequenceStatus.WAITING
276
+ # Additional attribute `transferred` to mark that this `seq_group` involves a transfer process.
277
+ # During the next scheduling, block allocation will no longer be required
278
+ # since it has already been completed.
279
+ seq.transferred = True
280
+ seq.data._stage = SequenceStage.PREFILL
281
+ self._transfer_status[seq_group] = set()
282
+ # Use `create_task` to avoid blocking subsequent scheduling.
283
+ asyncio.create_task(
284
+ self._do_transfer(virtual_engine, local, remote, seq_group)
285
+ )
286
+ # The `seq_group` that is currently being transferred enters a new queue.
287
+ self._transferring.append(seq_group)
288
+ has_transferring = True
289
+ continue
290
+ else:
291
+ scheduled_seq_groups.append(seq_group)
276
292
 
277
293
  if self.cache_config.enable_prefix_caching:
278
294
  common_computed_block_nums = (
@@ -21,11 +21,11 @@ from ..block_tracker import VLLMBlockTracker
21
21
 
22
22
 
23
23
  class ExtendedBlockTracker(VLLMBlockTracker):
24
- def get_hash_to_address_and_block_id(self):
25
- return self._hash_to_address_and_block_id
24
+ def get_hash_to_rank_and_block_id(self):
25
+ return self._hash_to_rank_and_block_id
26
26
 
27
- def get_address_to_hash_and_block_id(self):
28
- return self._address_to_hash_and_block_id
27
+ def get_rank_to_hash_and_block_id(self):
28
+ return self._rank_to_hash_and_block_id
29
29
 
30
30
 
31
31
  @pytest.fixture
@@ -53,53 +53,54 @@ async def test_block_tracker(actor_pool_context):
53
53
  )
54
54
 
55
55
  virtual_engine = 0
56
+ rank = 0
56
57
  block_infos = [(123, 0), (456, 1), (789, 2)]
57
58
 
58
59
  # register blocks
59
- await tracker_ref.register_blocks(virtual_engine, block_infos, addr)
60
+ await tracker_ref.register_blocks(virtual_engine, block_infos, rank)
60
61
 
61
62
  # query blocks
62
63
  res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5)])
63
64
  assert len(res) == 1
64
- assert addr in res
65
- assert len(res[addr]) == 2
66
- assert {x[0] for x in res[addr]} == {123, 789}
67
- assert {x[1] for x in res[addr]} == {0, 2}
68
- assert {x[2] for x in res[addr]} == {4, 5}
65
+ assert rank in res
66
+ assert len(res[rank]) == 2
67
+ assert {x[0] for x in res[rank]} == {123, 789}
68
+ assert {x[1] for x in res[rank]} == {0, 2}
69
+ assert {x[2] for x in res[rank]} == {4, 5}
69
70
 
70
71
  # query with extra info
71
72
  res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (789, 5), (110, 6)])
72
73
  assert len(res) == 1
73
- assert addr in res
74
- assert len(res[addr]) == 2
75
- assert {x[0] for x in res[addr]} == {123, 789}
76
- assert {x[1] for x in res[addr]} == {0, 2}
77
- assert {x[2] for x in res[addr]} == {4, 5}
74
+ assert rank in res
75
+ assert len(res[rank]) == 2
76
+ assert {x[0] for x in res[rank]} == {123, 789}
77
+ assert {x[1] for x in res[rank]} == {0, 2}
78
+ assert {x[2] for x in res[rank]} == {4, 5}
78
79
 
79
80
  # unregister block
80
- await tracker_ref.unregister_block(virtual_engine, addr, 1)
81
+ await tracker_ref.unregister_block(virtual_engine, rank, 1)
81
82
  res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)])
82
83
  assert len(res) == 1
83
- assert addr in res
84
- assert len(res[addr]) == 1
85
- assert {x[0] for x in res[addr]} == {123}
86
- assert {x[1] for x in res[addr]} == {
84
+ assert rank in res
85
+ assert len(res[rank]) == 1
86
+ assert {x[0] for x in res[rank]} == {123}
87
+ assert {x[1] for x in res[rank]} == {
87
88
  0,
88
89
  }
89
- assert {x[2] for x in res[addr]} == {
90
+ assert {x[2] for x in res[rank]} == {
90
91
  4,
91
92
  }
92
93
  # nothing happens
93
- await tracker_ref.unregister_block(virtual_engine, addr, 3)
94
+ await tracker_ref.unregister_block(virtual_engine, rank, 3)
94
95
  res = await tracker_ref.query_blocks(virtual_engine, [(123, 4), (456, 7)])
95
96
  assert len(res) == 1
96
- assert addr in res
97
- assert len(res[addr]) == 1
98
- assert {x[0] for x in res[addr]} == {123}
99
- assert {x[1] for x in res[addr]} == {
97
+ assert rank in res
98
+ assert len(res[rank]) == 1
99
+ assert {x[0] for x in res[rank]} == {123}
100
+ assert {x[1] for x in res[rank]} == {
100
101
  0,
101
102
  }
102
- assert {x[2] for x in res[addr]} == {
103
+ assert {x[2] for x in res[rank]} == {
103
104
  4,
104
105
  }
105
106
  # query returns empty
@@ -107,16 +108,40 @@ async def test_block_tracker(actor_pool_context):
107
108
  assert res == {}
108
109
 
109
110
  # check internal data
110
- hash_to_address_and_block_id = await tracker_ref.get_hash_to_address_and_block_id()
111
- assert virtual_engine in hash_to_address_and_block_id
112
- assert hash_to_address_and_block_id[virtual_engine] == {
111
+ hash_to_rank_and_block_id = await tracker_ref.get_hash_to_rank_and_block_id()
112
+ assert virtual_engine in hash_to_rank_and_block_id
113
+ assert hash_to_rank_and_block_id[virtual_engine] == {
113
114
  123: {
114
- (addr, 0),
115
+ (rank, 0),
115
116
  },
116
117
  456: set(),
117
- 789: {(addr, 2)},
118
+ 789: {(rank, 2)},
118
119
  }
119
120
 
120
- address_to_hash_and_block_id = await tracker_ref.get_address_to_hash_and_block_id()
121
- assert virtual_engine in address_to_hash_and_block_id
122
- assert address_to_hash_and_block_id[virtual_engine] == {addr: {(123, 0), (789, 2)}}
121
+ rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
122
+ assert virtual_engine in rank_to_hash_and_block_id
123
+ assert rank_to_hash_and_block_id[virtual_engine] == {rank: {(123, 0), (789, 2)}}
124
+
125
+ # register blocks
126
+ new_rank = 1
127
+ block_infos = [(111, 7), (222, 8), (333, 9), (123, 10)]
128
+ await tracker_ref.register_blocks(virtual_engine, block_infos, new_rank)
129
+
130
+ # test unregister rank
131
+ await tracker_ref.unregister_rank(0)
132
+ res = await tracker_ref.query_blocks(virtual_engine, [(789, 5)])
133
+ assert len(res) == 0
134
+ res = await tracker_ref.query_blocks(virtual_engine, [(123, 6)])
135
+ assert len(res) == 1
136
+ assert new_rank in res
137
+
138
+ # check internal data
139
+ rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
140
+ assert rank in rank_to_hash_and_block_id[virtual_engine]
141
+ assert new_rank in rank_to_hash_and_block_id[virtual_engine]
142
+
143
+ # test register rank
144
+ await tracker_ref.register_rank(0)
145
+ rank_to_hash_and_block_id = await tracker_ref.get_rank_to_hash_and_block_id()
146
+ assert rank not in rank_to_hash_and_block_id[virtual_engine]
147
+ assert new_rank in rank_to_hash_and_block_id[virtual_engine]