sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 import logging
 import threading
 from collections import deque
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,14 +32,18 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
+    is_mla_backend,
     kv_to_page_indices,
     kv_to_page_num,
     poll_and_all_reduce,
+    prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -58,9 +63,9 @@ class PrefillBootstrapQueue:
     def __init__(
         self,
         token_to_kv_pool: KVCache,
+        draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers:
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -69,7 +74,9 @@ class PrefillBootstrapQueue:
         scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
-        self.
+        self.draft_token_to_kv_pool = draft_token_to_kv_pool
+
+        self.is_mla_backend = is_mla_backend(token_to_kv_pool)
 
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -94,25 +101,32 @@ class PrefillBootstrapQueue:
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
+        if self.draft_token_to_kv_pool is not None:
+            # We should also transfer draft model kv cache. The indices are
+            # always shared with a target model.
+            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+            )
+            kv_data_ptrs += draft_kv_data_ptrs
+            kv_data_lens += draft_kv_data_lens
+            kv_item_lens += draft_kv_item_lens
+
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
 
         # Define req -> input ids buffer
-        kv_args.aux_data_ptrs =
-
-
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
-            kv_args,
+            kv_args,
+            DisaggregationMode.PREFILL,
+            self.scheduler.server_args,
+            self.is_mla_backend,
         )
         return kv_manager
 
@@ -130,6 +144,10 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         self.queue.append(req)
 
+    def extend(self, reqs: List[Req]) -> None:
+        for req in reqs:
+            self.add(req)
+
     def _process_req(self, req: Req) -> None:
         """
         Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
@@ -152,7 +170,18 @@ class PrefillBootstrapQueue:
             if poll == KVPoll.Bootstrapping:
                 continue
             elif poll == KVPoll.Failed:
-
+                error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                self.scheduler.stream_output([req], req.return_logprob)
+                indices_to_remove.add(i)
+                continue
 
             # KV.WaitingForInput
             num_kv_indices = len(req.origin_input_ids)
@@ -245,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))
 
+            if self.last_batch is None:
+                # Create a dummy first batch to start the pipeline for overlap schedule.
+                # It is now used for triggering the sampling_info_done event.
+                tmp_batch = ScheduleBatch(
+                    reqs=None,
+                    forward_mode=ForwardMode.DUMMY_FIRST,
+                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                )
+                self.set_next_batch_sampling_info_done(tmp_batch)
+
         if self.last_batch:
             tmp_batch, tmp_result = self.result_queue.popleft()
             self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
@@ -268,45 +307,93 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
-
         (
             logits_output,
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )
 
+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()
-
-
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )
+
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
+
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+
+        # We need to remove the sync in the following function for overlap schedule.
+        self.set_next_batch_sampling_info_done(batch)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -332,7 +419,17 @@ class SchedulerDisaggregationPrefillMixin:
                 # FIXME: clean up req's data in transfer engine
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
-
+                error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.warning(error_message)
+                self.tree_cache.cache_finished_req(req)  # unlock the tree
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                done_reqs.append(req)
 
         for req in done_reqs:
             self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
@@ -340,7 +437,11 @@ class SchedulerDisaggregationPrefillMixin:
             )
 
         # Stream requests which have finished transfer
-        self.stream_output(
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -366,7 +467,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
@@ -374,44 +475,28 @@ class SchedulerDisaggregationPrefillMixin:
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
            else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None
 
-        if
-            end_idx % page_size != 0
-        ):  # todo: remove the second condition
+        if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size
 
-        # Update next start_send_idx
-        req.start_send_idx = end_idx
-
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-
-
-
-        )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
-
-        page_start_idx = start_idx // page_size
-        page_end_idx = page_start_idx + len(page_indices)
-
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
-
-        req.disagg_kv_sender.send(
-            page_indices, slice(page_start_idx, page_end_idx), last_chunk
-        )
+        req.disagg_kv_sender.send(page_indices)
sglang/srt/disaggregation/utils.py
CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import dataclasses
+import os
+import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import requests
@@ -13,6 +15,14 @@ import torch.distributed as dist
 
 from sglang.srt.utils import get_ip
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+FakeBootstrapHost = "2.2.2.2"
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
 
 class DisaggregationMode(Enum):
     NULL = "null"
@@ -20,11 +30,17 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
-FakeBootstrapHost = "2.2.2.2"
-
-
 def poll_and_all_reduce(pollers, gloo_group):
-
+    # at a certain prob, the poll is failed to simulate failure
+    if FAILURE_PROB > 0:
+        from sglang.srt.disaggregation.base import KVPoll
+
+        polls = [
+            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
+            for poller in pollers
+        ]
+    else:
+        polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
     dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
     return tensor_to_reduce.tolist()
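The new FAILURE_PROB hook makes poll_and_all_reduce randomly report KVPoll.Failed, which exercises the abort paths added to prefill.py above. A usage sketch, assuming the variable is set before sglang.srt.disaggregation.utils is first imported (it is read once at module import time):

```python
import os

# Roughly 5% of polls will report KVPoll.Failed on any rank that sets this.
os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"

from sglang.srt.disaggregation.utils import FAILURE_PROB

print(FAILURE_PROB)  # 0.05
```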
@@ -112,7 +128,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
 
 
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
-    # 1. The page is
+    # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
     # The return vector is kv_indices[::page_size] // page_size
     if page_size == 1:  # shortcut
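The restored comment states the invariant behind the formula kv_indices[::page_size] // page_size: every page except possibly the last is full, so sampling one kv index per page suffices. A small worked example with an assumed page_size of 4:

```python
import numpy as np

page_size = 4
kv_indices = np.array([8, 9, 10, 11, 12, 13, 14, 15, 16, 17])  # 2.5 pages

# Take every page_size-th kv index, then map kv index -> page index.
page_indices = kv_indices[::page_size] // page_size
print(page_indices)  # [2 3 4]: pages 2 and 3 are full, page 4 is the partial tail
```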
@@ -162,3 +178,104 @@ def register_disaggregation_server(
         warnings.warn(
             f"Failed to register disaggregation server: {res.status_code} {res.text}"
         )
+
+
+def is_mla_backend(target_kv_pool) -> bool:
+    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
+
+    return isinstance(target_kv_pool, MLATokenToKVPool)
+
+
+def prepare_abort(req: Req, error_message: str, status_code=None):
+    from sglang.srt.managers.schedule_batch import FINISH_ABORT
+
+    # populate finish metadata and stream output
+    req.finished_reason = FINISH_ABORT(error_message, status_code)
+
+    if req.return_logprob:
+        req.input_token_logprobs_val = []
+        req.input_token_logprobs_idx = []
+        req.input_top_logprobs_val = []
+        req.input_top_logprobs_idx = []
+        req.input_token_ids_logprobs_val = []
+        req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
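The (size, 16) shapes follow from the RDMA comment in __init__: only element [0] of each row carries data, but a row of 16 int32 or float32 values is 64 bytes, matching the stated minimum transfer size. A quick check of the per-item length that get_buf_infos reports for such a buffer:

```python
import torch

size = 8  # hypothetical number of metadata slots
output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")

print(output_ids[0].nbytes)  # 64  -> per-item length (one request's padded slot)
print(output_ids.nbytes)     # 512 -> total buffer length registered for transfer
```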
sglang/srt/distributed/utils.py
CHANGED
@@ -127,14 +127,14 @@ class StatelessProcessGroup:
         key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
-        self.entries.append((key, time.
+        self.entries.append((key, time.perf_counter()))
 
     def expire_data(self):
         """Expire data that is older than `data_expiration_seconds` seconds."""
         while self.entries:
             # check the oldest entry
             key, timestamp = self.entries[0]
-            if time.
+            if time.perf_counter() - timestamp > self.data_expiration_seconds:
                 self.store.delete_key(key)
                 self.entries.popleft()
             else:
@@ -158,7 +158,7 @@ class StatelessProcessGroup:
             key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
-            self.entries.append((key, time.
+            self.entries.append((key, time.perf_counter()))
             return obj
         else:
             key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
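Both timestamp call sites in StatelessProcessGroup now use time.perf_counter() (the removed calls are truncated in this view, presumably a wall-clock timestamp). perf_counter is monotonic, so the age computed in expire_data cannot jump or go negative if the system clock is adjusted. For illustration only:

```python
import time

start = time.perf_counter()
time.sleep(0.01)
elapsed = time.perf_counter() - start  # monotonic: always >= 0
print(f"{elapsed:.4f}s")
```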
sglang/srt/entrypoints/engine.py
CHANGED
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -150,9 +151,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[
-                List[
-
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -221,9 +222,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[
-                List[
-
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
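The annotation now names a shared ImageDataItem alias and admits three shapes: a single item, a flat list, or a nested per-prompt list. A hedged sketch of the corresponding call patterns, assuming an Engine instance `llm` and hypothetical image paths:

```python
# One image for a single prompt.
llm.generate(prompt="Describe this image.", image_data="cat.png")

# A nested list pairs each prompt in a batch with its own image list.
llm.generate(
    prompt=["Describe A.", "Describe B."],
    image_data=[["a.png"], ["b1.png", "b2.png"]],
)
```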
@@ -285,6 +286,21 @@ class Engine(EngineBase):
         ret = loop.run_until_complete(generator.__anext__())
         return ret
 
+    async def async_encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+        image_data: Optional[Union[List[str], str]] = None,
+    ) -> Dict:
+        """
+        Asynchronous version of encode method.
+
+        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        return await generator.__anext__()
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
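A usage sketch for the new async_encode, assuming an embedding model (the model path here is a hypothetical choice); unlike encode(), it awaits the result inside an existing event loop instead of driving the loop itself:

```python
import asyncio
import sglang as sgl

async def main():
    # Hypothetical embedding checkpoint; any supported embedding model works.
    llm = sgl.Engine(model_path="intfloat/e5-mistral-7b-instruct")
    try:
        ret = await llm.async_encode("The capital of France is Paris.")
        print(ret)  # a dict containing the embedding result
    finally:
        llm.shutdown()

asyncio.run(main())
```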
@@ -305,7 +321,26 @@ class Engine(EngineBase):
         loop.run_until_complete(self.tokenizer_manager.start_profile())
 
     def stop_profile(self):
-
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.stop_profile())
+
+    def start_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.start_expert_distribution_record()
+        )
+
+    def stop_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.stop_expert_distribution_record()
+        )
+
+    def dump_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.dump_expert_distribution_record()
+        )
 
     def get_server_info(self):
         loop = asyncio.get_event_loop()
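A hedged sketch of the new expert-distribution hooks, which mirror start_profile/stop_profile by driving the corresponding tokenizer_manager coroutines; the MoE model path below is a hypothetical choice:

```python
import sglang as sgl

llm = sgl.Engine(model_path="Qwen/Qwen1.5-MoE-A2.7B")  # hypothetical MoE checkpoint
llm.start_expert_distribution_record()
llm.generate(prompt="Hello, world!")
llm.stop_expert_distribution_record()
llm.dump_expert_distribution_record()  # emit the recorded expert statistics
llm.shutdown()
```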
@@ -315,7 +350,7 @@ class Engine(EngineBase):
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
-
+            "internal_states": internal_states,
             "version": __version__,
         }
 
@@ -471,7 +506,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.
+            "0.1.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 