sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
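The biggest structural change in this release is the OpenAI-compatible layer: the monolithic sglang/srt/openai_api/adapter.py is deleted and replaced by per-endpoint modules under sglang/srt/entrypoints/openai/ (serving_chat.py, serving_completions.py, serving_embedding.py, and so on), with protocol.py moving along. Code that imports the protocol module by its old path has to follow the move; a minimal compatibility shim might look like the sketch below (the imported class name is only an illustrative example of a symbol defined in protocol.py):

    # Hypothetical downstream shim; ChatCompletionRequest is an example symbol.
    try:
        # New location as of sglang 0.4.8
        from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
    except ImportError:
        # Old location in sglang 0.4.7 and earlier
        from sglang.srt.openai_api.protocol import ChatCompletionRequest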
sglang/srt/layers/attention/flashmla_backend.py

@@ -2,9 +2,6 @@ from __future__ import annotations

  """
  Support attention backend for FlashMLA.
-
- #TODO
- Enable speculative sampling in FlashMLA
  """

  from dataclasses import dataclass
@@ -14,8 +11,6 @@ import torch
  import triton
  from flash_mla import flash_mla_with_kvcache, get_mla_metadata

- from sglang.global_config import global_config
- from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
  from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -24,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
  from sglang.srt.speculative.spec_info import SpecInfo


@@ -154,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  def init_cuda_graph_state(
  self,
  max_bs: int,
+ max_num_tokens: int,
  block_kv_indices: Optional[torch.Tensor] = None,
  ):
  if block_kv_indices is None:
@@ -330,7 +325,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  )

  def get_cuda_graph_seq_len_fill_value(self):
- return 1024
+ return 1

  def forward_decode(
  self,
@@ -464,11 +459,9 @@ class FlashMLAMultiStepDraftBackend:
  topk: int,
  speculative_num_steps: int,
  ):
- from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
-
  if topk > 1:
  raise ValueError(
- f"Currently FlashMLA only supports topk=1 for speculative decoding"
+ "Currently FlashMLA only supports topk=1 for speculative decoding"
  )
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
@@ -510,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:

  self.common_template(forward_batch, call_fn)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  for i in range(self.speculative_num_steps):
- self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+ self.attn_backends[i].init_cuda_graph_state(
+ max_bs, max_num_tokens, block_kv_indices=None
+ )

  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
  def call_fn(i, forward_batch):
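The FlashMLA hunks above are part of a cross-backend signature change: init_cuda_graph_state now receives max_num_tokens in addition to max_bs, so CUDA-graph buffers can be sized by token count rather than by request count (the two differ once speculative decoding places several draft tokens per request in a batch). A minimal sketch of a backend implementing the updated hook, with illustrative class and buffer names that are not taken from this diff:

    import torch

    class ToyAttnBackend:
        def __init__(self, num_heads: int, head_dim: int, device: str = "cuda"):
            self.num_heads, self.head_dim, self.device = num_heads, head_dim, device

        def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
            # Per-token scratch buffers are sized by max_num_tokens, not max_bs.
            self.graph_logits = torch.zeros(
                (max_num_tokens, self.num_heads, self.head_dim), device=self.device
            )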
sglang/srt/layers/attention/tbo_backend.py

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
  if forward_batch_child.batch_size > 0:
  child.init_forward_metadata(forward_batch=forward_batch_child)

- def init_cuda_graph_state(self, max_bs: int):
- self.primary.init_cuda_graph_state(max_bs=max_bs)
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+ self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
  for item in self.children:
  # TODO for children, maybe can provide *smaller* max_bs to optimize
- item.init_cuda_graph_state(max_bs=max_bs)
+ item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

  def init_forward_metadata_capture_cuda_graph(
  self,
sglang/srt/layers/attention/triton_backend.py

@@ -12,7 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.layers.radix_attention import AttentionType
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.utils import get_bool_env_var, get_device_core_count
+ from sglang.srt.utils import get_bool_env_var, get_device_core_count, next_power_of_2

  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


- @triton.jit
- def get_num_kv_splits_triton(
- num_kv_splits_ptr,
- seq_lens_ptr,
- num_seq,
- num_group,
- num_head,
- num_kv_head,
- max_kv_splits,
- device_core_count,
- MAX_NUM_SEQ: tl.constexpr,
- ):
- # TODO: this method is tunable, we need more online serving data to tune it
- offs_seq = tl.arange(0, MAX_NUM_SEQ)
- mask_seq = offs_seq < num_seq
-
- seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
- max_seq_len = tl.max(seq_lens)
- seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
- min_seq_len = tl.min(seq_lens)
- if max_seq_len * 8 < min_seq_len * 10:
- min_seq_len = max_seq_len
- max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
- kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
- # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
- ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
- ext_device_core_count = tl.cast(
- device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
- )
- block_h, num_kv_group = 16, num_head // num_kv_head
- if num_kv_group == 1:
- token_grid = num_seq * num_group * num_head
- else:
- # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
- block_h = tl.minimum(block_h, num_kv_group)
- token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
- max_kv_splits_2 = tl.minimum(
- tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
- )
- kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
- num_kv_splits = tl.maximum(
- tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
- )
-
- offs_token = offs_seq * num_group
- mask_token = offs_token < num_seq * num_group
- for i in range(0, num_group):
- tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
- def update_sliding_window_buffer(
- window_kv_indptr,
- req_to_token,
- sliding_window_size,
- seq_lens,
- req_pool_indices,
- bs,
- device,
- ):
- window_kv_lens = torch.minimum(
- seq_lens,
- torch.tensor(sliding_window_size + 1),
- )
- window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
- window_kv_indptr = window_kv_indptr[: bs + 1]
- window_kv_indices = torch.empty(
- window_kv_indptr[-1], dtype=torch.int32, device=device
- )
- window_kv_start_idx = seq_lens - window_kv_lens
- create_flashinfer_kv_indices_triton[(bs,)](
- req_to_token,
- req_pool_indices,
- window_kv_lens,
- window_kv_indptr,
- window_kv_start_idx,
- window_kv_indices,
- req_to_token.stride(0),
- )
- return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
- def update_sliding_window_buffer_cuda_graph(
- window_kv_indptr,
- window_kv_indices,
- req_to_token,
- sliding_window_size,
- seq_lens,
- req_pool_indices,
- bs,
- ):
- window_kv_lens = torch.minimum(
- seq_lens,
- torch.tensor(sliding_window_size + 1),
- )
- window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
- window_kv_indptr = window_kv_indptr[: bs + 1]
- window_kv_start_idx = seq_lens - window_kv_lens
- create_flashinfer_kv_indices_triton[(bs,)](
- req_to_token,
- req_pool_indices,
- window_kv_lens,
- window_kv_indptr,
- window_kv_start_idx,
- window_kv_indices,
- req_to_token.stride(0),
- )
- return window_kv_indptr, window_kv_lens
-
-
  @dataclass
  class ForwardMetadata:
  attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):

  super().__init__()

- self.decode_attention_fwd = decode_attention_fwd
- self.extend_attention_fwd = extend_attention_fwd
+ self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+ self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

  self.skip_prefill = skip_prefill
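The constructor change above wraps the Triton decode/extend entry points in torch.compiler.disable, which keeps torch.compile from tracing into the kernel launchers while leaving eager execution unchanged. The same pattern in isolation, with a stand-in function rather than sglang's kernels:

    import torch

    def _launch_kernel(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for a Triton kernel launch that should not be traced.
        return x * 2

    # The wrapped callable runs eagerly even inside compiled callers.
    launch_kernel = torch.compiler.disable(_launch_kernel)

    @torch.compile
    def forward(x: torch.Tensor) -> torch.Tensor:
        return launch_kernel(x) + 1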
@@ -372,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
  num_kv_splits = None
  attn_logits = None
  attn_lse = None
+
  elif forward_batch.forward_mode.is_draft_extend():
  kv_indices, kv_indptr, qo_indptr, custom_mask = (
  spec_info.generate_attn_arg_prefill(
@@ -446,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
  )

  def init_cuda_graph_state(
- self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+ self,
+ max_bs: int,
+ max_num_tokens: int,
+ kv_indices_buf: Optional[torch.Tensor] = None,
  ):
  self.cuda_graph_attn_logits = torch.zeros(
- (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+ (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
  dtype=torch.float32,
  device=self.device,
  )
  self.cuda_graph_attn_lse = torch.zeros(
- (max_bs, self.num_head, self.max_kv_splits),
+ (max_num_tokens, self.num_head, self.max_kv_splits),
  dtype=torch.float32,
  device=self.device,
  )
  self.cuda_graph_num_kv_splits = torch.full(
- (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+ (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
  )
  if kv_indices_buf is None:
  self.cuda_graph_kv_indices = torch.zeros(
- (max_bs * self.max_context_len),
+ (max_num_tokens * self.max_context_len),
  dtype=torch.int32,
  device=self.device,
  )
@@ -472,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):

  if not self.skip_prefill:
  self.cuda_graph_custom_mask = torch.zeros(
- (max_bs * self.max_context_len),
+ (max_num_tokens * self.max_context_len),
  dtype=torch.uint8,
  device=self.device,
  )
@@ -480,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
  if self.sliding_window_size is not None and self.sliding_window_size > 0:
  if kv_indices_buf is None:
  self.cuda_graph_window_kv_indices = torch.zeros(
- (max_bs * self.sliding_window_size),
+ (max_num_tokens * self.sliding_window_size),
  dtype=torch.int32,
  device=self.device,
  )
@@ -488,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
  self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

  self.cuda_graph_window_num_kv_splits = torch.full(
- (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+ (max_num_tokens,),
+ self.max_kv_splits,
+ dtype=torch.int32,
+ device=self.device,
  )

  def init_forward_metadata_capture_cuda_graph(
@@ -569,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
  )

  custom_mask = self.cuda_graph_custom_mask
+ custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
  mask_indptr = self.mask_indptr[: bs + 1]
  mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -877,6 +774,7 @@ class TritonMultiStepDraftBackend:
  self.device = model_runner.device
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+ self.page_size = model_runner.server_args.page_size

  def common_template(
  self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
@@ -894,14 +792,13 @@ class TritonMultiStepDraftBackend:
  kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
- num_seqs,
- self.topk,
  self.pool_len,
  kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
- triton.next_power_of_2(num_seqs),
- triton.next_power_of_2(self.speculative_num_steps),
- triton.next_power_of_2(bs),
+ next_power_of_2(num_seqs),
+ next_power_of_2(self.speculative_num_steps),
+ next_power_of_2(bs),
+ self.page_size,
  )

  for i in range(self.speculative_num_steps):
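The kernel launch above also swaps triton.next_power_of_2 for a next_power_of_2 imported from sglang.srt.utils (the new sglang/math_utils.py in the file list is presumably where it lives), keeping the host-side argument math free of a Triton dependency. A generic sketch of such a helper, shown only for reference and not necessarily matching sglang's implementation:

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n; returns 1 for n <= 1 in this sketch.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    assert [next_power_of_2(x) for x in (1, 3, 8, 1000)] == [1, 4, 8, 1024]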
@@ -932,15 +829,15 @@ class TritonMultiStepDraftBackend:

  self.common_template(forward_batch, kv_indices, call_fn)

- def init_cuda_graph_state(self, max_bs: int):
+ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
  self.cuda_graph_kv_indices = torch.zeros(
- (self.speculative_num_steps, max_bs * self.max_context_len),
+ (self.speculative_num_steps, max_num_tokens * self.max_context_len),
  dtype=torch.int32,
  device=self.device,
  )
  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
- max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+ max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
  )

  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -973,3 +870,114 @@ class TritonMultiStepDraftBackend:
  )

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+ @triton.jit
+ def get_num_kv_splits_triton(
+ num_kv_splits_ptr,
+ seq_lens_ptr,
+ num_seq,
+ num_group,
+ num_head,
+ num_kv_head,
+ max_kv_splits,
+ device_core_count,
+ MAX_NUM_SEQ: tl.constexpr,
+ ):
+ # TODO: this method is tunable, we need more online serving data to tune it
+ offs_seq = tl.arange(0, MAX_NUM_SEQ)
+ mask_seq = offs_seq < num_seq
+
+ seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+ max_seq_len = tl.max(seq_lens)
+ seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+ min_seq_len = tl.min(seq_lens)
+ if max_seq_len * 8 < min_seq_len * 10:
+ min_seq_len = max_seq_len
+ max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+ kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+ # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+ ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+ ext_device_core_count = tl.cast(
+ device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+ )
+ block_h, num_kv_group = 16, num_head // num_kv_head
+ if num_kv_group == 1:
+ token_grid = num_seq * num_group * num_head
+ else:
+ # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+ block_h = tl.minimum(block_h, num_kv_group)
+ token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+ max_kv_splits_2 = tl.minimum(
+ tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+ )
+ kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+ num_kv_splits = tl.maximum(
+ tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+ )
+
+ offs_token = offs_seq * num_group
+ mask_token = offs_token < num_seq * num_group
+ for i in range(0, num_group):
+ tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+ def update_sliding_window_buffer(
+ window_kv_indptr,
+ req_to_token,
+ sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ device,
+ ):
+ window_kv_lens = torch.minimum(
+ seq_lens,
+ torch.tensor(sliding_window_size + 1),
+ )
+ window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+ window_kv_indptr = window_kv_indptr[: bs + 1]
+ window_kv_indices = torch.empty(
+ window_kv_indptr[-1], dtype=torch.int32, device=device
+ )
+ window_kv_start_idx = seq_lens - window_kv_lens
+ create_flashinfer_kv_indices_triton[(bs,)](
+ req_to_token,
+ req_pool_indices,
+ window_kv_lens,
+ window_kv_indptr,
+ window_kv_start_idx,
+ window_kv_indices,
+ req_to_token.stride(0),
+ )
+ return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+ def update_sliding_window_buffer_cuda_graph(
+ window_kv_indptr,
+ window_kv_indices,
+ req_to_token,
+ sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ ):
+ window_kv_lens = torch.minimum(
+ seq_lens,
+ torch.tensor(sliding_window_size + 1),
+ )
+ window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+ window_kv_indptr = window_kv_indptr[: bs + 1]
+ window_kv_start_idx = seq_lens - window_kv_lens
+ create_flashinfer_kv_indices_triton[(bs,)](
+ req_to_token,
+ req_pool_indices,
+ window_kv_lens,
+ window_kv_indptr,
+ window_kv_start_idx,
+ window_kv_indices,
+ req_to_token.stride(0),
+ )
+ return window_kv_indptr, window_kv_lens
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -31,11 +31,6 @@ _is_hip = is_hip()

  logger = logging.getLogger(__name__)

- # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
- logger.warning(
- "The following error message 'operation scheduled before its operands' can be ignored."
- )
-

  _MIN_BLOCK_KV = 32

@@ -713,7 +708,7 @@ def decode_attention_fwd(
  num_kv_splits,
  max_kv_splits,
  sm_scale,
- logit_cap,
+ logit_cap=logit_cap,
  )
  else:
  # GQA/MQA/MLA
@@ -729,5 +724,5 @@ def decode_attention_fwd(
  num_kv_splits,
  max_kv_splits,
  sm_scale,
- logit_cap,
+ logit_cap=logit_cap,
  )
sglang/srt/layers/attention/vision.py

@@ -1,15 +1,17 @@
  from __future__ import annotations

+ import dataclasses
+ import functools
  import math
- from functools import lru_cache, wraps
- from typing import Optional, Tuple
+ from functools import lru_cache
+ from typing import Any, Optional, Tuple, Union

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange

- from sglang.srt.utils import is_cuda
+ from sglang.srt.utils import is_cuda, print_info_once
  _is_cuda = is_cuda()

@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.quantization import QuantizationConfig
  from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.utils import add_prefix, logger
+ from sglang.srt.utils import add_prefix

  ROTARY_EMBED_CLASSES = {
  "normal": apply_rotary_pos_emb,
  }


- def execute_once(func):
- has_run = None
+ @dataclasses.dataclass
+ class SingletonCache:
+ data: Any = None

- @wraps(func)
- def wrapper(*args, **kwargs):
- nonlocal has_run
- if not has_run:
- func(*args, **kwargs)
- has_run = True
+ def set_data(self, value: Any) -> None:
+ self.data = value

- return wrapper
+ def get_data(self) -> Optional[Any]:
+ return self.data

+ def empty(self) -> bool:
+ return self.get_data() is None

- @execute_once
- def info_once(message: str):
- logger.info(message)
+
+ # TODO: requires real seqlens from images
+ @functools.lru_cache(maxsize=128)
+ def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+ """
+ Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+ Caches the result based on these parameters.
+ """
+ cu_seqlens = torch.arange(
+ 0,
+ (batch_size + 1) * seqlen,
+ step=seqlen,
+ dtype=torch.int32,
+ device=device,
+ )
+ return cu_seqlens


  class VisionSdpaAttention(nn.Module):
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
  q: torch.Tensor,
  k: torch.Tensor,
  v: torch.Tensor,
- cu_seqlens: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+ bsz: int,
+ seq_len: int,
  **kwargs,
  ) -> torch.Tensor:
  r"""
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
  Returns:
  [b * s, h, head_size]
  """
- cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+ if cu_seqlens is None:
+ cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+ elif isinstance(cu_seqlens, SingletonCache):
+ if cu_seqlens.empty():
+ cu_seqlens.set_data(
+ _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+ )
+ cu_seqlens = cu_seqlens.get_data()
+
+ cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
  seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
  max_seqlen = seq_lens.max().item()
  output = flash_attn_varlen_func(
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
  if global_server_args_dict["mm_attention_backend"] is None:
  if qkv_backend is None:
  qkv_backend = "sdpa"
- info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+ print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
  else:
  qkv_backend = global_server_args_dict["mm_attention_backend"]

- info_once(f"Using {qkv_backend} as multimodal attention backend.")
+ print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

  self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
  head_dim=self.head_size,
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
  # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
  qkv, _ = self.qkv_proj(x)

- # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+ # [s, b, head, head_dim_sum]
  new_x_shape = qkv.size()[:-1] + (
  head,
- 3 * self.hidden_size_per_attention_head,
+ self.q_size + 2 * self.kv_size,
  )
  qkv = qkv.view(*new_x_shape)

  # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
- q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
  # [s, b, head, head_size] --> [b, s, head, head_size]
  q, k, v = [
  rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
  k=k,
  v=v,
  bsz=bsz,
+ seq_len=s,
  cu_seqlens=cu_seqlens,
  attention_mask=attention_mask,
  )
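The vision changes above make cu_seqlens optional for the FlashAttention-3 path: when every image in the batch shares one sequence length, the cumulative offsets depend only on (batch_size, seqlen, device), so they are built by an lru_cache'd helper and can be memoized in a caller-owned SingletonCache instead of being recomputed on every call. A small sketch of the offsets themselves (shapes are illustrative):

    import torch

    def cu_seqlens_for(batch_size: int, seqlen: int, device: str = "cpu") -> torch.Tensor:
        # Offsets [0, s, 2s, ..., b*s] of each fixed-length sequence in the
        # flattened (b * s) token dimension used by varlen attention kernels.
        return torch.arange(
            0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device
        )

    cu = cu_seqlens_for(batch_size=3, seqlen=4)
    print(cu.tolist())                  # [0, 4, 8, 12]
    print((cu[1:] - cu[:-1]).tolist())  # per-sequence lengths: [4, 4, 4]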