sglang 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +14 -12
- sglang/srt/managers/controller/infer_batch.py +47 -49
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +35 -23
- sglang/srt/managers/controller/tp_worker.py +127 -138
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +19 -6
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +113 -84
- sglang/srt/server_args.py +23 -15
- sglang/srt/utils.py +16 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -167
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py

@@ -7,8 +7,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.server import global_server_args_dict


 class RadixAttention(nn.Module):
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
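For context, the new `store_kv_cache` path wraps the in-place KV-cache write in `torch.library.custom_op` with a registered fake kernel, so tracing (CUDA graph capture, `torch.compile`) sees an opaque op whose only side effect is mutating the cache, with a plain-function fallback on older PyTorch. A minimal standalone sketch of that pattern (the op name `demo::write_cache` and the shapes are illustrative, not sglang's):

```python
import torch

try:
    # PyTorch >= 2.4: register the in-place cache write as a custom op whose
    # only side effect is mutating `cache`, plus a fake (meta) kernel so that
    # tracing treats it as opaque instead of inlining the indexing.
    @torch.library.custom_op("demo::write_cache", mutates_args={"cache"})
    def write_cache(vals: torch.Tensor, cache: torch.Tensor, loc: torch.Tensor) -> None:
        cache[loc] = vals

    @write_cache.register_fake
    def _(vals, cache, loc):
        pass

except AttributeError:
    # Older PyTorch: fall back to a plain in-place function.
    def write_cache(vals, cache, loc):
        cache[loc] = vals


cache = torch.zeros(8, 4)
write_cache(torch.ones(2, 4), cache, torch.tensor([1, 5]))
print(cache[1], cache[5])  # both rows are now ones
```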
sglang/srt/layers/token_attention.py

@@ -5,8 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.
-from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.server import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2

-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(

     num_warps = 1

-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
         num_warps=num_warps,
         num_stages=3,
     )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


 def token_attention_fwd(
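The deleted `cached_kernel_stage1` / `cached_kernel_stage2` fast paths and the `wrap_kernel_launcher` import were a hand-rolled launcher cache. Triton's `@triton.jit` kernels already memoize their compiled binaries per specialization, so launching the kernel through the normal `kernel[grid](...)` syntax is enough. A generic sketch of that point (this toy kernel is not sglang's attention kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x + 1, mask=mask)


def add_one(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Launching the jitted kernel directly is enough: Triton caches the
    # compiled binary per specialization, so repeated launches with the same
    # signature do not recompile. No module-level launcher cache is needed.
    add_one_kernel[grid](x, y, n, BLOCK=1024)
    return y


if torch.cuda.is_available():
    print(add_one(torch.zeros(4096, device="cuda"))[:4])
```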
sglang/srt/managers/controller/cuda_graph_runner.py

@@ -3,9 +3,10 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -152,8 +150,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.
-            self.position_ids_offsets.
+            self.seq_lens.fill_(1)
+            self.position_ids_offsets.zero_()
             self.out_cache_loc.zero_()

         # Common inputs
@@ -183,14 +181,18 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=
-
-
+                next_token_logprobs=(
+                    output.next_token_logprobs[:raw_bs]
+                    if output.next_token_logprobs is not None
+                    else None
+                ),
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=
-
-
+                decode_top_logprobs=(
+                    output.decode_top_logprobs[:raw_bs]
+                    if output.decode_top_logprobs is not None
+                    else None
+                ),
             )
         return output
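The `CudaGraphRunner` changes keep the existing pad-and-truncate scheme: a batch of size `raw_bs` is padded up to the nearest captured batch size (via `bisect_left` over `batch_size_list`), the padded `seq_lens` slots are now filled with 1 and the offsets zeroed, and the outputs are sliced back down to `raw_bs`. A rough sketch of that bookkeeping, with a stand-in for the captured-graph replay (names are illustrative, not sglang's):

```python
import bisect

import torch

batch_size_list = [1, 2, 4, 8, 16]  # batch sizes for which graphs were captured


def run_padded(raw_bs, seq_lens, position_ids_offsets, replay):
    # Pad up to the nearest captured batch size.
    bs = batch_size_list[bisect.bisect_left(batch_size_list, raw_bs)]

    padded_seq_lens = torch.ones(bs, dtype=seq_lens.dtype)
    padded_seq_lens[:raw_bs] = seq_lens                   # padding rows keep length 1
    padded_offsets = torch.zeros(bs, dtype=position_ids_offsets.dtype)
    padded_offsets[:raw_bs] = position_ids_offsets        # padding rows get offset 0

    logits = replay(bs, padded_seq_lens, padded_offsets)  # replay the captured graph
    return logits[:raw_bs]                                # drop the padded rows


out = run_padded(
    raw_bs=3,
    seq_lens=torch.tensor([5, 7, 2]),
    position_ids_offsets=torch.zeros(3, dtype=torch.long),
    replay=lambda bs, sl, off: torch.randn(bs, 32000),    # stand-in for the graph
)
print(out.shape)  # torch.Size([3, 32000])
```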
sglang/srt/managers/controller/infer_batch.py

@@ -7,6 +7,7 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
@@ -15,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

-# Store some global server args
-global_server_args_dict = {}
-

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
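With `global_server_args_dict` removed from `infer_batch.py`, modules such as `token_attention.py` now import it from `sglang.srt.server` and read flags like `attention_reduce_in_fp32` at import time. A minimal sketch of that shared-settings pattern, using stand-in names rather than sglang's actual module layout:

```python
# Stand-in for the module that now owns the dict (sglang.srt.server).
global_server_args_dict = {}


def apply_server_args(args: dict) -> None:
    """Populate the shared dict once, before the consumer modules are imported."""
    global_server_args_dict.update(args)


# Stand-in for a consumer such as sglang/srt/layers/token_attention.py,
# which reads the flag at import time with a default:
apply_server_args({"attention_reduce_in_fp32": True})
REDUCE_IN_FP32 = global_server_args_dict.get("attention_reduce_in_fp32", False)
print(REDUCE_IN_FP32)  # True
```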
@@ -84,6 +82,15 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -134,7 +141,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:
@@ -144,13 +151,11 @@ class Req:
             )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def
-
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]

         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -164,13 +169,7 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

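The renamed `init_incremental_detokenize` / `get_next_inc_detokenization` pair implements the `surr_offset` / `read_offset` window sketched in the comment block above: decode the surrounding ids and the full read window, and only emit the suffix once it is strictly longer than the surrounding text and does not end in a partial character. A simplified, self-contained sketch of that logic (any tokenizer exposing `.decode(ids)` works; the function name here is illustrative):

```python
def incremental_detokenize(tokenizer, all_ids, surr_offset, read_offset):
    """Return (new_text, surr_offset, read_offset) after seeing `all_ids`."""
    surr_ids = all_ids[surr_offset:read_offset]   # already-emitted context window
    read_ids = all_ids[surr_offset:]              # context window plus new tokens

    surr_text = tokenizer.decode(surr_ids)
    new_text = tokenizer.decode(read_ids)

    # Only emit once the decoded text is unambiguous: strictly longer than the
    # surrounding context and not ending in a partial character ("�").
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        delta = new_text[len(surr_text):]
        # Slide the window: the old read position becomes the new surround.
        return delta, read_offset, len(all_ids)

    return "", surr_offset, read_offset
```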
@@ -272,6 +271,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -282,10 +282,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -327,6 +323,13 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
+
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
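The new guard turns an exhausted `req_to_token_pool` allocation (which returns `None`) into an explicit error instead of a crash further down. A toy sketch of a fixed-capacity slot pool with that contract (this `SlotPool` is illustrative, not sglang's `ReqToTokenPool`):

```python
import torch


class SlotPool:
    """Toy fixed-capacity slot pool whose alloc() returns None when exhausted."""

    def __init__(self, capacity: int):
        self.free = torch.ones(capacity, dtype=torch.bool)

    def alloc(self, need: int):
        free_idx = torch.nonzero(self.free).flatten()
        if free_idx.numel() < need:
            return None                    # not enough slots: let the caller decide
        chosen = free_idx[:need]
        self.free[chosen] = False
        return chosen


pool = SlotPool(capacity=4)
print(pool.alloc(3))                       # tensor([0, 1, 2])
if pool.alloc(3) is None:                  # only one slot left -> allocation fails
    print("Out of memory. Please set a smaller number for `--max-running-requests`.")
```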
@@ -398,10 +401,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
@@ -499,7 +502,7 @@ class Batch:
             cur_output_ids = req.output_ids

             req.output_ids.extend(suffix_ids)
-            decode_res, new_text = req.
+            decode_res, new_text = req.get_next_inc_detokenization()
             if not decode_res:
                 req.output_ids = cur_output_ids
                 continue
@@ -518,6 +521,9 @@ class Batch:
                 req.output_ids = cur_output_ids
                 continue

+            # The decode status has diverged from detokenizer_manager
+            req.vid += 1
+
             # insert the old request into tree_cache
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
@@ -659,20 +665,21 @@ class Batch:

         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-
-
-
-
-
-
-
-
-
-
+
+        max_top_k_round, batch_size = 32, probs.shape[0]
+        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, self.top_ks, self.top_ps
+        )
+
+        # FIXME: this is a temporary fix for the illegal token ids
+        illegal_mask = torch.logical_or(
+            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
         )
-
-
-
+        if torch.any(illegal_mask):
+            warnings.warn("Illegal sampled token ids")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -682,18 +689,7 @@ class Batch:
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

-        return batch_next_token_ids
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids


 @dataclass
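Sampling now goes through flashinfer's fused `top_k_top_p_sampling_from_probs`, with an argmax fallback when it returns out-of-range ids. For reference, the removed `_top_p_top_k` helper shown above can be combined with `torch.multinomial` into a roughly equivalent pure-PyTorch path; this sketch is a reference fallback, not the kernel sglang actually calls, and assumes `top_ks` / `top_ps` are shaped `(batch, 1)`:

```python
import torch


def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
    # Same spirit as the deleted helper: zero out tokens beyond the nucleus
    # (top-p) and beyond the top-k rank, then renormalize by the row maximum.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx


def sample_top_k_top_p(probs, top_ks, top_ps):
    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    # Map the position in the sorted distribution back to a vocabulary id.
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


probs = torch.softmax(torch.randn(2, 32000), dim=-1)
top_ks = torch.tensor([[40], [1]])      # per-request top-k, shape (batch, 1)
top_ps = torch.tensor([[0.9], [1.0]])   # per-request top-p, shape (batch, 1)
print(sample_top_k_top_p(probs, top_ks, top_ps))
```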
@@ -829,6 +825,7 @@ def init_flashinfer_args(
     prefix_lens,
     flashinfer_decode_wrapper,
 ):
+    """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -894,6 +891,7 @@ def init_flashinfer_args(


 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")