sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -17,12 +17,14 @@ from __future__ import annotations
 
 import bisect
 import inspect
+import logging
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
+from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,18 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
+    empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
@@ -147,10 +156,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )
 
     gpu_mem = get_device_memory_capacity()
-    if gpu_mem is not None and gpu_mem > 96 * 1024:
-        capture_bs += list(range(160, 257, 8))
-    if gpu_mem is not None and gpu_mem > 180 * 1000:
-        capture_bs += list(range(256, 513, 16))
+    if gpu_mem is not None:
+        if gpu_mem > 90 * 1024:  # H200, H20
+            capture_bs += list(range(160, 257, 8))
+        if gpu_mem > 160 * 1000:  # B200, MI300
+            capture_bs += list(range(256, 513, 16))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
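Note on the hunk above: the memory thresholds drop from 96 GiB / 180 GB to 90 GiB / 160 GB, so larger decode batch sizes get CUDA-graph capture on high-memory GPUs. A standalone sketch of the new logic follows; it assumes gpu_mem is the device capacity in MB (as the 90 * 1024 comparison suggests) and is not sglang code.

```python
# Standalone sketch (not the sglang implementation): how the list of batch
# sizes to capture is extended based on reported device memory in MB.
def extend_capture_bs(capture_bs, gpu_mem):
    if gpu_mem is not None:
        if gpu_mem > 90 * 1024:  # roughly H20/H200-class capacity
            capture_bs += list(range(160, 257, 8))
        if gpu_mem > 160 * 1000:  # roughly B200/MI300-class capacity
            capture_bs += list(range(256, 513, 16))
    return sorted(set(capture_bs))


# An 80 GB card keeps the base list; a 141 GB card also captures up to 256.
print(extend_capture_bs(list(range(1, 161, 8)), gpu_mem=80 * 1024)[-1])   # 153
print(extend_capture_bs(list(range(1, 161, 8)), gpu_mem=141 * 1024)[-1])  # 256
```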
@@ -201,12 +211,17 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
@@ -226,16 +241,20 @@ class CudaGraphRunner:
             self.model_runner.server_args.speculative_num_draft_tokens
         )
 
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
@@ -286,18 +305,30 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                    self.max_num_token,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -309,20 +340,23 @@ class CudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
-
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
-                if self.disable_padding
-                else total_global_tokens <= self.max_bs
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
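The rewritten can_run above first normalizes the request to a single cuda_graph_bs (dividing the DP-wide token count by num_tokens_per_bs when EAGLE speculative decoding is active) and only then applies the padding check; the DP sync condition is applied separately via require_mlp_sync. A simplified, self-contained sketch of just the batch-size part, with a hypothetical helper name and inputs:

```python
# Hypothetical standalone version of the batch-size check in can_run().
def batch_size_is_supported(
    global_num_tokens_cpu,   # tokens per DP rank, or None without DP attention
    batch_size,              # local batch size
    num_tokens_per_bs,       # >1 when EAGLE verifies several tokens per request
    is_eagle,
    captured_bs,             # batch sizes that have a captured graph
    disable_padding,
):
    if global_num_tokens_cpu is not None:  # mirrors require_mlp_tp_gather
        total = sum(global_num_tokens_cpu)
        cuda_graph_bs = total // num_tokens_per_bs if is_eagle else total
    else:
        cuda_graph_bs = batch_size
    if disable_padding:
        return cuda_graph_bs in captured_bs
    return cuda_graph_bs <= max(captured_bs)


# 32 DP-wide tokens with 4 draft tokens per request map back to batch size 8.
print(batch_size_is_supported([12, 20], 8, 4, True, list(range(1, 33)), False))  # True
```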
@@ -333,50 +367,91 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
-        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
-    def capture(self):
-        with graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
-            avail_mem = get_available_gpu_memory(
-                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
             )
-            # Reverse the order to enable better memory sharing across cuda graphs.
-            capture_range = (
-                tqdm.tqdm(list(reversed(self.capture_bs)))
-                if get_tensor_model_parallel_rank() == 0
-                else reversed(self.capture_bs)
-            )
-            for bs in capture_range:
-                if get_tensor_model_parallel_rank() == 0:
-                    avail_mem = get_available_gpu_memory(
-                        self.model_runner.device,
-                        self.model_runner.gpu_id,
-                        empty_cache=False,
-                    )
-                    capture_range.set_description(
-                        f"Capturing batches ({avail_mem=:.2f} GB)"
-                    )
 
-                with patch_model(
-                    self.model_runner.model,
-                    bs in self.compile_bs,
-                    num_tokens=bs * self.num_tokens_per_bs,
-                    tp_group=self.model_runner.tp_group,
-                ) as forward:
-                    (
-                        graph,
-                        output_buffers,
-                    ) = self.capture_one_batch_size(bs, forward)
-                    self.graphs[bs] = graph
-                    self.output_buffers[bs] = output_buffers
-
-                # Save gemlite cache after each capture
-                save_gemlite_cache()
+        with graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+        if self.enable_profile_cuda_graph:
+            log_message = (
+                "Sorted by CUDA Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cuda_time_total", row_limit=10
+                )
+                + "\n\nSorted by CPU Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cpu_time_total", row_limit=10
+                )
+            )
+            logger.info(log_message)
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
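The new capture() wraps the whole capture loop in torch.profiler.profile when profiling is enabled, then logs tables sorted by CUDA and CPU time. A minimal sketch of that pattern follows; it uses contextlib.nullcontext in place of sglang's empty_context, a dummy workload in place of the capture loop, and only enables the CUDA activity when a GPU is present.

```python
import logging
from contextlib import nullcontext

import torch
from torch.profiler import ProfilerActivity, profile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

enable_profile = True
activities = [ProfilerActivity.CPU] + (
    [ProfilerActivity.CUDA] if torch.cuda.is_available() else []
)
profile_context = (
    profile(activities=activities, record_shapes=True)
    if enable_profile
    else nullcontext()
)

with profile_context as prof:
    # Stand-in for the capture loop: any torch ops run here are recorded.
    x = torch.randn(256, 256)
    for _ in range(8):
        x = x @ x

if enable_profile:
    logger.info(
        "Sorted by CPU Time:\n"
        + prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=10
        )
    )
```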
@@ -402,11 +477,11 @@ class CudaGraphRunner:
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )
 
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
-                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                        for i in range(self.dp_size)
                    ],
                    dtype=torch.int32,
@@ -415,6 +490,16 @@ class CudaGraphRunner:
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -443,7 +528,7 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
@@ -509,21 +594,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        # If the capture_hidden_mode changes, we need to recapture the graph
-        hidden_mode_from_spec_info = getattr(
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        if (
-            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != CaptureHiddenMode.FULL
-        ):
-            self.capture_hidden_mode = CaptureHiddenMode.FULL
-            self.capture()
-        elif (
-            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != hidden_mode_from_spec_info
-        ):
-            self.capture_hidden_mode = hidden_mode_from_spec_info
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(
@@ -537,15 +635,18 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
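In replay_prepare, the running batch is padded up to the nearest captured batch size with bisect_left, and the padded slots are now filled with the backend-specific seq_len_fill_value instead of a hard-coded 1. A toy illustration of the padding lookup, with made-up numbers:

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 24, 32]  # batch sizes with captured graphs
raw_bs = 11                            # current batch size

index = bisect.bisect_left(capture_bs, raw_bs)
bs = capture_bs[index]
print(bs)  # 16 -> the 5 padded slots get seq_len_fill_value, not a hard-coded 1
```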
@@ -557,7 +658,7 @@ class CudaGraphRunner:
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -569,27 +670,28 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
-            forward_batch.forward_mode,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
+            self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
 
         # Store fields
@@ -637,11 +739,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.ones(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
sglang/srt/model_executor/forward_batch_info.py

@@ -31,6 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE
 
 
+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2
 
     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST
 
+    def __lt__(self, other):
+        return self.value < other.value
+
 
 @dataclass
 class ForwardBatch:
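Giving CaptureHiddenMode explicit values with NULL < LAST < FULL, plus an ordering via total_ordering and __lt__, is what lets CudaGraphRunner.can_run and recapture_if_needed take a max() over several requirement sources: a graph captured with FULL can also serve LAST and NULL requests. A small self-contained sketch mirroring the enum from the diff above (not imported from sglang):

```python
from enum import IntEnum
from functools import total_ordering


@total_ordering
class CaptureHiddenMode(IntEnum):  # mirrors the enum in the diff above
    NULL = 0  # do not capture anything
    LAST = 1  # capture the hidden state of the last token
    FULL = 2  # capture hidden states of all tokens

    def __lt__(self, other):
        return self.value < other.value


# The strongest requirement wins, e.g. spec decoding may need LAST while
# returning hidden states needs FULL.
required = max(CaptureHiddenMode.NULL, CaptureHiddenMode.LAST, CaptureHiddenMode.FULL)
print(required)  # CaptureHiddenMode.FULL
```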
@@ -219,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -295,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -311,17 +320,30 @@ class ForwardBatch:
 
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
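Under DP attention, ForwardBatch.init_new now scales the per-rank token counts by spec_num_draft_tokens, so the gathered buffer is sized for every draft token rather than one token per request. A worked example with made-up numbers:

```python
# Illustrative numbers only: four DP ranks, EAGLE drafting 4 tokens per request.
batch_global_num_tokens = [3, 5, 4, 4]  # token counts reported per DP rank
spec_num_draft_tokens = 4               # draft tokens per request; 1 if not speculative

global_num_tokens = [x * spec_num_draft_tokens for x in batch_global_num_tokens]
sum_len = sum(global_num_tokens)        # rows allocated for gathered_buffer

print(global_num_tokens)  # [12, 20, 16, 16]
print(sum_len)            # 64
```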
@@ -351,8 +373,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,