sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -9,15 +9,18 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl

+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import (
+    Req,
     ScheduleBatch,
     get_last_loc,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
 from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

@@ -167,12 +170,12 @@ class EagleVerifyOutput:
     draft_input: EagleDraftInput
     # Logit outputs from target worker
     logits_output: LogitsProcessorOutput
-    # Accepeted token ids including the bonus token
+    # Accepted token ids including the bonus token
     verified_id: torch.Tensor
-    # Accepeted token length per sequence in a batch in CPU.
+    # Accepted token length per sequence in a batch in CPU.
     accept_length_per_req_cpu: List[int]
-    # Accepeted indices from logits_output.next_token_logits
-    accepeted_indices: torch.Tensor
+    # Accepted indices from logits_output.next_token_logits
+    accepted_indices: torch.Tensor


 @dataclass
@@ -187,6 +190,7 @@ class EagleVerifyInput:
     draft_token_num: int
     spec_steps: int
     capture_hidden_mode: CaptureHiddenMode
+    grammar: BaseGrammarObject = None

     @classmethod
     def create(
@@ -307,6 +311,7 @@ class EagleVerifyInput:
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
+        vocab_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch
@@ -316,7 +321,7 @@ class EagleVerifyInput:

         This API updates values inside logits_output based on the accepted
         tokens. I.e., logits_output.next_token_logits only contains
-        accepeted token logits.
+        accepted token logits.
         """
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
@@ -343,6 +348,13 @@ class EagleVerifyInput:
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Apply grammar mask
+        if vocab_mask is not None:
+            assert self.grammar is not None
+            self.grammar.apply_vocab_mask(
+                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+            )
+
         # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
@@ -440,6 +452,15 @@ class EagleVerifyInput:
                         break
                     else:
                         new_accept_index_.append(idx)
+                        # update grammar state
+                        if req.grammar is not None:
+                            try:
+                                req.grammar.accept_token(id)
+                            except ValueError as e:
+                                logger.info(
+                                    f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
+                                )
+                                raise e
                 if not req.finished():
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
@@ -493,7 +514,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-                accepeted_indices=accept_index,
+                accepted_indices=accept_index,
             )
         else:
             assign_req_to_token_pool[(bs,)](
@@ -539,7 +560,7 @@ class EagleVerifyInput:
                 logits_output=logits_output,
                 verified_id=verified_id,
                 accept_length_per_req_cpu=accept_length_cpu,
-                accepeted_indices=accept_index,
+                accepted_indices=accept_index,
             )


@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
     accept_length.fill_(simulate_acc_len - 1)
     predict.fill_(100)  # some legit token id
     return sim_accept_index
+
+
+def traverse_tree(
+    retrieve_next_token: torch.Tensor,
+    retrieve_next_sibling: torch.Tensor,
+    draft_tokens: torch.Tensor,
+    grammar: BaseGrammarObject,
+    allocate_token_bitmask: torch.Tensor,
+):
+    """
+    Traverse the tree constructed by the draft model to generate the logits mask.
+    """
+    assert (
+        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+    )
+
+    allocate_token_bitmask.fill_(0)
+
+    def dfs(
+        curr: int,
+        retrieve_next_token: torch.Tensor,
+        retrieve_next_sibling: torch.Tensor,
+        parent_pos: int,
+    ):
+        if curr == 0:
+            # the first token generated by the target model, and thus it is always
+            # accepted from the previous iteration
+            accepted = True
+        else:
+            parent_bitmask = allocate_token_bitmask[parent_pos]
+            curr_token_id = draft_tokens[curr]
+            # 32 boolean bitmask values are packed into 32-bit integers
+            accepted = (
+                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
+            ) != 0

+        if accepted:
+            if curr != 0:
+                # Accept the current token
+                grammar.accept_token(draft_tokens[curr])
+            if not grammar.is_terminated():
+                # Generate the bitmask for the current token
+                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
+                if retrieve_next_token[curr] != -1:
+                    # Visit the child node
+                    dfs(
+                        retrieve_next_token[curr],
+                        retrieve_next_token,
+                        retrieve_next_sibling,
+                        curr,
+                    )
+
+            if curr != 0:
+                # Rollback the current token
+                grammar.rollback(1)
+
+        if retrieve_next_sibling[curr] != -1:
+            # Visit the sibling node
+            dfs(
+                retrieve_next_sibling[curr],
+                retrieve_next_token,
+                retrieve_next_sibling,
+                parent_pos,
+            )
+
+    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
+
+
+def generate_token_bitmask(
+    reqs: List[Req],
+    verify_input: EagleVerifyInput,
+    retrieve_next_token_cpu: torch.Tensor,
+    retrieve_next_sibling_cpu: torch.Tensor,
+    draft_tokens_cpu: torch.Tensor,
+    vocab_size: int,
+):
+    """
+    Generate the logit mask for structured output.
+    Draft model's token can be either valid or invalid with respect to the grammar.
+    We need to perform DFS to figure out:
+    1. which tokens are accepted by the grammar
+    2. what is the corresponding logit mask.
+    """
+
+    num_draft_tokens = draft_tokens_cpu.shape[-1]
+
+    allocate_token_bitmask = None
+    assert len(reqs) == retrieve_next_token_cpu.shape[0]
+    grammar = None
+    for i, req in enumerate(reqs):
+        if req.grammar is not None:
+            if allocate_token_bitmask is None:
+                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
+                    vocab_size=vocab_size,
+                    batch_size=draft_tokens_cpu.numel(),
+                    device="cpu",
+                )
+            grammar = req.grammar
+            traverse_tree(
+                retrieve_next_token_cpu[i],
+                retrieve_next_sibling_cpu[i],
+                draft_tokens_cpu[i],
+                req.grammar,
+                allocate_token_bitmask[
+                    i * num_draft_tokens : (i + 1) * num_draft_tokens
+                ],
+            )
+
+    verify_input.grammar = grammar
+    return allocate_token_bitmask
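
For intuition, here is a minimal sketch (not part of the diff) of the bit-packing convention the DFS above relies on: each 32-bit word of the vocab bitmask stores the allow bits for 32 consecutive token ids, so token t lives at word t // 32, bit t % 32. Names and sizes below are illustrative.

# Illustrative only: the packed allow-bit layout used by the grammar bitmask.
vocab_size = 100
words = [0] * ((vocab_size + 31) // 32)  # one 32-bit word per 32 token ids

def allow(token_id: int) -> None:
    # Set the allow bit for token_id.
    words[token_id // 32] |= 1 << (token_id % 32)

def is_allowed(token_id: int) -> bool:
    # Same test traverse_tree applies to a draft token against its parent's bitmask.
    return (words[token_id // 32] & (1 << (token_id % 32))) != 0

allow(42)
assert is_allowed(42) and not is_allowed(43)
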
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
+    generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -199,9 +200,22 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import (
+                FlashMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
+                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
             )

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -215,7 +229,7 @@ class EAGLEWorker(TpModelWorker):
             return

         # Capture draft
-        tic = time.time()
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -223,7 +237,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

         # Capture extend
@@ -245,14 +259,14 @@ class EAGLEWorker(TpModelWorker):
         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
         Returns:
-            A tuple of the final logit output of the target model, next tokens accepeted,
-            the batch id (used for overlap schedule), and number of accepeted tokens.
+            A tuple of the final logit output of the target model, next tokens accepted,
+            the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch = self.verify(
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )

             # If it is None, it means all requests are finished
@@ -264,21 +278,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )

-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False

     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +312,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,9 +493,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+
+        if batch.has_grammar:
+            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
+            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
+            draft_tokens_cpu = spec_info.draft_token.view(
+                spec_info.retrive_next_token.shape
+            ).cpu()
+
+        # Forward
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
+
+        vocab_mask = None
+        if batch.has_grammar:
+            # Generate the logit mask for structured output.
+            # Overlap the CPU operations for bitmask generation with the forward pass.
+            vocab_mask = generate_token_bitmask(
+                batch.reqs,
+                spec_info,
+                retrieve_next_token_cpu,
+                retrieve_next_sibling_cpu,
+                draft_tokens_cpu,
+                batch.sampling_info.vocab_size,
+            )
+
+            if vocab_mask is not None:
+                assert spec_info.grammar is not None
+                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
+                # otherwise, this vocab mask will be the one from the previous extend stage
+                # and will be applied to produce wrong results
+                batch.sampling_info.vocab_mask = None
+
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
@@ -488,14 +535,15 @@ class EAGLEWorker(TpModelWorker):
             logits_output,
             self.token_to_kv_pool_allocator,
             self.page_size,
+            vocab_mask,
         )

         # Post process based on verified outputs.
-        # Pick indices that we care (accepeted)
+        # Pick indices that we care (accepted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices
+            res.accepted_indices
         ]
-        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
@@ -504,7 +552,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph

     def add_logprob_values(
         self,
@@ -590,14 +638,14 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        # Backup fileds that will be modified in-place
+        # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
@@ -617,7 +665,7 @@ class EAGLEWorker(TpModelWorker):
         )

         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)

         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
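
For intuition, a small self-contained sketch (not sglang code) of what applying the verify-stage vocab mask does before greedy sampling: disallowed logits are pushed to -inf so argmax can only pick grammar-legal tokens. The tensors and values below are illustrative.

# Illustrative only: effect of a grammar vocab mask on greedy verification.
import torch

logits = torch.tensor([[2.0, 5.0, 1.0, 0.5]])        # one row per draft position
allowed = torch.tensor([[True, False, True, True]])   # unpacked from the grammar bitmask
masked = logits.masked_fill(~allowed, float("-inf"))
target_predict = torch.argmax(masked, dim=-1)          # tensor([0]): the banned id 1 is skipped
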
sglang/srt/utils.py CHANGED
@@ -46,7 +46,19 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import numpy as np
 import psutil
@@ -125,10 +137,6 @@ builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
 builtins.FP8_E4M3_MIN = FP8_E4M3_MIN


-def is_rocm() -> bool:
-    return torch.cuda.is_available() and torch.version.hip
-
-
 def is_cuda():
     return torch.cuda.is_available() and torch.version.cuda

@@ -250,7 +258,7 @@ def mark_start(name, interval=0.1, color=0, indent=0):
     torch.cuda.synchronize()
     if time_infos.get(name, None) is None:
         time_infos[name] = TimeInfo(name, interval, color, indent)
-    time_infos[name].acc_time -= time.time()
+    time_infos[name].acc_time -= time.perf_counter()


 def mark_end(name):
@@ -258,7 +266,7 @@ def mark_end(name):
     if not show_time_cost:
         return
     torch.cuda.synchronize()
-    time_infos[name].acc_time += time.time()
+    time_infos[name].acc_time += time.perf_counter()
     if time_infos[name].check():
         time_infos[name].pretty_print()

@@ -268,11 +276,11 @@ def calculate_time(show=False, min_cost_ms=0.0):
     def inner_func(*args, **kwargs):
         torch.cuda.synchronize()
         if show:
-            start_time = time.time()
+            start_time = time.perf_counter()
         result = func(*args, **kwargs)
         torch.cuda.synchronize()
         if show:
-            cost_time = (time.time() - start_time) * 1000
+            cost_time = (time.perf_counter() - start_time) * 1000
             if cost_time > min_cost_ms:
                 print(f"Function {func.__name__} took {cost_time} ms to run.")
         return result
@@ -282,7 +290,9 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +354,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()

     return free_gpu_memory / (1 << 30)
@@ -1849,6 +1859,8 @@ def get_cuda_version():


 def launch_dummy_health_check_server(host, port):
+    import asyncio
+
     import uvicorn
     from fastapi import FastAPI, Response

@@ -1864,13 +1876,27 @@ def launch_dummy_health_check_server(host, port):
         """Check the health of the http server."""
         return Response(status_code=200)

-    uvicorn.run(
+    config = uvicorn.Config(
         app,
         host=host,
         port=port,
         timeout_keep_alive=5,
-        loop="uvloop",
+        loop="auto",
+        log_config=None,
+        log_level="warning",
     )
+    server = uvicorn.Server(config=config)
+
+    try:
+        loop = asyncio.get_running_loop()
+        logger.info(
+            f"Dummy health check server scheduled on existing loop at {host}:{port}"
+        )
+        loop.create_task(server.serve())
+
+    except RuntimeError:
+        logger.info(f"Starting dummy health check server at {host}:{port}")
+        server.run()


 def create_checksum(directory: str):
@@ -2075,8 +2101,6 @@ def is_fa3_default_architecture(hf_config):
         "Qwen2ForCausalLM",
         "Llama4ForConditionalGeneration",
         "LlamaForCausalLM",
-        "MistralForCausalLM",
-        "MixtralForCausalLM",
         "Gemma2ForCausalLM",
         "Gemma3ForConditionalGeneration",
         "Qwen3ForCausalLM",
@@ -2103,3 +2127,36 @@ def log_info_on_rank0(logger, msg):

     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def load_json_config(data: str):
+    try:
+        return json.loads(data)
+    except JSONDecodeError:
+        return json.loads(Path(data).read_text())
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
+
+
+T = TypeVar("T")
+
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
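
For context, a small usage sketch (not from the diff) of the new Withable helper added above: it holds a value only for the duration of the with_value block and resets to None afterwards. The variable names are illustrative.

# Illustrative only: scoping a value with Withable.
from sglang.srt.utils import Withable  # class added in this version

current_run = Withable()  # .value is None outside the with block

def log_step():
    print(f"step for {current_run.value}")

with current_run.with_value("run-0"):
    log_step()                    # prints "step for run-0"
assert current_run.value is None  # reset on exit
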
@@ -90,7 +90,7 @@ def run_eval(args):
     #####################################

     # Run requests
-    tic = time.time()
+    tic = time.perf_counter()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=args.temperature if hasattr(args, "temperature") else 0,
@@ -99,7 +99,7 @@ def run_eval(args):
         return_logprob=getattr(args, "return_logprob", None),
         logprob_start_len=getattr(args, "logprob_start_len", None),
     )
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic

     preds = []
     for i in range(len(states)):
@@ -89,7 +89,7 @@ def run_eval(args):
     }

     # Run requests
-    tic = time.time()
+    tic = time.perf_counter()

     loop = asyncio.get_event_loop()

@@ -98,7 +98,7 @@ def run_eval(args):
     )

     # End requests
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic

     # Shutdown the engine
     engine.shutdown()
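
The repeated time.time() → time.perf_counter() swaps above all follow the same idiom; a minimal sketch for reference: perf_counter is a monotonic, high-resolution clock, so measured latency is unaffected by wall-clock adjustments.

# Illustrative only: the timing idiom used by these benchmark scripts.
import time

tic = time.perf_counter()
# ... issue the batch of requests ...
latency = time.perf_counter() - tic  # elapsed seconds on a monotonic clock
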
sglang/test/run_eval.py CHANGED
@@ -71,9 +71,9 @@ def run_eval(args):
     )

     # Run eval
-    tic = time.time()
+    tic = time.perf_counter()
     result = eval_obj(sampler)
-    latency = time.time() - tic
+    latency = time.perf_counter() - tic

     # Dump reports
     metrics = result.metrics | {"score": result.score}
sglang/test/runners.py CHANGED
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
+import transformers
 from transformers import (
+    AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForVision2Seq,
@@ -211,7 +213,12 @@ class HFRunner:

         # Load the model and tokenizer
         if self.model_type == "generation":
-            self.base_model = AutoModelForCausalLM.from_pretrained(
+            config = AutoConfig.from_pretrained(model_path)
+            if model_archs := getattr(config, "architectures"):
+                model_cls = getattr(transformers, model_archs[0])
+            else:
+                model_cls = AutoModelForCausalLM
+            self.base_model = model_cls.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=self.trust_remote_code,
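
For reference, a small sketch (not from the diff) of the architecture-resolution idiom the HFRunner change uses: read config.architectures and look the class name up on the transformers module, falling back to the Auto* factory. The model id below is illustrative.

# Illustrative only: resolving a concrete model class from config.architectures.
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")      # illustrative model id
archs = getattr(config, "architectures", None)   # ["GPT2LMHeadModel"] for gpt2
model_cls = getattr(transformers, archs[0]) if archs else AutoModelForCausalLM
# model_cls is transformers.GPT2LMHeadModel here; from_pretrained(...) would then
# load the weights with that concrete class instead of the Auto* factory.
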