sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their respective registries. It is provided for informational purposes only.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -129,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -181,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
@@ -200,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
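With this hunk, logprob and streaming options become constructor arguments of `Req` rather than flags mutated after construction. A minimal sketch of the new call follows; the leading `rid` argument and the `SamplingParams` import path are assumptions taken from surrounding sglang code, not from this hunk.

```python
from sglang.srt.sampling.sampling_params import SamplingParams  # assumed path

# Hypothetical construction of a request with the new 0.4.1 arguments.
req = Req(
    rid="req-0",                      # assumed argument, not shown in this hunk
    origin_input_text="Hello",
    origin_input_ids=(9906,),
    sampling_params=SamplingParams(),
    return_logprob=True,              # new: previously set via attribute
    top_logprobs_num=5,               # new: previously set via attribute
    stream=True,                      # new: previously set via attribute
)
```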
@@ -217,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds

+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None
@@ -228,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -241,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""

         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
-        self.return_logprob = False
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num = 0
+        self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.input_token_logprobs = None
-        self.input_top_logprobs = None
-        self.output_token_logprobs = []
-        self.output_top_logprobs = []
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None

         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
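The single logprob lists are split here into parallel `*_val`/`*_idx` lists, allocated only when `return_logprob` is set. A hedged sketch of how a consumer would adapt, assuming the 0.4.0 entries were `(logprob, token_id)` tuples:

```python
# Before (0.4.0.post1), assuming (logprob, token_id) tuples:
#     for logprob, token_id in req.output_token_logprobs: ...

# After (0.4.1): values and token ids live in two parallel lists,
# and both are None unless the request was created with return_logprob=True.
if req.return_logprob:
    for logprob, token_id in zip(
        req.output_token_logprobs_val, req.output_token_logprobs_idx
    ):
        print(f"token {token_id}: logprob {logprob:.4f}")
```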
@@ -295,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)

-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
        return self.finished_reason is not None

     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
@@ -454,15 +470,31 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.output_token_logprobs = self.output_token_logprobs[:k]
-        self.output_top_logprobs = self.output_top_logprobs[:k]
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k

         return True

+    def reset_for_retract(self):
+        self.prefix_indices = []
+        self.last_node = None
+        self.extend_input_len = 0
+        self.is_retracted = True
+
+        # For incremental logprobs
+        # TODO: Fix the `logprob_start_len`
+        self.last_update_decode_tokens = 0
+        self.logprob_start_len = 10**9
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+        return (
+            f"rid(n={self.rid}, "
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+        )


 bid = 0
@@ -470,7 +502,7 @@ bid = 0

 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch on the scheduler."""
+    """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
     reqs: List[Req]
@@ -876,15 +908,7 @@ class ScheduleBatch:
             )
             residual_size = max(0, residual_size)
             self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-            req.prefix_indices = []
-            req.last_node = None
-            req.extend_input_len = 0
-            req.is_retracted = True
-
-            # For incremental logprobs
-            req.last_update_decode_tokens = 0
-            req.logprob_start_len = 10**9
+            req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)

@@ -1068,9 +1092,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)

-        self.return_logprob = self.return_logprob or other.return_logprob
-        self.has_stream = self.has_stream or other.has_stream
-        self.has_grammar = self.has_grammar or other.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
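For plain Python bools, in-place bitwise OR is equivalent to the replaced `x = x or y` form, so this hunk is a pure style change:

```python
has_stream = False
has_stream |= True
assert has_stream is True  # same result as: has_stream = has_stream or True
```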
@@ -1097,7 +1121,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1152,9 +1175,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
sglang/srt/managers/schedule_policy.py

@@ -20,9 +20,11 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from typing import Dict, List, Optional

+import torch
+
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.radix_cache import TreeNode
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
@@ -32,6 +34,21 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (against existing cache) less than this value,
+# the scheduler runs the in-batch prefix caching check for this request.
+# If we set it to -1, it means we disable in-batch prefix caching.
+IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
+)
+
+# Threshold for in-batch prefix cache.
+# If a request has a matched prefix length (within the waiting queue) larger than this value,
+# the scheduler deprioritizes this request
+IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
+    os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
+)
+

 class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
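Both thresholds are read once at module import time via `os.environ.get`, so they must be set before the scheduler module is imported. A hypothetical configuration sketch (the values shown are examples, not defaults):

```python
import os

# Set before importing sglang, since the thresholds are evaluated at import time.
os.environ["IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD"] = "-1"   # -1 disables the check
os.environ["IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD"] = "64"

import sglang  # noqa: E402  (import after configuring the environment)
```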
@@ -42,6 +59,11 @@ class SchedulePolicy:
         self.policy = policy
         self.tree_cache = tree_cache

+        # It is used to find the matching prefix for in-batch prefix caching.
+        self.waiting_queue_radix_tree = RadixCache(
+            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+        )
+
     def calc_priority(self, waiting_queue: List[Req]):
         if len(waiting_queue) > 128 and self.policy == "lpm":
             # Turn off the expensive prefix matching and sorting when the #queue is large.
@@ -52,17 +74,53 @@ class SchedulePolicy:
         # Compute matched prefix length
         prefix_computed = False
         if policy == "lpm" or policy == "dfs-weight":
+            # rid to deprioritize in the current run for in-batch prefix caching.
+            temporary_deprioritized = set()
+            self.waiting_queue_radix_tree.reset()
+
             for r in waiting_queue:
+                prefix_ids = r.adjust_max_prefix_ids()
+
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                    rid=r.rid, key=prefix_ids
                 )

+                # NOTE(sang): This logic is for in-batch prefix caching;
+                # If there are more than 1 request that have small matching prefix from
+                # existing cache, but all those requests share the same prefix, we prefer
+                # to schedule only one of them so that we can increase the cache hit rate.
+                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
+                # threshold means we cannot use in-batch prefix caching for short prefixes.
+                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                    in_batch_matching_prefixes, _ = (
+                        self.waiting_queue_radix_tree.match_prefix(
+                            rid=r.rid, key=prefix_ids
+                        )
+                    )
+                    if (
+                        len(in_batch_matching_prefixes)
+                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                    ):
+                        temporary_deprioritized.add(r.rid)
+                    else:
+                        # Insert with a dummy key
+                        self.waiting_queue_radix_tree.insert(
+                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                        )
+
             prefix_computed = True

         if policy == "lpm":
             # Longest Prefix Match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            waiting_queue.sort(
+                key=lambda r: (
+                    -len(r.prefix_indices)
+                    if r.rid not in temporary_deprioritized
+                    else float("inf")
+                )
+            )
         elif policy == "fcfs":
             # first come first serve
             pass
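The new sort key preserves longest-prefix-match ordering for normal requests (the negated prefix lengths sort ascending, so longer prefixes come first) while `float("inf")` sinks temporarily deprioritized requests to the back of the queue. A standalone illustration with made-up values:

```python
# (rid, matched_prefix_len, temporarily_deprioritized) -- made-up values
reqs = [("a", 5, False), ("b", 40, True), ("c", 12, False)]

reqs.sort(key=lambda r: -r[1] if not r[2] else float("inf"))

# Longest prefix first among normal requests; deprioritized ones go last,
# even though "b" has the longest matched prefix.
assert [r[0] for r in reqs] == ["c", "a", "b"]
```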
@@ -72,6 +130,7 @@ class SchedulePolicy:
         elif policy == "random":
             random.shuffle(waiting_queue)
         elif policy == "dfs-weight":
+            # Experimental policy based on custom weights
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -101,8 +160,8 @@ class SchedulePolicy:
     def get_dfs_priority(
         self,
         cur_node: TreeNode,
-        node_to_priority: Dict,
-        last_node_to_reqs: Dict,
+        node_to_priority: Dict[TreeNode, int],
+        last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
     ):
         childs = [child for child in cur_node.children.values()]