sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tokenizer_manager.py

@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index = not (index is not None)
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             pixel_values, image_hash, image_size = await self._get_pixel_values(
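
Note: the new else-branch above resolves (input_text, input_ids, rid) differently depending on whether the batch request carries raw text or pre-tokenized ids. The standalone sketch below mirrors that branching; the resolve_prefill_input helper and its encode argument are illustrative names, not part of the package.

    from typing import List, Optional, Tuple, Union

    def resolve_prefill_input(
        text: Union[None, str, List[str]],
        input_ids: Union[None, List[int], List[List[int]]],
        rids: List[str],
        index: int,
        encode,  # callable str -> List[int], e.g. a tokenizer's encode method
    ) -> Tuple[Optional[str], Union[List[int], List[List[int]]], str]:
        # Prefer raw text if the request has it; otherwise fall back to ids.
        if text is not None:
            if isinstance(text, list):
                input_text, rid = text[index], rids[index]
            else:
                input_text, rid = text, rids[0]
            return input_text, encode(input_text), rid
        # Pre-tokenized: List[List[int]] is a batch, a flat List[int] is one prompt.
        if isinstance(input_ids, list) and isinstance(input_ids[0], list):
            return None, input_ids[index], rids[index]
        return None, input_ids, rids[0]

    # Example with two pre-tokenized prompts (encode is unused on this path):
    print(resolve_prefill_input(None, [[1, 2, 3], [4, 5]], ["a", "b"], 1, None))
    # -> (None, [4, 5], 'b')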
@@ -240,11 +254,11 @@ class TokenizerManager:
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-                    pass
-            if len(input_id_result) > 1 and input_id_result is not None:
+            if input_id_result is not None and len(input_id_result) > 1:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
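
Note: the reordered condition is a correctness fix, not just style. With the old order, len(input_id_result) is evaluated before the None check and raises TypeError whenever input_id_result is None; putting the None check first short-circuits safely. A small illustration (not code from the package):

    input_id_result = None

    # New order: the None check short-circuits, so len() is never called.
    if input_id_result is not None and len(input_id_result) > 1:
        print("many results")
    else:
        print("skipped safely")  # this branch runs

    # Old order: len(None) raises before the None check is ever reached.
    try:
        if len(input_id_result) > 1 and input_id_result is not None:
            print("many results")
    except TypeError as exc:
        print(f"old order fails: {exc}")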
@@ -264,11 +278,12 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text = obj.text
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text = obj.text[i]
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -293,7 +308,6 @@ class TokenizerManager:
                 event = asyncio.Event()
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state
-
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -326,7 +340,6 @@ class TokenizerManager:
                 )
                 assert state.finished
                 del self.rid_to_state[rid]
-
         yield output_list

     def _validate_input_length(self, input_ids: List[int]):
@@ -375,8 +388,13 @@ class TokenizerManager:
                 obj.return_text_in_logprobs,
             )

+            # Log requests
             if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj.text}, out={out}")
+                if obj.text is None:
+                    in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                else:
+                    in_obj = {"text": obj.text}
+                logger.info(f"in={in_obj}, out={out}")

             state.out_list = []
             if state.finished:

sglang/srt/managers/tp_worker.py

@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         try:
@@ -290,10 +289,10 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )

-        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
@@ -353,7 +352,7 @@ class ModelTpServer:
             )
             self.waiting_queue.append(req)

-    def get_new_prefill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -364,12 +363,13 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
+            try_match_ids = req.input_ids
+            if req.return_logprob:
+                try_match_ids = req.input_ids[: req.logprob_start_len]
+            # NOTE: the prefix_indices must always be aligned with last_node
             prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
+                rid=req.rid, key=try_match_ids
             )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node
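
Note: this change moves the logprob truncation from the match result to the match key. Truncating prefix_indices after the fact could leave it describing a shorter prefix than the last_node returned by the cache, which is the misalignment the new NOTE warns about; truncating the key before matching keeps the two consistent. Below is only a toy model of that alignment issue, with a stand-in match_prefix rather than the real radix cache:

    # Stand-in prefix matcher: returns matched positions plus the "node",
    # here identified by the cached prefix it ends on.
    def match_prefix(cache, key):
        best = ()
        for cached in cache:
            n = 0
            while n < min(len(cached), len(key)) and cached[n] == key[n]:
                n += 1
            if n > len(best):
                best = cached[:n]
        return list(range(len(best))), best

    cache = {(1, 2, 3, 4, 5): "node"}
    input_ids = [1, 2, 3, 4, 5, 6]
    logprob_start_len = 2

    # New behavior: shorten the key first, so positions and node agree.
    positions, node = match_prefix(cache, tuple(input_ids[:logprob_start_len]))
    assert len(positions) == len(node)  # both describe (1, 2)

    # Old behavior: match the full key, then cut the positions afterwards.
    positions, node = match_prefix(cache, tuple(input_ids))
    positions = positions[:logprob_start_len]
    assert len(positions) != len(node)  # node still ends at (1, 2, 3, 4, 5)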
@@ -525,7 +525,7 @@ class ModelTpServer:
             )

         # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
@@ -534,7 +534,7 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch

-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -623,14 +623,13 @@ class ModelTpServer:
                 )
                 req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+    def cache_filled_batch(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
+                req_pool_idx=req.req_pool_idx,
                 del_in_memory_pool=False,
                 old_last_node=req.last_node,
             )
@@ -638,9 +637,9 @@ class ModelTpServer:

             if req is self.current_inflight_req:
                 # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+                self.req_to_token_pool.free(req.req_pool_idx)

-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -699,7 +698,7 @@ class ModelTpServer:

         self.handle_finished_requests(batch)

-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_vids = []
         decoded_texts = []
@@ -781,14 +780,13 @@ class ModelTpServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                 )

                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -799,7 +797,7 @@ class ModelTpServer:
         else:
             batch.reqs = []

-    def filter_out_inflight(self, batch: Batch):
+    def filter_out_inflight(self, batch: ScheduleBatch):
         # TODO(lsyin): reduce the overhead, make a special version for this
         if self.current_inflight_req is None:
             return

sglang/srt/mem_cache/memory_pool.py

@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
+from typing import List

 import torch

@@ -27,62 +28,42 @@ class ReqToTokenPool:

     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size

-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None

-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

         return select_index

     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += free_index.shape[0]
+            self.free_slots.extend(free_index)

     def clear(self):
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))


-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
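
Note: the refactor replaces the boolean mem_state tensor (and the torch.nonzero scan in alloc) with a plain Python free list, which is why the leak check in tp_worker.py above now compares len(free_slots) against size. A minimal standalone sketch of the same free-list pattern, assuming integer slot indices:

    from typing import List, Optional

    class FreeSlotPool:
        """Free-list allocator in the style of the new ReqToTokenPool."""

        def __init__(self, size: int):
            self.size = size
            self.free_slots = list(range(size))

        def alloc(self, need_size: int) -> Optional[List[int]]:
            if need_size > len(self.free_slots):
                return None  # not enough request slots available
            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

        def free(self, free_index) -> None:
            if isinstance(free_index, int):
                self.free_slots.append(free_index)
            else:
                self.free_slots.extend(free_index)

        def clear(self) -> None:
            self.free_slots = list(range(self.size))

    pool = FreeSlotPool(4)
    slots = pool.alloc(3)  # [0, 1, 2]
    pool.free(slots)
    assert len(pool.free_slots) == pool.size  # the invariant tp_worker.py checks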
@@ -90,15 +71,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

@@ -139,3 +111,67 @@

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
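
Note: the split introduces a shared base class for slot bookkeeping and two buffer layouts. MHATokenToKVPool keeps separate K and V buffers of shape [size + 1, head_num, head_dim] per layer, while MLATokenToKVPool stores one compressed entry of width kv_lora_rank + qk_rope_head_dim per token and serves its "value" view as the first kv_lora_rank channels of that same buffer. A quick shape check of the MLA layout (the sizes below are illustrative, not from a real config, and dtype/device are left at their CPU defaults so it runs anywhere):

    import torch

    size, kv_lora_rank, qk_rope_head_dim, layer_num = 8, 512, 64, 2

    kv_buffer = [
        torch.empty((size + 1, 1, kv_lora_rank + qk_rope_head_dim))
        for _ in range(layer_num)
    ]

    key_view = kv_buffer[0]                        # compressed KV plus rope part
    value_view = kv_buffer[0][..., :kv_lora_rank]  # value view is a slice, no copy

    print(key_view.shape)    # torch.Size([9, 1, 576])
    print(value_view.shape)  # torch.Size([9, 1, 512])
    assert value_view.data_ptr() == key_view.data_ptr()  # same storage, just a view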

sglang/srt/model_config.py

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
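
Note: with attention_arch recorded on the config, the model runner can construct the matching KV pool. The actual wiring lives in sglang/srt/model_executor/model_runner.py, which is not shown in this section, so the snippet below is only an illustrative sketch that reuses the class and field names from the hunks above (the helper name is hypothetical, and the head_num choice ignores tensor-parallel sharding for brevity):

    from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
    from sglang.srt.model_config import AttentionArch

    def create_token_to_kv_pool(model_config, size, dtype, layer_num):
        # Hypothetical helper, not the package's real function.
        if model_config.attention_arch == AttentionArch.MLA:
            return MLATokenToKVPool(
                size,
                dtype=dtype,
                kv_lora_rank=model_config.kv_lora_rank,
                qk_rope_head_dim=model_config.qk_rope_head_dim,
                layer_num=layer_num,
            )
        return MHATokenToKVPool(
            size,
            dtype=dtype,
            head_num=model_config.num_key_value_heads,
            head_dim=model_config.head_dim,
            layer_num=layer_num,
        )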

sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import (
-    Batch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

-    def replay(self, batch: Batch):
+    def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)