sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -46,7 +46,6 @@ class ServerArgs:
  tokenizer_path: Optional[str] = None
  tokenizer_mode: str = "auto"
  skip_tokenizer_init: bool = False
- enable_tokenizer_batch_encode: bool = False
  load_format: str = "auto"
  trust_remote_code: bool = False
  dtype: str = "auto"
@@ -59,6 +58,7 @@ class ServerArgs:
  chat_template: Optional[str] = None
  completion_template: Optional[str] = None
  is_embedding: bool = False
+ enable_multimodal: Optional[bool] = None
  revision: Optional[str] = None

  # Port for the HTTP server
@@ -97,8 +97,13 @@ class ServerArgs:
  log_requests_level: int = 0
  show_time_cost: bool = False
  enable_metrics: bool = False
+ bucket_time_to_first_token: Optional[List[float]] = None
+ bucket_e2e_request_latency: Optional[List[float]] = None
+ bucket_inter_token_latency: Optional[List[float]] = None
+ collect_tokens_histogram: bool = False
  decode_log_interval: int = 40
  enable_request_time_stats_logging: bool = False
+ kv_events_config: Optional[str] = None

  # API related
  api_key: Optional[str] = None
@@ -120,6 +125,7 @@ class ServerArgs:

  # Model override args in JSON
  json_model_override_args: str = "{}"
+ preferred_sampling_params: Optional[str] = None

  # LoRA
  lora_paths: Optional[List[str]] = None
@@ -154,9 +160,9 @@ class ServerArgs:
  disable_cuda_graph: bool = False
  disable_cuda_graph_padding: bool = False
  enable_nccl_nvls: bool = False
+ enable_tokenizer_batch_encode: bool = False
  disable_outlines_disk_cache: bool = False
  disable_custom_all_reduce: bool = False
- enable_multimodal: Optional[bool] = None
  disable_overlap_schedule: bool = False
  enable_mixed_chunk: bool = False
  enable_dp_attention: bool = False
@@ -164,6 +170,17 @@ class ServerArgs:
  enable_ep_moe: bool = False
  enable_deepep_moe: bool = False
  deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+ ep_num_redundant_experts: int = 0
+ ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
+ init_expert_location: str = "trivial"
+ enable_eplb: bool = False
+ eplb_rebalance_num_iterations: int = 1000
+ expert_distribution_recorder_mode: Optional[
+ Literal["stat", "per_pass", "per_token"]
+ ] = None
+ expert_distribution_recorder_buffer_size: Optional[int] = None
+ enable_expert_distribution_metrics: bool = False
+ deepep_config: Optional[str] = None
  enable_torch_compile: bool = False
  torch_compile_max_bs: int = 32
  cuda_graph_max_bs: Optional[int] = None
@@ -229,7 +246,7 @@ class ServerArgs:
  # Set mem fraction static, which depends on the tensor parallelism size
  if self.mem_fraction_static is None:
  parallel_size = self.tp_size * self.pp_size
- if gpu_mem <= 81920:
+ if gpu_mem is not None and gpu_mem <= 81920:
  if parallel_size >= 16:
  self.mem_fraction_static = 0.79
  elif parallel_size >= 8:
@@ -242,7 +259,7 @@ class ServerArgs:
  self.mem_fraction_static = 0.88
  else:
  self.mem_fraction_static = 0.88
- if gpu_mem > 96 * 1024:
+ if gpu_mem is not None and gpu_mem > 96 * 1024:
  mem_fraction = self.mem_fraction_static
  self.mem_fraction_static = min(
  mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
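Note on the two guards above: the GPU memory probe can return None when total device memory cannot be queried, and comparing None with an int raises a TypeError on Python 3, so both branches now fall through to the defaults instead of crashing. For reference, a minimal sketch of the large-GPU adjustment for a hypothetical 141 GB device (the upper cap inside the min(...) call lies outside this hunk):

    gpu_mem = 141 * 1024                       # hypothetical, in MiB
    mem_fraction = 0.88
    boosted = mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem
    print(round(boosted, 3))                   # ~0.921, before the min(...) cap is applied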
@@ -307,12 +324,6 @@ class ServerArgs:
  if self.grammar_backend is None:
  self.grammar_backend = "xgrammar"

- if self.pp_size > 1:
- self.disable_overlap_schedule = True
- logger.warning(
- "Overlap scheduler is disabled because of using pipeline parallelism."
- )
-
  # Data parallelism attention
  if self.enable_dp_attention:
  self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -354,6 +365,15 @@ class ServerArgs:
  "Pipeline parallelism is incompatible with overlap schedule."
  )

+ if self.expert_distribution_recorder_buffer_size is None:
+ # TODO pr-chain: enable this later
+ # if (x := self.eplb_rebalance_num_iterations) is not None:
+ # self.expert_distribution_recorder_buffer_size = x
+ if False:
+ pass
+ elif self.expert_distribution_recorder_mode is not None:
+ self.expert_distribution_recorder_buffer_size = 1000
+
  # Speculative Decoding
  if self.speculative_algorithm == "NEXTN":
  # NEXTN shares the same implementation of EAGLE
@@ -474,11 +494,6 @@ class ServerArgs:
  action="store_true",
  help="If set, skip init tokenizer and pass input_ids in generate request.",
  )
- parser.add_argument(
- "--enable-tokenizer-batch-encode",
- action="store_true",
- help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
- )
  parser.add_argument(
  "--load-format",
  type=str,
@@ -556,6 +571,7 @@ class ServerArgs:
  "w8a8_int8",
  "w8a8_fp8",
  "moe_wna16",
+ "qoq",
  ],
  help="The quantization method.",
  )
@@ -603,6 +619,12 @@ class ServerArgs:
  action="store_true",
  help="Whether to use a CausalLM as an embedding model.",
  )
+ parser.add_argument(
+ "--enable-multimodal",
+ default=ServerArgs.enable_multimodal,
+ action="store_true",
+ help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
+ )
  parser.add_argument(
  "--revision",
  type=str,
@@ -780,6 +802,39 @@ class ServerArgs:
  action="store_true",
  help="Enable log prometheus metrics.",
  )
+ parser.add_argument(
+ "--bucket-time-to-first-token",
+ type=float,
+ nargs="+",
+ default=ServerArgs.bucket_time_to_first_token,
+ help="The buckets of time to first token, specified as a list of floats.",
+ )
+ parser.add_argument(
+ "--bucket-inter-token-latency",
+ type=float,
+ nargs="+",
+ default=ServerArgs.bucket_inter_token_latency,
+ help="The buckets of inter-token latency, specified as a list of floats.",
+ )
+ parser.add_argument(
+ "--bucket-e2e-request-latency",
+ type=float,
+ nargs="+",
+ default=ServerArgs.bucket_e2e_request_latency,
+ help="The buckets of end-to-end request latency, specified as a list of floats.",
+ )
+ parser.add_argument(
+ "--collect-tokens-histogram",
+ action="store_true",
+ default=ServerArgs.collect_tokens_histogram,
+ help="Collect prompt/generation tokens histogram.",
+ )
+ parser.add_argument(
+ "--kv-events-config",
+ type=str,
+ default=None,
+ help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
+ )
  parser.add_argument(
  "--decode-log-interval",
  type=int,
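Together, these arguments let a deployment tune the Prometheus histograms exported under --enable-metrics. A hypothetical launch line combining the new flags (only the flags themselves come from this diff; the model path and bucket values are placeholders):

    python -m sglang.launch_server --model-path <model> --enable-metrics \
        --bucket-time-to-first-token 0.1 0.25 0.5 1 2 5 \
        --bucket-inter-token-latency 0.002 0.005 0.01 0.02 0.05 \
        --bucket-e2e-request-latency 0.5 1 2 5 10 \
        --collect-tokens-histogram

--kv-events-config expects a JSON string; its schema is not documented in this hunk and is presumably handled by the new sglang/srt/disaggregation/kv_events.py module (entry 11 in the file list).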
@@ -868,6 +923,11 @@ class ServerArgs:
  help="A dictionary in JSON string format used to override default model configurations.",
  default=ServerArgs.json_model_override_args,
  )
+ parser.add_argument(
+ "--preferred-sampling-params",
+ type=str,
+ help="json-formatted sampling settings that will be returned in /get_model_info",
+ )

  # LoRA
  parser.add_argument(
@@ -896,6 +956,7 @@ class ServerArgs:
  "--attention-backend",
  type=str,
  choices=[
+ "aiter",
  "flashinfer",
  "triton",
  "torch_native",
@@ -1043,6 +1104,11 @@ class ServerArgs:
  action="store_true",
  help="Enable NCCL NVLS for prefill heavy requests when available.",
  )
+ parser.add_argument(
+ "--enable-tokenizer-batch-encode",
+ action="store_true",
+ help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
+ )
  parser.add_argument(
  "--disable-outlines-disk-cache",
  action="store_true",
@@ -1053,12 +1119,6 @@ class ServerArgs:
  action="store_true",
  help="Disable the custom all-reduce kernel and fall back to NCCL.",
  )
- parser.add_argument(
- "--enable-multimodal",
- default=ServerArgs.enable_multimodal,
- action="store_true",
- help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
- )
  parser.add_argument(
  "--disable-overlap-schedule",
  action="store_true",
@@ -1072,7 +1132,7 @@ class ServerArgs:
  parser.add_argument(
  "--enable-dp-attention",
  action="store_true",
- help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
+ help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
  )
  parser.add_argument(
  "--enable-dp-lm-head",
@@ -1212,6 +1272,58 @@ class ServerArgs:
  default="auto",
  help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
  )
+ parser.add_argument(
+ "--ep-num-redundant-experts",
+ type=int,
+ default=ServerArgs.ep_num_redundant_experts,
+ help="Allocate this number of redundant experts in expert parallel.",
+ )
+ parser.add_argument(
+ "--ep-dispatch-algorithm",
+ type=str,
+ default=ServerArgs.ep_dispatch_algorithm,
+ help="The algorithm to choose ranks for redundant experts in expert parallel.",
+ )
+ parser.add_argument(
+ "--init-expert-location",
+ type=str,
+ default=ServerArgs.init_expert_location,
+ help="Initial location of EP experts.",
+ )
+ parser.add_argument(
+ "--enable-eplb",
+ action="store_true",
+ help="Enable EPLB algorithm",
+ )
+ parser.add_argument(
+ "--eplb-rebalance-num-iterations",
+ type=int,
+ default=ServerArgs.eplb_rebalance_num_iterations,
+ help="Number of iterations to automatically trigger a EPLB re-balance.",
+ )
+ parser.add_argument(
+ "--expert-distribution-recorder-mode",
+ type=str,
+ default=ServerArgs.expert_distribution_recorder_mode,
+ help="Mode of expert distribution recorder.",
+ )
+ parser.add_argument(
+ "--expert-distribution-recorder-buffer-size",
+ type=int,
+ default=ServerArgs.expert_distribution_recorder_buffer_size,
+ help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+ )
+ parser.add_argument(
+ "--enable-expert-distribution-metrics",
+ action="store_true",
+ help="Enable logging metrics for expert balancedness",
+ )
+ parser.add_argument(
+ "--deepep-config",
+ type=str,
+ default=ServerArgs.deepep_config,
+ help="Tuned DeepEP config suitable for your own cluster.",
+ )

  parser.add_argument(
  "--n-share-experts-fusion",
@@ -1326,8 +1438,6 @@ class ServerArgs:

  # FIXME pp constraints
  if self.pp_size > 1:
- logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
- self.disable_overlap_schedule = True
  assert (
  self.disable_overlap_schedule
  and self.speculative_algorithm is None
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -9,15 +9,18 @@ import torch.nn.functional as F
  import triton
  import triton.language as tl

+ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import (
+ Req,
  ScheduleBatch,
  get_last_loc,
  global_server_args_dict,
  )
  from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
  from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

@@ -187,6 +190,7 @@ class EagleVerifyInput:
  draft_token_num: int
  spec_steps: int
  capture_hidden_mode: CaptureHiddenMode
+ grammar: BaseGrammarObject = None

  @classmethod
  def create(
@@ -307,6 +311,7 @@ class EagleVerifyInput:
  logits_output: torch.Tensor,
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
  page_size: int,
+ vocab_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """
  Verify and find accepted tokens based on logits output and batch
@@ -343,6 +348,13 @@ class EagleVerifyInput:
  torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
  )

+ # Apply grammar mask
+ if vocab_mask is not None:
+ assert self.grammar is not None
+ self.grammar.apply_vocab_mask(
+ logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+ )
+
  # Sample tokens
  if batch.sampling_info.is_all_greedy:
  target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
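apply_vocab_mask is the grammar backend's hook for constrained decoding: before the greedy/sampled verification in the lines that follow, it suppresses the logits of tokens the grammar forbids. A standalone sketch of the idea with a plain boolean mask (shapes and names are illustrative only; the real backends use the packed bitmask produced by generate_token_bitmask further down):

    import torch

    logits = torch.randn(4, 32000)                     # [num_draft_tokens, vocab_size]
    allowed = torch.zeros(4, 32000, dtype=torch.bool)
    allowed[:, :100] = True                            # pretend the grammar only allows ids < 100
    logits = logits.masked_fill(~allowed, float("-inf"))
    print(logits.argmax(dim=-1))                       # argmax now stays inside the allowed set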
@@ -440,6 +452,15 @@ class EagleVerifyInput:
  break
  else:
  new_accept_index_.append(idx)
+ # update grammar state
+ if req.grammar is not None:
+ try:
+ req.grammar.accept_token(id)
+ except ValueError as e:
+ logger.info(
+ f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
+ )
+ raise e
  if not req.finished():
  new_accept_index.extend(new_accept_index_)
  unfinished_index.append(i)
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
  accept_length.fill_(simulate_acc_len - 1)
  predict.fill_(100) # some legit token id
  return sim_accept_index
+
+
+ def traverse_tree(
+ retrieve_next_token: torch.Tensor,
+ retrieve_next_sibling: torch.Tensor,
+ draft_tokens: torch.Tensor,
+ grammar: BaseGrammarObject,
+ allocate_token_bitmask: torch.Tensor,
+ ):
+ """
+ Traverse the tree constructed by the draft model to generate the logits mask.
+ """
+ assert (
+ retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
+ )
+
+ allocate_token_bitmask.fill_(0)
+
+ def dfs(
+ curr: int,
+ retrieve_next_token: torch.Tensor,
+ retrieve_next_sibling: torch.Tensor,
+ parent_pos: int,
+ ):
+ if curr == 0:
+ # the first token generated by the target model, and thus it is always
+ # accepted from the previous iteration
+ accepted = True
+ else:
+ parent_bitmask = allocate_token_bitmask[parent_pos]
+ curr_token_id = draft_tokens[curr]
+ # 32 boolean bitmask values are packed into 32-bit integers
+ accepted = (
+ parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
+ ) != 0
+
+ if accepted:
+ if curr != 0:
+ # Accept the current token
+ grammar.accept_token(draft_tokens[curr])
+ if not grammar.is_terminated():
+ # Generate the bitmask for the current token
+ grammar.fill_vocab_mask(allocate_token_bitmask, curr)
+ if retrieve_next_token[curr] != -1:
+ # Visit the child node
+ dfs(
+ retrieve_next_token[curr],
+ retrieve_next_token,
+ retrieve_next_sibling,
+ curr,
+ )
+
+ if curr != 0:
+ # Rollback the current token
+ grammar.rollback(1)
+
+ if retrieve_next_sibling[curr] != -1:
+ # Visit the sibling node
+ dfs(
+ retrieve_next_sibling[curr],
+ retrieve_next_token,
+ retrieve_next_sibling,
+ parent_pos,
+ )
+
+ dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
+
+
+ def generate_token_bitmask(
+ reqs: List[Req],
+ verify_input: EagleVerifyInput,
+ retrieve_next_token_cpu: torch.Tensor,
+ retrieve_next_sibling_cpu: torch.Tensor,
+ draft_tokens_cpu: torch.Tensor,
+ vocab_size: int,
+ ):
+ """
+ Generate the logit mask for structured output.
+ Draft model's token can be either valid or invalid with respect to the grammar.
+ We need to perform DFS to figure out:
+ 1. which tokens are accepted by the grammar
+ 2. what is the corresponding logit mask.
+ """
+
+ num_draft_tokens = draft_tokens_cpu.shape[-1]
+
+ allocate_token_bitmask = None
+ assert len(reqs) == retrieve_next_token_cpu.shape[0]
+ grammar = None
+ for i, req in enumerate(reqs):
+ if req.grammar is not None:
+ if allocate_token_bitmask is None:
+ allocate_token_bitmask = req.grammar.allocate_vocab_mask(
+ vocab_size=vocab_size,
+ batch_size=draft_tokens_cpu.numel(),
+ device="cpu",
+ )
+ grammar = req.grammar
+ traverse_tree(
+ retrieve_next_token_cpu[i],
+ retrieve_next_sibling_cpu[i],
+ draft_tokens_cpu[i],
+ req.grammar,
+ allocate_token_bitmask[
+ i * num_draft_tokens : (i + 1) * num_draft_tokens
+ ],
+ )
+
+ verify_input.grammar = grammar
+ return allocate_token_bitmask
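The bitmask convention used by traverse_tree packs 32 vocabulary entries into each 32-bit word, so membership is a single bit test. A small self-contained sketch of that packing (plain Python ints for clarity; only the //32 and %32 convention comes from the code above):

    vocab_size = 100
    bitmask = [0] * ((vocab_size + 31) // 32)          # one 32-bit word per 32 token ids

    def set_allowed(token_id):
        bitmask[token_id // 32] |= 1 << (token_id % 32)

    def is_allowed(token_id):
        # the same test traverse_tree applies to parent_bitmask
        return (bitmask[token_id // 32] & (1 << (token_id % 32))) != 0

    set_allowed(42)
    print(is_allowed(42), is_allowed(43))              # True False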
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
  EagleVerifyInput,
  EagleVerifyOutput,
  assign_draft_cache_locs,
+ generate_token_bitmask,
  select_top_k_tokens,
  )
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -199,6 +200,19 @@ class EAGLEWorker(TpModelWorker):
  self.draft_extend_attn_backend = None
  self.padded_static_len = self.speculative_num_steps + 1
  self.has_prefill_wrapper_verify = False
+ elif self.server_args.attention_backend == "flashmla":
+ from sglang.srt.layers.attention.flashmla_backend import (
+ FlashMLAMultiStepDraftBackend,
+ )
+
+ self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
+ self.draft_model_runner,
+ self.topk,
+ self.speculative_num_steps,
+ )
+ self.draft_extend_attn_backend = None
+ self.padded_static_len = self.speculative_num_steps + 1
+ self.has_prefill_wrapper_verify = False
  else:
  raise ValueError(
  f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -215,7 +229,7 @@ class EAGLEWorker(TpModelWorker):
  return

  # Capture draft
- tic = time.time()
+ tic = time.perf_counter()
  before_mem = get_available_gpu_memory(self.device, self.gpu_id)
  logger.info(
  f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -223,7 +237,7 @@ class EAGLEWorker(TpModelWorker):
  self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
  after_mem = get_available_gpu_memory(self.device, self.gpu_id)
  logger.info(
- f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+ f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
  )

  # Capture extend
@@ -479,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
  batch.forward_mode = ForwardMode.TARGET_VERIFY
  batch.spec_info = spec_info
  model_worker_batch = batch.get_model_worker_batch()
+
+ if batch.has_grammar:
+ retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
+ retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
+ draft_tokens_cpu = spec_info.draft_token.view(
+ spec_info.retrive_next_token.shape
+ ).cpu()
+
+ # Forward
  logits_output, _, can_run_cuda_graph = (
  self.target_worker.forward_batch_generation(
  model_worker_batch, skip_sample=True
  )
  )
+
+ vocab_mask = None
+ if batch.has_grammar:
+ # Generate the logit mask for structured output.
+ # Overlap the CPU operations for bitmask generation with the forward pass.
+ vocab_mask = generate_token_bitmask(
+ batch.reqs,
+ spec_info,
+ retrieve_next_token_cpu,
+ retrieve_next_sibling_cpu,
+ draft_tokens_cpu,
+ batch.sampling_info.vocab_size,
+ )
+
+ if vocab_mask is not None:
+ assert spec_info.grammar is not None
+ vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
+ # otherwise, this vocab mask will be the one from the previous extend stage
+ # and will be applied to produce wrong results
+ batch.sampling_info.vocab_mask = None
+
  self._detect_nan_if_needed(logits_output)
  spec_info.hidden_states = logits_output.hidden_states
  res: EagleVerifyOutput = spec_info.verify(
@@ -491,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
  logits_output,
  self.token_to_kv_pool_allocator,
  self.page_size,
+ vocab_mask,
  )

  # Post process based on verified outputs.