sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/double_sparsity_backend.py (new file)
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class DoubleSparseAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            flash_decode_attention_fwd,
+            flash_decode_sparse_attention_fwd,
+        )
+        from sglang.srt.layers.attention.triton_ops.extend_attention import (
+            extend_attention_fwd,
+        )
+
+        super().__init__()
+
+        self.decode_attention_fwd = flash_decode_attention_fwd
+        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = model_runner.model_config.num_attention_heads
+        self.head_dim = model_runner.model_config.hidden_size // self.num_head
+        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
+
+        self.sorted_channels = model_runner.sorted_channels
+        self.sparse_decode_thresold = (
+            model_runner.server_args.ds_sparse_decode_threshold
+        )
+        self.att_out_approx: torch.Tensor = None
+        self.mid_out: torch.Tensor = None
+        self.mid_o_logexpsum: torch.Tensor = None
+
+        # TODO: Change the hard-coded block_seq_num
+        self.BLOCK_SEQ = 128
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
+            min_seq_len = torch.min(forward_batch.seq_lens).item()
+            max_extend_len = None
+            # NOTE: Align sequence order with req_to_token order
+            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices
+            ]
+
+            bsz = forward_batch.seq_lens.shape[0]
+
+            att_out_approx = torch.empty(
+                [self.num_head, bsz, max_seq_len],
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            block_seq_num = (
+                self.heavy_token_num + self.BLOCK_SEQ - 1
+            ) // self.BLOCK_SEQ
+
+            mid_out = torch.empty(
+                [bsz, self.num_head, block_seq_num, self.head_dim],
+                dtype=torch.float32,
+                device="cuda",
+            )
+            mid_o_logexpsum = torch.empty(
+                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
+            )
+            self.att_out_approx = att_out_approx
+            self.mid_out = mid_out
+            self.mid_o_logexpsum = mid_o_logexpsum
+
+        else:
+            start_loc = attn_logits = max_seq_len = min_seq_len = None
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            ds_req_to_token = None
+
+        self.forward_metadata = (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        )
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO(Andy): Support CUDA graph for double sparse attention
+        raise ValueError(
+            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        # TODO: Add min seqlen
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
+        # and set a minimum value for sparse_decode
+        if (
+            min_seq_len < self.heavy_token_num
+            or max_seq_len < self.sparse_decode_thresold
+        ):
+            self.decode_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                start_loc,
+                forward_batch.seq_lens,
+                attn_logits,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+        else:
+            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
+            q_label = torch.gather(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                2,
+                self.sorted_channels[layer.layer_id]
+                .unsqueeze(0)
+                .expand(q.shape[0], -1, -1),
+            )
+            self.decode_sparse_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_label,
+                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
+                ds_req_to_token,
+                forward_batch.seq_lens,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+                self.heavy_token_num,
+                self.att_out_approx,
+                self.mid_out,
+                self.mid_o_logexpsum,
+                self.BLOCK_SEQ,
+            )
+
+        return o
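
For orientation: the decode path above falls back to the dense Triton kernel whenever the shortest sequence in the batch is still below the heavy-token budget, or the longest sequence has not reached the sparse-decode threshold; only otherwise does it gather the per-layer label channels and call the sparse kernel. A minimal standalone sketch of that dispatch rule follows; the helper name and the example numbers are illustrative only, not part of the package.

def use_sparse_decode(min_seq_len, max_seq_len, heavy_token_num, sparse_decode_threshold):
    # Mirror of the condition in DoubleSparseAttnBackend.forward_decode:
    # dense decode is used if any sequence is shorter than the heavy-token
    # budget or the longest sequence is below the sparse-decode threshold.
    return min_seq_len >= heavy_token_num and max_seq_len >= sparse_decode_threshold

# With an assumed budget of 256 heavy tokens and a threshold of 4096,
# a batch with sequence lengths 300..5000 takes the sparse path,
# while a batch containing a 100-token sequence does not.
assert use_sparse_decode(300, 5000, 256, 4096)
assert not use_sparse_decode(100, 5000, 256, 4096)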