sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
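One practical consequence visible in the listing above: sglang/srt/managers/image_processor.py and the image_processors/ package are removed, replaced by multimodal_processor.py and the multimodal_processors/ package. The sketch below is a hypothetical compatibility shim, not part of either release; the module paths come from the listing, but the public API of these modules is not shown in this diff.

# Hypothetical import guard for code that straddles 0.4.4.post1 and 0.4.4.post3.
# Module names are taken from the file listing above; treat this as a sketch only.
try:
    from sglang.srt.managers import multimodal_processor as mm_processor  # 0.4.4.post3 layout
except ImportError:
    from sglang.srt.managers import image_processor as mm_processor  # 0.4.4.post1 layout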
sglang/srt/managers/scheduler.py
CHANGED
@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler(
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

@@ -348,7 +379,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args

@@ -435,6 +473,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                 page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -478,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )

-
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -498,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

-    @
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
@@ -538,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -549,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None

@@ -600,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):

             output = self._request_dispatcher(recv_req)
             if output is not None:
-
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
@@ -666,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
             return

         # Handle multimodal inputs
-        if recv_req.
-            image_inputs =
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -681,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -756,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        self.
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+
+        else:
+            self.waiting_queue.append(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)

     def handle_embedding_request(
         self,
@@ -775,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs =
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -789,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
            )
            logger.error(error_msg)
            req.origin_input_ids = [0]
-           req.
+           req.multimodal_inputs = None
            req.sampling_params.max_new_tokens = 0
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -875,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -893,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

@@ -997,7 +1186,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         ret = None

         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)

         return ret
@@ -1492,6 +1681,41 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )

+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
@@ -1561,6 +1785,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1569,6 +1796,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1579,7 +1807,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir,
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1589,8 +1821,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1602,7 +1836,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             activities = ["CPU", "GPU"]

         self.torch_profiler_output_dir = output_dir
-        self.
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1619,13 +1853,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
             )
             self.torch_profiler.start()

         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)

+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
@@ -1634,7 +1872,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ProfileReqOutput(success=True, message="Succeeded")

     def stop_profile(self) -> None:
-        if self.
+        if self.profiler_activities is None:
             return

         logger.info("Stop profiling...")
@@ -1647,27 +1885,41 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
         )

-        if "MEM" in self.
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)

+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.
+        self.profiler_activities = None

         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
                 ProfileReqOutput(success=True, message="Succeeded.")
             )

+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
@@ -1718,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"
@@ -1726,7 +1977,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"

     # Config the process
-
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
@@ -1753,10 +2004,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-
-
-
-        scheduler.
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
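The recv_requests() change above adds a second non-blocking drain loop for the new RPC DEALER socket, and process_input_requests() routes any RpcReqOutput back over that socket. As a rough, self-contained illustration of the polling pattern (a hypothetical helper, not sglang's API), the sketch below shows the same drain-until-empty idiom using pyzmq:

# Hypothetical helper illustrating the non-blocking drain pattern used in recv_requests().
import zmq

def drain_pending(sock: zmq.Socket) -> list:
    """Collect every object currently queued on the socket without blocking."""
    items = []
    while True:
        try:
            # NOBLOCK makes recv raise a ZMQError (zmq.Again) once the queue is empty.
            items.append(sock.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break
    return items

The scheduler applies this loop twice per iteration, once for the tokenizer socket and once for recv_from_rpc, so RPC traffic never blocks the main event loop.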
sglang/srt/managers/session_controller.py
CHANGED
@@ -138,7 +138,7 @@ class Session:
                 token_ids_logprob=req.token_ids_logprob,
             )
             if last_req is not None:
-                new_req.
+                new_req.multimodal_inputs = last_req.mm_inputs
             new_req.tokenizer = tokenizer
             if abort:
                 new_req.to_abort = True