sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10

@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -38,6 +44,12 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +65,9 @@ class EAGLEDraftCudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +92,32 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.
+            (self.max_bs, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )
 
+        if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+
         # Capture
         try:
             with model_capture_mode():
@@ -92,11 +128,24 @@ class EAGLEDraftCudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.batch_size
+
         is_bs_supported = (
-
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else
+            else cuda_graph_bs <= self.max_bs
         )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
@@ -116,8 +165,58 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
 
+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
-            topk_p=topk_p,
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )
 
         # Forward batch
@@ -133,11 +232,14 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )
 
         # Attention backend
@@ -147,6 +249,9 @@ class EAGLEDraftCudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,12 +289,19 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -204,6 +316,13 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -212,14 +331,16 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]
 
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if forward_batch.seq_lens_cpu is not None
-
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
 
         # Replay
         self.graphs[bs].replay()
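
The padding lookup added to this runner's replay path can be checked in isolation. The following is a minimal standalone sketch, not sglang code: `capture_bs`, `num_tokens_per_bs`, the per-rank token counts, and the gating flag are made-up illustrative values.

```python
import bisect

# Illustrative values only (not taken from sglang).
capture_bs = [1, 2, 4, 8, 16, 32]    # batch sizes with captured CUDA graphs, sorted
num_tokens_per_bs = 4                # draft tokens generated per request
require_mlp_tp_gather = True         # hypothetical DP-attention gating flag
global_num_tokens_cpu = [10, 9, 9]   # tokens per data-parallel rank (hypothetical)
raw_bs = 7                           # local batch size

if require_mlp_tp_gather:
    # With DP attention, the graph is chosen by the global batch size:
    # total tokens across ranks divided by tokens per request.
    total_batch_size = sum(global_num_tokens_cpu) // num_tokens_per_bs
    index = bisect.bisect_left(capture_bs, total_batch_size)
else:
    index = bisect.bisect_left(capture_bs, raw_bs)

bs = capture_bs[index]
print(bs)  # 8 -> the smallest captured graph that can hold the (padded) batch
```

Padding up to the next captured size keeps the number of CUDA graphs small while still covering arbitrary batch sizes.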

sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12

@@ -21,6 +21,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,6 +41,10 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +61,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
 
         self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
-            self.max_num_token
+            self.max_bs, self.max_num_token
         )
         self.seq_len_fill_value = (
             self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +100,27 @@ class EAGLEDraftExtendCudaGraphRunner:
             (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )
 
+        if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
         # Capture
         try:
             with model_capture_mode():
@@ -100,14 +131,24 @@ class EAGLEDraftExtendCudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
-
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.seq_lens.numel()
 
         is_bs_supported = (
-
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else
+            else cuda_graph_bs <= self.max_bs
         )
 
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
@@ -128,6 +169,53 @@ class EAGLEDraftExtendCudaGraphRunner:
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
 
+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             accept_length=accept_length,
@@ -147,6 +235,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +258,9 @@ class EAGLEDraftExtendCudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,38 +297,57 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
 
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)
+            self.extend_seq_lens.fill_(1)
 
         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-
+        if forward_batch.extend_seq_lens is not None:
+            self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
         self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
-
+        if forward_batch.spec_info.accept_length is not None:
+            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if bs != raw_bs:
+            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
-            forward_batch.spec_info.positions = None
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,
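
Both runners fill `global_num_tokens_gpu` during capture by splitting `num_tokens` as evenly as possible across the data-parallel ranks, with the remainder going to the lowest-indexed ranks. A quick sanity check of that expression with made-up numbers (not sglang code):

```python
# Hypothetical values: 10 tokens split across 3 data-parallel ranks.
num_tokens = 10
dp_size = 3

# (i < remainder) is a bool that Python treats as 0/1, so the first
# `remainder` ranks each receive one extra token.
per_rank = [
    num_tokens // dp_size + (i < (num_tokens % dp_size)) for i in range(dp_size)
]
print(per_rank)       # [4, 3, 3]
print(sum(per_rank))  # 10 -- the split always adds back up to num_tokens
```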

sglang/srt/speculative/eagle_utils.py +80 -8

@@ -21,20 +21,22 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
+logger = logging.getLogger(__name__)
+
 if is_cuda():
     from sgl_kernel import (
+        fast_topk,
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
-    from sgl_kernel.top_k import fast_topk
 elif is_hip():
-    from sgl_kernel import verify_tree_greedy
+    from sgl_kernel import fast_topk, verify_tree_greedy
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +71,8 @@ class EagleDraftInput:
     kv_indices: torch.Tensor = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_idle():
+            return
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
@@ -80,6 +84,25 @@ class EagleDraftInput:
             )
             pt += extend_len
 
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        dtype: torch.dtype,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
@@ -193,7 +216,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.draft_token
 
         if page_size == 1:
@@ -265,7 +316,7 @@ class EagleVerifyInput:
         self,
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
     ) -> torch.Tensor:
@@ -279,6 +330,26 @@ class EagleVerifyInput:
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -992,10 +1063,11 @@ def select_top_k_tokens(
         topk_index = topk_index.reshape(-1, topk**2)
         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
 
-
-
-
-
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
+            hidden_states = hidden_states[selected_input_index, :]
 
     tree_info = (
         expand_scores,  # shape: (b, topk, topk)
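
The index arithmetic in the final `select_top_k_tokens` hunk maps each selected candidate back to the draft branch (row of `hidden_states`) it came from. The sketch below reproduces just that gather on the CPU with toy shapes (batch of 2, topk of 2); the tensors and their contents are invented for illustration and are not the actual sglang inputs.

```python
import torch

topk = 2
bs = 2

# One hidden-state row per (request, draft-branch) pair, feature size 3.
hidden_states = torch.arange(bs * topk * 3, dtype=torch.float32).reshape(bs * topk, 3)
# Candidate picked inside each request's topk**2 grid (hypothetical picks).
topk_cs_index = torch.tensor([[0, 3], [1, 2]])

if hidden_states.shape[0] > 0:
    # candidate // topk recovers the branch index; torch.arange(...).repeat_interleave
    # adds each request's base row offset into the flattened hidden_states.
    selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
        0, hidden_states.shape[0], step=topk
    ).repeat_interleave(topk)
    hidden_states = hidden_states[selected_input_index, :]

print(selected_input_index)  # tensor([0, 1, 2, 3])
print(hidden_states.shape)   # torch.Size([4, 3])
```

The `hidden_states.shape[0] > 0` guard lets the same path handle idle batches (zero rows) without indexing errors, which matches the idle-input plumbing added elsewhere in this file.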