sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens
 
         self.can_run_list = []
         self.new_inflight_req = None
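
This hunk modifies the `PrefillAdder` constructor in `sglang/srt/managers/policy_scheduler.py`: the new `mixed_with_decode_tokens` argument lets a prefill batch reserve budget for decode tokens that run in the same forward pass (mixed chunked prefill). A minimal sketch of the budget arithmetic, using a hypothetical standalone helper rather than the real class:

```python
from typing import Optional


def remaining_budgets(
    rem_total_tokens: int,
    rem_input_tokens: int,
    rem_chunk_tokens: Optional[int],
    mixed_with_decode_tokens: int = 0,
):
    """Mirror of the adjustment above: every token budget is reduced by the
    decode tokens sharing the batch, so prefill cannot overcommit the step."""
    rem_total_tokens -= mixed_with_decode_tokens
    rem_input_tokens -= mixed_with_decode_tokens
    if rem_chunk_tokens is not None:
        rem_chunk_tokens -= mixed_with_decode_tokens
    return rem_total_tokens, rem_input_tokens, rem_chunk_tokens


# Example: chunked-prefill budgets shared with 64 in-flight decode requests.
print(remaining_budgets(16384, 8192, 8192, mixed_with_decode_tokens=64))
# -> (16320, 8128, 8128)
```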
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
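
The remaining hunks are from `sglang/srt/managers/schedule_batch.py`. The new `from __future__ import annotations`, together with the `TYPE_CHECKING` guard added in the next hunk, lets the module annotate with `SampleOutput` (defined in the new `sglang/srt/layers/sampler.py`) without importing it at runtime, the usual way to break an import cycle. A generic sketch of the pattern with hypothetical module names, not the actual sglang modules:

```python
# module_a.py
from __future__ import annotations  # annotations become lazily evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only seen by type checkers (mypy, pyright); never executed at runtime,
    # so module_b may in turn import module_a without a circular import.
    from module_b import SampleOutput


def check(sample_output: SampleOutput) -> bool:
    # At runtime the annotation is just the string "SampleOutput".
    return bool(sample_output)
```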
@@ -16,20 +18,22 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -37,7 +41,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
@@ -268,7 +272,7 @@ class Req:
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            warnings.warn(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -327,17 +331,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunekd prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -385,43 +385,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
@@ -460,8 +423,32 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
 
-        self.batch_sampling_params(vocab_size)
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
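
The batched sampling parameters that `batch_sampling_params` used to build inline are now assembled by `SamplingBatchInfo.from_schedule_batch` (see the new `sglang/srt/sampling/sampling_batch_info.py` in the file list). A rough sketch of that consolidation, reconstructed only from the removed code above; the real class presumably also carries the penalizer orchestrator, logit bias, and regex vocab mask, and its exact interface may differ:

```python
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class SamplingBatchInfoSketch:
    """Hypothetical, simplified stand-in for SamplingBatchInfo."""

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor

    @classmethod
    def from_requests(cls, reqs: List, device: str = "cpu"):
        # Same tensorization the removed batch_sampling_params performed
        # (sglang places these on "cuda"; "cpu" here keeps the sketch portable).
        return cls(
            temperatures=torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float,
                device=device,
            ).view(-1, 1),
            top_ps=torch.tensor(
                [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
            ),
            top_ks=torch.tensor(
                [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
            ),
        )

    def filter(self, new_indices):
        # Keep only unfinished requests, as filter_batch now delegates here.
        self.temperatures = self.temperatures[new_indices]
        self.top_ps = self.top_ps[new_indices]
        self.top_ks = self.top_ks[new_indices]
```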
@@ -634,7 +621,7 @@ class ScheduleBatch:
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
 
@@ -647,6 +634,8 @@
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -667,23 +656,13 @@
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
@@ -698,111 +677,17 @@
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
             batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
+                sample_output.success, batch_next_token_ids, argmax_ids
            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
-
-        return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+        return sample_output.batch_next_token_ids
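
`ScheduleBatch.sample` and the torch-native fallback sampler move out of this file (sampling now lives in the new `sglang/srt/layers/sampler.py`), and the batch only repairs failed samples in `check_sample_results`. A small self-contained illustration of that greedy fallback, using a hypothetical stand-in for `SampleOutput`:

```python
import dataclasses

import torch


@dataclasses.dataclass
class FakeSampleOutput:
    # Hypothetical stand-in for sglang.srt.layers.sampler.SampleOutput.
    success: torch.Tensor               # bool flag per request
    probs: torch.Tensor                 # (batch, vocab) probabilities
    batch_next_token_ids: torch.Tensor  # sampled token id per request


def fallback_to_greedy(out: FakeSampleOutput) -> torch.Tensor:
    """Same recovery as check_sample_results: if top-k/top-p sampling failed
    for any request, replace its token with the argmax of NaN-cleaned probs."""
    if not torch.all(out.success):
        probs = out.probs.masked_fill(torch.isnan(out.probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        out.batch_next_token_ids = torch.where(
            out.success, out.batch_next_token_ids, argmax_ids
        )
    return out.batch_next_token_ids


# Request 0 sampled fine; request 1 failed and falls back to its argmax (token 2).
out = FakeSampleOutput(
    success=torch.tensor([True, False]),
    probs=torch.tensor([[0.1, 0.9, 0.0], [0.2, 0.3, 0.5]]),
    batch_next_token_ids=torch.tensor([1, 0]),
)
print(fallback_to_greedy(out))  # tensor([1, 2])
```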