sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28

```diff
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -97,13 +98,6 @@ class EAGLEDraftCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -111,12 +105,30 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         # Capture
         try:
@@ -130,9 +142,9 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -168,26 +180,20 @@ class EAGLEDraftCudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -233,6 +239,7 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -290,12 +297,13 @@ class EAGLEDraftCudaGraphRunner:
 
         # Pad
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -316,12 +324,10 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
 
         # Attention backend
         if bs != raw_bs:
@@ -330,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.req_pool_indices = self.req_pool_indices[:bs]
             forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
```
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37

```diff
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
+from sglang.srt.layers.dp_attention import DPPaddingMode
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -84,7 +85,15 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.hidden_states = torch.zeros(
             (
                 self.max_num_token,
-                self.model_runner.model_config.hidden_size * 3,
+                (
+                    self.model_runner.model_config.hf_config.target_hidden_size
+                    * 3
+                    if hasattr(
+                        self.model_runner.model_config.hf_config,
+                        "target_hidden_size",
+                    )
+                    else self.model_runner.model_config.hidden_size * 3
+                ),
             ),
             dtype=self.model_runner.dtype,
         )
@@ -101,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
         )
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
@@ -115,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
+
         # Capture
         try:
             with model_capture_mode():
@@ -133,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -172,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [bs] * self.dp_size,
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -203,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
             )
             self.global_num_tokens_for_logprob_gpu.copy_(
                 torch.tensor(
-                    [num_tokens],
+                    [bs],
                     dtype=torch.int32,
                     device=self.input_ids.device,
                 )
             )
-            global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
-            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
-            global_num_tokens = None
             gathered_buffer = None
-            global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -235,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=global_num_tokens,
-            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -298,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]
         if self.require_mlp_tp_gather:
-            total_batch_size = (
-                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else sum(forward_batch.global_num_tokens_cpu)
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
 
@@ -326,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
+        # TODO(ch-wan): support num_token_non_padded
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
-            self.global_num_tokens_for_logprob_gpu.copy_(
-                forward_batch.global_num_tokens_for_logprob_gpu
-            )
-            forward_batch.gathered_buffer = self.gathered_buffer
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs)
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
```
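The draft-extend runner picks up the same DP padding changes, plus one of its own: its `hidden_states` buffer holds three concatenated hidden states per token, and the width now prefers an optional `target_hidden_size` attribute on the HF config (for draft models whose width differs from the target model) over the runner's own `hidden_size`. A small sketch of that selection, with hypothetical stand-ins for sglang's config objects:

```python
from types import SimpleNamespace


def draft_extend_hidden_width(model_config) -> int:
    # Three hidden states are concatenated per token; prefer the HF config's
    # target_hidden_size when present, else fall back to the model's own width.
    hf_config = model_config.hf_config
    if hasattr(hf_config, "target_hidden_size"):
        return hf_config.target_hidden_size * 3
    return model_config.hidden_size * 3


# Hypothetical configs for illustration only.
cfg = SimpleNamespace(hidden_size=4096, hf_config=SimpleNamespace())
assert draft_extend_hidden_width(cfg) == 4096 * 3
cfg.hf_config.target_hidden_size = 2048
assert draft_extend_hidden_width(cfg) == 2048 * 3
```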
sglang/srt/speculative/eagle_utils.py +51 -23

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import logging
 import os
 import time
@@ -70,9 +71,20 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
 
+    # Shape info for padding
+    num_tokens_per_batch: int = -1
+    num_tokens_for_logprob_per_batch: int = -1
+
+    # Inputs for draft extend
+    # shape: (b,)
+    seq_lens_for_draft_extend: torch.Tensor = None
+    req_pool_indices_for_draft_extend: torch.Tensor = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
+
         if batch.forward_mode.is_idle():
             return
+
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
@@ -94,7 +106,7 @@ class EagleDraftInput:
         capture_hidden_mode: CaptureHiddenMode,
     ):
         return cls(
-            verified_id=torch.empty((0,), device=device, dtype=torch.long),
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
             hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@@ -108,7 +120,10 @@ class EagleDraftInput:
         batch: ScheduleBatch,
         speculative_num_steps: int,
     ):
-
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.verified_id
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
@@ -315,7 +330,7 @@ class EagleVerifyInput:
     def verify(
         self,
         batch: ScheduleBatch,
-        logits_output: torch.Tensor,
+        logits_output: LogitsProcessorOutput,
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
@@ -362,6 +377,11 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
 
+        if bs != len(sampling_info):
+            sampling_info = copy.deepcopy(sampling_info)
+            # NOTE: retrive_index are the indices of the requests that are kept.
+            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
+
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
             apply_custom_logit_processor(
@@ -593,13 +613,14 @@ class EagleVerifyInput:
             batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
 
-        draft_input = EagleDraftInput(
-            hidden_states=batch.spec_info.hidden_states[accept_index],
-            verified_id=verified_id,
-            accept_length=accept_length,
-            accept_length_cpu=accept_length.tolist(),
-        )
+        draft_input = EagleDraftInput(
+            hidden_states=batch.spec_info.hidden_states[accept_index],
+            verified_id=verified_id,
+            accept_length=accept_length,
+            accept_length_cpu=accept_length.tolist(),
+            seq_lens_for_draft_extend=batch.seq_lens,
+            req_pool_indices_for_draft_extend=batch.req_pool_indices,
+        )
 
         return EagleVerifyOutput(
             draft_input=draft_input,
@@ -622,7 +643,6 @@ class EagleVerifyInput:
             batch.seq_lens.add_(accept_length + 1)
 
         accept_length_cpu = accept_length.tolist()
-        draft_input = EagleDraftInput()
         if len(unfinished_accept_index) > 0:
             unfinished_accept_index = torch.cat(unfinished_accept_index)
             unfinished_index_device = torch.tensor(
@@ -653,18 +673,26 @@ class EagleVerifyInput:
                 next_power_of_2(self.draft_token_num),
             )
 
-            draft_input.hidden_states = batch.spec_info.hidden_states[
-                unfinished_accept_index
-            ]
-            draft_input.verified_id = predict[unfinished_accept_index]
-            draft_input.accept_length_cpu = draft_input_accept_length_cpu
-            draft_input.accept_length = accept_length[unfinished_index_device]
-            draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                unfinished_index_device
-            ]
-            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
-                unfinished_index_device
-            ]
+            draft_input = EagleDraftInput(
+                hidden_states=batch.spec_info.hidden_states[
+                    unfinished_accept_index
+                ],
+                verified_id=predict[unfinished_accept_index],
+                accept_length_cpu=draft_input_accept_length_cpu,
+                accept_length=accept_length[unfinished_index_device],
+                seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                req_pool_indices_for_draft_extend=batch.req_pool_indices[
+                    unfinished_index_device
+                ],
+            )
+        else:
+            draft_input = EagleDraftInput.create_idle_input(
+                device=batch.device,
+                hidden_size=batch.model_config.hidden_size,
+                dtype=batch.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
 
         return EagleVerifyOutput(
             draft_input=draft_input,
```