sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/qwen3_next.py
CHANGED
@@ -15,14 +15,12 @@
|
|
15
15
|
"""Qwen3Hybrid model configuration"""
|
16
16
|
|
17
17
|
import enum
|
18
|
-
import os
|
19
18
|
|
20
|
-
import numpy as np
|
21
|
-
import torch
|
22
19
|
from transformers.configuration_utils import PretrainedConfig
|
23
20
|
from transformers.modeling_rope_utils import rope_config_validation
|
24
21
|
from transformers.utils import logging
|
25
22
|
|
23
|
+
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
26
24
|
from sglang.srt.distributed.utils import divide
|
27
25
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
28
26
|
|
@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
|
|
282
280
|
]
|
283
281
|
|
284
282
|
@property
|
285
|
-
def
|
286
|
-
|
287
|
-
|
288
|
-
self.
|
289
|
-
|
283
|
+
def mamba2_cache_params(self) -> Mamba2CacheParams:
|
284
|
+
shape = Mamba2StateShape.create(
|
285
|
+
tp_world_size=get_attention_tp_size(),
|
286
|
+
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
|
287
|
+
n_groups=self.linear_num_key_heads,
|
288
|
+
num_heads=self.linear_num_value_heads,
|
289
|
+
head_dim=self.linear_value_head_dim,
|
290
|
+
state_size=self.linear_key_head_dim,
|
291
|
+
conv_kernel=self.linear_conv_kernel_dim,
|
290
292
|
)
|
291
|
-
conv_state_shape = (
|
292
|
-
divide(conv_dim, world_size),
|
293
|
-
self.linear_conv_kernel_dim - 1,
|
294
|
-
)
|
295
|
-
|
296
|
-
temporal_state_shape = (
|
297
|
-
divide(self.linear_num_value_heads, world_size),
|
298
|
-
self.linear_key_head_dim,
|
299
|
-
self.linear_value_head_dim,
|
300
|
-
)
|
301
|
-
conv_dtype = torch.bfloat16
|
302
|
-
dtype_map = {
|
303
|
-
"float32": torch.float32,
|
304
|
-
"bfloat16": torch.bfloat16,
|
305
|
-
}
|
306
|
-
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
|
307
|
-
mamba_layers = self.linear_layer_ids
|
308
|
-
return (
|
309
|
-
conv_state_shape,
|
310
|
-
temporal_state_shape,
|
311
|
-
conv_dtype,
|
312
|
-
ssm_dtype,
|
313
|
-
mamba_layers,
|
314
|
-
)
|
315
|
-
|
316
|
-
@property
|
317
|
-
def mamba_cache_per_req(self):
|
318
|
-
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
|
319
|
-
self.hybrid_gdn_params
|
320
|
-
)
|
321
|
-
mamba_layers_len = len(mamba_layers)
|
322
293
|
|
323
|
-
return (
|
324
|
-
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
|
325
|
-
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
|
326
|
-
) * mamba_layers_len
|
294
|
+
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
|
sglang/srt/disaggregation/decode.py
CHANGED
@@ -747,11 +747,13 @@ class SchedulerDisaggregationDecodeMixin:
|
|
747
747
|
|
748
748
|
@torch.no_grad()
|
749
749
|
def event_loop_overlap_disagg_decode(self: Scheduler):
|
750
|
-
result_queue = deque()
|
750
|
+
self.result_queue = deque()
|
751
751
|
self.last_batch: Optional[ScheduleBatch] = None
|
752
752
|
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
|
753
753
|
|
754
754
|
while True:
|
755
|
+
self.launch_last_batch_sample_if_needed()
|
756
|
+
|
755
757
|
recv_reqs = self.recv_requests()
|
756
758
|
self.process_input_requests(recv_reqs)
|
757
759
|
# polling and allocating kv cache
|
@@ -774,23 +776,13 @@ class SchedulerDisaggregationDecodeMixin:
|
|
774
776
|
None, delay_process=True
|
775
777
|
)
|
776
778
|
if batch_:
|
777
|
-
result_queue.append((batch_.copy(), result))
|
779
|
+
self.result_queue.append((batch_.copy(), result))
|
778
780
|
last_batch_in_queue = True
|
779
781
|
else:
|
780
782
|
if prepare_mlp_sync_flag:
|
781
783
|
self.prepare_mlp_sync_batch(batch)
|
782
784
|
result = self.run_batch(batch)
|
783
|
-
result_queue.append((batch.copy(), result))
|
784
|
-
|
785
|
-
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
786
|
-
# Create a dummy first batch to start the pipeline for overlap schedule.
|
787
|
-
# It is now used for triggering the sampling_info_done event.
|
788
|
-
tmp_batch = ScheduleBatch(
|
789
|
-
reqs=None,
|
790
|
-
forward_mode=ForwardMode.DUMMY_FIRST,
|
791
|
-
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
792
|
-
)
|
793
|
-
self.set_next_batch_sampling_info_done(tmp_batch)
|
785
|
+
self.result_queue.append((batch.copy(), result))
|
794
786
|
last_batch_in_queue = True
|
795
787
|
|
796
788
|
elif prepare_mlp_sync_flag:
|
@@ -798,15 +790,12 @@ class SchedulerDisaggregationDecodeMixin:
|
|
798
790
|
None, delay_process=True
|
799
791
|
)
|
800
792
|
if batch:
|
801
|
-
result_queue.append((batch.copy(), result))
|
793
|
+
self.result_queue.append((batch.copy(), result))
|
802
794
|
last_batch_in_queue = True
|
803
795
|
|
804
796
|
# Process the results of the previous batch but skip if the last batch is extend
|
805
797
|
if self.last_batch and self.last_batch_in_queue:
|
806
|
-
tmp_batch, tmp_result = result_queue.popleft()
|
807
|
-
tmp_batch.next_batch_sampling_info = (
|
808
|
-
self.tp_worker.cur_sampling_info if batch else None
|
809
|
-
)
|
798
|
+
tmp_batch, tmp_result = self.result_queue.popleft()
|
810
799
|
self.process_batch_result(tmp_batch, tmp_result)
|
811
800
|
|
812
801
|
queue_size = (
|
sglang/srt/disaggregation/decode_kvcache_offload_manager.py
CHANGED
@@ -4,7 +4,6 @@ import time
|
|
4
4
|
|
5
5
|
import torch
|
6
6
|
|
7
|
-
from sglang import ServerArgs
|
8
7
|
from sglang.srt.managers.cache_controller import HiCacheController
|
9
8
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
10
9
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
@@ -17,6 +16,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
|
|
17
16
|
MHATokenToKVPoolHost,
|
18
17
|
MLATokenToKVPoolHost,
|
19
18
|
)
|
19
|
+
from sglang.srt.server_args import ServerArgs
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
22
22
|
|
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
|
|
319
319
|
|
320
320
|
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
|
321
321
|
# Make descs
|
322
|
-
|
322
|
+
if self.is_mla_backend:
|
323
|
+
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
324
|
+
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
325
|
+
)
|
326
|
+
kv_item_len = self.kv_args.kv_item_lens[0]
|
327
|
+
layers_params = [
|
328
|
+
(
|
329
|
+
src_kv_ptrs[layer_id],
|
330
|
+
dst_kv_ptrs[layer_id],
|
331
|
+
kv_item_len,
|
332
|
+
)
|
333
|
+
for layer_id in range(layers_current_pp_stage)
|
334
|
+
]
|
335
|
+
else:
|
336
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
337
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
338
|
+
)
|
339
|
+
|
340
|
+
kv_item_len = self.kv_args.kv_item_lens[0]
|
341
|
+
layers_params = [
|
342
|
+
(
|
343
|
+
src_k_ptrs[layer_id],
|
344
|
+
dst_k_ptrs[layer_id],
|
345
|
+
kv_item_len,
|
346
|
+
)
|
347
|
+
for layer_id in range(layers_current_pp_stage)
|
348
|
+
] + [
|
349
|
+
(
|
350
|
+
src_v_ptrs[layer_id],
|
351
|
+
dst_v_ptrs[layer_id],
|
352
|
+
kv_item_len,
|
353
|
+
)
|
354
|
+
for layer_id in range(layers_current_pp_stage)
|
355
|
+
]
|
356
|
+
|
323
357
|
src_addrs = []
|
324
358
|
dst_addrs = []
|
325
|
-
for
|
326
|
-
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
|
327
|
-
dst_ptr = dst_kv_ptrs[layer_id]
|
328
|
-
item_len = self.kv_args.kv_item_lens[layer_id]
|
329
|
-
|
359
|
+
for src_ptr, dst_ptr, item_len in layers_params:
|
330
360
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
331
361
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
332
362
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
|
|
397
427
|
num_heads_to_send = dst_heads_per_rank
|
398
428
|
dst_head_start_offset = 0
|
399
429
|
|
430
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
431
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
432
|
+
)
|
400
433
|
# Create transfer descriptors
|
401
434
|
src_addrs = []
|
402
435
|
dst_addrs = []
|
@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
|
|
404
437
|
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
405
438
|
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
406
439
|
|
407
|
-
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
408
|
-
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
409
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
410
|
-
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
411
|
-
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
412
|
-
|
413
440
|
# Calculate precise byte offset and length for the sub-slice within the token
|
414
441
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
415
442
|
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
|
|
420
447
|
src_k_ptrs[layer_id],
|
421
448
|
dst_k_ptrs[layer_id],
|
422
449
|
)
|
423
|
-
for layer_id in range(
|
450
|
+
for layer_id in range(layers_current_pp_stage)
|
424
451
|
] + [
|
425
452
|
(
|
426
453
|
src_v_ptrs[layer_id],
|
427
454
|
dst_v_ptrs[layer_id],
|
428
455
|
)
|
429
|
-
for layer_id in range(
|
456
|
+
for layer_id in range(layers_current_pp_stage)
|
430
457
|
]
|
431
458
|
|
432
459
|
src_addrs = []
|
@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
|
|
496
523
|
dst_aux_index: int,
|
497
524
|
notif: str,
|
498
525
|
):
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
526
|
+
src_addrs = []
|
527
|
+
dst_addrs = []
|
528
|
+
|
529
|
+
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
530
|
+
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
531
|
+
|
532
|
+
for i, _ in enumerate(dst_aux_ptrs):
|
533
|
+
length = prefill_aux_item_lens[i]
|
534
|
+
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
535
|
+
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
536
|
+
src_addrs.append((src_addr, length, 0))
|
537
|
+
dst_addrs.append((dst_addr, length, 0))
|
538
|
+
|
507
539
|
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
508
540
|
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
509
541
|
# Transfer data
|
@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
|
|
576
608
|
|
577
609
|
handles.append(kv_xfer_handle)
|
578
610
|
# Only the last chunk we need to send the aux data.
|
579
|
-
if is_last:
|
611
|
+
if is_last and self.pp_group.is_last_rank:
|
580
612
|
assert aux_index is not None
|
581
613
|
aux_xfer_handle = self.send_aux(
|
582
614
|
req.agent_name,
|
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -321,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
321
321
|
self.result_queue = deque()
|
322
322
|
|
323
323
|
while True:
|
324
|
+
self.launch_last_batch_sample_if_needed()
|
325
|
+
|
324
326
|
recv_reqs = self.recv_requests()
|
325
327
|
self.process_input_requests(recv_reqs)
|
326
328
|
self.waiting_queue.extend(
|
@@ -336,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
336
338
|
result = self.run_batch(batch)
|
337
339
|
self.result_queue.append((batch.copy(), result))
|
338
340
|
|
339
|
-
if self.last_batch is None:
|
340
|
-
# Create a dummy first batch to start the pipeline for overlap schedule.
|
341
|
-
# It is now used for triggering the sampling_info_done event.
|
342
|
-
tmp_batch = ScheduleBatch(
|
343
|
-
reqs=None,
|
344
|
-
forward_mode=ForwardMode.DUMMY_FIRST,
|
345
|
-
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
346
|
-
)
|
347
|
-
self.set_next_batch_sampling_info_done(tmp_batch)
|
348
|
-
|
349
341
|
if self.last_batch:
|
350
342
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
351
|
-
tmp_batch.next_batch_sampling_info = (
|
352
|
-
self.tp_worker.cur_sampling_info if batch else None
|
353
|
-
)
|
354
343
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
355
344
|
|
356
345
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
@@ -368,7 +357,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
368
357
|
self: Scheduler,
|
369
358
|
batch: ScheduleBatch,
|
370
359
|
result: GenerationBatchResult,
|
371
|
-
launch_done: Optional[threading.Event] = None,
|
372
360
|
) -> None:
|
373
361
|
"""
|
374
362
|
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
@@ -379,31 +367,30 @@ class SchedulerDisaggregationPrefillMixin:
|
|
379
367
|
next_token_ids,
|
380
368
|
extend_input_len_per_req,
|
381
369
|
extend_logprob_start_len_per_req,
|
370
|
+
copy_done,
|
382
371
|
) = (
|
383
372
|
result.logits_output,
|
384
373
|
result.next_token_ids,
|
385
374
|
result.extend_input_len_per_req,
|
386
375
|
result.extend_logprob_start_len_per_req,
|
376
|
+
result.copy_done,
|
387
377
|
)
|
388
378
|
|
379
|
+
if copy_done is not None:
|
380
|
+
copy_done.synchronize()
|
381
|
+
|
389
382
|
logprob_pt = 0
|
390
383
|
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
391
|
-
|
392
|
-
|
393
|
-
logits_output
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
logits_output.next_token_logprobs.tolist()
|
402
|
-
)
|
403
|
-
if logits_output.input_token_logprobs is not None:
|
404
|
-
logits_output.input_token_logprobs = tuple(
|
405
|
-
logits_output.input_token_logprobs.tolist()
|
406
|
-
)
|
384
|
+
next_token_ids = result.next_token_ids.tolist()
|
385
|
+
if batch.return_logprob:
|
386
|
+
if logits_output.next_token_logprobs is not None:
|
387
|
+
logits_output.next_token_logprobs = (
|
388
|
+
logits_output.next_token_logprobs.tolist()
|
389
|
+
)
|
390
|
+
if logits_output.input_token_logprobs is not None:
|
391
|
+
logits_output.input_token_logprobs = tuple(
|
392
|
+
logits_output.input_token_logprobs.tolist()
|
393
|
+
)
|
407
394
|
|
408
395
|
hidden_state_offset = 0
|
409
396
|
for i, (req, next_token_id) in enumerate(
|
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
491
478
|
if self.enable_overlap:
|
492
479
|
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
|
493
480
|
|
494
|
-
# We need to remove the sync in the following function for overlap schedule.
|
495
|
-
self.set_next_batch_sampling_info_done(batch)
|
496
481
|
self.maybe_send_health_check_signal()
|
497
482
|
|
498
483
|
def process_disagg_prefill_inflight_queue(
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
703
703
|
if server_args.attention_backend == "flashinfer":
|
704
704
|
assert_pkg_version(
|
705
705
|
"flashinfer_python",
|
706
|
-
"0.4.
|
706
|
+
"0.4.0",
|
707
707
|
"Please uninstall the old version and "
|
708
708
|
"reinstall the latest version by following the instructions "
|
709
709
|
"at https://docs.flashinfer.ai/installation.html.",
|
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|
711
711
|
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
712
712
|
assert_pkg_version(
|
713
713
|
"sgl-kernel",
|
714
|
-
"0.3.
|
714
|
+
"0.3.15",
|
715
715
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
716
716
|
)
|
717
717
|
|
sglang/srt/entrypoints/grpc_request_manager.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
|
|
27
27
|
TokenizedEmbeddingReqInput,
|
28
28
|
TokenizedGenerateReqInput,
|
29
29
|
)
|
30
|
+
from sglang.srt.managers.scheduler import is_health_check_generate_req
|
30
31
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
31
32
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
32
33
|
from sglang.utils import get_exception_traceback
|
@@ -263,8 +264,8 @@ class GrpcRequestManager:
|
|
263
264
|
response = await task
|
264
265
|
|
265
266
|
# Add index for client-side ordering
|
266
|
-
if isinstance(response, dict)
|
267
|
-
response_rid = response
|
267
|
+
if isinstance(response, dict):
|
268
|
+
response_rid = response.get("request_id", "")
|
268
269
|
if response_rid in rid_to_index:
|
269
270
|
response["index"] = rid_to_index[response_rid]
|
270
271
|
|
@@ -338,12 +339,9 @@ class GrpcRequestManager:
|
|
338
339
|
break
|
339
340
|
|
340
341
|
except asyncio.TimeoutError:
|
341
|
-
# Timeout
|
342
|
-
|
343
|
-
|
344
|
-
)
|
345
|
-
await self.abort_request(request_id)
|
346
|
-
return
|
342
|
+
# Timeout is for periodic client cancellation check
|
343
|
+
# Continue waiting for scheduler response
|
344
|
+
continue
|
347
345
|
|
348
346
|
finally:
|
349
347
|
# Always clean up request state when exiting
|
@@ -397,9 +395,7 @@ class GrpcRequestManager:
|
|
397
395
|
# Wait for result in background
|
398
396
|
async def wait_for_result():
|
399
397
|
try:
|
400
|
-
# Wait for completion
|
401
398
|
await state.event.wait()
|
402
|
-
# Get result from queue
|
403
399
|
result = await state.out_queue.get()
|
404
400
|
future.set_result(result)
|
405
401
|
except Exception as e:
|
@@ -414,6 +410,10 @@ class GrpcRequestManager:
|
|
414
410
|
|
415
411
|
async def abort_request(self, request_id: str) -> bool:
|
416
412
|
"""Abort a running request."""
|
413
|
+
# Skip aborting health check requests (they clean themselves up)
|
414
|
+
if request_id.startswith("HEALTH_CHECK"):
|
415
|
+
return False
|
416
|
+
|
417
417
|
if request_id not in self.rid_to_state:
|
418
418
|
return False
|
419
419
|
|
@@ -437,19 +437,6 @@ class GrpcRequestManager:
|
|
437
437
|
|
438
438
|
return True
|
439
439
|
|
440
|
-
async def pause_generation(self):
|
441
|
-
"""Pause generation processing."""
|
442
|
-
async with self.is_pause_cond:
|
443
|
-
self.is_pause = True
|
444
|
-
logger.info("Generation paused")
|
445
|
-
|
446
|
-
async def resume_generation(self):
|
447
|
-
"""Resume generation processing."""
|
448
|
-
async with self.is_pause_cond:
|
449
|
-
self.is_pause = False
|
450
|
-
self.is_pause_cond.notify_all()
|
451
|
-
logger.info("Generation resumed")
|
452
|
-
|
453
440
|
async def handle_loop(self):
|
454
441
|
"""
|
455
442
|
Main event loop - processes outputs from scheduler.
|