sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(

@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]

@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
+           mrope_positions=mrope_positions,
            global_num_tokens_gpu=global_num_tokens,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
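The two CUDA graph runner hunks above follow the usual preallocate-and-slice pattern: a fixed (3, max_num_token) M-RoPE position buffer is created once, and each captured batch size takes a view of its first num_tokens columns so tensor addresses stay stable across graph replays. A minimal sketch of that slicing with invented sizes (illustrative only, not sglang API):

    import torch

    max_num_token = 8192
    # One row per M-RoPE position axis, allocated once at init time.
    mrope_positions_buf = torch.zeros((3, max_num_token), dtype=torch.int64)

    def slice_for_batch(num_tokens: int) -> torch.Tensor:
        # A view into the same storage, so a captured graph keeps valid pointers.
        return mrope_positions_buf[:, :num_tokens]

    view = slice_for_batch(256)
    assert view.shape == (3, 256)
    assert view.data_ptr() == mrope_positions_buf.data_ptr()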
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )

         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(

@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]

@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
+           mrope_positions=mrope_positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),

@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
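The last hunk guards the hidden_states copy so it only runs when the incoming spec_info.hidden_states has the same hidden width as the preallocated buffer (EAGLE3-style auxiliary hidden states can be wider). A self-contained illustration of that shape guard; tensor names and sizes here are made up:

    import torch

    num_tokens, hidden = 16, 4096
    buffer = torch.zeros((1024, hidden))
    matching = torch.randn((num_tokens, hidden))      # same width: copied
    wider = torch.randn((num_tokens, hidden * 3))     # e.g. concatenated aux states: skipped

    for src in (matching, wider):
        if src.shape[1] == buffer.shape[1]:
            buffer[:num_tokens].copy_(src)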
sglang/srt/speculative/eagle_worker.py

@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,

@@ -46,6 +47,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
+    get_bool_env_var,
     is_cuda,
     next_power_of_2,
 )

@@ -54,6 +56,7 @@ if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


 @contextmanager

@@ -137,8 +140,15 @@ class EAGLEWorker(TpModelWorker):
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()

         if self.speculative_algorithm.is_eagle3():
-            # EAGLE3 models don't share lm_head
-            self.draft_model_runner.model.set_embed(embed)
+            # most cases EAGLE3 models don't share lm_head
+            # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
+            if (
+                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
+                and self.draft_model_runner.model.load_lm_head_from_target
+            ):
+                self.draft_model_runner.model.set_embed_and_head(embed, head)
+            else:
+                self.draft_model_runner.model.set_embed(embed)

         # grab hot token ids
         if self.draft_model_runner.model.hot_token_id is not None:
@@ -178,137 +188,189 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()

-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()

-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": self._create_fa3_decode_backend,
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )

+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": self._create_fa3_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )

-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
             )

-                self.topk,
-                self.speculative_num_steps,
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
-                TRTLLMHAAttnMultiStepDraftBackend,
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
            )

-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
             self.has_prefill_wrapper_verify = True
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
         else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
             raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
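The hunk above replaces the old if/elif backend selection with small per-backend factory methods dispatched through a name-to-factory map; _create_backend first reads the speculative-specific server argument and falls back to the global attention_backend when it is unset. A stripped-down sketch of the same dispatch shape, using placeholder classes instead of the real sglang backends:

    from typing import Callable, Dict, Optional


    class _StubBackend:
        def __init__(self, name: str):
            self.name = name


    class BackendFactory:
        def __init__(self, decode_attention_backend: Optional[str], attention_backend: str):
            self.decode_attention_backend = decode_attention_backend
            self.attention_backend = attention_backend

        def _create_backend(
            self,
            backend_name: str,
            backend_map: Dict[str, Callable[[], _StubBackend]],
            error_template: str,
        ):
            # Prefer the speculative-specific setting; fall back to the global backend.
            backend_type = getattr(self, backend_name)
            if backend_type is None:
                backend_type = self.attention_backend
            if backend_type not in backend_map:
                raise ValueError(error_template.format(backend_type=backend_type))
            return backend_map[backend_type]()

        def create_decode_backend(self):
            backend_map = {
                "flashinfer": lambda: _StubBackend("flashinfer"),
                "triton": lambda: _StubBackend("triton"),
            }
            return self._create_backend(
                "decode_attention_backend",
                backend_map,
                "EAGLE is not supported in decode attention backend {backend_type}",
            )


    # No decode-specific backend set, so the global backend is used.
    print(BackendFactory(None, "triton").create_decode_backend().name)  # triton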
@@ -674,6 +736,14 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]

@@ -758,6 +828,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

@@ -781,15 +866,20 @@ class EAGLEWorker(TpModelWorker):
         token_ids_logprobs = batch.token_ids_logprobs
         accepted_indices = res.accepted_indices
         assert len(accepted_indices) == len(logits_output.next_token_logits)
+
         temperatures = batch.sampling_info.temperatures
         num_draft_tokens = batch.spec_info.draft_token_num
         # acceptance indices are the indices in a "flattened" batch.
         # dividing it to num_draft_tokens will yield the actual batch index.
         temperatures = temperatures[accepted_indices // num_draft_tokens]
+        if RETURN_ORIGINAL_LOGPROB:
+            logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
+            )
+        else:
+            logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits / temperatures, dim=-1
+            )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

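The RETURN_ORIGINAL_LOGPROB switch introduced above decides whether accepted-token logprobs are taken from the raw logits or from temperature-scaled logits. Its effect on a toy logit vector, in plain PyTorch (independent of sglang):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.0])
    temperature = 0.5

    original = torch.log_softmax(logits, dim=-1)
    scaled = torch.log_softmax(logits / temperature, dim=-1)

    # Dividing by a temperature < 1 sharpens the distribution: the top token's
    # logprob moves toward 0 while the others drop.
    print(original)  # tensor([-0.4076, -1.4076, -2.4076])
    print(scaled)    # tensor([-0.1429, -2.1429, -4.1429])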
@@ -806,13 +896,19 @@ class EAGLEWorker(TpModelWorker):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
-           ) = get_top_logprobs(
+           ) = get_top_logprobs(
+               logprobs,
+               top_logprobs_nums_repeat_interleaved,
+           )

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
-           ) = get_token_ids_logprobs(
+           ) = get_token_ids_logprobs(
+               logprobs,
+               token_ids_logprobs_repeat_interleaved,
+           )

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
sglang/srt/speculative/spec_info.py

@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3

+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
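With the enum extended, the new mode resolves the same way the existing EAGLE modes do. A small usage sketch against the class as shown in this diff (assuming the unchanged tail of from_string simply looks the name up in name_map):

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    algo = SpeculativeAlgorithm.from_string("STANDALONE")
    assert algo.is_standalone()
    assert not algo.is_eagle3()

    # A missing value maps to NONE via the `None` key in name_map.
    assert SpeculativeAlgorithm.from_string(None).is_none()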
sglang/srt/speculative/standalone_worker.py

@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)