sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/managers/schedule_batch.py (+46 -166):

```diff
@@ -16,20 +16,18 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
```
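Related to the new `SamplingBatchInfo` import above, the file list shows `sglang/srt/sampling_params.py` moving to `sglang/srt/sampling/sampling_params.py`. A minimal sketch of the corresponding import paths in 0.2.14.post1; whether the old module path is still re-exported is not visible in this diff:

```python
# New locations under the sglang.srt.sampling package (0.2.14.post1).
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

# Old location in 0.2.13 (module lived at sglang/srt/sampling_params.py):
# from sglang.srt.sampling_params import SamplingParams
```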
```diff
@@ -37,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
@@ -264,11 +262,18 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            logger.warning("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            logger.warning("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
```
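The added guards protect the jump-forward path against retokenization edge cases: re-encoding `origin_input_text + decoded_text + jump_forward_str` can yield an empty or shorter token sequence than the original prompt, and the boundary token can fuse with the following text. A self-contained sketch of the same consistency check; `encode` and the argument names here are illustrative stand-ins, not sglang APIs:

```python
from typing import Callable, List


def retokenization_is_consistent(
    encode: Callable[[str], List[int]],
    prompt_ids: List[int],
    all_text: str,
) -> bool:
    """Mirror of the guards above: re-encode the combined text and verify that
    the original prompt ids are still a clean prefix of the new ids."""
    all_ids = encode(all_text)
    if not all_ids:
        return False  # empty re-encoding, nothing to align against
    prompt_tokens = len(prompt_ids)
    if prompt_tokens > len(all_ids):
        return False  # re-encoded text is shorter than the prompt itself
    # Token fusion: the boundary token changed when prompt and output were
    # re-encoded together (e.g. a trailing space merged with the next word).
    return all_ids[prompt_tokens - 1] == prompt_ids[-1]
```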
```diff
@@ -327,17 +332,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunekd prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -385,43 +386,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
```
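The removed `batch_sampling_params` is the clearest picture this diff gives of what the new `SamplingBatchInfo` (added separately in `sglang/srt/sampling/sampling_batch_info.py`, +136 lines, not shown here) has to hold. A rough sketch of such a container, reconstructed from the removed code; the class name `BatchedSamplingState`, the method `from_requests`, and the omission of the penalizer orchestrator are assumptions for illustration only:

```python
import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class BatchedSamplingState:
    """Illustrative stand-in for SamplingBatchInfo: the batched sampling
    tensors that previously lived directly on ScheduleBatch."""

    temperatures: torch.Tensor  # shape (bs, 1), float
    top_ps: torch.Tensor        # shape (bs,), float
    top_ks: torch.Tensor        # shape (bs,), int
    logit_bias: Optional[torch.Tensor] = None  # allocated lazily, as before

    @classmethod
    def from_requests(cls, reqs, device: str = "cuda"):
        # Same tensorization the removed batch_sampling_params performed.
        return cls(
            temperatures=torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float, device=device,
            ).view(-1, 1),
            top_ps=torch.tensor(
                [r.sampling_params.top_p for r in reqs],
                dtype=torch.float, device=device,
            ),
            top_ks=torch.tensor(
                [r.sampling_params.top_k for r in reqs],
                dtype=torch.int, device=device,
            ),
        )
```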
```diff
@@ -460,8 +424,32 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
 
-
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
```
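`mix_with_running` (together with the new `prefix_lens_cpu` field) implements mixed chunked prefill: a freshly prepared extend batch absorbs the running decode batch, with each running request contributing exactly one token. A hedged sketch of the call pattern; the function name and the surrounding scheduling logic are hypothetical, only the `mix_with_running` call itself comes from the diff:

```python
def build_mixed_batch(new_prefill_batch, running_batch):
    # Illustrative only: fold the running decode batch into a freshly prepared
    # prefill batch so a single forward pass serves both.
    if running_batch is not None and running_batch.batch_size() > 0:
        # Each running request becomes a one-token extend
        # (req.extend_input_len = 1 in the diff above).
        new_prefill_batch.mix_with_running(running_batch)
    return new_prefill_batch
```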
```diff
@@ -634,7 +622,7 @@ class ScheduleBatch:
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
@@ -647,6 +635,8 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
```
```diff
@@ -667,23 +657,13 @@ class ScheduleBatch:
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
```
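`SamplingBatchInfo.filter` and `SamplingBatchInfo.merge` replace the explicit loops over `temperatures`, `top_ps`, `top_ks`, and `logit_bias`. The underlying tensor operations are the ones visible in the removed code: index the surviving rows on filter, concatenate along the batch dimension on merge. A small self-contained illustration on a generic per-request tensor:

```python
import torch

# Batched per-request values, one row per request (e.g. temperatures).
values = torch.tensor([0.7, 1.0, 0.2, 0.9])

# filter(): keep only the unfinished requests, in their new order.
new_indices = torch.tensor([0, 2])          # requests 1 and 3 finished
values = values[new_indices]                # -> tensor([0.7, 0.2])

# merge(): append another batch's values along the batch dimension.
other_values = torch.tensor([1.3, 0.5])
values = torch.concat([values, other_values])  # -> tensor([0.7, 0.2, 1.3, 0.5])
```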
```diff
@@ -698,111 +678,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
     def sample(self, logits: torch.Tensor):
-
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        from sglang.srt.layers.sampler import Sampler
 
-
-        batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-        for i, req in enumerate(self.reqs):
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                )
+        sampler = Sampler()
 
-        self.
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
         return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
```
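After this change `ScheduleBatch.sample` is a thin wrapper: temperature scaling, logit bias, regex masks, penalties, and the flashinfer/torch top-k/top-p sampling all move into the new `sglang.srt.layers.sampler.Sampler` (added in `sglang/srt/layers/sampler.py`, +101 lines per the list above), driven by `self.sampling_info`. The `Sampler` internals are not shown in this diff; as a reference point, here is a self-contained sketch of the torch-native top-k/top-p sampling scheme that the removed fallback implemented (renamed and lightly simplified, so not the actual sglang function):

```python
import torch


def top_k_top_p_sample(probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    """Per-row top-k/top-p (nucleus) sampling with plain PyTorch ops,
    following the same scheme as the removed fallback above."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens outside the nucleus (top-p) ...
    probs_sort[(cumulative - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # ... and tokens beyond each row's top-k.
    ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # Renormalize the surviving mass and sample one token per row.
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


# Tiny usage example: a batch of 2 rows over a 5-token vocabulary.
probs = torch.softmax(torch.randn(2, 5), dim=-1)
ids = top_k_top_p_sample(probs, top_ks=torch.tensor([3, 5]), top_ps=torch.tensor([0.9, 1.0]))
```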