sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
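The three hunks above follow the usual CUDA-graph capture pattern: tensors are allocated once at maximum size, and each batch works on a sliced view whose storage address stays stable across graph replays. A minimal standalone sketch of that pattern (names and sizes are hypothetical, not taken from the diff); the leading dimension of 3 matches M-RoPE's three rotary components (temporal, height, width) used by the multimodal models touched elsewhere in this release:

    import torch

    MAX_NUM_TOKEN = 4096

    # Allocated once, before capture. M-RoPE carries one position id per
    # rotary component, hence the leading dimension of 3.
    mrope_positions_buf = torch.zeros((3, MAX_NUM_TOKEN), dtype=torch.int64)

    def prepare_mrope(num_tokens: int, batch_positions: torch.Tensor) -> torch.Tensor:
        # Slice a view of the shared buffer and fill it in place; the view's
        # storage address is stable, which is what graph replay requires.
        view = mrope_positions_buf[:, :num_tokens]
        view.copy_(batch_positions)
        return view

    # Example: a batch of 5 tokens.
    print(prepare_mrope(5, torch.arange(15, dtype=torch.int64).reshape(3, 5)))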
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
 
         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]
 
@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
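The last hunk replaces an unconditional copy with a width check before writing into the preallocated hidden-state buffer. A self-contained illustration of that guard (shapes hypothetical; the 3x width mirrors EAGLE3-style hidden states concatenated from several target layers):

    import torch

    buf = torch.zeros(8, 4096)        # preallocated capture buffer, width = hidden_size
    src = torch.randn(6, 3 * 4096)    # e.g. hidden states from 3 target layers, concatenated

    if src.shape[1] == buf.shape[1]:
        buf[: src.shape[0]].copy_(src)   # widths match: in-place copy is safe
    else:
        pass  # widths differ: skip the copy instead of raising inside copy_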
sglang/srt/speculative/eagle_utils.py

@@ -26,8 +26,6 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
-logger = logging.getLogger(__name__)
-
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
sglang/srt/speculative/eagle_worker.py

@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -47,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -187,137 +189,197 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
 
-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
-
-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
-
-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
-
-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        else:
-            raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
-            )
-
-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()
+
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()
+
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
+
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
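This hunk replaces a roughly 130-line if/elif chain with per-backend factory methods selected through dispatch tables, so the decode and draft-extend backends can now be configured independently (including the new speculative_attention_mode switch). A condensed, self-contained sketch of the dispatch pattern itself, separate from sglang's classes (all names hypothetical):

    from typing import Callable, Dict, Optional

    def create_backend(
        requested: Optional[str],
        default: str,
        backend_map: Dict[str, Callable[[], object]],
        error_template: str,
    ) -> object:
        backend_type = requested if requested is not None else default
        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))
        # Factories import lazily inside themselves, so unused backends
        # (and their heavy dependencies) are never imported.
        return backend_map[backend_type]()

    backends = {
        "triton": lambda: "TritonMultiStepDraftBackend()",
        "fa3": lambda: "FlashAttentionMultiStepBackend()",
    }
    print(create_backend(None, "triton", backends, "unsupported backend {backend_type}"))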
@@ -683,6 +745,14 @@ class EAGLEWorker(TpModelWorker):
 
         # Set inputs
         forward_batch.input_ids = input_ids
+        # This is a temporary fix for the case that the user is using standalone
+        # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+        # rope kernel needs cache_loc to be contiguous.
+        if (
+            self.server_args.speculative_algorithm == "STANDALONE"
+            and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+        ):
+            out_cache_loc = out_cache_loc.contiguous()
         forward_batch.out_cache_loc = out_cache_loc[i]
         forward_batch.positions.add_(1)
         forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
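The fix above repacks out_cache_loc into dense memory before the per-step slice is handed to the gpt-oss rope kernel. A tiny standalone example (shapes hypothetical) of how a strided view becomes non-contiguous and how .contiguous() repairs it:

    import torch

    out_cache_loc = torch.arange(12).reshape(4, 3).t()  # a transposed (strided) view
    print(out_cache_loc.is_contiguous())                # False
    packed = out_cache_loc.contiguous()                 # materializes a dense copy
    print(packed.is_contiguous())                       # True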
@@ -767,6 +837,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
sglang/srt/speculative/spec_info.py

@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3
 
+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
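Taken together, the two hunks register the new STANDALONE algorithm end to end. A short usage sketch that follows directly from the code above (the import path also appears in standalone_worker.py below):

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    algo = SpeculativeAlgorithm.from_string("STANDALONE")
    assert algo.is_standalone()
    assert not algo.is_eagle3()
    assert SpeculativeAlgorithm.from_string(None).is_none()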
sglang/srt/speculative/standalone_worker.py (new file)

@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
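For reference, a hypothetical launch sketch for the new standalone path, where the draft model is an ordinary smaller checkpoint rather than an EAGLE head. The keyword names mirror the ServerArgs fields read in StandaloneWorker.__init__ above; the model choices and exact argument spellings are assumptions and should be verified against the installed version:

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",                   # target model (assumption)
        speculative_algorithm="STANDALONE",
        speculative_draft_model_path="meta-llama/Llama-3.2-1B-Instruct", # plain draft model (assumption)
        speculative_num_steps=3,        # draft steps per verify round
        speculative_eagle_topk=1,       # branching factor during drafting
        speculative_num_draft_tokens=4,
    )
    print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
    llm.shutdown()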