sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -9,15 +9,18 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import (
+    Req,
     ScheduleBatch,
     get_last_loc,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
 
@@ -167,12 +170,12 @@ class EagleVerifyOutput:
     draft_input: EagleDraftInput
     # Logit outputs from target worker
     logits_output: LogitsProcessorOutput
-    #
+    # Accepted token ids including the bonus token
     verified_id: torch.Tensor
-    #
+    # Accepted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
-    #
-
+    # Accepted indices from logits_output.next_token_logits
+    accepted_indices: torch.Tensor
 
 
 @dataclass
@@ -187,6 +190,7 @@ class EagleVerifyInput:
     draft_token_num: int
     spec_steps: int
     capture_hidden_mode: CaptureHiddenMode
+    grammar: BaseGrammarObject = None
 
     @classmethod
     def create(
@@ -307,6 +311,7 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
+        vocab_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -316,7 +321,7 @@ class EagleVerifyInput:
 
         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
-
+        accepted token logits.
         """
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
@@ -343,6 +348,13 @@ class EagleVerifyInput:
             torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
         )
 
+        # Apply grammar mask
+        if vocab_mask is not None:
+            assert self.grammar is not None
+            self.grammar.apply_vocab_mask(
+                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+            )
+
         # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
@@ -440,6 +452,15 @@ class EagleVerifyInput:
                         break
                     else:
                         new_accept_index_.append(idx)
+                        # update grammar state
+                        if req.grammar is not None:
+                            try:
+                                req.grammar.accept_token(id)
+                            except ValueError as e:
+                                logger.info(
+                                    f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
+                                )
+                                raise e
                 if not req.finished():
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
@@ -493,7 +514,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-
+                accepted_indices=accept_index,
             )
         else:
             assign_req_to_token_pool[(bs,)](
@@ -539,7 +560,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-
+                accepted_indices=accept_index,
             )
 
 
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
     accept_length.fill_(simulate_acc_len - 1)
     predict.fill_(100)  # some legit token id
     return sim_accept_index
+
+
+def traverse_tree(
+    retrieve_next_token: torch.Tensor,
+    retrieve_next_sibling: torch.Tensor,
+    draft_tokens: torch.Tensor,
+    grammar: BaseGrammarObject,
+    allocate_token_bitmask: torch.Tensor,
+):
+    """
+    Traverse the tree constructed by the draft model to generate the logits mask.
+    """
+    assert (
+        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+    )
+
+    allocate_token_bitmask.fill_(0)
+
+    def dfs(
+        curr: int,
+        retrieve_next_token: torch.Tensor,
+        retrieve_next_sibling: torch.Tensor,
+        parent_pos: int,
+    ):
+        if curr == 0:
+            # the first token generated by the target model, and thus it is always
+            # accepted from the previous iteration
+            accepted = True
+        else:
+            parent_bitmask = allocate_token_bitmask[parent_pos]
+            curr_token_id = draft_tokens[curr]
+            # 32 boolean bitmask values are packed into 32-bit integers
+            accepted = (
+                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
+            ) != 0
+
+        if accepted:
+            if curr != 0:
+                # Accept the current token
+                grammar.accept_token(draft_tokens[curr])
+            if not grammar.is_terminated():
+                # Generate the bitmask for the current token
+                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
+                if retrieve_next_token[curr] != -1:
+                    # Visit the child node
+                    dfs(
+                        retrieve_next_token[curr],
+                        retrieve_next_token,
+                        retrieve_next_sibling,
+                        curr,
+                    )
+
+            if curr != 0:
+                # Rollback the current token
+                grammar.rollback(1)
+
+        if retrieve_next_sibling[curr] != -1:
+            # Visit the sibling node
+            dfs(
+                retrieve_next_sibling[curr],
+                retrieve_next_token,
+                retrieve_next_sibling,
+                parent_pos,
+            )
+
+    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
+
+
+def generate_token_bitmask(
+    reqs: List[Req],
+    verify_input: EagleVerifyInput,
+    retrieve_next_token_cpu: torch.Tensor,
+    retrieve_next_sibling_cpu: torch.Tensor,
+    draft_tokens_cpu: torch.Tensor,
+    vocab_size: int,
+):
+    """
+    Generate the logit mask for structured output.
+    Draft model's token can be either valid or invalid with respect to the grammar.
+    We need to perform DFS to figure out:
+    1. which tokens are accepted by the grammar
+    2. what is the corresponding logit mask.
+    """
+
+    num_draft_tokens = draft_tokens_cpu.shape[-1]
+
+    allocate_token_bitmask = None
+    assert len(reqs) == retrieve_next_token_cpu.shape[0]
+    grammar = None
+    for i, req in enumerate(reqs):
+        if req.grammar is not None:
+            if allocate_token_bitmask is None:
+                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
+                    vocab_size=vocab_size,
+                    batch_size=draft_tokens_cpu.numel(),
+                    device="cpu",
+                )
+            grammar = req.grammar
+            traverse_tree(
+                retrieve_next_token_cpu[i],
+                retrieve_next_sibling_cpu[i],
+                draft_tokens_cpu[i],
+                req.grammar,
+                allocate_token_bitmask[
+                    i * num_draft_tokens : (i + 1) * num_draft_tokens
+                ],
+            )
+
+    verify_input.grammar = grammar
+    return allocate_token_bitmask
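Note on the packed vocab bitmask that the new traverse_tree helper reads: a draft token is accepted only if its bit is set in the mask filled at its parent position, with 32 boolean values packed into each 32-bit word. A minimal, self-contained sketch of that packing convention, using plain Python ints rather than the torch tensors used above (the allow/is_allowed names are illustrative, not part of the sglang API):

vocab_size = 100
mask = [0] * ((vocab_size + 31) // 32)  # one 32-bit word per 32 tokens

def allow(token_id: int) -> None:
    # set bit (token_id % 32) of word (token_id // 32)
    mask[token_id // 32] |= 1 << (token_id % 32)

def is_allowed(token_id: int) -> bool:
    # the same test dfs() performs against the parent position's bitmask
    return (mask[token_id // 32] & (1 << (token_id % 32))) != 0

allow(42)
assert is_allowed(42) and not is_allowed(43)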
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -199,9 +200,22 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import (
+                FlashMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
-                f"EAGLE is not
+                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
             )
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -215,7 +229,7 @@ class EAGLEWorker(TpModelWorker):
             return
 
         # Capture draft
-        tic = time.
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -223,7 +237,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
         # Capture extend
@@ -245,14 +259,14 @@ class EAGLEWorker(TpModelWorker):
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
         Returns:
-            A tuple of the final logit output of the target model, next tokens
-            the batch id (used for overlap schedule), and number of
+            A tuple of the final logit output of the target model, next tokens accepted,
+            the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch =
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )
 
             # If it is None, it means all requests are finished
@@ -264,21 +278,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids =
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
            )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +312,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,9 +493,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-
-
+
+        if batch.has_grammar:
+            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
+            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
+            draft_tokens_cpu = spec_info.draft_token.view(
+                spec_info.retrive_next_token.shape
+            ).cpu()
+
+        # Forward
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
+
+        vocab_mask = None
+        if batch.has_grammar:
+            # Generate the logit mask for structured output.
+            # Overlap the CPU operations for bitmask generation with the forward pass.
+            vocab_mask = generate_token_bitmask(
+                batch.reqs,
+                spec_info,
+                retrieve_next_token_cpu,
+                retrieve_next_sibling_cpu,
+                draft_tokens_cpu,
+                batch.sampling_info.vocab_size,
+            )
+
+            if vocab_mask is not None:
+                assert spec_info.grammar is not None
+                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
+                # otherwise, this vocab mask will be the one from the previous extend stage
+                # and will be applied to produce wrong results
+                batch.sampling_info.vocab_mask = None
+
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
@@ -488,14 +535,15 @@ class EAGLEWorker(TpModelWorker):
             logits_output,
             self.token_to_kv_pool_allocator,
             self.page_size,
+            vocab_mask,
         )
 
         # Post process based on verified outputs.
-        # Pick indices that we care (
+        # Pick indices that we care (accepted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.
+            res.accepted_indices
         ]
-        logits_output.hidden_states = logits_output.hidden_states[res.
+        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
@@ -504,7 +552,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
         self,
@@ -590,14 +638,14 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        # Backup
+        # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
@@ -617,7 +665,7 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
 
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
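The verify path above overlaps the CPU-side grammar work with the GPU forward pass: the tree tensors are copied to the CPU first, the target forward is launched with skip_sample=True, the bitmask is built on the CPU while the GPU is busy, and only then is the finished mask moved to the device. A rough sketch of that pattern, with hypothetical callables standing in for the sglang internals:

def verify_with_overlap(run_target_forward, build_cpu_bitmask, device):
    # 1. Enqueue the GPU forward pass; CUDA kernels run asynchronously w.r.t. the CPU.
    logits = run_target_forward()
    # 2. Do the CPU-heavy grammar traversal while the GPU is still working.
    vocab_mask_cpu = build_cpu_bitmask()
    # 3. Transfer the finished mask; it is only consumed after the forward completes.
    vocab_mask = vocab_mask_cpu.to(device) if vocab_mask_cpu is not None else None
    return logits, vocab_mask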
sglang/srt/utils.py
CHANGED
@@ -46,7 +46,19 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 import psutil
@@ -125,10 +137,6 @@ builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
 builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
 
 
-def is_rocm() -> bool:
-    return torch.cuda.is_available() and torch.version.hip
-
-
 def is_cuda():
     return torch.cuda.is_available() and torch.version.cuda
 
@@ -250,7 +258,7 @@ def mark_start(name, interval=0.1, color=0, indent=0):
     torch.cuda.synchronize()
     if time_infos.get(name, None) is None:
         time_infos[name] = TimeInfo(name, interval, color, indent)
-    time_infos[name].acc_time -= time.
+    time_infos[name].acc_time -= time.perf_counter()
 
 
 def mark_end(name):
@@ -258,7 +266,7 @@ def mark_end(name):
     if not show_time_cost:
         return
     torch.cuda.synchronize()
-    time_infos[name].acc_time += time.
+    time_infos[name].acc_time += time.perf_counter()
     if time_infos[name].check():
         time_infos[name].pretty_print()
 
@@ -268,11 +276,11 @@ def calculate_time(show=False, min_cost_ms=0.0):
         def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
-                start_time = time.
+                start_time = time.perf_counter()
            result = func(*args, **kwargs)
            torch.cuda.synchronize()
            if show:
-                cost_time = (time.
+                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result
@@ -282,7 +290,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +354,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
 
     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
-
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
        )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()
 
     return free_gpu_memory / (1 << 30)
@@ -1849,6 +1859,8 @@ def get_cuda_version():
 
 
 def launch_dummy_health_check_server(host, port):
+    import asyncio
+
     import uvicorn
     from fastapi import FastAPI, Response
 
@@ -1864,13 +1876,27 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)
 
-    uvicorn.
+    config = uvicorn.Config(
         app,
         host=host,
         port=port,
         timeout_keep_alive=5,
-        loop="
+        loop="auto",
+        log_config=None,
+        log_level="warning",
     )
+    server = uvicorn.Server(config=config)
+
+    try:
+        loop = asyncio.get_running_loop()
+        logger.info(
+            f"Dummy health check server scheduled on existing loop at {host}:{port}"
+        )
+        loop.create_task(server.serve())
+
+    except RuntimeError:
+        logger.info(f"Starting dummy health check server at {host}:{port}")
+        server.run()
 
 
 def create_checksum(directory: str):
@@ -2075,8 +2101,6 @@ def is_fa3_default_architecture(hf_config):
         "Qwen2ForCausalLM",
         "Llama4ForConditionalGeneration",
         "LlamaForCausalLM",
-        "MistralForCausalLM",
-        "MixtralForCausalLM",
         "Gemma2ForCausalLM",
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
@@ -2103,3 +2127,36 @@ def log_info_on_rank0(logger, msg):
 
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def load_json_config(data: str):
+    try:
+        return json.loads(data)
+    except JSONDecodeError:
+        return json.loads(Path(data).read_text())
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
+
+
+T = TypeVar("T")
+
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
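The new Withable class added above is a small scoped-value holder: with_value installs a value for the duration of a with block and clears it (with assertions) on exit. A usage sketch; the current_layer name is illustrative only, not something the diff defines:

from sglang.srt.utils import Withable  # available after this version

current_layer = Withable()  # holds an Optional[str] here

def do_profiled_work():
    # readers can inspect whichever value is currently installed
    print("profiling", current_layer.value)

with current_layer.with_value("layers.0"):
    do_profiled_work()              # prints "profiling layers.0"
assert current_layer.value is None  # cleared when the with block exits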
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -90,7 +90,7 @@ def run_eval(args):
     #####################################
 
     # Run requests
-    tic = time.
+    tic = time.perf_counter()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=args.temperature if hasattr(args, "temperature") else 0,
@@ -99,7 +99,7 @@ def run_eval(args):
         return_logprob=getattr(args, "return_logprob", None),
         logprob_start_len=getattr(args, "logprob_start_len", None),
     )
-    latency = time.
+    latency = time.perf_counter() - tic
 
     preds = []
     for i in range(len(states)):
sglang/test/few_shot_gsm8k_engine.py
CHANGED
@@ -89,7 +89,7 @@ def run_eval(args):
     }
 
     # Run requests
-    tic = time.
+    tic = time.perf_counter()
 
     loop = asyncio.get_event_loop()
 
@@ -98,7 +98,7 @@ def run_eval(args):
     )
 
     # End requests
-    latency = time.
+    latency = time.perf_counter() - tic
 
     # Shutdown the engine
     engine.shutdown()
sglang/test/run_eval.py
CHANGED
@@ -71,9 +71,9 @@ def run_eval(args):
     )
 
     # Run eval
-    tic = time.
+    tic = time.perf_counter()
     result = eval_obj(sampler)
-    latency = time.
+    latency = time.perf_counter() - tic
 
     # Dump reports
     metrics = result.metrics | {"score": result.score}
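Across these test scripts the timing calls (the removed lines are truncated in this view) are replaced with time.perf_counter(), a monotonic, high-resolution clock, so latency measurements are unaffected by system clock adjustments. The measurement pattern in isolation, with a stand-in workload:

import time

def do_work():
    # stand-in for the measured call, e.g. run_batch(...) or eval_obj(sampler)
    sum(range(1_000_000))

tic = time.perf_counter()
do_work()
latency = time.perf_counter() - tic
print(f"latency: {latency:.3f} s")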
sglang/test/runners.py
CHANGED
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import transformers
 from transformers import (
+    AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForVision2Seq,
@@ -211,7 +213,12 @@ class HFRunner:
 
         # Load the model and tokenizer
         if self.model_type == "generation":
-
+            config = AutoConfig.from_pretrained(model_path)
+            if model_archs := getattr(config, "architectures"):
+                model_cls = getattr(transformers, model_archs[0])
+            else:
+                model_cls = AutoModelForCausalLM
+            self.base_model = model_cls.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=self.trust_remote_code,