sglang-0.4.4-py3-none-any.whl → sglang-0.4.4.post2-py3-none-any.whl
This diff covers two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between those published versions.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
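The list above only gives per-file line counts. To inspect any of these files yourself, both wheels can be fetched with `pip download sglang==0.4.4 --no-deps` and `pip download sglang==0.4.4.post2 --no-deps`; the sketch below, which assumes the two `.whl` files are already in the working directory, then rebuilds the diff for a single member using only the standard library.

```python
# Minimal sketch: rebuild the per-file diff for one member of the two wheels.
# Assumes both wheel files have already been downloaded into the working directory.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.4-py3-none-any.whl"
NEW_WHEEL = "sglang-0.4.4.post2-py3-none-any.whl"
MEMBER = "sglang/srt/managers/scheduler.py"


def read_member(wheel_path: str, member: str) -> list[str]:
    """Return the decoded lines of one file stored inside a wheel (a zip archive)."""
    with zipfile.ZipFile(wheel_path) as wheel:
        return wheel.read(member).decode("utf-8").splitlines(keepends=True)


diff = difflib.unified_diff(
    read_member(OLD_WHEEL, MEMBER),
    read_member(NEW_WHEEL, MEMBER),
    fromfile=f"{OLD_WHEEL}/{MEMBER}",
    tofile=f"{NEW_WHEEL}/{MEMBER}",
)
print("".join(diff))
```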
sglang/srt/managers/scheduler.py
CHANGED
@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    ImageInputs,
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
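The hunk above turns `Scheduler` into a composition of mixins: output processing and the new prefill/decode disaggregation loops each live in their own class and are folded in through inheritance. The toy below is not sglang code; it is a minimal sketch of the same pattern with made-up class and method names, showing how methods defined on the mixins become available on the combined scheduler.

```python
# Hypothetical illustration of the mixin layout; these names do not exist in sglang.
class OutputProcessorMixin:
    def process_batch_result(self, result: str) -> str:
        return f"processed({result})"


class DisaggregationPrefillMixin:
    def event_loop_prefill(self) -> str:
        return "prefill loop"


class DisaggregationDecodeMixin:
    def event_loop_decode(self) -> str:
        return "decode loop"


class ToyScheduler(
    OutputProcessorMixin,
    DisaggregationDecodeMixin,
    DisaggregationPrefillMixin,
):
    """A single class that exposes all mixin methods through its MRO."""


scheduler = ToyScheduler()
print(scheduler.process_batch_result("batch-0"))          # -> processed(batch-0)
print(scheduler.event_loop_prefill(), "/", scheduler.event_loop_decode())
```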
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args

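The `_request_dispatcher` entries registered above map request classes to handler methods, and this release adds `RpcReqInput` and `ExpertDistributionReq` to that table. sglang's actual `TypeBasedDispatcher` lives in `sglang.utils` and is not shown in this diff; the snippet below is only a hedged re-implementation of the idea, a callable that picks a handler by the runtime type of the incoming object.

```python
# A minimal type-based dispatcher in the spirit of the table above.
# Illustrative sketch only, not sglang's TypeBasedDispatcher.
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple, Type


@dataclass
class RpcRequest:
    method: str


@dataclass
class AbortRequest:
    rid: str


class SimpleDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Walk the registered (type, handler) pairs and call the first match.
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Unsupported request type: {type(obj)}")


dispatcher = SimpleDispatcher(
    [
        (RpcRequest, lambda req: f"rpc:{req.method}"),
        (AbortRequest, lambda req: f"abort:{req.rid}"),
    ]
)
print(dispatcher(RpcRequest(method="save_remote_model")))  # -> rpc:save_remote_model
```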
@@ -434,6 +472,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -477,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )

-    @torch.no_grad()
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
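The comment in `init_disaggregation` explains the shape choice: each metadata row is padded so that `row_length * dtype.itemsize` is at least 64 bytes, which the code treats as the minimum size that works well over RDMA. With `aux_dtype = torch.int32` that means 16 elements per row, since 16 × 4 B = 64 B. A quick check of that arithmetic:

```python
import torch

aux_dtype = torch.int32
row_length = 16  # second dimension of the (buffer_size, 16) metadata buffer

# element_size() gives the byte width of one aux_dtype element (4 for int32).
row_bytes = row_length * torch.empty(0, dtype=aux_dtype).element_size()
print(row_bytes)        # 64
assert row_bytes >= 64  # the padding target mentioned in the comment above
```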
@@ -497,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

             self.last_batch = batch

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
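Both event loops switch from `@torch.no_grad()` to `@DynamicGradMode()`, which this release imports from `sglang.srt.utils` (the new `sglang/test/test_dynamic_grad_mode.py` in the file list exercises it). The actual implementation is not shown in this diff; the sketch below is only a guess at the general shape of such a decorator, a grad-disabling wrapper whose mode (`no_grad` versus `inference_mode`) can be flipped at runtime rather than fixed at import time.

```python
# Hypothetical sketch only; this is not sglang's DynamicGradMode implementation.
import functools

import torch


class ToyDynamicGradMode:
    """Decorator that applies whichever grad-disabling context is currently selected."""

    use_inference_mode: bool = False  # flipped at runtime, e.g. by server configuration

    @classmethod
    def set_inference_mode(cls, enabled: bool) -> None:
        cls.use_inference_mode = enabled

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ctx = (
                torch.inference_mode()
                if type(self).use_inference_mode
                else torch.no_grad()
            )
            with ctx:
                return func(*args, **kwargs)

        return wrapper


@ToyDynamicGradMode()
def step() -> bool:
    # Grad is disabled inside the wrapper regardless of which mode is active.
    return torch.is_grad_enabled()


print(step())                                # False (no_grad by default)
ToyDynamicGradMode.set_inference_mode(True)
print(step())                                # still False, now via inference_mode
```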
@@ -537,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):

             self.last_batch = batch

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -548,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None

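`recv_requests` now drains a second socket (`recv_from_rpc`, a `zmq.DEALER`) with the same non-blocking pattern used for the tokenizer socket: call `recv_pyobj(zmq.NOBLOCK)` in a loop and stop when `zmq.ZMQError` is raised (`zmq.Again`, a subclass, when nothing is pending). A self-contained sketch of that drain loop over an in-process socket pair, with made-up payloads:

```python
import zmq

ctx = zmq.Context.instance()
sender = ctx.socket(zmq.PAIR)
receiver = ctx.socket(zmq.PAIR)
receiver.bind("inproc://rpc-demo")
sender.connect("inproc://rpc-demo")

# Pretend two RPC requests arrived while the scheduler was busy (illustrative payloads).
sender.send_pyobj({"method": "save_remote_model", "parameters": {"url": "s3://bucket/ckpt"}})
sender.send_pyobj({"method": "save_sharded_model", "parameters": {"path": "/tmp/shards"}})

received = []
while True:
    try:
        received.append(receiver.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break  # zmq.Again: nothing left to drain

print(received)
sender.close()
receiver.close()
```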
@@ -599,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):

             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
@@ -665,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
             return

         # Handle multimodal inputs
-        if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -680,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
             req.origin_input_ids = [0]
-            req.image_inputs = None
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -755,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
         self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+        else:
+            self.waiting_queue.append(req)
+
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)

     def handle_embedding_request(
         self,
@@ -774,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -788,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
             req.origin_input_ids = [0]
-            req.image_inputs = None
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -874,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -892,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

@@ -997,7 +1187,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

         return ret

@@ -1269,39 +1459,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0

-
-
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-
-
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
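The rewritten `prepare_dp_attn_batch` replaces the old per-value gather plus a separate `all_reduce` for the cuda-graph flag with one collective: each rank packs `num_tokens`, `can_cuda_graph`, `global_num_tokens_for_logprob`, and `is_extend_in_batch` into a 4-element int64 tensor, gathers everything with a single `all_gather_into_tensor`, and slices the columns back out. The sketch below reproduces that packing on a throwaway single-process gloo group (world size 1, placeholder values) just to show the call shape; it assumes a PyTorch version where the gloo backend supports `all_gather_into_tensor`.

```python
import os

import torch
import torch.distributed as dist

# A one-process "cluster" is enough to exercise the packed-gather call shape.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

dp_size, attn_tp_size = 1, 1          # placeholder sizes
num_tokens = 128                      # placeholder batch statistics
can_cuda_graph = 1
num_tokens_for_logprob = 128
is_extend_in_batch = 0

local_info = torch.tensor(
    [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
    dtype=torch.int64,
)
global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
dist.all_gather_into_tensor(global_info.flatten(), local_info)

print(global_info[:, 0, 0].tolist())       # per-rank token counts
print(min(global_info[:, 0, 1].tolist()))  # cuda graph only if every rank can run it
dist.destroy_process_group()
```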
@@ -1458,6 +1681,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )

+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
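`handle_rpc_request` resolves the handler with `getattr(self, recv_req.method)` and reports success or the exception text back through `RpcReqOutput`; `save_remote_model` and `save_sharded_model` above are the first methods exposed this way. A stripped-down sketch of that reflection-based dispatch, using a toy target class rather than the real scheduler:

```python
# Illustrative sketch of getattr-based RPC dispatch; the target class is made up.
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class RpcRequest:
    method: str
    parameters: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RpcResult:
    success: bool
    message: str = ""


class ToyWorker:
    def save_remote_model(self, params: Dict[str, Any]) -> None:
        print(f"saving model to {params['url']}")

    def handle_rpc(self, req: RpcRequest) -> RpcResult:
        try:
            handler = getattr(self, req.method)  # resolve the method by name
            handler(req.parameters)
            return RpcResult(success=True)
        except Exception as exc:                 # report the failure instead of crashing
            return RpcResult(success=False, message=str(exc))


worker = ToyWorker()
print(worker.handle_rpc(RpcRequest("save_remote_model", {"url": "s3://bucket/ckpt"})))
print(worker.handle_rpc(RpcRequest("does_not_exist")))  # success=False with the error text
```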
@@ -1527,6 +1791,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1535,6 +1802,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1634,6 +1902,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ProfileReqOutput(success=True, message="Succeeded.")
         )

+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
@@ -1692,7 +1971,7 @@ def run_scheduler_process(
     prefix = f" DP{dp_rank} TP{tp_rank}"

     # Config the process
-
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
@@ -1719,10 +1998,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
                 ]
                 .cpu()
                 .clone()
+                .tolist()
             )

             if req.grammar is not None:
@@ -245,7 +246,9 @@ class SchedulerOutputProcessorMixin:
             )

             if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(
+                req.hidden_states.append(
+                    logits_output.hidden_states[i].cpu().clone().tolist()
+                )

             if req.grammar is not None and batch.spec_algorithm.is_none():
                 req.grammar.accept_token(next_token_id)
sglang/srt/managers/session_controller.py
CHANGED
@@ -138,7 +138,7 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.image_inputs = last_req.image_inputs
+            new_req.multimodal_inputs = last_req.mm_inputs
             new_req.tokenizer = tokenizer
         if abort:
             new_req.to_abort = True