sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/transfer_engine.py
CHANGED
@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
         self, session_id: str, buffer: int, peer_buffer_address: int, length: int
     ) -> int:
         """Synchronously transfer data to the specified address."""
-
+        # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
+        # later: based on the cached queue pair to send data
         ret = self.engine.transfer_sync_write(
             session_id, buffer, peer_buffer_address, length
         )
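Note: the added comments describe a connect-once pattern. The first transfer for a given session_id (which embeds the remote IP) constructs a queue pair and caches it; later transfers reuse the cached pair. A minimal sketch of that caching idea, assuming a hypothetical create_queue_pair helper (the real caching lives inside the Mooncake transfer engine, not in sglang):

    _queue_pair_cache: dict = {}

    def _get_queue_pair(session_id: str):
        # session_id embeds the remote IP, so it keys the cached connection
        if session_id not in _queue_pair_cache:
            _queue_pair_cache[session_id] = create_queue_pair(session_id)  # hypothetical helper
        return _queue_pair_cache[session_id]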
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -35,29 +35,19 @@ logger = logging.getLogger(__name__)
 NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
 
 
-# From Mooncake backend.
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
 
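Note: the rewrite replaces the element-by-element Python loop with a vectorised split. np.diff marks every position where either index sequence stops being contiguous, and np.split cuts both arrays at those positions. A self-contained illustration of the same logic:

    import numpy as np

    src = np.array([3, 4, 5, 9, 10], dtype=np.int64)
    dst = np.array([0, 1, 2, 3, 5], dtype=np.int64)

    # breaks fall after index 2 (src jumps 5 -> 9) and after index 3 (dst jumps 3 -> 5)
    brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    print([g.tolist() for g in np.split(src, brk)])  # [[3, 4, 5], [9], [10]]
    print([g.tolist() for g in np.split(dst, brk)])  # [[0, 1, 2], [3], [5]]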
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 import logging
 import threading
 from collections import deque
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -38,8 +40,10 @@ from sglang.srt.disaggregation.utils import (
     kv_to_page_indices,
     kv_to_page_num,
     poll_and_all_reduce,
+    prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -59,9 +63,9 @@ class PrefillBootstrapQueue:
     def __init__(
         self,
         token_to_kv_pool: KVCache,
+        draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -70,8 +74,9 @@ class PrefillBootstrapQueue:
         scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
+        self.draft_token_to_kv_pool = draft_token_to_kv_pool
+
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype
 
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -96,20 +101,24 @@ class PrefillBootstrapQueue:
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
+        if self.draft_token_to_kv_pool is not None:
+            # We should also transfer draft model kv cache. The indices are
+            # always shared with a target model.
+            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+            )
+            kv_data_ptrs += draft_kv_data_ptrs
+            kv_data_lens += draft_kv_data_lens
+            kv_item_lens += draft_kv_item_lens
+
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
 
         # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -135,6 +144,10 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         self.queue.append(req)
 
+    def extend(self, reqs: List[Req]) -> None:
+        for req in reqs:
+            self.add(req)
+
     def _process_req(self, req: Req) -> None:
         """
         Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
@@ -157,7 +170,18 @@ class PrefillBootstrapQueue:
             if poll == KVPoll.Bootstrapping:
                 continue
             elif poll == KVPoll.Failed:
-
+                error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                self.scheduler.stream_output([req], req.return_logprob)
+                indices_to_remove.add(i)
+                continue
 
             # KV.WaitingForInput
             num_kv_indices = len(req.origin_input_ids)
@@ -250,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))
 
+            if self.last_batch is None:
+                # Create a dummy first batch to start the pipeline for overlap schedule.
+                # It is now used for triggering the sampling_info_done event.
+                tmp_batch = ScheduleBatch(
+                    reqs=None,
+                    forward_mode=ForwardMode.DUMMY_FIRST,
+                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                )
+                self.set_next_batch_sampling_info_done(tmp_batch)
+
         if self.last_batch:
             tmp_batch, tmp_result = self.result_queue.popleft()
             self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
@@ -273,10 +307,9 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
-
         (
             logits_output,
             next_token_ids,
@@ -289,27 +322,78 @@ class SchedulerDisaggregationPrefillMixin:
             result.extend_logprob_start_len_per_req,
         )
 
+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()
-
-
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )
+
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
+
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+
+        # We need to remove the sync in the following function for overlap schedule.
+        self.set_next_batch_sampling_info_done(batch)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -335,7 +419,17 @@ class SchedulerDisaggregationPrefillMixin:
                 # FIXME: clean up req's data in transfer engine
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
-
+                error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.warning(error_message)
+                self.tree_cache.cache_finished_req(req)  # unlock the tree
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                done_reqs.append(req)
 
         for req in done_reqs:
             self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
@@ -343,7 +437,11 @@ class SchedulerDisaggregationPrefillMixin:
             )
 
         # Stream requests which have finished transfer
-        self.stream_output(
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -369,7 +467,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
@@ -377,44 +475,28 @@ class SchedulerDisaggregationPrefillMixin:
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None
 
-        if not last_chunk and (
-            end_idx % page_size != 0
-        ):  # todo: remove the second condition
+        if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size
 
-        # Update next start_send_idx
-        req.start_send_idx = end_idx
-
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-
-
-
-        )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
-
-        page_start_idx = start_idx // page_size
-        page_end_idx = page_start_idx + len(page_indices)
-
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
-
-        req.disagg_kv_sender.send(
-            page_indices, slice(page_start_idx, page_end_idx), last_chunk
-        )
+        req.disagg_kv_sender.send(page_indices)
sglang/srt/disaggregation/utils.py
CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import dataclasses
+import os
+import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import requests
@@ -13,6 +15,14 @@ import torch.distributed as dist
 
 from sglang.srt.utils import get_ip
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+FakeBootstrapHost = "2.2.2.2"
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
 
 class DisaggregationMode(Enum):
     NULL = "null"
@@ -20,11 +30,17 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
-FakeBootstrapHost = "2.2.2.2"
-
-
 def poll_and_all_reduce(pollers, gloo_group):
-    polls = [int(poller.poll()) for poller in pollers]
+    # at a certain prob, the poll is failed to simulate failure
+    if FAILURE_PROB > 0:
+        from sglang.srt.disaggregation.base import KVPoll
+
+        polls = [
+            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
+            for poller in pollers
+        ]
+    else:
+        polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
     dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
     return tensor_to_reduce.tolist()
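Note: the new FAILURE_PROB hook lets tests randomly replace poll results with KVPoll.Failed so the abort paths added in prefill.py can be exercised. A hedged usage sketch; the variable must be set before the sglang process imports this module:

    import os

    # make roughly 5% of KV-transfer polls report failure (fault-injection testing only)
    os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"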
@@ -168,3 +184,98 @@ def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 
     return isinstance(target_kv_pool, MLATokenToKVPool)
+
+
+def prepare_abort(req: Req, error_message: str, status_code=None):
+    from sglang.srt.managers.schedule_batch import FINISH_ABORT
+
+    # populate finish metadata and stream output
+    req.finished_reason = FINISH_ABORT(error_message, status_code)
+
+    if req.return_logprob:
+        req.input_token_logprobs_val = []
+        req.input_token_logprobs_idx = []
+        req.input_top_logprobs_val = []
+        req.input_top_logprobs_idx = []
+        req.input_token_ids_logprobs_val = []
+        req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
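Note: MetadataBuffers packs the first output token's metadata into five fixed-size CPU tensors, with rows padded to 16 elements because RDMA transfers need at least 64 bytes, and exposes them as parallel pointer/length lists. A short sketch of how a consumer could inspect them, assuming only the class defined above:

    buffers = MetadataBuffers(size=64)  # one row per metadata slot
    ptrs, data_lens, item_lens = buffers.get_buf_infos()
    assert len(ptrs) == len(data_lens) == len(item_lens) == 5  # five parallel tensors
    assert buffers.output_ids[0].nbytes == 16 * 4  # 16 int32 slots meets the 64-byte minimum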
sglang/srt/distributed/utils.py
CHANGED
@@ -127,14 +127,14 @@ class StatelessProcessGroup:
         key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
-        self.entries.append((key, time.time()))
+        self.entries.append((key, time.perf_counter()))
 
     def expire_data(self):
         """Expire data that is older than `data_expiration_seconds` seconds."""
         while self.entries:
             # check the oldest entry
             key, timestamp = self.entries[0]
-            if time.time() - timestamp > self.data_expiration_seconds:
+            if time.perf_counter() - timestamp > self.data_expiration_seconds:
                 self.store.delete_key(key)
                 self.entries.popleft()
             else:
@@ -158,7 +158,7 @@ class StatelessProcessGroup:
         key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
         self.store.set(key, pickle.dumps(obj))
         self.broadcast_send_counter += 1
-        self.entries.append((key, time.time()))
+        self.entries.append((key, time.perf_counter()))
         return obj
     else:
         key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
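Note: these hunks swap time.time() for time.perf_counter() when timestamping store entries. Both measure the interval here, but perf_counter is monotonic, so the expiry check cannot misfire if the system wall clock is adjusted. A tiny illustration:

    import time

    t0 = time.perf_counter()  # monotonic; immune to wall-clock changes
    time.sleep(0.01)
    assert time.perf_counter() - t0 >= 0.01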
sglang/srt/entrypoints/engine.py
CHANGED
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -150,9 +151,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[
-                List[
-
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -221,9 +222,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[
-                List[
-
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -320,7 +321,26 @@ class Engine(EngineBase):
         loop.run_until_complete(self.tokenizer_manager.start_profile())
 
     def stop_profile(self):
-        self.tokenizer_manager.stop_profile()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.stop_profile())
+
+    def start_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.start_expert_distribution_record()
+        )
+
+    def stop_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.stop_expert_distribution_record()
+        )
+
+    def dump_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.dump_expert_distribution_record()
+        )
 
     def get_server_info(self):
         loop = asyncio.get_event_loop()
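Note: stop_profile becomes a blocking wrapper around the now-asynchronous tokenizer_manager coroutine, and three new pass-throughs expose expert-distribution recording on the Engine. A hedged usage sketch; the model path is a placeholder and any served MoE model applies:

    import sglang as sgl

    engine = sgl.Engine(model_path="Qwen/Qwen1.5-MoE-A2.7B")  # placeholder MoE model
    engine.start_expert_distribution_record()
    engine.generate("Hello")
    engine.stop_expert_distribution_record()
    engine.dump_expert_distribution_record()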
@@ -486,7 +506,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.
+            "0.1.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
-from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -182,13 +182,14 @@ async def health_generate(request: Request) -> Response:
     async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
         break
 
-    tic = time.time()
+    tic = time.perf_counter()
     task = asyncio.create_task(gen())
-    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+    while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
         await asyncio.sleep(1)
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+            _global_state.tokenizer_manager.health_check_failed = False
             return Response(status_code=200)
 
     task.cancel()
@@ -202,6 +203,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+    _global_state.tokenizer_manager.health_check_failed = True
     return Response(status_code=503)
 
 
@@ -353,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
 @app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
-    _global_state.tokenizer_manager.stop_profile()
+    await _global_state.tokenizer_manager.stop_profile()
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,