sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -79,13 +79,17 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -141,10 +147,19 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_event,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -158,6 +173,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -348,6 +364,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None

@@ -401,7 +429,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -488,7 +516,7 @@ class Scheduler(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
-        self.
+        self.init_profiler()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
@@ -500,6 +528,7 @@ class Scheduler(
        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
@@ -524,6 +553,14 @@ class Scheduler(
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
@@ -542,18 +579,10 @@ class Scheduler(
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
            ]
        )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
    def init_tokenizer(self):
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation
@@ -630,6 +659,7 @@ class Scheduler(
            hicache_write_policy=server_args.hicache_write_policy,
            hicache_io_backend=server_args.hicache_io_backend,
            hicache_mem_layout=server_args.hicache_mem_layout,
+            enable_metrics=self.enable_metrics,
            hicache_storage_backend=server_args.hicache_storage_backend,
            hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
            model_name=server_args.served_model_name,
@@ -662,6 +692,21 @@ class Scheduler(
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
@@ -793,6 +838,10 @@ class Scheduler(
            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
@@ -814,6 +863,10 @@ class Scheduler(
            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
            if batch:
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
@@ -1077,6 +1130,12 @@ class Scheduler(
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
+
+        for req in recv_reqs:
+            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("", req.rid, anonymous=True)
+
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
@@ -1104,13 +1163,13 @@ class Scheduler(
                    self.send_to_tokenizer.send_pyobj(abort_req)
                continue

-            # If it is a
-            if isinstance(recv_req,
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
                worker_id = recv_req.worker_id
                recv_req = recv_req.obj
                output = self._request_dispatcher(recv_req)
                if output is not None:
-                    output =
+                    output = MultiTokenizerWrapper(worker_id, output)
                    self.send_to_tokenizer.send_pyobj(output)
                continue

@@ -1122,15 +1181,21 @@ class Scheduler(
            else:
                self.send_to_tokenizer.send_pyobj(output)

+    def init_req_max_new_tokens(self, req):
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_len - len(req.origin_input_ids) - 1,
+        )
+
    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

        # Create a new request
        if (
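
Note: the new init_req_max_new_tokens helper centralizes the cap that was previously applied inline in handle_generate_request: a request's max_new_tokens is clamped to the room left in the per-request length budget. A minimal standalone sketch of the same clamping rule (the helper name and numbers below are illustrative, not part of the module):

# Sketch of the clamping rule behind init_req_max_new_tokens (illustrative values).
def clamp_max_new_tokens(requested, prompt_len, max_req_len):
    # Treat "no limit requested" as a very large number, then cap by the space
    # left for generation: request length budget minus the prompt minus one slot.
    requested = requested if requested is not None else 1 << 30
    return min(requested, max_req_len - prompt_len - 1)

# A 1,000-token prompt under a 4,096-token request budget can generate at most
# 3,095 tokens, regardless of what the client asked for.
assert clamp_max_new_tokens(None, 1000, 4096) == 3095
assert clamp_max_new_tokens(128, 1000, 4096) == 128
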
@@ -1189,6 +1254,7 @@ class Scheduler(
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
+                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return
        else:
@@ -1196,6 +1262,7 @@ class Scheduler(
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return

@@ -1215,9 +1282,13 @@ class Scheduler(
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
            )
+            self.init_req_max_new_tokens(req)
            self._add_request_to_queue(req)
            return

+        # initialize before returning
+        self.init_req_max_new_tokens(req)
+
        # Validate prompt length
        error_msg = validate_input_length(
            req,
@@ -1232,26 +1303,25 @@ class Scheduler(
        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
-
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

-        if req.logprob_start_len >= len(req.origin_input_ids):
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
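
Note: with this change, prefill-only requests (requests that never decode, such as scoring/embedding-style calls) point logprob_start_len one past the prompt so that no input logprobs are computed, while text generation keeps the previous default. A small illustrative sketch of that selection rule (the helper name is hypothetical):

# Sketch of the default logprob_start_len choice (recv logprob_start_len == -1
# or logprobs not requested).
def default_logprob_start_len(prompt_len: int, is_prefill_only: bool) -> int:
    if is_prefill_only:
        # One past the last input token: input logprob computation is skipped.
        return prompt_len
    # Generation keeps the old behavior: start at the final input token.
    return prompt_len - 1

assert default_logprob_start_len(16, is_prefill_only=True) == 16
assert default_logprob_start_len(16, is_prefill_only=False) == 15
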
@@ -1310,6 +1380,7 @@ class Scheduler(
        else:
            self._prefetch_kvcache(req)
            self.waiting_queue.append(req)
+            trace_slice_end("process req", req.rid, auto_next_anon=True)

    def _prefetch_kvcache(self, req: Req):
        if self.enable_hicache_storage:
@@ -1421,9 +1492,11 @@ class Scheduler(
        _, _, available_size, evictable_size = self._get_token_info()
        protected_size = self.tree_cache.protected_size()
        memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
            self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
+            - protected_size
        )
        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

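
Note: the leak check now asserts one invariant regardless of cache mode: free tokens plus evictable tokens must equal pool capacity minus tokens protected from eviction. A tiny self-contained check of that bookkeeping, using made-up numbers:

# Token-accounting invariant behind the leak check (illustrative numbers only).
max_total_num_tokens = 10_000   # KV pool capacity
protected_size = 250            # radix-cache tokens pinned by in-flight requests
available_size = 7_750          # free slots in the allocator
evictable_size = 2_000          # cached slots that could be evicted

leaked = (available_size + evictable_size) != (max_total_num_tokens - protected_size)
assert not leaked  # 7,750 + 2,000 == 10,000 - 250
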
@@ -1474,6 +1547,20 @@ class Scheduler(
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

@@ -1521,7 +1608,12 @@ class Scheduler(
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
            # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1568,11 +1660,7 @@ class Scheduler(

        # Handle DP attention
        if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
@@ -1792,10 +1880,6 @@ class Scheduler(
        if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
            if self.pp_group.is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = (
                    self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1864,8 +1948,23 @@ class Scheduler(
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "decode loop",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "prefill",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
@@ -1897,86 +1996,6 @@ class Scheduler(
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
        )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
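
Note: the removed handle_dp_balance_data body (the scheduler now reaches this bookkeeping through the new maybe_update_dp_balance_data / maybe_handle_dp_balance_data calls) gathers one fixed-size int32 tensor per worker, laid out as | holding_tokens | count | request ids... |. A standalone sketch of that pack/unpack convention, with hypothetical helper names and no collective communication involved:

import torch

GATHER_TENSOR_SIZE = 512  # 1 slot for holding_tokens + 1 for the count + up to 510 ids

def pack_balance_info(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    # Layout: [holding_tokens, len(recv_ids), recv_ids...], zero-padded to a fixed size.
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
    return t

def unpack_balance_info(t: torch.Tensor):
    count = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + count].tolist()

packed = pack_balance_info(1234, [7, 8, 9])
assert unpack_balance_info(packed) == (1234, [7, 8, 9])
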
@@ -2270,39 +2289,50 @@ class Scheduler(
            if_success = False
        return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
        if self.is_hybrid:
-
+            num_tokens_full = (
                self.full_tokens_per_layer
                - self.token_to_kv_pool_allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
-
+            num_tokens_swa = (
                self.swa_tokens_per_layer
                - self.token_to_kv_pool_allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
-
+            num_tokens = max(num_tokens_full, num_tokens_swa)
        else:
-
+            num_tokens = (
                self.max_total_num_tokens
                - self.token_to_kv_pool_allocator.available_size()
                - self.tree_cache.evictable_size()
            )
-
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-
+            num_tokens += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            num_tokens += sum(
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

    def get_internal_state(self, recv_req: GetInternalStateReq):
        ret = dict(global_server_args_dict)
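
Note: get_load now answers a GetLoadReqInput with a structured GetLoadReqOutput instead of a bare number, reporting KV-cache tokens in use plus the prompt tokens still sitting in the waiting/bootstrap/prealloc queues. A rough sketch of the aggregation, with hypothetical names standing in for the scheduler state:

from dataclasses import dataclass

@dataclass
class LoadReport:
    dp_rank: int
    num_reqs: int
    num_waiting_reqs: int
    num_tokens: int

def report_load(dp_rank, kv_tokens_in_use, running_reqs, waiting_prompt_lens):
    # kv_tokens_in_use: tokens currently held in the KV pool (capacity - free - evictable)
    # waiting_prompt_lens: token counts of prompts queued but not yet scheduled
    num_waiting = len(waiting_prompt_lens)
    return LoadReport(
        dp_rank=dp_rank,
        num_reqs=running_reqs + num_waiting,
        num_waiting_reqs=num_waiting,
        num_tokens=kv_tokens_in_use + sum(waiting_prompt_lens),
    )

# e.g. 2 running requests, 3 queued prompts of 100 tokens each, 5,000 KV tokens in use
assert report_load(0, 5_000, 2, [100, 100, 100]).num_tokens == 5_300
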
@@ -2317,10 +2347,9 @@ class Scheduler(
            "token_capacity": int(self.max_total_num_tokens),
        }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
@@ -2329,8 +2358,6 @@ class Scheduler(
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
        return GetInternalStateReqOutput(internal_state=ret)

    def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2494,6 +2521,22 @@ class Scheduler(
        self.send_to_detokenizer.send_pyobj(recv_req)
        return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:
@@ -2615,6 +2658,15 @@ def run_scheduler_process(
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
    # Generate the prefix
    prefix = ""
    if dp_rank is not None: