sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

```diff
@@ -129,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
```
```diff
@@ -181,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
```
```diff
@@ -200,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
```
```diff
@@ -217,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds

+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None
```
```diff
@@ -228,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream

         # For incremental decoding
         # ----- | --------- read_ids -------|
```
```diff
@@ -241,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""

         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
-        self.return_logprob =
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num =
+        self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.
-        self.
-        self.
-        self.
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None

         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
```
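The hunk above changes how a `Req` tracks logprobs: values and token ids now live in parallel `*_val` / `*_idx` lists, and the output-side lists are allocated only when `return_logprob` is requested in the constructor. Below is a minimal, self-contained sketch of that storage pattern; `ReqLogprobsSketch` and `append_decode_step` are illustrative names, not the actual sglang `Req` API.

```python
from typing import List, Optional


class ReqLogprobsSketch:
    """Illustrative sketch of the val/idx logprob storage shown in the diff above."""

    def __init__(self, return_logprob: bool = False, top_logprobs_num: int = 0):
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num

        # Output-side lists exist only when logprobs were requested;
        # otherwise they stay None and no per-token bookkeeping is done.
        if return_logprob:
            self.output_token_logprobs_val: Optional[List[float]] = []
            self.output_token_logprobs_idx: Optional[List[int]] = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = None

    def append_decode_step(self, token_id: int, logprob: float):
        if self.return_logprob:
            # Parallel lists: position i of *_val and *_idx describe the same token.
            self.output_token_logprobs_val.append(logprob)
            self.output_token_logprobs_idx.append(token_id)


req = ReqLogprobsSketch(return_logprob=True)
req.append_decode_step(token_id=42, logprob=-0.7)
print(req.output_token_logprobs_val, req.output_token_logprobs_idx)  # [-0.7] [42]
```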
```diff
@@ -295,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)

-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
        return self.finished_reason is not None

     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
```
```diff
@@ -454,15 +470,31 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.
-        self.
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k

         return True

+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )


 bid = 0
```
```diff
@@ -470,7 +502,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all
+    """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
```
```diff
@@ -876,15 +908,7 @@ class ScheduleBatch:
                 )
                 residual_size = max(0, residual_size)
                 self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)

```
```diff
@@ -1068,9 +1092,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)

-        self.return_logprob
-        self.has_stream
-        self.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
```
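For plain Python bools, the in-place `|=` used above behaves like `flag = flag or other_flag` and keeps the result a bool, so a merged batch needs logprob, streaming, or grammar handling if either side did. A one-line sketch of that semantics:

```python
# In-place OR on bools: the merged flag is True if either input flag was True.
return_logprob, other_return_logprob = False, True
return_logprob |= other_return_logprob
assert return_logprob is True
```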
```diff
@@ -1097,7 +1121,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
```
```diff
@@ -1152,9 +1175,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
```
sglang/srt/managers/schedule_policy.py

```diff
@@ -20,9 +20,11 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from typing import Dict, List, Optional

+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
```
```diff
@@ -32,6 +34,21 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (against existing cache) less than this value,
+# the scheduler runs the in-batch prefix caching check for this request.
+# If we set it to -1, it means we disable in-batch prefix caching.
+IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
+)
+
+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (within the waiting queue) larger than this value,
+# the scheduler deprioritizes this request
+IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
+)
+

 class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
```
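Both thresholds are read from environment variables once, at module level, so they have to be in the environment before `sglang.srt.managers.schedule_policy` is imported (e.g., before launching the server process). A minimal sketch, using the variable names from the hunk above:

```python
import os

# Must be set before the schedule_policy module is imported,
# because the thresholds are read once at import time.
os.environ["IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD"] = "64"
os.environ["IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD"] = "16"

# Per the comment in the diff, setting the check threshold to -1
# disables in-batch prefix caching entirely:
# os.environ["IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD"] = "-1"
```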
```diff
@@ -42,6 +59,11 @@ class SchedulePolicy:
         self.policy = policy
         self.tree_cache = tree_cache

+        # It is used to find the matching prefix for in-batch prefix caching.
+        self.waiting_queue_radix_tree = RadixCache(
+            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+        )
+
     def calc_priority(self, waiting_queue: List[Req]):
         if len(waiting_queue) > 128 and self.policy == "lpm":
             # Turn off the expensive prefix matching and sorting when the #queue is large.
```
```diff
@@ -52,17 +74,53 @@ class SchedulePolicy:
         # Compute matched prefix length
         prefix_computed = False
         if policy == "lpm" or policy == "dfs-weight":
+            # rid to deprioritize in the current run for in-batch prefix caching.
+            temporary_deprioritized = set()
+            self.waiting_queue_radix_tree.reset()
+
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
+
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=
+                    rid=r.rid, key=prefix_ids
                 )

+                # NOTE(sang): This logic is for in-batch prefix caching;
+                # If there are more than 1 request that have small matching prefix from
+                # existing cache, but all those requests share the same prefix, we prefer
+                # to schedule only one of them so that we can increase the cache hit rate.
+                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
+                # threshold means we cannot use in-batch prefix caching for short prefixes.
+                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                    in_batch_matching_prefixes, _ = (
+                        self.waiting_queue_radix_tree.match_prefix(
+                            rid=r.rid, key=prefix_ids
+                        )
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                    ):
+                        temporary_deprioritized.add(r.rid)
+                    else:
+                        # Insert with a dummy key
+                        self.waiting_queue_radix_tree.insert(
+                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                        )
+
             prefix_computed = True

         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(
+            waiting_queue.sort(
+                key=lambda r: (
+                    -len(r.prefix_indices)
+                    if r.rid not in temporary_deprioritized
+                    else float("inf")
+                )
+            )
         elif policy == "fcfs":
             # first come first serve
             pass
```
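The net effect of the hunk above is on the longest-prefix-match ordering: a request whose prefix is already well represented in the waiting queue's own radix tree gets a sort key of `float("inf")` and falls to the back, while everything else sorts by descending matched prefix length. A tiny standalone sketch of just that sort key, using a hypothetical `DummyReq` instead of the real sglang `Req` and `RadixCache`:

```python
# Illustration of the LPM sort key with deprioritization (dummy objects only).
from dataclasses import dataclass, field
from typing import List


@dataclass
class DummyReq:
    rid: str
    prefix_indices: List[int] = field(default_factory=list)


waiting_queue = [
    DummyReq("a", prefix_indices=[0] * 5),
    DummyReq("b", prefix_indices=[0] * 40),
    DummyReq("c", prefix_indices=[0] * 20),
]
# Pretend "b"'s prefix was already seen often enough within this batch.
temporary_deprioritized = {"b"}

waiting_queue.sort(
    key=lambda r: (
        -len(r.prefix_indices)
        if r.rid not in temporary_deprioritized
        else float("inf")
    )
)
print([r.rid for r in waiting_queue])  # ['c', 'a', 'b']: longest prefix first, deprioritized last
```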
```diff
@@ -72,6 +130,7 @@ class SchedulePolicy:
         elif policy == "random":
             random.shuffle(waiting_queue)
         elif policy == "dfs-weight":
+            # Experimental policy based on custom weights
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
```
```diff
@@ -101,8 +160,8 @@ class SchedulePolicy:
     def get_dfs_priority(
         self,
         cur_node: TreeNode,
-        node_to_priority: Dict,
-        last_node_to_reqs: Dict,
+        node_to_priority: Dict[TreeNode, int],
+        last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
     ):
         childs = [child for child in cur_node.children.values()]
```