sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -15,11 +15,12 @@

 import argparse
 import dataclasses
+import json
 import logging
 import os
 import random
 import tempfile
-from typing import List, Optional
+from typing import List, Literal, Optional

 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -127,14 +128,14 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] =
+    grammar_backend: Optional[str] = None

     # Speculative decoding
     speculative_algorithm: Optional[str] = None
     speculative_draft_model_path: Optional[str] = None
-    speculative_num_steps: int =
-    speculative_eagle_topk: int =
-    speculative_num_draft_tokens: int =
+    speculative_num_steps: Optional[int] = None
+    speculative_eagle_topk: Optional[int] = None
+    speculative_num_draft_tokens: Optional[int] = None
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
@@ -160,6 +161,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
    cuda_graph_max_bs: Optional[int] = None
@@ -177,10 +179,12 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False
+    enable_flashinfer_mla: bool = False  # TODO: remove this argument
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    n_share_experts_fusion: int = 0
+    disable_shared_experts_fusion: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -192,6 +196,13 @@ class ServerArgs:
     disaggregation_bootstrap_port: int = 8998

     def __post_init__(self):
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -215,6 +226,9 @@ class ServerArgs:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None

+        if is_hip():
+            self.disable_shared_experts_fusion = True
+
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -253,15 +267,11 @@ class ServerArgs:
         else:
             self.cuda_graph_max_bs = 160

-        #
+        # Set kernel backends for hpu device
         if self.device == "hpu":
             self.attention_backend = "torch_native"
             self.sampling_backend = "pytorch"

-        if self.attention_backend is None:
-            self.attention_backend = (
-                "flashinfer" if is_flashinfer_available() else "triton"
-            )
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -273,6 +283,10 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True

+        # Choose grammar backend
+        if self.grammar_backend is None:
+            self.grammar_backend = "xgrammar"
+
         # Expert parallelism
         if self.enable_ep_moe:
             self.ep_size = self.tp_size
@@ -295,6 +309,10 @@ class ServerArgs:
         self.enable_sp_layernorm = False
         # DeepEP MoE
         if self.enable_deepep_moe:
+            if self.deepep_mode == "auto":
+                assert (
+                    not self.enable_dp_attention
+                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             self.ep_size = self.tp_size
             self.enable_sp_layernorm = (
                 self.dp_size < self.tp_size if self.enable_dp_attention else True
@@ -313,12 +331,29 @@ class ServerArgs:
             or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
-                self.max_running_requests =
+                self.max_running_requests = 48
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
+
+            # Auto choose parameters
+            if self.speculative_num_steps is None:
+                assert (
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
+                )
+                (
+                    self.speculative_num_steps,
+                    self.speculative_eagle_topk,
+                    self.speculative_num_draft_tokens,
+                ) = auto_choose_speculative_params(self)
+
+            if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                self.speculative_eagle_topk = 1
+                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
             # assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -462,6 +497,7 @@ class ServerArgs:
                 "modelopt",
                 "w8a8_int8",
                 "w8a8_fp8",
+                "moe_wna16",
             ],
             help="The quantization method.",
         )
@@ -795,14 +831,14 @@ class ServerArgs:
         parser.add_argument(
             "--grammar-backend",
             type=str,
-            choices=["xgrammar", "outlines", "llguidance"],
+            choices=["xgrammar", "outlines", "llguidance", "none"],
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization",
+            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
         )
         parser.add_argument(
             "--enable-flashmla",
@@ -1060,6 +1096,25 @@ class ServerArgs:
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--deepep-mode",
+            type=str,
+            choices=["normal", "low_latency", "auto"],
+            help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
+        )
+
+        parser.add_argument(
+            "--n-share-experts-fusion",
+            type=int,
+            default=0,
+            help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
+            "we use tp_size by default.",
+        )
+        parser.add_argument(
+            "--disable-shared-experts-fusion",
+            action="store_true",
+            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+        )

         # Server warmups
         parser.add_argument(
@@ -1253,3 +1308,33 @@ class DeprecatedAction(argparse.Action):

     def __call__(self, parser, namespace, values, option_string=None):
         raise ValueError(self.help)
+
+
+def auto_choose_speculative_params(self: ServerArgs):
+    """
+    Automatically choose the parameters for speculative decoding.
+
+    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+    """
+    if self.decrypted_config_file:
+        config_path = self.decrypted_config_file
+    else:
+        config_path = os.path.join(self.model_path, "config.json")
+    if not os.path.exists(config_path):
+        raise ValueError(f"{config_path} is not found.")
+
+    config = json.load(open(config_path))
+
+    arch = config.get("architectures", ["Unknown"])[0]
+
+    if arch in ["LlamaForCausalLM"]:
+        # The default value for llama
+        return (5, 4, 8)
+    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+        # The default value for deepseek
+        return (5, 4, 8)
+    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+        return (5, 4, 8)
+    else:
+        # The default value for all other models
+        return (5, 4, 8)
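Note: with this change, EAGLE users can leave the three speculative knobs unset and let `__post_init__` fill them in. A minimal sketch (not part of the diff) of how that plays out; the paths are placeholders, and the checkpoint directory must contain a local config.json for auto_choose_speculative_params to read:

from sglang.srt.server_args import ServerArgs

# Placeholder local checkpoint; auto_choose_speculative_params reads
# <model_path>/config.json to pick defaults by architecture.
args = ServerArgs(
    model_path="/models/llama-3-8b",
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="/models/llama-3-8b-eagle-draft",
)

# The three knobs were left as None, so __post_init__ called
# auto_choose_speculative_params and filled in (5, 4, 8).
assert (
    args.speculative_num_steps,
    args.speculative_eagle_topk,
    args.speculative_num_draft_tokens,
) == (5, 4, 8)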
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.
-            forward_batch.
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.positions = self.positions[:raw_num_token]
         forward_batch.seq_lens = self.seq_lens[:raw_bs]
         forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-        if forward_batch.
-            forward_batch.
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]

         return out
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -10,11 +11,15 @@ import triton.language as tl

 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (
@@ -34,6 +39,9 @@ import logging
 logger = logging.getLogger(__name__)


+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+
+
 @dataclass
 class EagleDraftInput:
     # The inputs for decode
@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-
+            next_power_of_2(speculative_num_steps + 1),
         )

         batch.seq_lens_sum = sum(seq_lens_cpu)
@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )

-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-
+            next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(
@@ -282,6 +306,7 @@ class EagleVerifyInput:
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -305,6 +330,7 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(
@@ -317,6 +343,7 @@ class EagleVerifyInput:
             torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
         )

+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
@@ -378,13 +405,24 @@ class EagleVerifyInput:
             deterministic=True,
         )

+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False

-        #
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []
@@ -407,13 +445,28 @@ class EagleVerifyInput:
                 unfinished_index.append(i)
             req.spec_verify_ct += 1

+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
@@ -422,7 +475,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -443,13 +496,6 @@ class EagleVerifyInput:
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,
@@ -457,7 +503,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()
@@ -465,20 +511,21 @@ class EagleVerifyInput:
         draft_input = EagleDraftInput()
         if len(new_accept_index) > 0:
             new_accept_index = torch.tensor(new_accept_index, device="cuda")
+            unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
             draft_input.hidden_states = batch.spec_info.hidden_states[
                 new_accept_index
             ]
             draft_input.verified_id = predict[new_accept_index]
-            draft_input.accept_length = accept_length[unfinished_index]
             draft_input.accept_length_cpu = [
                 accept_length_cpu[i] for i in unfinished_index
             ]
+            draft_input.accept_length = accept_length[unfinished_index_device]
             if has_finished:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-
+                    unfinished_index_device
                 ]
                 draft_input.req_pool_indices_for_draft_extend = (
-                    batch.req_pool_indices[
+                    batch.req_pool_indices[unfinished_index_device]
                 )
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens
@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
     kv_start = tl.load(seq_lens + pid)
-
+
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
     tl.store(kv_indptr + zid, base + zid * iters)


+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
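As a worked example of the page-alignment logic in align_evict_mask_to_page_size, here is a pure-Python sketch (illustrative only, not from the diff) of the same arithmetic for a single request: draft slots that share a page with the last accepted token are withheld from eviction so that no page is partially freed:

def align_evict_mask(seq_len, evict_mask, page_size):
    # evict_mask[i] is True if draft slot i should be freed.
    num_draft_tokens = len(evict_mask)
    num_false = evict_mask.count(False)            # number of accepted slots
    # First draft position at which a fully evictable page begins.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        evict_mask[i] = False                      # keep slots on the partial page
    return evict_mask

# seq_len=5, page_size=4, 2 accepted slots: slot 2 lands on the same page as
# the last accepted token, so it is kept even though it was not accepted.
print(align_evict_mask(5, [False, False, True, True, True, True], 4))
# -> [False, False, False, True, True, True]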