sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -16,20 +16,18 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -37,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
@@ -264,11 +262,18 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            logger.warning("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            logger.warning("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            warnings.warn(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -327,17 +332,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunked prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -385,43 +386,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizer will do nothing if it evaluates itself as not required by looking at
-        # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially since we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
@@ -460,8 +424,32 @@
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
 
-        self.batch_sampling_params(vocab_size)
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
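Context for the hunks above: the batched sampling tensors that previously lived on ScheduleBatch now travel in a SamplingBatchInfo object built via SamplingBatchInfo.from_schedule_batch(self, vocab_size) (new file sglang/srt/sampling/sampling_batch_info.py, +136 lines in the list above). Below is a rough sketch of such a container, inferred from the deleted batch_sampling_params(); apart from the names visible in this diff, everything here is an assumption, not the actual implementation.

# Hypothetical sketch of a SamplingBatchInfo-like container, inferred from the
# removed batch_sampling_params(); not the real sglang.srt.sampling.sampling_batch_info.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingBatchInfoSketch:
    temperatures: torch.Tensor                  # shape (bs, 1), divides the logits
    top_ps: torch.Tensor                        # shape (bs,), nucleus threshold per request
    top_ks: torch.Tensor                        # shape (bs,), top-k per request
    logit_bias: Optional[torch.Tensor] = None   # allocated lazily, as before

    @classmethod
    def from_schedule_batch(cls, batch, vocab_size: int, device: str = "cuda"):
        # vocab_size would size penalizer / logit_bias tensors; omitted in this sketch.
        reqs = batch.reqs
        return cls(
            temperatures=torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float,
                device=device,
            ).view(-1, 1),
            top_ps=torch.tensor(
                [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
            ),
            top_ks=torch.tensor(
                [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
            ),
        )

The later hunks then call filter(), merge(), and update_regex_vocab_mask() on this object instead of manipulating per-field tensors on the batch directly.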
@@ -634,7 +622,7 @@
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
@@ -647,6 +635,8 @@
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -667,23 +657,13 @@
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
@@ -698,111 +678,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
     def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        from sglang.srt.layers.sampler import Sampler
 
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
+        sampler = Sampler()
 
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
         return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-p sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
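The torch top-k/top-p fallback deleted at the end of this hunk is presumably absorbed into the new sglang/srt/layers/sampler.py module listed above (+101 lines), which sample() now delegates to. For reference, here is a self-contained toy example of the same top-k/top-p logic, trimmed of the RuntimeError fallback; the input numbers are made up for demonstration.

# Toy demonstration of the deleted torch top-k/top-p fallback (error handling omitted).
import torch


def top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0   # nucleus (top-p) mask
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0                                                           # top-k mask
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])          # rescale; multinomial accepts unnormalized weights
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)   # 2 requests, vocab of 8
top_ks = torch.tensor([3, 5])                      # keep the 3 / 5 most likely tokens
top_ps = torch.tensor([0.9, 0.8])                  # per-request nucleus thresholds
print(top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps))  # one token id per request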