sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         # Capture
         try:
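Taken together, the two hunks above move the `gathered_buffer` allocation under the per-mode branches and size it for the new max-length DP padding: with every DP rank padded to the longest rank's token count, an MLP-TP gather needs room for `max_num_token` rows from each of the `dp_size` ranks, while the attention-TP path still gathers at most `max_num_token` rows. A minimal standalone sketch of the sizing rule (the sizes below are illustrative, not the runner's actual values):

    import torch

    # Illustrative sizes; the real values come from the model config and server args.
    dp_size, max_num_token, hidden_size = 4, 8, 16
    dtype = torch.float16

    # MLP-TP gather: every DP rank is padded to the same token count, so the
    # worst case is dp_size * max_num_token rows.
    mlp_tp_buffer = torch.zeros((max_num_token * dp_size, hidden_size), dtype=dtype)

    # Attention-TP gather: only one padded segment is needed.
    attn_tp_buffer = torch.zeros((max_num_token, hidden_size), dtype=dtype)

    # At replay time with num_tokens (padded) tokens per rank, the slices used are:
    num_tokens = 5
    mlp_view = mlp_tp_buffer[: num_tokens * dp_size]  # one segment per DP rank
    attn_view = attn_tp_buffer[:num_tokens]
    assert mlp_view.shape == (num_tokens * dp_size, hidden_size)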
@@ -130,9 +142,9 @@
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
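The capture-time change from an uneven per-rank split to `[num_tokens] * self.dp_size` is the same padding-mode switch seen above: during CUDA-graph capture every rank now reports exactly `num_tokens` tokens instead of a share that sums to `num_tokens`. A small sketch contrasting the two schemes (plain Python, values illustrative):

    def sum_based_split(num_tokens: int, dp_size: int) -> list[int]:
        # Old scheme: distribute num_tokens across ranks; entries sum to num_tokens.
        return [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)]

    def max_padded_split(num_tokens: int, dp_size: int) -> list[int]:
        # New scheme: every rank is padded to the same count.
        return [num_tokens] * dp_size

    print(sum_based_split(10, 4))   # [3, 3, 2, 2]
    print(max_padded_split(10, 4))  # [10, 10, 10, 10]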
@@ -233,6 +239,7 @@
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -290,12 +297,13 @@
 
         # Pad
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
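Replay keeps using `bisect` to pick the smallest captured batch size that can hold the live batch; only the value searched for changes (the per-rank maximum rather than the global sum). A sketch of the bucket selection, assuming a sorted capture list whose largest entry bounds the input:

    import bisect

    capture_bs = [1, 2, 4, 8, 16, 32]  # batch sizes captured as CUDA graphs, sorted

    def pick_graph_bs(needed_bs: int) -> int:
        # Smallest captured size >= needed_bs; the batch is then padded up to it.
        # Assumes needed_bs <= capture_bs[-1], which can_run() has already checked.
        index = bisect.bisect_left(capture_bs, needed_bs)
        return capture_bs[index]

    assert pick_graph_bs(5) == 8
    assert pick_graph_bs(8) == 8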
@@ -316,12 +324,10 @@
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
 
         # Attention backend
         if bs != raw_bs:
@@ -330,7 +336,6 @@
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -123,12 +117,31 @@
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
+
         # Capture
         try:
             with model_capture_mode():
@@ -141,9 +154,9 @@
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -180,27 +193,19 @@
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [bs] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -211,18 +216,14 @@
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [num_tokens],
+                    [bs],
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
-            global_num_tokens = None
             gathered_buffer = None
-            global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -243,8 +244,9 @@
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
-            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -306,12 +308,13 @@
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
 
@@ -334,12 +337,10 @@
         self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
+        # TODO(ch-wan): support num_token_non_padded
        if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs)
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
sglang/srt/speculative/eagle_utils.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import logging
 import os
 import time
@@ -70,9 +71,20 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+    num_tokens_for_logprob_per_batch: int = -1
+
+    # Inputs for draft extend
+    # shape: (b,)
+    seq_lens_for_draft_extend: torch.Tensor = None
+    req_pool_indices_for_draft_extend: torch.Tensor = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
+
         if batch.forward_mode.is_idle():
             return
+
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
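Read as plain dataclass metadata, the additions to `EagleDraftInput` fall into two groups: per-batch token counts used for padding, and two tensors that tell the draft-extend pass which sequences to run over. A minimal sketch of just the new fields (the real class carries many more):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class DraftInputSketch:
        # Shape info for padding: tokens contributed per request to the forward
        # pass and to logprob computation; -1 means "not set yet".
        num_tokens_per_batch: int = -1
        num_tokens_for_logprob_per_batch: int = -1

        # Inputs for draft extend, each of shape (b,): the sequence lengths and
        # request-pool slots the draft-extend pass should use.
        seq_lens_for_draft_extend: Optional[torch.Tensor] = None
        req_pool_indices_for_draft_extend: Optional[torch.Tensor] = None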
@@ -94,7 +106,7 @@
         capture_hidden_mode: CaptureHiddenMode,
     ):
         return cls(
-            verified_id=
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
             hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@@ -108,7 +120,10 @@
         batch: ScheduleBatch,
         speculative_num_steps: int,
     ):
-        assert len(self.verified_id) == len(batch.out_cache_loc)
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.verified_id
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -315,7 +330,7 @@ class EagleVerifyInput:
     def verify(
         self,
         batch: ScheduleBatch,
-        logits_output: torch.Tensor,
+        logits_output: LogitsProcessorOutput,
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
@@ -362,6 +377,11 @@
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        if bs != len(sampling_info):
+            sampling_info = copy.deepcopy(sampling_info)
+            # NOTE: retrive_index are the indices of the requests that are kept.
+            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
+
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
             apply_custom_logit_processor(
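The new `copy.deepcopy` guards shared state: when the verify batch no longer matches the sampling info one-to-one, the info is filtered down to the kept requests (via `retrive_index`) on a private copy, so the caller's sampling state is untouched. The pattern, with a stand-in class (hypothetical; only the shape of the real `filter_batch` call is mirrored):

    import copy

    class SamplingInfoSketch:
        """Stand-in for the real sampling info; only the pattern is shown."""

        def __init__(self, temperatures):
            self.temperatures = temperatures

        def __len__(self):
            return len(self.temperatures)

        def filter_batch(self, keep_indices):
            self.temperatures = [self.temperatures[i] for i in keep_indices]

    bs = 2
    sampling_info = SamplingInfoSketch([0.7, 1.0, 0.3])
    if bs != len(sampling_info):
        # Filter a private copy so the original batch's sampling state survives.
        sampling_info = copy.deepcopy(sampling_info)
        sampling_info.filter_batch([0, 2])  # indices of the kept requests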
@@ -593,13 +613,14 @@
             batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
 
-            draft_input = EagleDraftInput(
-                hidden_states=batch.spec_info.hidden_states[accept_index],
-                verified_id=verified_id,
-                accept_length=accept_length,
-                accept_length_cpu=accept_length.tolist(),
-            )
+            draft_input = EagleDraftInput(
+                hidden_states=batch.spec_info.hidden_states[accept_index],
+                verified_id=verified_id,
+                accept_length=accept_length,
+                accept_length_cpu=accept_length.tolist(),
+                seq_lens_for_draft_extend=batch.seq_lens,
+                req_pool_indices_for_draft_extend=batch.req_pool_indices,
+            )
 
             return EagleVerifyOutput(
                 draft_input=draft_input,
@@ -622,7 +643,6 @@
             batch.seq_lens.add_(accept_length + 1)
 
             accept_length_cpu = accept_length.tolist()
-            draft_input = EagleDraftInput()
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
@@ -653,18 +673,26 @@
                     next_power_of_2(self.draft_token_num),
                 )
 
-                draft_input.hidden_states = batch.spec_info.hidden_states[
-                    unfinished_accept_index
-                ]
-                draft_input.verified_id = predict[unfinished_accept_index]
-                draft_input.accept_length_cpu = draft_input_accept_length_cpu
-                draft_input.accept_length = accept_length[unfinished_index_device]
-                draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                    unfinished_index_device
-                ]
-                draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                    unfinished_index_device
-                ]
+                draft_input = EagleDraftInput(
+                    hidden_states=batch.spec_info.hidden_states[
+                        unfinished_accept_index
+                    ],
+                    verified_id=predict[unfinished_accept_index],
+                    accept_length_cpu=draft_input_accept_length_cpu,
+                    accept_length=accept_length[unfinished_index_device],
+                    seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                    req_pool_indices_for_draft_extend=batch.req_pool_indices[
+                        unfinished_index_device
+                    ],
+                )
+            else:
+                draft_input = EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                )
 
             return EagleVerifyOutput(
                 draft_input=draft_input,
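With the constructor-style rewrite, the no-unfinished-requests branch now yields a well-formed idle input (zero-length tensors) instead of a half-initialized `EagleDraftInput()`, so downstream code and DP collectives always see consistent shapes. A dict-based sketch of what an idle input looks like (field set taken from the hunks above; the real method returns the dataclass):

    import torch

    def make_idle_input(device, hidden_size, dtype, topk):
        # Zero-length tensors: a valid, empty draft input for ranks with no work.
        return {
            "verified_id": torch.empty((0,), device=device, dtype=torch.int32),
            "hidden_states": torch.empty((0, hidden_size), device=device, dtype=dtype),
            "topk_p": torch.empty((0, topk), device=device, dtype=torch.float32),
            "topk_index": torch.empty((0, topk), device=device, dtype=torch.int64),
        }

    idle = make_idle_input("cpu", hidden_size=16, dtype=torch.float16, topk=4)
    assert idle["verified_id"].numel() == 0  # empty, but shape-consistent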
sglang/srt/speculative/eagle_worker.py

@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_batch_speculative_generation(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@
                 self.verify(batch, spec_info)
             )
 
-            if self.check_forward_draft_extend_after_decode(batch):
-                # decode is not finished
-                with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
-
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                # NOTE: We should use `check_forward_draft_extend_after_decode`
+                # when DP attention is enabled, but it is slow. Skip it for now.
+                if (
+                    self.server_args.enable_dp_attention
+                    or batch.spec_info.verified_id.shape[0] > 0
+                ):
+                    # decode is not finished
+                    self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
                 verify_output.verified_id,
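The reshuffled condition encodes a lockstep rule: under DP attention, `forward_draft_extend_after_decode` contains collective communication, so every rank must enter it together even if it has no verified tokens; without DP attention a rank may skip the pass when it has nothing to extend. The decision reduces to (sketch, parameter names illustrative):

    def should_run_draft_extend(enable_dp_attention: bool, num_verified: int) -> bool:
        # With DP attention the pass contains collectives, so all ranks must enter
        # together regardless of local progress; otherwise only ranks with work run it.
        return enable_dp_attention or num_verified > 0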
@@ -339,10 +344,7 @@
         )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        local_need_forward = (
-            batch.spec_info.verified_id is not None
-            and batch.spec_info.verified_id.shape[0] > 0
-        )
+        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward
 
@@ -361,7 +363,7 @@
 
     def forward_target_extend(
         self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput,
+    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
         """Run the target extend.
 
         Args:
@@ -376,7 +378,6 @@
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -508,13 +509,15 @@
             self._draft_preprocess_decode(batch)
 
         spec_info = batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
 
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        spec_info.num_tokens_per_batch = self.topk
+        spec_info.num_tokens_for_logprob_per_batch = self.topk
         batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.topk
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@
                 forward_batch
             )
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 # Initialize attention backend
                 self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info = forward_batch.spec_info
+        assert isinstance(spec_info, EagleDraftInput)
         out_cache_loc = forward_batch.out_cache_loc
         topk_p, topk_index, hidden_states = (
             spec_info.topk_p,
@@ -621,8 +626,8 @@
         spec_info.hidden_states = hidden_states
 
         # Run forward
-        logits_output = self.draft_model_runner.forward(
-            forward_batch
+        logits_output, _ = self.draft_model_runner.forward(
+            forward_batch, skip_attn_backend_init=True
         )
         self._detect_nan_if_needed(logits_output)
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@
             else ForwardMode.IDLE
         )
         batch.spec_info = spec_info
+
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -782,8 +787,8 @@
         self,
         batch: ScheduleBatch,
         hidden_states: torch.Tensor,
-        next_token_ids:
-        seq_lens_cpu: torch.Tensor,
+        next_token_ids: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Run draft model extend. This API modifies the states of the batch.
 
@@ -795,6 +800,8 @@
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
+            num_tokens_per_batch=1,
+            num_tokens_for_logprob_per_batch=1,
         )
         batch.return_hidden_states = False
         batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
-        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -814,37 +820,45 @@
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
+
         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle:
-            # Prepare metadata
-            if batch.spec_info.verified_id is not None:
-                batch.spec_info.prepare_extend_after_decode(
-                    batch,
-                    self.speculative_num_steps,
-                )
-        else:
-            batch = batch.copy()
-            batch.prepare_for_idle()
-            hidden_size = (
-                self.model_config.hidden_size * 3
-                if self.speculative_algorithm.is_eagle3()
-                else self.model_config.hidden_size
-            )
-            batch.spec_info = EagleDraftInput.create_idle_input(
-                device=self.device,
-                hidden_size=hidden_size,
-                dtype=self.model_config.dtype,
-                topk=self.topk,
-                capture_hidden_mode=CaptureHiddenMode.LAST,
-            )
+
+        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
+        batch.spec_info.num_tokens_for_logprob_per_batch = 1
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
+        batch.forward_mode = (
+            ForwardMode.DRAFT_EXTEND
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
+
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
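The new block converts a non-idle batch whose `verified_id` came back empty (every request finished at verify time) into an idle batch before draft extend, so the rank still participates in DP collectives without touching real KV-cache state. The implied decision table, as a sketch:

    def classify_draft_extend(input_is_idle: bool, num_verified: int) -> str:
        # Mirrors the branch above: idle stays idle; a drained batch is converted
        # to idle via batch.copy() + prepare_for_idle(); otherwise run DRAFT_EXTEND.
        if input_is_idle:
            return "IDLE"
        return "IDLE_CONVERTED" if num_verified == 0 else "DRAFT_EXTEND"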
@@ -869,12 +883,13 @@
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
+            forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
                 self.draft_model_runner.attn_backend.init_forward_metadata(
                     forward_batch
                 )
-            logits_output = self.draft_model_runner.forward(
-                forward_batch
+            logits_output, _ = self.draft_model_runner.forward(
+                forward_batch, skip_attn_backend_init=True
             )
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 