sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py

```diff
@@ -10,7 +10,7 @@ import threading
 import uuid
 from collections import defaultdict
 from functools import cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
 
 logger = logging.getLogger(__name__)
 
+NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
+
+
+# From Mooncake backend.
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    src_groups = []
+    dst_groups = []
+    current_src = [src_indices[0]]
+    current_dst = [dst_indices[0]]
+
+    for i in range(1, len(src_indices)):
+        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
+        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
+        if src_contiguous and dst_contiguous:
+            current_src.append(src_indices[i])
+            current_dst.append(dst_indices[i])
+        else:
+            src_groups.append(current_src)
+            dst_groups.append(current_dst)
+            current_src = [src_indices[i]]
+            current_dst = [dst_indices[i]]
+
+    src_groups.append(current_src)
+    dst_groups.append(current_dst)
+
+    return src_groups, dst_groups
+
+
+GUARD = "NixlMsgGuard".encode("ascii")
+
 
 @dataclasses.dataclass
 class TransferInfo:
```
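A note on the grouping helper above: it coalesces runs of indices that are contiguous on both the source and destination side, so one transfer descriptor can later cover a whole run instead of one descriptor per KV slot. A minimal sketch of the expected behavior (the index arrays here are invented for illustration):

```python
import numpy as np

# Hypothetical inputs: KV slot indices on the prefill (src) and decode (dst) side.
src = np.array([0, 1, 2, 7, 8], dtype=np.int64)
dst = np.array([4, 5, 6, 9, 10], dtype=np.int64)

src_groups, dst_groups = group_concurrent_contiguous(src, dst)
# Runs that are contiguous on BOTH sides are merged:
# src_groups == [[0, 1, 2], [7, 8]]
# dst_groups == [[4, 5, 6], [9, 10]]
```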
```diff
@@ -45,19 +77,36 @@ class TransferInfo:
     dst_aux_index: int
     dst_gpu_id: int
 
+    def is_dummy(self):
+        return self.endpoint == ""
+
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
-        return cls(
-            room=int(msg[0].decode("ascii")),
-            endpoint=msg[1].decode("ascii"),
-            dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
-            dst_gpu_id=int(msg[8].decode("ascii")),
-        )
+        if len(msg) == 1:
+            # dummy msg
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint="",
+                dst_port=0,
+                agent_metadata=b"",
+                dst_kv_ptrs=[],
+                dst_kv_indices=np.array([], dtype=np.int64),
+                dst_aux_ptrs=[],
+                dst_aux_index=0,
+                dst_gpu_id=0,
+            )
+        else:
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint=msg[1].decode("ascii"),
+                dst_port=int(msg[2].decode("ascii")),
+                agent_metadata=msg[3],
+                dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
+                dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
+                dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+                dst_aux_index=int(msg[7].decode("ascii")),
+                dst_gpu_id=int(msg[8].decode("ascii")),
+            )
```
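For reference, the multipart frame layout that `from_zmq` parses can be sketched as below. This is an illustrative reconstruction based on the parsing code above and the receiver-side `init` shown further down; every concrete value is a placeholder:

```python
import struct
import numpy as np

# Placeholder values; the real ones come from KVArgs and the NIXL agent.
kv_ptrs = [0x7F0000001000, 0x7F0000002000]
aux_ptrs = [0x7F0000003000]
kv_indices = np.array([0, 1, 2], dtype=np.int64)

msg = [
    str(42).encode("ascii"),                           # msg[0]: bootstrap room
    b"10.0.0.1",                                       # msg[1]: endpoint (sender IP)
    str(5757).encode("ascii"),                         # msg[2]: dst port
    b"<agent metadata>",                               # msg[3]: NIXL agent metadata
    b"".join(struct.pack("Q", p) for p in kv_ptrs),    # msg[4]: packed kv ptrs (u64 each)
    kv_indices.tobytes(),                              # msg[5]: kv indices (int64)
    b"".join(struct.pack("Q", p) for p in aux_ptrs),   # msg[6]: packed aux ptrs
    str(0).encode("ascii"),                            # msg[7]: aux index
    str(0).encode("ascii"),                            # msg[8]: gpu id
]
# A dummy message is just [room]; TransferInfo.is_dummy() is then True.
```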
```diff
@@ -83,6 +132,7 @@ class NixlKVManager(BaseKVManager):
         args: KVArgs,
         disaggregation_mode: DisaggregationMode,
         server_args: ServerArgs,
+        is_mla_backend: Optional[bool] = False,
     ):
         try:
             from nixl._api import nixl_agent
@@ -98,6 +148,19 @@ class NixlKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+
+        self.tp_rank = args.engine_rank
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if self.enable_dp_attention:
+            assert (
+                server_args.dp_size > 1
+            ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
+            self.dp_size = server_args.dp_size
+            self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
+            self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
+            self.dp_rank = args.engine_rank // self.tp_size_of_dp
+
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
```
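The DP-attention bookkeeping added above is plain integer arithmetic. A worked sketch with assumed sizes (`tp_size=8`, `dp_size=2`) shows how an engine rank maps to an attention-TP rank and a DP rank:

```python
# Assumed configuration, for illustration only.
tp_size, dp_size = 8, 2
tp_size_of_dp = tp_size // dp_size  # 4 TP ranks per DP replica

for engine_rank in range(tp_size):
    attn_tp_rank = engine_rank % tp_size_of_dp   # position inside the replica
    dp_rank = engine_rank // tp_size_of_dp       # which replica
    print(engine_rank, attn_tp_rank, dp_rank)
# Engine ranks 0-3 -> attn_tp_rank 0-3 in dp_rank 0;
# engine ranks 4-7 -> attn_tp_rank 0-3 in dp_rank 1.
```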
```diff
@@ -110,7 +173,8 @@ class NixlKVManager(BaseKVManager):
             self._start_bootstrap_thread()
             self._register_to_bootstrap()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            # bootstrap key -> (remote_engine_rank -> possible remote source info)
+            self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
@@ -126,6 +190,7 @@ class NixlKVManager(BaseKVManager):
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
         self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
         aux_addrs = []
@@ -134,6 +199,7 @@ class NixlKVManager(BaseKVManager):
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
         self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+        logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
@@ -157,6 +223,12 @@ class NixlKVManager(BaseKVManager):
         dst_gpu_id: int,
         notif: str,
     ):
+        # group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
         # Make descs
         num_layers = len(self.kv_args.kv_data_ptrs)
         src_addrs = []
```
```diff
@@ -166,12 +238,16 @@ class NixlKVManager(BaseKVManager):
             dst_ptr = dst_kv_ptrs[layer_id]
             item_len = self.kv_args.kv_item_lens[layer_id]
 
-            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
-                src_addr = src_ptr + int(prefill_index) * item_len
-                dst_addr = dst_ptr + int(decode_index) * item_len
-                length = item_len
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
                 src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                 dst_addrs.append((dst_addr, length, dst_gpu_id))
+
+        logger.debug(
+            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
+        )
         src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
         dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
         # Transfer data
```
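Per the rewritten loop, a run of N contiguous slots becomes a single descriptor whose length is `N * item_len`, starting at the first slot's address. A small sketch of that arithmetic (all addresses and sizes invented):

```python
# Assumed values, for illustration.
item_len = 4096                 # bytes per KV slot for one layer
src_ptr = 0x7F0000000000        # base address of the layer's KV buffer
prefill_block = [7, 8, 9]       # one contiguous run from group_concurrent_contiguous

src_addr = src_ptr + prefill_block[0] * item_len  # start of the run
length = item_len * len(prefill_block)            # covers all three slots
# One (src_addr, length, gpu_id) descriptor replaces three single-slot descriptors.
```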
```diff
@@ -180,7 +256,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -213,7 +289,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -240,6 +316,9 @@ class NixlKVManager(BaseKVManager):
         req = self.transfer_infos[bootstrap_room]
         assert bootstrap_room == req.room
 
+        if req.is_dummy():
+            return []
+
         peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
         chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
         assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +335,7 @@ class NixlKVManager(BaseKVManager):
         handles = [kv_xfer_handle]
         # Only the last chunk we need to send the aux data.
         if is_last:
+            assert aux_index is not None
             aux_xfer_handle = self.send_aux(
                 peer_name,
                 aux_index,
```
|
|
325
405
|
"""This thread recvs transfer info from the decode engine"""
|
326
406
|
while True:
|
327
407
|
waiting_req_bytes = self.server_socket.recv_multipart()
|
408
|
+
logger.debug(
|
409
|
+
f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
|
410
|
+
)
|
411
|
+
assert (
|
412
|
+
waiting_req_bytes[0] == GUARD
|
413
|
+
), f"First message should be {GUARD}. Foreign traffic?"
|
414
|
+
waiting_req_bytes = waiting_req_bytes[1:]
|
328
415
|
room = waiting_req_bytes[0].decode("ascii")
|
329
416
|
if room == "None":
|
330
417
|
continue
|
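The guard frame added above is a fixed magic prefix: senders prepend `GUARD` to every multipart message, and the bootstrap thread validates and strips it before parsing. A minimal standalone sketch of that framing over a plain ZMQ PUSH/PULL pair (not the actual sglang sockets):

```python
import zmq

GUARD = b"NixlMsgGuard"  # same magic bytes as in the module above

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
port = pull.bind_to_random_port("tcp://127.0.0.1")
push = ctx.socket(zmq.PUSH)
push.connect(f"tcp://127.0.0.1:{port}")

push.send_multipart([GUARD, b"42"])            # sender prepends the guard frame

frames = pull.recv_multipart()
assert frames[0] == GUARD, "Foreign traffic?"  # receiver validates the prefix
frames = frames[1:]                            # then strips it; payload: [b"42"]
```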
```diff
@@ -372,14 +459,13 @@ class NixlKVSender(BaseKVSender):
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
-            return KVPoll.WaitingForInput
-
+            return KVPoll.WaitingForInput  # type: ignore
         states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
         if all([x == "DONE" for x in states]):
-            return KVPoll.Success
+            return KVPoll.Success  # type: ignore
         if any([x == "ERR" for x in states]):
             raise Exception("KVSender transfer encountered an error.")
-        return KVPoll.WaitingForInput
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -401,7 +487,7 @@ class NixlKVReceiver(BaseKVReceiver):
         # NOTE: key distinguished by bootstrap_addr and engine_rank
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
 
-        if bootstrap_key not in self.kv_mgr.
+        if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
                 self.kv_mgr.kv_args.engine_rank
             )
@@ -410,25 +496,79 @@ class NixlKVReceiver(BaseKVReceiver):
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.
+                self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
         else:
-            self.bootstrap_info = self.kv_mgr.
-
+            self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
         assert self.bootstrap_info is not None
 
-
+    # return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
+    # In each dict, there are multiple possible remotes named "equal sources".
+    # We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
+    def _get_bootstrap_info_from_server(
+        self, engine_rank
+    ) -> Optional[List[Dict[int, NixlEngineInfo]]]:
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-
-
-
+            if self.kv_mgr.enable_dp_attention:
+                url = f"http://{self.bootstrap_addr}/route"
+                response = requests.get(url)
+                if response.status_code != 200:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
+
                 bootstrap_info = response.json()
-
-
-
-
+                assert isinstance(bootstrap_info, dict)
+                bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
+
+                # split out who need to send to this rank.
+                # currently for dpsk mla model, those ranks share the same latent cache.
+                # pick one as the real source
+
+                prefill_tp_size = len(bootstrap_info.keys())
+
+                assert (
+                    prefill_tp_size >= self.kv_mgr.tp_size_of_dp
+                ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
+
+                num_remote_tp_rank_we_managed = (
+                    prefill_tp_size // self.kv_mgr.tp_size_of_dp
+                )
+
+                # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
+                remote_tp_ranks = list(range(0, prefill_tp_size))
+                # split it into tp_size_of_dp parts and get our part
+                remote_tp_ranks_grouped = [
+                    remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
+                    for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
+                ]
+                managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
+
+                assert len(managed_ranks) == num_remote_tp_rank_we_managed
+
+                logger.debug(
+                    f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
                 )
-
+
+                return [
+                    {
+                        rk: bootstrap_info[rk]
+                        for rk in bootstrap_info.keys()
+                        if rk in managed_ranks
+                    }
+                ]
+            else:
+                url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+                response = requests.get(url)
+                if response.status_code == 200:
+                    bootstrap_info = response.json()
+                    return [{engine_rank: bootstrap_info}]
+                else:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
         except Exception as e:
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
```
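The comment in the code states the intent: each decode attention-TP rank handles the prefill ranks in `[num * attn_tp_rank, num * attn_tp_rank + num)`. A worked sketch of that intent with assumed sizes (prefill TP size 8, decode `tp_size_of_dp` 4), not a restatement of the exact list comprehension above:

```python
# Assumed sizes, for illustration.
prefill_tp_size = 8
tp_size_of_dp = 4
num = prefill_tp_size // tp_size_of_dp  # each decode rank manages 2 prefill ranks

for attn_tp_rank in range(tp_size_of_dp):
    managed = list(range(num * attn_tp_rank, num * attn_tp_rank + num))
    print(attn_tp_rank, managed)
# 0 -> [0, 1], 1 -> [2, 3], 2 -> [4, 5], 3 -> [6, 7]
# The receiver then picks one rank from its slice; the rest get dummy messages.
```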
```diff
@@ -440,43 +580,67 @@ class NixlKVReceiver(BaseKVReceiver):
         return socket
 
     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
 
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        self._connect("tcp://" + self.prefill_server_url).send_multipart(
-            [
-                str(self.bootstrap_room).encode("ascii"),
-                get_local_ip_by_remote().encode("ascii"),
-                str(self.kv_mgr.rank_port).encode("ascii"),
-                self.kv_mgr.agent.get_agent_metadata(),
-                packed_kv_data_ptrs,
-                kv_indices.tobytes(),
-                packed_aux_data_ptrs,
-                str(aux_index).encode("ascii"),
-                str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+        assert self.bootstrap_info is not None
+        assert self.bootstrap_room is not None
+
+        for equal_sources in self.bootstrap_info:
+            remote_rank = list(equal_sources.keys())[
+                self.bootstrap_room % len(equal_sources)
             ]
-        )
+            self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
+            logger.debug(
+                f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
+            )
+
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            logger.debug(
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+            )
+            self._connect("tcp://" + self.prefill_server_url).send_multipart(
+                [
+                    GUARD,
+                    str(self.bootstrap_room).encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.kv_mgr.agent.get_agent_metadata(),
+                    packed_kv_data_ptrs,
+                    kv_indices.tobytes(),
+                    packed_aux_data_ptrs,
+                    str(aux_index).encode("ascii"),
+                    str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                ]
+            )
+
+            for dummy_rank in equal_sources.keys():
+                if dummy_rank == remote_rank:
+                    continue
+                dummy_info = equal_sources[dummy_rank]
+                dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
+                self._connect("tcp://" + dummy_url).send_multipart(
+                    [
+                        GUARD,
+                        str(self.bootstrap_room).encode("ascii"),
+                    ]
+                )
+
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
         if not self.started_transfer:
-            return KVPoll.WaitingForInput
+            return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
 
-        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
-            return KVPoll.Success
-        return KVPoll.WaitingForInput
+        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            return KVPoll.Success  # type: ignore
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
@@ -484,6 +648,7 @@ class NixlKVReceiver(BaseKVReceiver):
 
 class NixlKVBootstrapServer(BaseKVBootstrapServer):
     def __init__(self, port: int):
+        logger.debug(f"NixlKVBootstrapServer started on port {port}")
        self.port = port
         self.app = web.Application()
         self.store = dict()
@@ -564,13 +729,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
         engine_rank = int(data["engine_rank"])
         agent_name = data["agent_name"]
 
-        # Add lock to make sure thread-safe
         if role == "Prefill":
-            self.prefill_port_table[engine_rank] = {
-                "rank_ip": rank_ip,
-                "rank_port": rank_port,
-                "agent_name": agent_name,
-            }
+            async with self.lock:
+                self.prefill_port_table[engine_rank] = {
+                    "rank_ip": rank_ip,
+                    "rank_port": rank_port,
+                    "agent_name": agent_name,
+                }
             logger.info(
                 f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
             )
@@ -580,7 +745,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         if not engine_rank:
-
+            logger.debug(
+                f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
+            )
+            # Return a dict of all engine_rank
+            async with self.lock:
+                bootstrap_info = self.prefill_port_table
+            return web.json_response(bootstrap_info, status=200)
 
         # Find corresponding prefill info
         async with self.lock:
```
sglang/srt/disaggregation/prefill.py

```diff
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.utils import (
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    is_mla_backend,
     kv_to_page_indices,
     kv_to_page_num,
     poll_and_all_reduce,
@@ -69,6 +70,7 @@ class PrefillBootstrapQueue:
         scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
+        self.is_mla_backend = is_mla_backend(token_to_kv_pool)
         self.aux_dtype = aux_dtype
 
         self.metadata_buffers = metadata_buffers
@@ -112,7 +114,10 @@ class PrefillBootstrapQueue:
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
-            kv_args,
+            kv_args,
+            DisaggregationMode.PREFILL,
+            self.scheduler.server_args,
+            self.is_mla_backend,
         )
         return kv_manager
 
@@ -277,19 +282,17 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
            result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )
 
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
+            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
```
sglang/srt/disaggregation/utils.py

```diff
@@ -1,13 +1,18 @@
 from __future__ import annotations
 
+import dataclasses
+import warnings
 from collections import deque
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 import numpy as np
+import requests
 import torch
 import torch.distributed as dist
 
+from sglang.srt.utils import get_ip
+
 
 class DisaggregationMode(Enum):
     NULL = "null"
@@ -107,7 +112,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
 
 
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
-    # 1. The page is
+    # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
     # The return vector is kv_indices[::page_size] // page_size
     if page_size == 1:  # shortcut
```
```diff
@@ -119,3 +124,47 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
 def kv_to_page_num(num_kv_indices: int, page_size: int):
     # ceil(num_kv_indices / page_size)
     return (num_kv_indices + page_size - 1) // page_size
+
+
+@dataclasses.dataclass
+class PDRegistryRequest:
+    """A request to register a machine itself to the LB."""
+
+    mode: str
+    registry_url: str
+    bootstrap_port: Optional[int] = None
+
+    def __post_init__(self):
+        if self.mode == "prefill" and self.bootstrap_port is None:
+            raise ValueError("Bootstrap port must be set in PREFILL mode.")
+        elif self.mode == "decode" and self.bootstrap_port is not None:
+            raise ValueError("Bootstrap port must not be set in DECODE mode.")
+        elif self.mode not in ["prefill", "decode"]:
+            raise ValueError(
+                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
+            )
+
+
+def register_disaggregation_server(
+    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
+):
+    boostrap_port = bootstrap_port if mode == "prefill" else None
+    registry_request = PDRegistryRequest(
+        mode=mode,
+        registry_url=f"http://{get_ip()}:{server_port}",
+        bootstrap_port=boostrap_port,
+    )
+    res = requests.post(
+        f"{pdlb_url}/register",
+        json=dataclasses.asdict(registry_request),
+    )
+    if res.status_code != 200:
+        warnings.warn(
+            f"Failed to register disaggregation server: {res.status_code} {res.text}"
+        )
+
+
+def is_mla_backend(target_kv_pool) -> bool:
+    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+    return isinstance(target_kv_pool, MLATokenToKVPool)
```
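`PDRegistryRequest.__post_init__` validates the mode/port combination at construction time, so `register_disaggregation_server` fails fast on a bad payload. A small sketch of the registration payload (the addresses are placeholders):

```python
import dataclasses

# Valid: prefill mode requires a bootstrap port.
req = PDRegistryRequest(
    mode="prefill",
    registry_url="http://10.0.0.1:30000",  # placeholder address
    bootstrap_port=8998,
)
print(dataclasses.asdict(req))
# {'mode': 'prefill', 'registry_url': 'http://10.0.0.1:30000', 'bootstrap_port': 8998}

# Invalid: decode mode must NOT carry a bootstrap port -> raises ValueError.
# PDRegistryRequest(mode="decode", registry_url="http://10.0.0.2:30000", bootstrap_port=8998)
```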
sglang/srt/distributed/device_communicators/custom_all_reduce.py

```diff
@@ -296,7 +296,6 @@ class CustomAllreduce:
             self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
         )
         self.register_buffer(self.buffer)
-        self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
 
         self.disabled = False
 
@@ -430,13 +429,7 @@ class CustomAllreduce:
 
         if _is_hip:
             if self.full_nvlink:
-
-                if self.MSCCL:
-                    return False
-                else:
-                    return inp_size < self.max_size
-            else:
-                return inp_size < self.max_size
+                return inp_size < self.max_size
             return False
 
         return False
```
sglang/srt/distributed/device_communicators/npu_communicator.py (new file)

```diff
@@ -0,0 +1,39 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.utils import is_npu
+
+
+class NpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not is_npu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += x.dim()
+        input_size = x.size()
+        output_size = (input_size[0] * world_size,) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((world_size,) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
```