sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional, Union
 
  import numpy as np
  import torch
+ import triton
+ import triton.language as tl
 
  from sglang.srt.configs.model_config import AttentionArch
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -64,6 +66,9 @@ class FlashAttentionMetadata:
 
  local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
+ # For sliding window attention topk>1 spec decoding
+ swa_spec_metadata: Optional[FlashAttentionMetadata] = None
+
 
  # Copied from:
  # https://github.com/houseroad/vllm/blob/4e45bfcaf928bdb9bd952b4ac922a3c205589ae8/vllm/v1/attention/backends/flash_attn.py
@@ -340,6 +345,13 @@ class FlashAttentionBackend(AttentionBackend):
  else None
  )
 
+ # For each layer, the sliding_window_size can be different. This is only used for preparing SWA metadata.
+ # We use `layer.sliding_window_size` to decide whether to use SWA for each layer.
+ self.sliding_window_size = model_runner.sliding_window_size
+ self.has_swa = (
+ self.sliding_window_size is not None and self.sliding_window_size > -1
+ )
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Initialize forward metadata hence all layers in the forward pass can reuse it."""
  metadata = FlashAttentionMetadata()
@@ -556,6 +568,12 @@
  (1, 0),
  )
  self.forward_metadata_spec_decode_expand = metadata_expand
+
+ if self.has_swa:
+ self._init_sliding_window_attn_spec_metadata(
+ metadata, metadata_expand
+ )
+
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -657,11 +675,10 @@
  # Calculate window size (can be moved to metadata if layer properties don't change)
  # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
  # here is two side inclusive
- window_size = (
- (layer.sliding_window_size, 0)
- if layer.sliding_window_size is not None and layer.sliding_window_size > -1
- else (-1, -1)
+ is_swa = (
+ layer.sliding_window_size is not None and layer.sliding_window_size > -1
  )
+ window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
  k_descale, v_descale = None, None
  # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
  # has corresponding quantization method so that layer.k_scale is not None,
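
For reference, the refactored per-layer window selection above boils down to the following sketch (window_tuple is an illustrative helper, not part of the module; the returned pair is the two-side-inclusive (left, right) window handed to FA3, and (-1, -1) means full attention):

def window_tuple(sliding_window_size):
    # SWA is active only when a non-negative window size is configured for the layer.
    is_swa = sliding_window_size is not None and sliding_window_size > -1
    return (sliding_window_size, 0) if is_swa else (-1, -1)

assert window_tuple(None) == (-1, -1)   # full-attention layer
assert window_tuple(-1) == (-1, -1)     # SWA explicitly disabled
assert window_tuple(4095) == (4095, 0)  # sliding-window layer (size already reduced by 1 upstream)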
@@ -684,8 +701,13 @@
  )
 
  # We do cascade attention for Target Verify with topk > 1
+ # We don't use cascade attention for Sliding Window Attention:
+ # - Different window sizes should be passed in for each q in the first stage of cascade attention, but FA3 interface doesn't support pass in a list of window sizes.
+ # - The overhead of duplicated computation of the common prefix part is small for sliding window layers (seq_len <= window_size), so we can just expand it.
  use_cascade_attn = (
- forward_batch.forward_mode.is_target_verify() and self.topk > 1
+ forward_batch.forward_mode.is_target_verify()
+ and self.topk > 1
+ and not is_swa
  )
 
  # For fa3 interface version compatibility, we put new fields into conditional keyword args
@@ -700,13 +722,18 @@
  cu_seqlens_q = local_metadata.local_query_start_loc
  cache_seqlens = local_metadata.local_seqused_k
  max_seqlen_q = local_metadata.local_max_query_len
- max_seqlen_k = local_metadata.local_max_seq_len
+ elif is_swa and metadata.swa_spec_metadata is not None:
+ swa_spec_metadata = metadata.swa_spec_metadata
+ page_table = swa_spec_metadata.page_table
+ cu_seqlens_q = swa_spec_metadata.cu_seqlens_q
+ cache_seqlens = swa_spec_metadata.cache_seqlens_int32
+ max_seqlen_q = swa_spec_metadata.max_seq_len_q
+ cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
  else:
  page_table = metadata.page_table
  cu_seqlens_q = metadata.cu_seqlens_q
  cache_seqlens = metadata.cache_seqlens_int32
  max_seqlen_q = metadata.max_seq_len_q
- max_seqlen_k = metadata.max_seq_len_k
  cu_seqlens_k = metadata.cu_seqlens_k
 
  # Use Flash Attention for prefill
@@ -1377,6 +1404,32 @@
  ),
  }
 
+ if self.has_swa:
+ self.target_verify_metadata_topk_swa = {
+ "cache_seqlens": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "page_table": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
  self.encoder_metadata = {
  "encoder_page_table": torch.zeros(
  max_bs,
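
As a rough sizing check for the new preallocation (the numbers below are illustrative, not package defaults), the dominant buffer is page_table with shape (max_bs * speculative_num_draft_tokens, max_context_len) in int32:

# Hypothetical capture configuration, for illustration only.
max_bs = 64
speculative_num_draft_tokens = 4
max_context_len = 8192

page_table_bytes = max_bs * speculative_num_draft_tokens * max_context_len * 4  # int32
print(page_table_bytes / 2**20)  # 8.0 MiB; the seqlen buffers add only a few KiB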
@@ -1564,6 +1617,28 @@
 
  self.target_verify_metadata_topk_normal[bs] = metadata
  self.target_verify_metadata_topk_expand[bs] = metadata_expand
+
+ if self.has_swa:
+ metadata_swa = FlashAttentionMetadata()
+ metadata_swa.cache_seqlens_int32 = (
+ self.target_verify_metadata_topk_swa["cache_seqlens"][
+ : bs * self.speculative_num_draft_tokens
+ ]
+ )
+ metadata_swa.max_seq_len_q = 1
+ metadata_swa.cu_seqlens_q = self.target_verify_metadata_topk_swa[
+ "cu_seqlens_q"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+ metadata_swa.cu_seqlens_k = self.target_verify_metadata_topk_swa[
+ "cu_seqlens_k"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+
+ metadata_swa.page_table = self.target_verify_metadata_topk_swa[
+ "page_table"
+ ][: bs * self.speculative_num_draft_tokens]
+ self.target_verify_metadata_topk_swa[bs] = metadata_swa
+ metadata.swa_spec_metadata = metadata_swa
+
  elif forward_mode.is_draft_extend():
  metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
  :bs
@@ -1804,6 +1879,12 @@
  )
  )
 
+ if self.has_swa:
+ metadata_swa = self.target_verify_metadata_topk_swa[bs]
+ self._init_sliding_window_attn_spec_metadata(
+ metadata, metadata_expand, metadata_swa
+ )
+
  elif forward_mode.is_draft_extend():
  metadata = self.draft_extend_metadata[bs]
  metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -2039,6 +2120,159 @@
  lam.local_max_query_len = int(seqlens_q_local_np.max())
  lam.local_max_seq_len = int(seqlens_k_local_np.max())
 
+ def _init_sliding_window_attn_spec_metadata(
+ self,
+ metadata: FlashAttentionMetadata,
+ metadata_expand: FlashAttentionMetadata,
+ metadata_swa: Optional[FlashAttentionMetadata] = None,
+ ):
+ # TODO: support page_size > 1 for swa spec
+ assert (
+ self.page_size == 1
+ ), "FlashAttention backend doesn't support topk > 1 speculative decoding with page size > 1 sliding window attention"
+
+ cache_seqlens_int32 = (
+ metadata.cache_seqlens_int32.repeat_interleave(
+ self.speculative_num_draft_tokens
+ )
+ + metadata_expand.cache_seqlens_int32
+ )
+ cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32), (1, 0)
+ )
+ bs = cache_seqlens_int32.shape[0]
+ page_table = (
+ metadata.page_table.new_zeros(
+ (bs, metadata.max_seq_len_k + metadata_expand.page_table.shape[1])
+ )
+ if metadata_swa is None
+ else metadata_swa.page_table
+ )
+
+ prepare_swa_spec_page_table_triton(
+ page_table,
+ metadata.page_table,
+ metadata_expand.page_table,
+ metadata.cache_seqlens_int32,
+ metadata_expand.cache_seqlens_int32,
+ self.speculative_num_draft_tokens,
+ )
+
+ if metadata_swa is None:
+ metadata_swa = FlashAttentionMetadata()
+ metadata_swa.max_seq_len_q = 1
+ metadata_swa.cu_seqlens_q = metadata_expand.cu_seqlens_q
+ metadata_swa.cache_seqlens_int32 = cache_seqlens_int32
+ metadata_swa.cu_seqlens_k = cu_seqlens_k
+ metadata_swa.page_table = page_table
+ else:
+ metadata_swa.cache_seqlens_int32.copy_(cache_seqlens_int32)
+ metadata_swa.cu_seqlens_k.copy_(cu_seqlens_k)
+
+ metadata.swa_spec_metadata = metadata_swa
+
+
+ @triton.jit
+ def _prepare_swa_spec_page_table_kernel(
+ dst_ptr,
+ src_a_ptr,
+ src_b_ptr,
+ seq_len_a_ptr,
+ seq_len_b_ptr,
+ dst_stride_m,
+ dst_stride_n,
+ a_stride_m,
+ a_stride_n,
+ b_stride_m,
+ b_stride_n,
+ LEN_A: tl.constexpr,
+ LEN_B: tl.constexpr,
+ REPEAT_STEP: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ ):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ idx_a = pid_m // REPEAT_STEP
+ idx_b = pid_m
+ seq_len_a = tl.load(seq_len_a_ptr + idx_a)
+ seq_len_b = tl.load(seq_len_b_ptr + idx_b)
+
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ total_len = seq_len_a + seq_len_b
+
+ if pid_n * BLOCK_N >= total_len:
+ return
+
+ mask = offs_n < total_len
+ dst = dst_ptr + pid_m * dst_stride_m + offs_n * dst_stride_n
+
+ if (pid_n + 1) * BLOCK_N < seq_len_a:
+ a_ptr = src_a_ptr + idx_a * a_stride_m + offs_n * a_stride_n
+ a_mask = mask & (offs_n < LEN_A)
+ val = tl.load(a_ptr, mask=a_mask, other=0)
+ tl.store(dst, val, mask=mask)
+ elif pid_n * BLOCK_N >= seq_len_a:
+ offs_b = offs_n - seq_len_a
+ b_ptr = src_b_ptr + idx_b * b_stride_m + offs_b * b_stride_n
+ b_mask = mask & (offs_b < LEN_B)
+ val = tl.load(b_ptr, mask=b_mask, other=0)
+ tl.store(dst, val, mask=mask)
+ else:
+ # mixed part
+ a_offs = offs_n
+ a_mask = (a_offs < seq_len_a) & (a_offs < LEN_A)
+ a_ptr = src_a_ptr + idx_a * a_stride_m + a_offs * a_stride_n
+ a_val = tl.load(a_ptr, mask=a_mask, other=0)
+
+ b_offs = offs_n - seq_len_a
+ b_mask = (b_offs >= 0) & (b_offs < seq_len_b) & (b_offs < LEN_B)
+ b_ptr = src_b_ptr + idx_b * b_stride_m + b_offs * b_stride_n
+ b_val = tl.load(b_ptr, mask=b_mask, other=0)
+
+ result = tl.where(offs_n < seq_len_a, a_val, b_val)
+ tl.store(dst, result, mask=mask)
+
+
+ def prepare_swa_spec_page_table_triton(
+ page_table_dst: torch.Tensor,
+ page_table_a: torch.Tensor,
+ page_table_b: torch.Tensor,  # expand page table
+ seq_len_a: torch.Tensor,
+ seq_len_b: torch.Tensor,  # expand seq lens
+ speculative_num_draft_tokens: int,
+ ):
+ # concat page_table and expand page_table by kv seq length
+ bs = seq_len_a.numel()
+ bs_expand = seq_len_b.numel()
+ assert bs_expand == bs * speculative_num_draft_tokens
+
+ LEN_A = page_table_a.shape[1]
+ LEN_B = page_table_b.shape[1]
+ LEN_OUT = LEN_A + LEN_B
+ REPEAT_STEP = speculative_num_draft_tokens
+ BLOCK_N = 256
+
+ grid = (bs_expand, triton.cdiv(LEN_OUT, BLOCK_N))
+ _prepare_swa_spec_page_table_kernel[grid](
+ page_table_dst,
+ page_table_a,
+ page_table_b,
+ seq_len_a,
+ seq_len_b,
+ page_table_dst.stride(0),
+ page_table_dst.stride(1),
+ page_table_a.stride(0),
+ page_table_a.stride(1),
+ page_table_b.stride(0),
+ page_table_b.stride(1),
+ LEN_A=LEN_A,
+ LEN_B=LEN_B,
+ REPEAT_STEP=REPEAT_STEP,
+ BLOCK_N=BLOCK_N,
+ num_warps=4,
+ )
+
 
 
  class FlashAttentionMultiStepBackend:
sglang/srt/layers/attention/flashinfer_backend.py

@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
  from sglang.srt.layers.dp_attention import get_attention_tp_size
  from sglang.srt.layers.radix_attention import AttentionType
- from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+ from sglang.srt.utils import (
+ is_flashinfer_available,
+ is_sm100_supported,
+ next_power_of_2,
+ )
 
  if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention