sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/double_sparsity_backend.py (new file)
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class DoubleSparseAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            flash_decode_attention_fwd,
+            flash_decode_sparse_attention_fwd,
+        )
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        super().__init__()
+
+        self.decode_attention_fwd = flash_decode_attention_fwd
+        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = model_runner.model_config.num_attention_heads
+        self.head_dim = model_runner.model_config.hidden_size // self.num_head
+        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
+
+        self.sorted_channels = model_runner.sorted_channels
+        self.sparse_decode_thresold = (
+            model_runner.server_args.ds_sparse_decode_threshold
+        )
+        self.att_out_approx: torch.Tensor = None
+        self.mid_out: torch.Tensor = None
+        self.mid_o_logexpsum: torch.Tensor = None
+
+        # TODO: Change the hard-coded block_seq_num
+        self.BLOCK_SEQ = 128
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
+            min_seq_len = torch.min(forward_batch.seq_lens).item()
+            max_extend_len = None
+            # NOTE: Align sequence order with req_to_token order
+            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices
+            ]
+
+            bsz = forward_batch.seq_lens.shape[0]
+
+            att_out_approx = torch.empty(
+                [self.num_head, bsz, max_seq_len],
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            block_seq_num = (
+                self.heavy_token_num + self.BLOCK_SEQ - 1
+            ) // self.BLOCK_SEQ
+
+            mid_out = torch.empty(
+                [bsz, self.num_head, block_seq_num, self.head_dim],
+                dtype=torch.float32,
+                device="cuda",
+            )
+            mid_o_logexpsum = torch.empty(
+                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
+            )
+            self.att_out_approx = att_out_approx
+            self.mid_out = mid_out
+            self.mid_o_logexpsum = mid_o_logexpsum
+
+        else:
+            start_loc = attn_logits = max_seq_len = min_seq_len = None
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            ds_req_to_token = None
+
+        self.forward_metadata = (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        )
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO(Andy): Support CUDA graph for double sparse attention
+        raise ValueError(
+            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        # TODO: Add min seqlen
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
+        # and set a minimum value for sparse_decode
+        if (
+            min_seq_len < self.heavy_token_num
+            or max_seq_len < self.sparse_decode_thresold
+        ):
+            self.decode_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                start_loc,
+                forward_batch.seq_lens,
+                attn_logits,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+        else:
+            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
+            q_label = torch.gather(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                2,
+                self.sorted_channels[layer.layer_id]
+                .unsqueeze(0)
+                .expand(q.shape[0], -1, -1),
+            )
+            self.decode_sparse_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_label,
+                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
+                ds_req_to_token,
+                forward_batch.seq_lens,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+                self.heavy_token_num,
+                self.att_out_approx,
+                self.mid_out,
+                self.mid_o_logexpsum,
+                self.BLOCK_SEQ,
+            )
+
+        return o