sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_utils.py (deleted)
@@ -1,237 +0,0 @@
-from enum import Enum, auto
-
-import torch
-import triton
-import triton.language as tl
-
-
-class WrapperDispatch(Enum):
-    SLIDING_WINDOW = auto()
-    CROSS_ATTENTION = auto()
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    max_context_len: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-class FlashinferUpdater:
-    def __init__(
-        self,
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers=None,
-        use_ragged=False,
-    ):
-        self.forward_mode = forward_mode
-        self.model_runner = model_runner
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = seq_lens
-        self.prefix_lens = prefix_lens
-        self.use_ragged = use_ragged
-
-        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        self.head_dim = model_runner.model_config.head_dim
-        self.batch_size = len(req_pool_indices)
-
-        self.decode_wrappers = (
-            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
-        )
-        self.prefill_wrapper_ragged = (
-            self.model_runner.attn_backend.prefill_wrapper_ragged
-        )
-        self.prefill_wrappers_paged = (
-            self.model_runner.attn_backend.prefill_wrappers_paged
-        )
-
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-    def _update_decode_indices(self, decode_wrapper):
-        assert not isinstance(decode_wrapper, list)
-        decode_wrapper.end_forward()
-        decode_wrapper.begin_forward(
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.model_runner.kv_cache_dtype,
-            q_data_type=self.model_runner.dtype,
-        )
-
-    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
-        assert not isinstance(paged_wrapper, list)
-        assert not isinstance(ragged_wrapper, list)
-
-        # extend part
-        qo_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
-
-        if self.use_ragged:
-            ragged_wrapper.end_forward()
-            ragged_wrapper.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-            )
-
-        # cached part
-        paged_wrapper.end_forward()
-        paged_wrapper.begin_forward(
-            qo_indptr,
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-        )
-
-    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
-        if dispatch_reason is None:
-            if self.use_ragged:
-                paged_kernel_lens = self.prefix_lens
-            else:
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = None
-        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            if wrapper_id == 0:
-                # window attention use paged only
-                if self.forward_mode.is_decode():
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size + 1),
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size)
-                        + self.seq_lens
-                        - self.prefix_lens,
-                    )
-            else:
-                # full attention
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = self.seq_lens - paged_kernel_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            self.kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _update_indicess_single_wrapper(self):
-        self._get_indices()
-
-        if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrappers[0])
-        else:
-            self._update_extend_indices(
-                self.prefill_wrapper_ragged,
-                self.prefill_wrappers_paged[0],
-            )
-
-    def _update_indices_cross_attention(self):
-        pass
-
-    def _update_indices_sliding_window(self):
-        assert self.use_ragged is False
-        for wrapper_id in range(2):
-            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
-            if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrappers[wrapper_id])
-            else:
-                self._update_extend_indices(
-                    None,
-                    self.prefill_wrappers_paged[wrapper_id],
-                )
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    decode_wrappers=None,
-    use_ragged=False,
-):
-    updater = FlashinferUpdater(
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers,
-        use_ragged,
-    )
-
-    dispatch_reason = model_runner.attn_backend.dispatch_reason
-
-    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-        updater._update_indices_sliding_window()
-    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-        updater._update_indices_cross_attention()
-    else:
-        assert model_runner.attn_backend.num_wrappers == 1
-        updater._update_indicess_single_wrapper()
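
For context on what the deleted helper did: the Triton kernel above gathers, for each request in the batch, that request's KV-cache token locations from the req_to_token pool into one flat kv_indices array, which the FlashInfer wrappers then consume via kv_indptr/kv_indices. Below is a minimal launch sketch of that kernel. The shapes and values are toy assumptions for illustration only (not taken from the package), and it assumes a CUDA device with triton installed and the create_flashinfer_kv_indices_triton definition from the deleted file above in scope.

    import torch

    batch_size = 2
    max_context_len = 8

    # Toy req_to_token pool: row i holds the KV-cache slots of request slot i.
    req_to_token = torch.arange(
        2 * max_context_len, dtype=torch.int32, device="cuda"
    ).reshape(2, max_context_len)
    req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
    paged_kernel_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")

    # Exclusive prefix sum: request i writes kv_indices[kv_indptr[i]:kv_indptr[i+1]].
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")

    # One Triton program per request, mirroring the launch in _get_indices above.
    create_flashinfer_kv_indices_triton[(batch_size,)](
        req_to_token,
        req_pool_indices,
        paged_kernel_lens,
        kv_indptr,
        None,  # kv_start_idx is None on the default (non-sliding-window) path
        kv_indices,
        max_context_len,
    )
    # kv_indices now holds [0, 1, 2, 3, 4, 8, 9, 10] for this toy layout.

Judging by the file list (flashinfer_backend.py +352 −83 alongside flashinfer_utils.py +0 −237), this indexing logic presumably moved into sglang/srt/layers/attention/flashinfer_backend.py in 0.3.4.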