sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mla_backend.py (new file)
@@ -0,0 +1,372 @@
+ from __future__ import annotations
+
+ """
+ Support attention backend for TRTLLM MLA kernels from flashinfer.
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ import triton
+
+ from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+ from sglang.srt.layers.attention.utils import (
+     TRITON_PAD_NUM_PAGE_PER_BLOCK,
+     create_flashmla_kv_indices_triton,
+ )
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.utils import is_flashinfer_available
+
+ if is_flashinfer_available():
+     import flashinfer
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.radix_attention import RadixAttention
+     from sglang.srt.model_executor.model_runner import ModelRunner
+     from sglang.srt.speculative.spec_info import SpecInfo
+
+ # Constants
+ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+ # Block constraint from flashinfer requirements
+ # From flashinfer.decode._check_trtllm_gen_mla_shape:
+ #   block_num % (128 / block_size) == 0
+ # This imposes that the total number of blocks must be divisible by
+ # (128 / block_size). We capture the 128 constant here so we can
+ # compute the LCM with other padding constraints.
+ TRTLLM_BLOCK_CONSTRAINT = 128
+
+
+ @dataclass
+ class TRTLLMMLADecodeMetadata:
+     """Metadata for TRTLLM MLA decode operations."""
+
+     workspace: Optional[torch.Tensor] = None
+     block_kv_indices: Optional[torch.Tensor] = None
+
+
+ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+     """TRTLLM MLA attention kernel from flashinfer."""
+
+     def __init__(
+         self,
+         model_runner: ModelRunner,
+         skip_prefill: bool = False,
+         kv_indptr_buf: Optional[torch.Tensor] = None,
+         q_indptr_decode_buf: Optional[torch.Tensor] = None,
+     ):
+         super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+         config = model_runner.model_config
+
+         # Model parameters
+         self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+         self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+         self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+         # MLA-specific dimensions
+         self.kv_lora_rank = config.kv_lora_rank
+         self.qk_nope_head_dim = config.qk_nope_head_dim
+         self.qk_rope_head_dim = config.qk_rope_head_dim
+         self.v_head_dim = config.v_head_dim
+         self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+         # Runtime parameters
+         self.scaling = config.scaling
+         self.data_type = model_runner.kv_cache_dtype
+         self.q_data_type = model_runner.dtype
+         self.page_size = model_runner.page_size
+         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+         # Workspace allocation
+         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+         self.workspace_buffer = torch.empty(
+             self.workspace_size, dtype=torch.int8, device=self.device
+         )
+
+         # CUDA graph state
+         self.decode_cuda_graph_metadata = {}
+         self.cuda_graph_kv_indices = None
+         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+     def _calc_padded_blocks(self, max_seq_len: int) -> int:
+         """
+         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+         Args:
+             max_seq_len: Maximum sequence length in tokens
+
+         Returns:
+             Number of blocks padded to satisfy all constraints
+         """
+         blocks = triton.cdiv(max_seq_len, self.page_size)
+
+         # Apply dual constraints (take LCM to satisfy both):
+         # 1. TRT-LLM: block_num % (128 / page_size) == 0
+         # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+         trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+         constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+         if blocks % constraint_lcm != 0:
+             blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
+         return blocks
+
+     def _create_block_kv_indices(
+         self,
+         batch_size: int,
+         max_blocks: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         device: torch.device,
+     ) -> torch.Tensor:
+         """
+         Create block KV indices tensor using Triton kernel.
+
+         Args:
+             batch_size: Batch size
+             max_blocks: Maximum number of blocks per sequence
+             req_pool_indices: Request pool indices
+             seq_lens: Sequence lengths
+             device: Target device
+
+         Returns:
+             Block KV indices tensor
+         """
+         block_kv_indices = torch.full(
+             (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+         )
+
+         create_flashmla_kv_indices_triton[(batch_size,)](
+             self.req_to_token,
+             req_pool_indices,
+             seq_lens,
+             None,
+             block_kv_indices,
+             self.req_to_token.stride(0),
+             max_blocks,
+             TRITON_PAD_NUM_PAGE_PER_BLOCK,
+             self.page_size,
+         )
+
+         return block_kv_indices
+
+     def init_cuda_graph_state(
+         self,
+         max_bs: int,
+         max_num_tokens: int,
+         kv_indices_buf: Optional[torch.Tensor] = None,
+     ):
+         """Initialize CUDA graph state for TRTLLM MLA."""
+         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+         self.cuda_graph_kv_indices = torch.full(
+             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
+         )
+         self.cuda_graph_workspace = torch.empty(
+             self.workspace_size, dtype=torch.int8, device=self.device
+         )
+
+     def init_forward_metadata_capture_cuda_graph(
+         self,
+         bs: int,
+         num_tokens: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+     ):
+         """Initialize metadata for CUDA graph capture."""
+         # Delegate to parent for non-decode modes or when speculative execution is used.
+         if not (forward_mode.is_decode_or_idle() and spec_info is None):
+             return super().init_forward_metadata_capture_cuda_graph(
+                 bs,
+                 num_tokens,
+                 req_pool_indices,
+                 seq_lens,
+                 encoder_lens,
+                 forward_mode,
+                 spec_info,
+             )
+
+         # Custom fast-path for decode/idle without speculative execution.
+         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+         block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+         create_flashmla_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices,
+             seq_lens,
+             None,
+             block_kv_indices,
+             self.req_to_token.stride(0),
+             max_seqlen_pad,
+             TRITON_PAD_NUM_PAGE_PER_BLOCK,
+             self.page_size,
+         )
+
+         metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+         self.decode_cuda_graph_metadata[bs] = metadata
+         self.forward_metadata = metadata
+
+     def init_forward_metadata_replay_cuda_graph(
+         self,
+         bs: int,
+         req_pool_indices: torch.Tensor,
+         seq_lens: torch.Tensor,
+         seq_lens_sum: int,
+         encoder_lens: Optional[torch.Tensor],
+         forward_mode: ForwardMode,
+         spec_info: Optional[SpecInfo],
+         seq_lens_cpu: Optional[torch.Tensor],
+     ):
+         """Replay CUDA graph with new inputs."""
+         # Delegate to parent for non-decode modes or when speculative execution is used.
+         if not (forward_mode.is_decode_or_idle() and spec_info is None):
+             return super().init_forward_metadata_replay_cuda_graph(
+                 bs,
+                 req_pool_indices,
+                 seq_lens,
+                 seq_lens_sum,
+                 encoder_lens,
+                 forward_mode,
+                 spec_info,
+                 seq_lens_cpu,
+             )
+
+         metadata = self.decode_cuda_graph_metadata[bs]
+
+         # Update block indices for new sequences.
+         create_flashmla_kv_indices_triton[(bs,)](
+             self.req_to_token,
+             req_pool_indices[:bs],
+             seq_lens[:bs],
+             None,
+             metadata.block_kv_indices,
+             self.req_to_token.stride(0),
+             metadata.block_kv_indices.shape[1],
+             TRITON_PAD_NUM_PAGE_PER_BLOCK,
+             self.page_size,
+         )
+
+     def get_cuda_graph_seq_len_fill_value(self) -> int:
+         """Get the fill value for sequence lengths in CUDA graph."""
+         return 1
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         """Initialize the metadata for a forward pass."""
+         # Delegate to parent for non-decode modes or when speculative execution is used.
+         if not (
+             forward_batch.forward_mode.is_decode_or_idle()
+             and forward_batch.spec_info is None
+         ):
+             return super().init_forward_metadata(forward_batch)
+
+         bs = forward_batch.batch_size
+
+         # Get maximum sequence length.
+         if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+             max_seq = forward_batch.seq_lens_cpu.max().item()
+         else:
+             max_seq = forward_batch.seq_lens.max().item()
+
+         max_seqlen_pad = self._calc_padded_blocks(max_seq)
+         block_kv_indices = self._create_block_kv_indices(
+             bs,
+             max_seqlen_pad,
+             forward_batch.req_pool_indices,
+             forward_batch.seq_lens,
+             forward_batch.seq_lens.device,
+         )
+
+         self.forward_metadata = TRTLLMMLADecodeMetadata(
+             self.workspace_buffer, block_kv_indices
+         )
+         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+     def forward_decode(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         layer: RadixAttention,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+         q_rope: Optional[torch.Tensor] = None,
+         k_rope: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """Run forward for decode using TRTLLM MLA kernel."""
+         # Save KV cache if requested
+         if k is not None and save_kv_cache:
+             cache_loc = forward_batch.out_cache_loc
+             if k_rope is not None:
+                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                     layer, cache_loc, k, k_rope
+                 )
+             elif v is not None:
+                 forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+         # Prepare query tensor inline
+         if q_rope is not None:
+             # q contains NOPE part (v_head_dim)
+             q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+             q_rope_reshaped = q_rope.view(
+                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+             )
+             query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+         else:
+             # q already has both parts
+             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+         if query.dim() == 3:
+             query = query.unsqueeze(1)
+
+         # Prepare KV cache inline
+         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+         pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+         # TRT-LLM expects single KV data with extra dimension
+         kv_cache = pages.unsqueeze(1)
+
+         # Get metadata
+         metadata = (
+             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+             or self.forward_metadata
+         )
+
+         # Scale computation for TRTLLM MLA kernel:
+         # - BMM1 scale = q_scale * k_scale * softmax_scale
+         # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
+         # - k_scale is read from model checkpoint if available
+         # TODO: Change once fp8 path is supported
+         q_scale = 1.0
+         k_scale = (
+             layer.k_scale_float
+             if getattr(layer, "k_scale_float", None) is not None
+             else 1.0
+         )
+
+         bmm1_scale = q_scale * k_scale * layer.scaling
+
+         # Call TRT-LLM kernel
+         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+             query=query,
+             kv_cache=kv_cache,
+             workspace_buffer=metadata.workspace,
+             qk_nope_head_dim=self.qk_nope_head_dim,
+             kv_lora_rank=self.kv_lora_rank,
+             qk_rope_head_dim=self.qk_rope_head_dim,
+             block_tables=metadata.block_kv_indices,
+             seq_lens=forward_batch.seq_lens.to(torch.int32),
+             max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+             bmm1_scale=bmm1_scale,
+         )
+
+         # Extract value projection part and reshape
+         raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+         return output
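A note on the padding rule in _calc_padded_blocks above: the backend rounds the per-sequence block count up to the least common multiple of the TRT-LLM requirement (block_num must be divisible by 128 / page_size) and the Triton page-table builder's 64-index burst. The standalone sketch below reproduces that arithmetic with plain integer math instead of triton.cdiv; the page size of 32 and sequence length of 1000 are assumed example values, not values taken from the package.

    import math

    TRTLLM_BLOCK_CONSTRAINT = 128       # from the backend above
    TRITON_PAD_NUM_PAGE_PER_BLOCK = 64  # from sglang/srt/layers/attention/utils.py below

    def padded_block_count(max_seq_len: int, page_size: int) -> int:
        blocks = -(-max_seq_len // page_size)  # ceil-divide tokens into pages
        lcm = math.lcm(TRTLLM_BLOCK_CONSTRAINT // page_size, TRITON_PAD_NUM_PAGE_PER_BLOCK)
        return -(-blocks // lcm) * lcm         # round up to the combined constraint

    # Assumed example: 1000 tokens at page_size=32 -> 32 raw blocks -> padded to 64,
    # which is divisible by both 128 // 32 == 4 and 64.
    assert padded_block_count(1000, 32) == 64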
sglang/srt/layers/attention/utils.py
@@ -1,6 +1,11 @@
  import triton
  import triton.language as tl

+ # Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+ # Number of pages that the kernel writes per iteration.
+ # Exposed here so other Python modules can import it instead of hard-coding 64.
+ TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+

  @triton.jit
  def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
      kv_indices_ptr,
      req_to_token_ptr_stride: tl.constexpr,
      kv_indices_ptr_stride: tl.constexpr,
+     NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
      PAGED_SIZE: tl.constexpr = 64,
  ):
      BLOCK_SIZE: tl.constexpr = 4096
-     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
      pid = tl.program_id(axis=0)

      # find the req pool idx, this is for batch to token
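The two hunks above hoist the kernel's hard-coded page burst size into the module-level constant TRITON_PAD_NUM_PAGE_PER_BLOCK and make it the default for the new NUM_PAGE_PER_BLOCK parameter, so callers that size the block table (such as the TRT-LLM MLA backend earlier in this diff) pad against the same value the kernel uses. The check below only illustrates that invariant; max_blocks is an assumed example value, not code from the package.

    from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK

    # Caller-side sanity check: the block-table width handed to
    # create_flashmla_kv_indices_triton should be a multiple of the kernel's
    # per-iteration burst, which is what _calc_padded_blocks guarantees.
    max_blocks = 192  # assumed example value, already padded by the caller
    assert max_blocks % TRITON_PAD_NUM_PAGE_PER_BLOCK == 0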
sglang/srt/layers/moe/cutlass_moe.py
@@ -209,7 +209,8 @@ def cutlass_fused_experts_fp8(
      )

      result = torch.empty((m, k), device=device, dtype=out_dtype)
-     return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
+     apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
+     return result


  FLOAT4_E2M1_MAX = 6.0
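For context on this last hunk: apply_shuffle_mul_sum fills the preallocated result tensor in place, so the fixed code calls it for its side effect, casts topk_weights to the output dtype, and returns the result buffer itself rather than the call's return value. The snippet below is a self-contained analog of that pattern under stated assumptions; weighted_sum_into_ is a hypothetical stand-in, not the real CUTLASS helper.

    import torch

    def weighted_sum_into_(out: torch.Tensor, src: torch.Tensor, weights: torch.Tensor) -> None:
        # Hypothetical stand-in for an in-place helper: writes into `out`, returns None.
        torch.sum(src * weights.to(out.dtype).unsqueeze(-1), dim=1, out=out)

    m, k, topk = 4, 8, 2
    src = torch.randn(m, topk, k)
    topk_weights = torch.rand(m, topk)

    result = torch.empty(m, k, dtype=src.dtype)
    weighted_sum_into_(result, src, topk_weights)  # fill the buffer in place...
    # ...and hand back `result` itself, not the helper's (None) return value.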