sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py:

```diff
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
```
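Both CUDA-graph runners extend the same preallocate-and-slice pattern to the new M-RoPE buffer: every tensor the captured graph reads is allocated once at maximum size, and each replay works on a view of the first `num_tokens` entries. A minimal sketch of the pattern, with illustrative constants standing in for the runner's real state:

```python
import torch

MAX_NUM_TOKEN = 64  # stand-in for the runner's capture-time maximum

# Capture time: allocate once at full size so graph replay never reallocates.
positions = torch.zeros((MAX_NUM_TOKEN,), dtype=torch.int64)
# M-RoPE carries three position channels per token, hence the (3, max) shape.
mrope_positions = torch.zeros((3, MAX_NUM_TOKEN), dtype=torch.int64)

def bind_batch(num_tokens: int, new_pos: torch.Tensor, new_mrope: torch.Tensor):
    # Replay time: copy into slices of the preallocated buffers in place,
    # so the captured graph keeps reading the same underlying storage.
    positions[:num_tokens].copy_(new_pos)
    mrope_positions[:, :num_tokens].copy_(new_mrope)
    return positions[:num_tokens], mrope_positions[:, :num_tokens]

pos, mrope = bind_batch(4, torch.arange(4), torch.arange(12).reshape(3, 4))
```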
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py:

```diff
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
 
         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]
 
@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
```
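The new guard in the last hunk only copies the draft hidden states when their width matches the preallocated buffer; `Tensor.copy_` raises on a width mismatch, which can plausibly occur under EAGLE3, where the captured hidden states may be a concatenation of several target layers. A toy illustration of the check (shapes are made up):

```python
import torch

buf = torch.zeros(8, 4)    # preallocated (max_tokens, hidden_size) buffer
same = torch.ones(8, 4)    # hidden width matches the buffer
wide = torch.ones(8, 12)   # e.g. concatenated aux hidden states, 3x wider

if same.shape[1] == buf.shape[1]:
    buf[:8].copy_(same)    # executes: shapes line up

if wide.shape[1] == buf.shape[1]:
    buf[:8].copy_(wide)    # skipped: copy_ would raise on the mismatch
```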
sglang/srt/speculative/eagle_utils.py:

```diff
@@ -26,8 +26,6 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
-logger = logging.getLogger(__name__)
-
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
```
sglang/srt/speculative/eagle_worker.py:

```diff
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -47,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -190,7 +192,7 @@ class EAGLEWorker(TpModelWorker):
         # Initialize decode attention backend
         self.draft_attn_backend = self._create_decode_backend()
 
-        # Initialize
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
         self.draft_extend_attn_backend = self._create_draft_extend_backend()
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -213,6 +215,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_decode_backend,
             "aiter": self._create_aiter_decode_backend,
             "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -230,14 +237,23 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_prefill_backend,
             "aiter": self._create_aiter_prefill_backend,
             "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
-
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
         return self._create_backend(
-            "prefill_attention_backend",
+            backend_name,
             backend_map,
-            "EAGLE is not supported in prefill attention backend {backend_type}",
+            "EAGLE is not supported in attention backend {backend_type}",
         )
 
     def _create_flashinfer_decode_backend(self):
```
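The backend maps are plain dispatch tables: a backend name from the server args selects a factory, the new `hybrid_linear_attn` entry picks its factory by GPU generation, and `backend_name` decides which server-args field supplies the key. A generic sketch of the same pattern with placeholder factories (none of these stand for the real backends):

```python
from typing import Callable, Dict

def create_fa3() -> str:      # placeholder factory
    return "fa3-backend"

def create_triton() -> str:   # placeholder factory
    return "triton-backend"

def create_backend(name: str, backend_map: Dict[str, Callable[[], str]],
                   error_template: str) -> str:
    if name not in backend_map:
        # Mirrors the diff's "... not supported in attention backend {backend_type}"
        raise ValueError(error_template.format(backend_type=name))
    return backend_map[name]()

IS_BLACKWELL = False  # assumption: stand-in for the real GPU-generation check
backend_map = {
    "fa3": create_fa3,
    "triton": create_triton,
    # Conditional entry, as in the diff: choose the factory by GPU generation.
    "hybrid_linear_attn": create_fa3 if not IS_BLACKWELL else create_triton,
}
print(create_backend("hybrid_linear_attn", backend_map,
                     "unsupported attention backend {backend_type}"))
```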
sglang/srt/speculative/eagle_worker.py (continued):

```diff
@@ -729,6 +745,14 @@ class EAGLEWorker(TpModelWorker):
 
         # Set inputs
         forward_batch.input_ids = input_ids
+        # This is a temporary fix for the case that the user is using standalone
+        # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+        # rope kernel needs cache_loc to be contiguous.
+        if (
+            self.server_args.speculative_algorithm == "STANDALONE"
+            and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+        ):
+            out_cache_loc = out_cache_loc.contiguous()
         forward_batch.out_cache_loc = out_cache_loc[i]
         forward_batch.positions.add_(1)
         forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
```
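The `.contiguous()` call exists because slicing and indexing can hand a kernel a strided view, and custom CUDA kernels such as the gpt-oss rope kernel mentioned in the comment typically assume densely packed memory. The distinction is easy to see on any tensor:

```python
import torch

x = torch.arange(12).reshape(3, 4)
col = x[:, 1]                 # strided view into x, not densely packed
print(col.is_contiguous())    # False
dense = col.contiguous()      # materializes a packed copy a kernel can consume
print(dense.is_contiguous())  # True
```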
sglang/srt/speculative/eagle_worker.py (continued):

```diff
@@ -813,6 +837,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
```
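The `+ 1` converts accepted draft-token counts into the total number of tokens committed per request: each verify step keeps the accepted draft tokens and also emits one token sampled from the target model (standard speculative-decoding accounting; this reading is an inference, not stated in the diff). The conversion itself:

```python
import torch

accept_length_per_req_cpu = [2, 0, 3]  # example accepted-draft counts per request
device = "cuda" if torch.cuda.is_available() else "cpu"

# Accepted draft tokens plus the one token the target model samples at verify time.
accepted_length = (
    torch.tensor(accept_length_per_req_cpu, device=device, dtype=torch.int32) + 1
)
print(accepted_length)  # tensor([3, 1, 4], dtype=torch.int32)
```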
sglang/srt/speculative/spec_info.py:

```diff
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3
 
+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
```
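`from_string` maps an optional CLI string onto the enum, with `None` resolving to `NONE`. A self-contained replica of the pattern including the new member (the uppercase normalization in the truncated branch is an assumption):

```python
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()

    def is_standalone(self) -> bool:
        return self == SpeculativeAlgorithm.STANDALONE

    @staticmethod
    def from_string(name):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
            None: SpeculativeAlgorithm.NONE,
        }
        if name is not None:
            name = name.upper()  # assumed body of the truncated branch
        return name_map[name]

assert SpeculativeAlgorithm.from_string("standalone").is_standalone()
```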
sglang/srt/speculative/standalone_worker.py (new file):

```diff
@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
```
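`StandaloneWorker` reuses the EAGLE machinery but loads an ordinary causal LM as the draft model, so no EAGLE-specific head is required. Assuming the existing speculative-decoding arguments carry over unchanged (this diff shows only the worker, not the wiring), an offline-engine launch might look like the following; the model names are placeholders:

```python
import sglang as sgl

# Hypothetical configuration: only the STANDALONE algorithm value is new here;
# the remaining options mirror sglang's existing EAGLE flags.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    speculative_algorithm="STANDALONE",
    speculative_draft_model_path="meta-llama/Llama-3.2-1B-Instruct",
    speculative_num_steps=3,
    speculative_eagle_topk=1,
    speculative_num_draft_tokens=4,
)
print(llm.generate("The capital of France is", {"max_new_tokens": 16}))
```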