sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -29,6 +29,7 @@ from typing import Deque, Dict, List, Optional, Tuple, Union
 import psutil
 import setproctitle
 import torch
+import torch.distributed
 import zmq
 from torch.cuda import Stream as CudaStream
 from torch.cuda import StreamContext as CudaStreamContext
@@ -151,11 +152,13 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     process_tracing_init,
+    trace_event_batch,
     trace_set_proc_propagate_context,
     trace_set_thread_info,
     trace_slice_batch,
@@ -168,7 +171,6 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_gc_logger,
     configure_logger,
-    disable_request_logging,
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -177,7 +179,6 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
-    pyspy_dump_schedulers,
     require_mlp_sync,
     require_mlp_tp_gather,
     set_gpu_proc_affinity,
@@ -197,6 +198,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
 TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
+TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@@ -212,6 +214,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerMultiplexMixin,
     SchedulerRuntimeCheckerMixin,
     SchedulerPPMixin,
 ):
@@ -251,6 +254,7 @@ class Scheduler(
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
+        self.enable_pdmux = server_args.enable_pdmux
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
         self.enable_metrics_for_all_schedulers = (
@@ -284,6 +288,10 @@ class Scheduler(
         # Init inter-process communication
         self.init_sockets(server_args, port_args)

+        # Init pdmux context
+        if self.enable_pdmux:
+            self.init_pdmux()
+
         # Init tokenizer
         self.init_tokenizer()

@@ -320,8 +328,28 @@ class Scheduler(

         # Launch a draft worker for speculative decoding

-
-            gpu_id,
+        draft_worker_kwargs = dict(
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            moe_ep_rank=moe_ep_rank,
+            server_args=server_args,
+            nccl_port=port_args.nccl_port,
+            target_worker=self.tp_worker,
+            dp_rank=dp_rank,
+        )
+
+        if server_args.speculative_draft_load_format is not None:
+            server_args.load_format = server_args.speculative_draft_load_format
+            logger.info(
+                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
+            )
+
+        # Draft workers are looked up via `SpeculativeAlgorithm` registry; new
+        # algorithms should register their factory instead of patching this code.
+        if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
+            draft_worker_kwargs["enable_overlap"] = self.enable_overlap
+        self.draft_worker = self.spec_algorithm.create_draft_worker(
+            **draft_worker_kwargs
         )

         # Dispatch the model worker
@@ -356,6 +384,17 @@ class Scheduler(
         self.pp_group = get_pp_group()
         self.world_group = get_world_group()

+        # With DP attention enabled, the entry rank is attn_tp_rank==0;
+        # otherwise the entry rank is TP group local rank 0.
+        # For #11910, use the CPU communication group to broadcast VLM Python objects,
+        # avoiding any coupling with CUDA streams/devices.
+        if self.server_args.enable_dp_attention:
+            self.cpu_group = self.attn_tp_cpu_group
+            self.is_entry_rank = self.attn_tp_rank == 0
+        else:
+            self.cpu_group = self.tp_cpu_group
+            self.is_entry_rank = self.tp_group.rank == 0
+
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)

@@ -392,6 +431,8 @@ class Scheduler(
         self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current split prefill batch
+        self.split_prefill_batch: Optional[ScheduleBatch] = None
         # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
@@ -494,7 +535,7 @@ class Scheduler(
         )
         self.init_disaggregation()

-        if
+        if envs.SGLANG_LOG_GC.get():
             configure_gc_logger()

         # Init prefill kv split size when deterministic inference is enabled with various attention backends
@@ -548,57 +589,6 @@ class Scheduler(
             ]
         )

-    def launch_draft_worker(
-        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
-    ):
-        if server_args.speculative_draft_load_format is not None:
-            server_args.load_format = server_args.speculative_draft_load_format
-            logger.info(
-                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
-            )
-
-        if self.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_worker import EAGLEWorker
-            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
-
-            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
-
-            self.draft_worker = WorkerClass(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
-
     def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -1162,6 +1152,70 @@ class Scheduler(
             self.max_req_len - len(req.origin_input_ids) - 1,
         )

+    def _process_and_broadcast_mm_inputs(
+        self,
+        raw_mm_inputs: Optional[dict],
+    ):
+        """Materialize MultimodalInputs once on the entry rank and broadcast to others.
+
+        Entry rank:
+        - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
+        - broadcasts to other ranks in self.cpu_group (if world_size > 1)
+
+        Non-entry ranks:
+        - receive the object via broadcast (if world_size > 1)
+        - otherwise (single-rank / no group) fall back to local from_dict
+
+        Returns:
+            MultimodalInputs | None
+        """
+        if raw_mm_inputs is None:
+            return None
+
+        group_world_size = 1
+        try:
+            if (
+                torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+                and self.cpu_group is not None
+            ):
+                group_world_size = torch.distributed.get_world_size(
+                    group=self.cpu_group
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get world size in mm_inputs handling with {e}, fallback to 1."
+            )
+
+        # In case tp size > 1, all the Scheduler TP ranks runs the duplicated computing
+        # process in CPU which occupies the main thread CPU cycle. This computing logic
+        # merely needs to be run on TP0 and be broadcast to other TP ranks.
+        # Since the Scheduler is single-threaded, any large CPU cost will impact
+        # handling of other messages. For example, CPU hits 99.9% can significantly
+        # increase the CUDA kernel launch time.
+        if self.is_entry_rank:
+            # Only the entry rank materializes once from dict.
+            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+            # Broadcast to other TP ranks (use src=0 within the group).
+            if group_world_size > 1:
+                obj_list = [image_inputs]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+        else:
+            # Non-entry ranks: receive if group size > 1; otherwise materialize locally.
+            if group_world_size > 1:
+                obj_list = [None]
+                torch.distributed.broadcast_object_list(
+                    obj_list, src=0, group=self.cpu_group
+                )
+                image_inputs = obj_list[0]
+            else:
+                image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
+
+        return image_inputs
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -1243,7 +1297,9 @@ class Scheduler(

         # Handle multimodal inputs
         if recv_req.mm_inputs is not None:
-            image_inputs =
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)
+
+            # The following steps are already fast, execute locally on each rank.
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1376,7 +1432,7 @@ class Scheduler(
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
             req.time_stats.wait_queue_entry_time = time.perf_counter()
-            trace_slice_end(
+            trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
@@ -1466,13 +1522,14 @@ class Scheduler(
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
             priority=recv_req.priority,
+            dimensions=recv_req.dimensions,
             http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer

         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs =
+            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -1639,6 +1696,10 @@ class Scheduler(
         if need_dp_attn_preparation:
             ret = self.prepare_mlp_sync_batch(ret)

+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
+
         return ret

     def get_num_allocatable_reqs(self, running_bs):
@@ -1682,6 +1743,12 @@ class Scheduler(
         # Get priority queue
         self.policy.calc_priority(self.waiting_queue)

+        if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
+            # If we are testing retraction and the running batch size exceeds
+            # TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
+            # in the waiting queue.
+            return None
+
         # Prefill policy
         adder = PrefillAdder(
             self.page_size,
@@ -1848,14 +1915,14 @@ class Scheduler(
         self.num_retracted_reqs = len(retracted_reqs)
         self.new_token_ratio = new_token_ratio
         for req in reqs_to_abort:
+            abort_reason: FINISH_ABORT = req.to_finish
             self.send_to_tokenizer.send_output(
-                AbortReq(abort_reason
+                AbortReq(abort_message=abort_reason.message, rid=req.rid), req
             )

         logger.info(
             "KV cache pool is full. Retract requests. "
             f"#retracted_reqs: {len(retracted_reqs)}, "
-            f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
             f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
         )

@@ -1894,7 +1961,6 @@ class Scheduler(

         # Run forward
         if self.is_generation:
-
             batch_or_worker_batch = batch

             if self.enable_overlap or self.spec_algorithm.is_none():
@@ -1951,6 +2017,9 @@ class Scheduler(
                 # The future value, usually for next batch preparation
                 # Current implementation strictly synchronizes the seq_lens
                 batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
+                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
+                future_indices_or_next_token_ids = batch_result.next_token_ids
             else:
                 batch_result = self.model_worker.forward_batch_generation(
                     batch_or_worker_batch
@@ -2012,13 +2081,10 @@ class Scheduler(
         ):
             if batch.forward_mode.is_decode():
                 self.process_batch_result_decode(batch, result)
-
-                trace_slice_batch("decode loop", batch.reqs)
+                trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)

             elif batch.forward_mode.is_extend():
                 self.process_batch_result_prefill(batch, result)
-                if self.enable_trace:
-                    trace_slice_batch("prefill", batch.reqs)

             elif batch.forward_mode.is_idle():
                 if self.enable_overlap:
@@ -2073,15 +2139,18 @@ class Scheduler(
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-
-
+            if local_batch.return_logprob:
+                num_tokens_for_logprob = sum(
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
                     for logprob_start_len, extend_len in zip(
-                local_batch.extend_logprob_start_lens,
+                        local_batch.extend_logprob_start_lens,
+                        local_batch.extend_lens,
                     )
-
-
+                )
+            else:
+                # When return_logprob = False, only need last token per request
+                num_tokens_for_logprob = local_batch.batch_size()

         if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
             can_cuda_graph = 1
@@ -2235,59 +2304,6 @@ class Scheduler(
             self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def watchdog_thread(self):
-        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.perf_counter()
-
-        while True:
-            current = time.perf_counter()
-            if self.cur_batch is not None:
-                if self.watchdog_last_forward_ct == self.forward_ct:
-                    if current > self.watchdog_last_time + self.watchdog_timeout:
-                        break
-                else:
-                    self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = current
-            time.sleep(self.watchdog_timeout // 2)
-
-        if not disable_request_logging():
-            # Print batch size and memory pool info to check whether there are de-sync issues.
-            if self.is_hybrid:
-                (
-                    _,
-                    _,
-                    _,
-                    _,
-                    full_available_size,
-                    full_evictable_size,
-                    swa_available_size,
-                    swa_evictable_size,
-                ) = self._get_swa_token_info()
-                info_msg = (
-                    f"{full_available_size=}, "
-                    f"{full_evictable_size=}, "
-                    f"{swa_available_size=}, "
-                    f"{swa_evictable_size=}, "
-                )
-            else:
-                _, _, available_size, evictable_size = self._get_token_info()
-                info_msg = f"{available_size=}, " f"{evictable_size=}, "
-            logger.error(
-                f"{self.cur_batch.batch_size()=}, "
-                f"{self.cur_batch.reqs=}, "
-                f"{info_msg}"
-            )
-
-        pyspy_dump_schedulers()
-        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-        print(file=sys.stderr, flush=True)
-        print(file=sys.stdout, flush=True)
-
-        # Wait for some time so that the parent process can print the error.
-        time.sleep(5)
-        self.parent_process.send_signal(signal.SIGQUIT)
-
     def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)
@@ -2302,13 +2318,30 @@ class Scheduler(
             if_success = False
         return ClearHiCacheReqOutput(success=if_success)

-    def
-
-        if (
+    def _is_no_request(self):
+        no_request = (
             len(self.waiting_queue) == 0
             and self.running_batch.is_empty()
+            and (self.last_batch is None or self.last_batch.is_empty())
+            and (self.cur_batch is None or self.cur_batch.is_empty())
+            and (not self.enable_overlap or len(self.result_queue) == 0)
             and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
-        )
+        )
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            no_request &= (
+                len(self.disagg_prefill_bootstrap_queue.queue) == 0
+                and len(self.disagg_prefill_inflight_queue) == 0
+            )
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            no_request &= (
+                len(self.disagg_decode_prealloc_queue.queue) == 0
+                and len(self.disagg_decode_transfer_queue.queue) == 0
+            )
+        return no_request
+
+    def flush_cache(self):
+        """Flush the memory pool and cache."""
+        if self._is_no_request():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -2322,10 +2355,10 @@ class Scheduler(

             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
-            self.
-            self.
-            self.
-            self.
+            self.spec_num_accepted_tokens = 0
+            self.spec_num_forward_ct = 0
+            self.spec_total_num_accepted_tokens = 0
+            self.spec_total_num_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -2398,13 +2431,16 @@ class Scheduler(
                 self.tp_worker.model_runner.graph_mem_usage, 2
             )

-        if not self.spec_algorithm.is_none() and self.
+        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
             ret["avg_spec_accept_length"] = (
-                self.
+                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
             )
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

+        # This field is not serializable.
+        ret.pop("model_config", None)
+
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2431,12 +2467,12 @@ class Scheduler(
                 if_success = False
                 break
         if if_success:
-            if not self.spec_algorithm.is_none() and self.
+            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                 avg_spec_accept_length = (
-                    self.
+                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                 )
                 logger.info(f"{avg_spec_accept_length=}")
-                self.
+                self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
             for k, v in server_args_dict.items():
                 setattr(get_global_server_args(), k, v)
             logger.info(f"Global server args updated! {get_global_server_args()=}")
@@ -2539,11 +2575,11 @@ class Scheduler(
             if not req.finished() and (
                 recv_req.abort_all or req.rid.startswith(recv_req.rid)
             ):
-                # Abort method 3: set `
+                # Abort method 3: set `to_finish`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
                 logger.debug(f"Abort running request. {req.rid=}")
-                req.
+                req.to_finish = FINISH_ABORT()

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2737,10 +2773,13 @@ def run_scheduler_process(

     # Set up tracing
     if server_args.enable_trace:
-        process_tracing_init(server_args.
-
-
-
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        thread_label = "Scheduler"
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill Scheduler"
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode Scheduler"
+        trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
@@ -2763,7 +2802,9 @@ def run_scheduler_process(

         disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
-            if
+            if scheduler.enable_pdmux:
+                scheduler.event_loop_pdmux()
+            elif server_args.pp_size > 1:
                 scheduler.event_loop_pp()
             elif scheduler.enable_overlap:
                 scheduler.event_loop_overlap()
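Note on the new `_process_and_broadcast_mm_inputs` helper above: it follows a common torch.distributed pattern in which the entry rank materializes an expensive, picklable object once and ships it to the other scheduler ranks over a CPU communication group. The sketch below is a minimal standalone illustration of that pattern, not sglang code: `build_payload` is a hypothetical stand-in for `MultimodalInputs.from_dict`, and a gloo process group launched via torchrun is assumed.

    # Run with e.g.: torchrun --nproc-per-node 2 broadcast_sketch.py
    import torch.distributed as dist

    def build_payload(raw: dict) -> dict:
        # Stand-in for the expensive CPU-side materialization done on the entry rank.
        return {"processed": raw}

    def process_and_broadcast(raw: dict, cpu_group=None, is_entry_rank: bool = True):
        world_size = dist.get_world_size(group=cpu_group) if dist.is_initialized() else 1
        if is_entry_rank:
            obj_list = [build_payload(raw)]   # heavy work runs once, on the entry rank
        else:
            obj_list = [None]                 # placeholder filled in by the broadcast
        if world_size > 1:
            # Pickles on the source rank and unpickles on the others; using a
            # gloo (CPU) group keeps this decoupled from CUDA streams/devices.
            dist.broadcast_object_list(obj_list, src=0, group=cpu_group)
        elif not is_entry_rank:
            obj_list = [build_payload(raw)]   # single-rank / no-group fallback
        return obj_list[0]

    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        result = process_and_broadcast({"image": "..."}, None, is_entry_rank=(rank == 0))
        print(rank, result)

Only rank 0 pays the construction cost; the other ranks receive the already-built object, which is the point of the change in handle_generate_request above.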

sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.environ import envs
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -18,6 +19,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+LOG_FORWARD_ITERS = envs.SGLANG_LOG_FORWARD_ITERS.get()


 class KvMetrics:
@@ -39,10 +41,13 @@ class SchedulerMetricsMixin:
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-
-
-        self.
-        self.
+
+        # The number of accepted tokens and forward ct for the recent `decode_log_interval` batches (for logging)
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        # The total number of accepted tokens and forward ct for the whole server lifetime
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0

@@ -67,8 +72,8 @@ class SchedulerMetricsMixin:
         )

     def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
-        self.
-        self.
+        self.spec_num_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens

     def log_prefill_stats(
@@ -122,8 +127,10 @@ class SchedulerMetricsMixin:
             num_used, token_usage, _, _ = self._get_token_info()
             token_usage_msg = f"token usage: {token_usage:.2f}, "

+        iter_msg = f" [{self.forward_ct + 1}]" if LOG_FORWARD_ITERS else ""
+
         f = (
-            f"Prefill batch
+            f"Prefill batch{iter_msg}, "
             f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
@@ -246,27 +253,28 @@ class SchedulerMetricsMixin:
             gap_latency / self.server_args.decode_log_interval
         )

-
+        iter_msg = f" [{self.forward_ct}]" if LOG_FORWARD_ITERS else ""
+        msg = f"Decode batch{iter_msg}, #running-req: {num_running_reqs}, {token_usage_msg}"

         if self.spec_algorithm.is_none():
             spec_accept_length = 0
             spec_accept_rate = 0
         else:
             spec_accept_length = (
-                self.
+                self.spec_num_accepted_tokens / self.spec_num_forward_ct
             )
             # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.
+            total_draft_tokens = self.spec_num_forward_ct * (
                 (self.server_args.speculative_num_steps or 0) + 1
             )
             spec_accept_rate = (
-                self.
+                self.spec_num_accepted_tokens / total_draft_tokens
                 if total_draft_tokens > 0
                 else 0
             )
-            self.
-            self.
-            self.
+            self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
+            self.spec_total_num_forward_ct += self.spec_num_forward_ct
+            self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
         msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
         cache_hit_rate = 0.0

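For reference, the renamed speculative-decoding counters combine into the logged metrics using only the formulas visible in the hunks above: accept length = accepted tokens / forward count, and accept rate = accepted tokens / (forward count × (speculative_num_steps + 1)). The snippet below is a small worked example with made-up numbers, not sglang code.

    # Worked example of the acceptance metrics computed in the decode logging above.
    speculative_num_steps = 4   # each forward pass verifies up to (4 + 1) tokens per request
    bs = 8                      # requests in the decode batch
    num_accepted_tokens = 24    # extra draft tokens accepted in this forward pass

    # update_spec_metrics accumulates per forward pass:
    spec_num_accepted_tokens = num_accepted_tokens + bs   # 32
    spec_num_forward_ct = bs                              # 8

    spec_accept_length = spec_num_accepted_tokens / spec_num_forward_ct        # 4.0
    total_draft_tokens = spec_num_forward_ct * (speculative_num_steps + 1)     # 40
    spec_accept_rate = spec_num_accepted_tokens / total_draft_tokens           # 0.8
    print(f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}")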