sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in its public registry. It is provided for informational purposes only.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -18,18 +18,18 @@ limitations under the License.
  import logging
  import warnings
  from dataclasses import dataclass
- from typing import List, Union
+ from typing import List, Optional, Union

- import numpy as np
  import torch
  from flashinfer.sampling import top_k_top_p_sampling_from_probs

+ import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.global_config import global_config
  from sglang.srt.constrained import RegexGuide
  from sglang.srt.constrained.jump_forward import JumpForwardMap
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
- from sglang.srt.mem_cache.radix_cache import RadixCache

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -98,7 +98,7 @@ class Req:
  self.origin_input_ids_unpadded = origin_input_ids # Before image padding
  self.origin_input_ids = origin_input_ids
  self.output_ids = [] # Each decode stage's output ids
- self.input_ids = None # input_ids = origin_input_ids + output_ids
+ self.fill_ids = None # fill_ids = origin_input_ids + output_ids

  # Memory info
  self.req_pool_idx = None
@@ -124,7 +124,7 @@ class Req:
  # For vision input
  self.pixel_values = None
  self.image_size = None
- self.image_offset = 0
+ self.image_offset = None
  self.pad_value = None

  # Prefix info
@@ -142,6 +142,7 @@ class Req:

  # Logprobs
  self.return_logprob = False
+ self.embedding = None
  self.logprob_start_len = 0
  self.top_logprobs_num = 0
  self.normalized_prompt_logprob = None
@@ -162,6 +163,32 @@ class Req:
  def finished(self) -> bool:
  return self.finished_reason is not None

+ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+ self.fill_ids = self.origin_input_ids + self.output_ids
+ if tree_cache is not None:
+ self.prefix_indices, self.last_node = tree_cache.match_prefix(
+ rid=self.rid, key=self.adjust_max_prefix_ids()
+ )
+ self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
+
+ def adjust_max_prefix_ids(self):
+ self.fill_ids = self.origin_input_ids + self.output_ids
+ input_len = len(self.fill_ids)
+ max_prefix_len = input_len
+
+ if self.sampling_params.max_new_tokens > 0:
+ # Need at least one token to compute logits
+ max_prefix_len = min(max_prefix_len, input_len - 1)
+
+ if self.return_logprob:
+ max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+
+ if self.normalized_prompt_logprob is None:
+ # Need at least two tokens to compute normalized logprob
+ max_prefix_len = min(max_prefix_len, input_len - 2)
+
+ return self.fill_ids[:max_prefix_len]
+
  # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
  def init_incremental_detokenize(self):
  first_iter = self.surr_offset is None or self.read_offset is None
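
The inline comments in adjust_max_prefix_ids() explain why the cached prefix must stop short of the full fill: at least one token has to be recomputed to produce logits, and two when a normalized prompt logprob is still pending. A minimal standalone sketch of that capping rule (the function name and numbers below are illustrative, not sglang API):

    def max_prefix_len(input_len: int, need_logits: bool, need_normalized_logprob: bool) -> int:
        # Mirror of the min() chain in Req.adjust_max_prefix_ids(), stripped of Req state.
        max_prefix = input_len
        if need_logits:
            # Need at least one token to compute logits
            max_prefix = min(max_prefix, input_len - 1)
        if need_normalized_logprob:
            # Need at least two tokens to compute normalized logprob
            max_prefix = min(max_prefix, input_len - 2)
        return max_prefix

    # A 10-token fill that still needs logits and a normalized prompt logprob
    # can reuse at most 8 tokens from the prefix cache.
    assert max_prefix_len(10, need_logits=True, need_normalized_logprob=True) == 8
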
@@ -176,6 +203,8 @@ class Req:
  return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

  def get_next_inc_detokenization(self):
+ if self.tokenizer is None:
+ return False, ""
  read_ids, read_offset = self.init_incremental_detokenize()
  surr_ids = read_ids[:read_offset]

@@ -200,16 +229,18 @@ class Req:
  return

  if len(self.output_ids) >= self.sampling_params.max_new_tokens:
- self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+ self.finished_reason = FINISH_LENGTH(
+ length=self.sampling_params.max_new_tokens
+ )
  return

- if (
- self.output_ids[-1] == self.tokenizer.eos_token_id
- and not self.sampling_params.ignore_eos
- ):
- self.finished_reason = FINISH_MATCHED_TOKEN(
- matched=self.tokenizer.eos_token_id
- )
+ last_token_id = self.output_ids[-1]
+ if self.tokenizer is None:
+ matched_eos = last_token_id in self.sampling_params.stop_token_ids
+ else:
+ matched_eos = last_token_id == self.tokenizer.eos_token_id
+ if matched_eos and not self.sampling_params.ignore_eos:
+ self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
  return

  if len(self.sampling_params.stop_strs) > 0:
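
The rewritten EOS check lets a request finish even when the worker holds no tokenizer (for example, token-id-only input); in that case the match falls back to sampling_params.stop_token_ids. A condensed sketch of the new decision as a standalone function (names are illustrative, not sglang API):

    def matched_eos(last_token_id, tokenizer, stop_token_ids) -> bool:
        # Without a tokenizer, fall back to the explicit stop-token-id set.
        if tokenizer is None:
            return last_token_id in stop_token_ids
        return last_token_id == tokenizer.eos_token_id
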
@@ -284,13 +315,12 @@ class ScheduleBatch:
  reqs: List[Req]
  req_to_token_pool: ReqToTokenPool
  token_to_kv_pool: BaseTokenToKVPool
- tree_cache: RadixCache
+ tree_cache: BasePrefixCache

  # Batched arguments to model runner
  input_ids: torch.Tensor = None
  req_pool_indices: torch.Tensor = None
  seq_lens: torch.Tensor = None
- prefix_lens: torch.Tensor = None
  position_ids_offsets: torch.Tensor = None
  out_cache_loc: torch.Tensor = None
  extend_num_tokens: int = None
@@ -299,17 +329,11 @@ class ScheduleBatch:
  return_logprob: bool = False
  top_logprobs_nums: List[int] = None

- # For multimodal
- pixel_values: List[torch.Tensor] = None
- image_sizes: List[List[int]] = None
- image_offsets: List[int] = None
-
  # Batched sampling params
  temperatures: torch.Tensor = None
  top_ps: torch.Tensor = None
  top_ks: torch.Tensor = None
- frequency_penalties: torch.Tensor = None
- presence_penalties: torch.Tensor = None
+ penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
  logit_bias: torch.Tensor = None

  @classmethod
@@ -373,15 +397,24 @@ class ScheduleBatch:
  self.top_ks = torch.tensor(
  [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
  )
- self.frequency_penalties = torch.tensor(
- [r.sampling_params.frequency_penalty for r in reqs],
- dtype=torch.float,
- device=device,
- )
- self.presence_penalties = torch.tensor(
- [r.sampling_params.presence_penalty for r in reqs],
- dtype=torch.float,
+
+ # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+ # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+ # should not add hefty computation overhead other than simple checks.
+ #
+ # While we choose not to even create the class instances if they are not required, this
+ # could add additional complexity to the {ScheduleBatch} class, especially we need to
+ # handle {filter_batch()} and {merge()} cases as well.
+ self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+ vocab_size=vocab_size,
+ batch=self,
  device=device,
+ Penalizers={
+ penaltylib.BatchedFrequencyPenalizer,
+ penaltylib.BatchedMinNewTokensPenalizer,
+ penaltylib.BatchedPresencePenalizer,
+ penaltylib.BatchedRepetitionPenalizer,
+ },
  )

  # Handle logit bias but only allocate when needed
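
The separate frequency/presence penalty tensors are replaced by a single BatchedPenalizerOrchestrator that owns all penalizers and is invoked from sample(), filter_batch(), and merge() below. As a hedged illustration of what one such penalizer contributes, a batched frequency penalty amounts to subtracting a per-request scaled count of already-generated tokens from the logits (this sketch is not the penaltylib implementation; all names are illustrative):

    import torch

    def apply_frequency_penalty(
        logits: torch.Tensor,               # [batch, vocab] raw logits
        output_counts: torch.Tensor,        # [batch, vocab] counts of tokens generated so far
        frequency_penalties: torch.Tensor,  # [batch] per-request penalty strength
    ) -> torch.Tensor:
        return logits - frequency_penalties.unsqueeze(1) * output_counts.float()

Keeping output_counts up to date is exactly what cumulate_output_tokens() does after each decode step in the hunks further down.
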
@@ -395,58 +428,40 @@ class ScheduleBatch:
  self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

  def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
- device = "cuda"
  bs = self.batch_size()
  reqs = self.reqs
- input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
- prefix_indices = [r.prefix_indices for r in reqs]
-
- # Handle prefix
- extend_lens = []
- prefix_lens = []
+ input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+ extend_num_tokens = sum(len(ids) for ids in input_ids)
  seq_lens = []

+ # Allocate memory
  req_pool_indices_cpu = self.alloc_req_slots(bs)
+ out_cache_loc = self.alloc_token_slots(extend_num_tokens)

+ pt = 0
  for i, req in enumerate(reqs):
  req.req_pool_idx = req_pool_indices_cpu[i]
- extend_lens.append(len(input_ids[i]))
+ pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+ ext_len = seq_len - pre_len
+ seq_lens.append(seq_len)

- if len(prefix_indices[i]) == 0:
- prefix_lens.append(0)
- else:
- prefix_lens.append(len(prefix_indices[i]))
+ if pre_len > 0:
  self.req_to_token_pool.req_to_token[req.req_pool_idx][
- : len(prefix_indices[i])
- ] = prefix_indices[i]
-
- seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
- # Allocate memory
- seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
- extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
- out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+ :pre_len
+ ] = req.prefix_indices

- pt = 0
- for i, req in enumerate(reqs):
- self.req_to_token_pool.req_to_token[req.req_pool_idx][
- prefix_lens[i] : prefix_lens[i] + extend_lens[i]
- ] = out_cache_loc[pt : pt + extend_lens[i]]
- pt += extend_lens[i]
+ self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+ out_cache_loc[pt : pt + ext_len]
+ )
+ pt += ext_len

  # Set fields
  with torch.device("cuda"):
  self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
  self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
  self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
- self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
-
- self.pixel_values = [r.pixel_values for r in reqs]
- self.image_sizes = [r.image_size for r in reqs]
- self.image_offsets = [
- r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
- ]
- self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+ self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+
  self.extend_num_tokens = extend_num_tokens
  self.out_cache_loc = out_cache_loc
  self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
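
prepare_for_extend() now fills each request's KV-slot table in a single pass: the first pre_len entries come from the cached prefix, the remaining ext_len entries from the freshly allocated out_cache_loc slots. A toy walk-through with made-up slot numbers, using plain Python lists in place of the pool tensors:

    prefix_indices = [100, 101, 102, 103]   # KV slots already cached for this request
    fill_ids = [11, 12, 13, 14, 15, 16]     # origin_input_ids + output_ids
    pre_len, seq_len = len(prefix_indices), len(fill_ids)
    ext_len = seq_len - pre_len             # 2 tokens still need KV slots
    out_cache_loc = [200, 201]              # freshly allocated slots for the extend part

    req_to_token = [None] * seq_len
    req_to_token[:pre_len] = prefix_indices
    req_to_token[pre_len:seq_len] = out_cache_loc
    assert req_to_token == [100, 101, 102, 103, 200, 201]
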
@@ -522,7 +537,7 @@ class ScheduleBatch:
  residual_size = max(0, residual_size)
  self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

- req.prefix_indices = None
+ req.prefix_indices = []
  req.last_node = None
  req.extend_input_len = 0

@@ -596,15 +611,7 @@ class ScheduleBatch:
  req.vid += 1

  # insert the old request into tree_cache
- self.tree_cache.cache_req(
- rid=req.rid,
- token_ids=cur_all_ids,
- last_uncached_pos=len(req.prefix_indices),
- req_pool_idx=req.req_pool_idx,
- )
-
- # unlock the last node
- self.tree_cache.dec_lock_ref(req.last_node)
+ self.tree_cache.cache_finished_req(req, cur_all_ids)

  # re-applying image padding
  if req.pixel_values is not None:
@@ -621,19 +628,21 @@ class ScheduleBatch:
  jump_forward_reqs.append(req)
  filter_indices.remove(i)

- if len(filter_indices) < len(self.reqs):
- self.filter_batch(filter_indices)
+ self.filter_batch(filter_indices)

  return jump_forward_reqs

  def prepare_for_decode(self, input_ids=None):
  if input_ids is None:
  input_ids = [
- r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+ r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+ for r in self.reqs
  ]
+ else:
+ self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+
  self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
  self.seq_lens.add_(1)
- self.prefix_lens = None

  # Alloc mem
  bs = self.batch_size()
@@ -644,23 +653,31 @@ class ScheduleBatch:
  ] = self.out_cache_loc

  def filter_batch(self, unfinished_indices: List[int]):
+ if unfinished_indices is None or len(unfinished_indices) == 0:
+ # Filter out all requests
+ self.reqs = []
+ return
+
+ if len(unfinished_indices) == len(self.reqs):
+ # No need to filter
+ return
+
  self.reqs = [self.reqs[i] for i in unfinished_indices]
  new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
  self.seq_lens = self.seq_lens[new_indices]
  self.input_ids = None
  self.req_pool_indices = self.req_pool_indices[new_indices]
- self.prefix_lens = None
  self.position_ids_offsets = self.position_ids_offsets[new_indices]
  self.out_cache_loc = None
  self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
  self.return_logprob = any(req.return_logprob for req in self.reqs)

+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
  for item in [
  "temperatures",
  "top_ps",
  "top_ks",
- "frequency_penalties",
- "presence_penalties",
  "logit_bias",
  ]:
  self_val = getattr(self, item, None)
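
filter_batch() keeps the Python request list and every batched tensor aligned by selecting both with the same unfinished_indices; the penalizer orchestrator is filtered with the identical indices. A toy illustration (all values are made up):

    import torch

    reqs = ["req0", "req1", "req2", "req3"]
    seq_lens = torch.tensor([5, 9, 7, 3])
    unfinished_indices = [1, 2]          # req0 and req3 have finished

    new_indices = torch.tensor(unfinished_indices)
    reqs = [reqs[i] for i in unfinished_indices]
    seq_lens = seq_lens[new_indices]
    assert reqs == ["req1", "req2"] and seq_lens.tolist() == [9, 7]
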
@@ -668,13 +685,17 @@ class ScheduleBatch:
  setattr(self, item, self_val[new_indices])

  def merge(self, other: "ScheduleBatch"):
+ # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
+ # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
+ # needs to be called with pre-merged Batch.reqs.
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
  self.reqs.extend(other.reqs)

  self.req_pool_indices = torch.concat(
  [self.req_pool_indices, other.req_pool_indices]
  )
  self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
- self.prefix_lens = None
  self.position_ids_offsets = torch.concat(
  [self.position_ids_offsets, other.position_ids_offsets]
  )
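
The comment above pins down an ordering constraint: penalizer state is sized from the batch's request list, so both orchestrators must still see their not-yet-merged request lists when merge() runs. A hedged toy penalizer showing why (this is not the actual penaltylib code; TinyPenalizer and its fields are invented for illustration):

    import torch

    class TinyPenalizer:
        def __init__(self, batch):
            self.batch = batch
            self.scale = None  # lazily allocated

        def _prepare(self):
            if self.scale is None:
                # Allocation is sized from the batch's *current* request list.
                self.scale = torch.zeros(len(self.batch.reqs))

        def merge(self, other: "TinyPenalizer"):
            # Valid only while self.batch.reqs and other.batch.reqs are still un-merged;
            # if self.batch.reqs were already extended, self.scale would be sized wrong.
            self._prepare()
            other._prepare()
            self.scale = torch.cat([self.scale, other.scale])
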
@@ -686,8 +707,6 @@ class ScheduleBatch:
  "temperatures",
  "top_ps",
  "top_ks",
- "frequency_penalties",
- "presence_penalties",
  ]:
  self_val = getattr(self, item, None)
  other_val = getattr(other, item, None)
@@ -711,6 +730,7 @@ class ScheduleBatch:
  self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

  def sample(self, logits: torch.Tensor):
+ # TODO(lsyin): move this into a part of layer and run with CUDA Graph
  # Post process logits
  logits = logits.contiguous()
  logits.div_(self.temperatures)
@@ -728,7 +748,8 @@ class ScheduleBatch:
  ] = 1
  logits[i].masked_fill_(~allowed_mask, float("-inf"))

- # TODO(lmzheng): apply penalty
+ logits = self.penalizer_orchestrator.apply(logits)
+
  probs = torch.softmax(logits, dim=-1)

  if not global_server_args_dict["disable_flashinfer_sampling"]:
@@ -761,6 +782,8 @@ class ScheduleBatch:
  req.regex_fsm_state, batch_next_token_ids_cpu[i]
  )

+ self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+
  return batch_next_token_ids


@@ -780,7 +803,7 @@ def top_k_top_p_sampling_from_probs_torch(
  sampled_index = torch.multinomial(probs_sort, num_samples=1)
  except RuntimeError:
  batch_next_token_ids = torch.zeros(
- (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+ (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
  )
  success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
  return batch_next_token_ids, success
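
top_k_top_p_sampling_from_probs_torch() is the pure-PyTorch fallback used when flashinfer sampling is disabled or fails; this hunk only changes the dtype of the dummy ids returned after a multinomial RuntimeError. For orientation, a hedged single-(k, p) sketch of this style of sampling (the real fallback takes per-request top_k/top_p tensors and also returns a success mask):

    import torch

    def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        # probs: [batch, vocab] probabilities (e.g. softmax of penalized logits).
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        cumsum = probs_sort.cumsum(dim=-1)
        # Drop tokens outside the nucleus (top_p) and past the top_k cutoff.
        probs_sort[(cumsum - probs_sort) > top_p] = 0.0
        probs_sort[:, top_k:] = 0.0
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(probs_sort, num_samples=1)
        # Map positions in the sorted order back to vocabulary ids.
        return probs_idx.gather(-1, sampled).squeeze(-1)
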