sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
CHANGED
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
+
         # Capture
         try:
             with model_capture_mode():
@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [bs] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-
-            gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
+                    [bs],
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
-            global_num_tokens = None
             gathered_buffer = None
-            global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=
-            global_num_tokens_for_logprob_gpu=
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
         if self.require_mlp_tp_gather:
-
-
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs,
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
 
@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
+        # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.
-            self.global_num_tokens_for_logprob_gpu.
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs)
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import logging
 import os
 import time
@@ -70,9 +71,20 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+    num_tokens_for_logprob_per_batch: int = -1
+
+    # Inputs for draft extend
+    # shape: (b,)
+    seq_lens_for_draft_extend: torch.Tensor = None
+    req_pool_indices_for_draft_extend: torch.Tensor = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
+
         if batch.forward_mode.is_idle():
             return
+
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
@@ -94,7 +106,7 @@ class EagleDraftInput:
         capture_hidden_mode: CaptureHiddenMode,
     ):
         return cls(
-            verified_id=
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
             hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@@ -108,7 +120,10 @@ class EagleDraftInput:
         batch: ScheduleBatch,
         speculative_num_steps: int,
     ):
-
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.verified_id
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -315,7 +330,7 @@ class EagleVerifyInput:
     def verify(
         self,
         batch: ScheduleBatch,
-        logits_output:
+        logits_output: LogitsProcessorOutput,
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
@@ -362,6 +377,11 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        if bs != len(sampling_info):
+            sampling_info = copy.deepcopy(sampling_info)
+            # NOTE: retrive_index are the indices of the requests that are kept.
+            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
+
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
             apply_custom_logit_processor(
@@ -593,13 +613,14 @@ class EagleVerifyInput:
             batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
 
-            draft_input = EagleDraftInput(
-
-
-
-
-
-
+            draft_input = EagleDraftInput(
+                hidden_states=batch.spec_info.hidden_states[accept_index],
+                verified_id=verified_id,
+                accept_length=accept_length,
+                accept_length_cpu=accept_length.tolist(),
+                seq_lens_for_draft_extend=batch.seq_lens,
+                req_pool_indices_for_draft_extend=batch.req_pool_indices,
+            )
 
             return EagleVerifyOutput(
                 draft_input=draft_input,
@@ -622,7 +643,6 @@ class EagleVerifyInput:
             batch.seq_lens.add_(accept_length + 1)
 
             accept_length_cpu = accept_length.tolist()
-            draft_input = EagleDraftInput()
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
@@ -653,18 +673,26 @@ class EagleVerifyInput:
                     next_power_of_2(self.draft_token_num),
                 )
 
-                draft_input
-
-
-
-
-
-
-                    unfinished_index_device
-
-
-
-
+                draft_input = EagleDraftInput(
+                    hidden_states=batch.spec_info.hidden_states[
+                        unfinished_accept_index
+                    ],
+                    verified_id=predict[unfinished_accept_index],
+                    accept_length_cpu=draft_input_accept_length_cpu,
+                    accept_length=accept_length[unfinished_index_device],
+                    seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                    req_pool_indices_for_draft_extend=batch.req_pool_indices[
+                        unfinished_index_device
+                    ],
+                )
+            else:
+                draft_input = EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                )
 
             return EagleVerifyOutput(
                 draft_input=draft_input,
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_batch_speculative_generation(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
                 self.verify(batch, spec_info)
             )
 
-
-
-
-
-
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                # NOTE: We should use `check_forward_draft_extend_after_decode`
+                # when DP attention is enabled, but it is slow. Skip it for now.
+                if (
+                    self.server_args.enable_dp_attention
+                    or batch.spec_info.verified_id.shape[0] > 0
+                ):
+                    # decode is not finished
+                    self.forward_draft_extend_after_decode(batch)
+
         return (
             logits_output,
             verify_output.verified_id,
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
         )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        local_need_forward =
-            batch.spec_info.verified_id is not None
-            and batch.spec_info.verified_id.shape[0] > 0
-        )
+        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward
 
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_target_extend(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
         """Run the target extend.
 
         Args:
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
         self._draft_preprocess_decode(batch)
 
         spec_info = batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
 
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        spec_info.num_tokens_per_batch = self.topk
+        spec_info.num_tokens_for_logprob_per_batch = self.topk
         batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.topk
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 # Initialize attention backend
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info = forward_batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         out_cache_loc = forward_batch.out_cache_loc
         topk_p, topk_index, hidden_states = (
             spec_info.topk_p,
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
         spec_info.hidden_states = hidden_states
 
         # Run forward
-        logits_output = self.draft_model_runner.
-            forward_batch
+        logits_output, _ = self.draft_model_runner.forward(
+            forward_batch, skip_attn_backend_init=True
         )
         self._detect_nan_if_needed(logits_output)
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
             else ForwardMode.IDLE
         )
         batch.spec_info = spec_info
+
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
         self,
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
-        next_token_ids:
-        seq_lens_cpu: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Run draft model extend. This API modifies the states of the batch.
 
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
+            num_tokens_per_batch=1,
+            num_tokens_for_logprob_per_batch=1,
         )
         batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
+
         input_is_idle = batch.forward_mode.is_idle()
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+        batch.spec_info.num_tokens_for_logprob_per_batch = 1
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
+        batch.forward_mode = (
+            ForwardMode.DRAFT_EXTEND
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
+
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
                 )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 self.draft_model_runner.attn_backend.init_forward_metadata(
                     forward_batch
                 )
-            logits_output = self.draft_model_runner.
-                forward_batch
+            logits_output, _ = self.draft_model_runner.forward(
+                forward_batch, skip_attn_backend_init=True
             )
             self.capture_for_decode(logits_output, forward_batch.spec_info)
 
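Two behavioral points in the worker changes above are easy to miss: with DP attention enabled, forward_draft_extend_after_decode now always runs (a rank whose verified_id is empty is converted to an idle batch), and the idle placeholder for EAGLE3 is three hidden states wide. Restated as small standalone helpers (hypothetical names, for illustration only):

    def should_run_draft_extend_after_decode(enable_dp_attention, num_verified_tokens):
        # Under DP attention every rank enters the step so collectives stay
        # aligned; otherwise skip it when this rank verified nothing.
        return enable_dp_attention or num_verified_tokens > 0

    def idle_draft_hidden_size(base_hidden_size, is_eagle3):
        # EAGLE3 drafts read a concatenation of three target hidden states,
        # so the zero-row placeholder must be three times wider.
        return base_hidden_size * 3 if is_eagle3 else base_hidden_size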
sglang/srt/two_batch_overlap.py
CHANGED
@@ -341,15 +341,18 @@ class TboDPAttentionPreparer:
 
     @staticmethod
     def _compute_global_forward_mode(forward_modes):
-
-
-            for x in forward_modes
+        forward_modes_excluding_idle = [
+            x for x in forward_modes if x != ForwardMode.IDLE.value
         ]
+
+        if not forward_modes_excluding_idle:
+            return ForwardMode.IDLE, False
+
         forward_mode_agree = TboDPAttentionPreparer._is_all_same(
-
+            forward_modes_excluding_idle
         )
         global_forward_mode = (
-            ForwardMode(
+            ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None
         )
         return global_forward_mode, forward_mode_agree
 
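The rewritten _compute_global_forward_mode drops idle ranks before checking agreement and reports IDLE when every rank is idle. A self-contained rendering of that logic (the enum values below are placeholders, not sglang's):

    from enum import IntEnum

    class ForwardMode(IntEnum):
        IDLE = 0  # placeholder values for illustration
        DECODE = 1
        EXTEND = 2

    def compute_global_forward_mode(forward_modes):
        # Ignore idle ranks when deciding whether the ranks agree.
        non_idle = [m for m in forward_modes if m != ForwardMode.IDLE.value]
        if not non_idle:
            return ForwardMode.IDLE, False
        agree = all(m == non_idle[0] for m in non_idle)
        return (ForwardMode(non_idle[0]) if agree else None), agree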
@@ -542,6 +545,7 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
+            dp_padding_mode=None,
             gathered_buffer=gathered_buffer,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,