sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_utils.py (deleted)
@@ -1,237 +0,0 @@
-from enum import Enum, auto
-
-import torch
-import triton
-import triton.language as tl
-
-
-class WrapperDispatch(Enum):
-    SLIDING_WINDOW = auto()
-    CROSS_ATTENTION = auto()
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    max_context_len: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-class FlashinferUpdater:
-    def __init__(
-        self,
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers=None,
-        use_ragged=False,
-    ):
-        self.forward_mode = forward_mode
-        self.model_runner = model_runner
-        self.req_pool_indices = req_pool_indices
-        self.seq_lens = seq_lens
-        self.prefix_lens = prefix_lens
-        self.use_ragged = use_ragged
-
-        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        self.head_dim = model_runner.model_config.head_dim
-        self.batch_size = len(req_pool_indices)
-
-        self.decode_wrappers = (
-            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
-        )
-        self.prefill_wrapper_ragged = (
-            self.model_runner.attn_backend.prefill_wrapper_ragged
-        )
-        self.prefill_wrappers_paged = (
-            self.model_runner.attn_backend.prefill_wrappers_paged
-        )
-
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-    def _update_decode_indices(self, decode_wrapper):
-        assert not isinstance(decode_wrapper, list)
-        decode_wrapper.end_forward()
-        decode_wrapper.begin_forward(
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.model_runner.kv_cache_dtype,
-            q_data_type=self.model_runner.dtype,
-        )
-
-    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
-        assert not isinstance(paged_wrapper, list)
-        assert not isinstance(ragged_wrapper, list)
-
-        # extend part
-        qo_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
-
-        if self.use_ragged:
-            ragged_wrapper.end_forward()
-            ragged_wrapper.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-            )
-
-        # cached part
-        paged_wrapper.end_forward()
-        paged_wrapper.begin_forward(
-            qo_indptr,
-            self.kv_indptr,
-            self.kv_indices,
-            self.kv_last_page_len,
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-        )
-
-    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
-        if dispatch_reason is None:
-            if self.use_ragged:
-                paged_kernel_lens = self.prefix_lens
-            else:
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = None
-        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            if wrapper_id == 0:
-                # window attention use paged only
-                if self.forward_mode.is_decode():
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size + 1),
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        self.seq_lens,
-                        torch.tensor(self.model_runner.sliding_window_size)
-                        + self.seq_lens
-                        - self.prefix_lens,
-                    )
-            else:
-                # full attention
-                paged_kernel_lens = self.seq_lens
-            self.kv_start_idx = self.seq_lens - paged_kernel_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            self.kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _update_indicess_single_wrapper(self):
-        self._get_indices()
-
-        if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrappers[0])
-        else:
-            self._update_extend_indices(
-                self.prefill_wrapper_ragged,
-                self.prefill_wrappers_paged[0],
-            )
-
-    def _update_indices_cross_attention(self):
-        pass
-
-    def _update_indices_sliding_window(self):
-        assert self.use_ragged is False
-        for wrapper_id in range(2):
-            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
-            if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrappers[wrapper_id])
-            else:
-                self._update_extend_indices(
-                    None,
-                    self.prefill_wrappers_paged[wrapper_id],
-                )
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    decode_wrappers=None,
-    use_ragged=False,
-):
-    updater = FlashinferUpdater(
-        forward_mode,
-        model_runner,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        decode_wrappers,
-        use_ragged,
-    )
-
-    dispatch_reason = model_runner.attn_backend.dispatch_reason
-
-    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-        updater._update_indices_sliding_window()
-    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-        updater._update_indices_cross_attention()
-    else:
-        assert model_runner.attn_backend.num_wrappers == 1
-        updater._update_indicess_single_wrapper()