sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0

```diff
--- a/sglang/srt/managers/policy_scheduler.py
+++ b/sglang/srt/managers/policy_scheduler.py
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens
 
         self.can_run_list = []
         self.new_inflight_req = None
```
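The new `mixed_with_decode_tokens` argument lets the scheduler deduct, up front, the tokens that an in-flight decode batch will consume before admitting new prefill work. A minimal sketch of the budget arithmetic (the class name and standalone constructor are simplifications; the real `PrefillAdder` takes additional arguments such as `tree_cache`):

```python
from typing import Optional


class PrefillAdderSketch:
    """Simplified stand-in for the PrefillAdder budget bookkeeping above."""

    def __init__(
        self,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        # Every budget shrinks by the tokens already promised to decode requests.
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens


# Example with made-up numbers: a running batch of 64 decode requests claims
# 64 tokens, so only the remainder is available for new prefill requests.
adder = PrefillAdderSketch(
    rem_total_tokens=8192,
    rem_input_tokens=4096,
    rem_chunk_tokens=2048,
    mixed_with_decode_tokens=64,
)
print(adder.rem_total_tokens, adder.rem_input_tokens, adder.rem_chunk_tokens)
# 8128 4032 1984
```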
```diff
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
```
```diff
@@ -16,20 +18,22 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
```
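The `from __future__ import annotations` / `if TYPE_CHECKING:` pair lets `schedule_batch.py` annotate methods with `SampleOutput` without importing `sglang.srt.layers.sampler` at runtime, which would otherwise risk a circular import with the new sampler layer. A generic sketch of the idiom, with a hypothetical module name:

```python
# A sketch of the TYPE_CHECKING idiom used above ("sampler_module" is hypothetical).
from __future__ import annotations  # annotations are stored as strings, not evaluated at import

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers evaluate this import, so no runtime dependency
    # (and no circular import) is created between the two modules.
    from sampler_module import SampleOutput


def first_token(output: SampleOutput) -> int:
    # At runtime the annotation is just the string "SampleOutput".
    return int(output.batch_next_token_ids[0])
```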
```diff
@@ -37,7 +41,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
 
```
```diff
@@ -268,7 +272,7 @@ class Req:
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            warnings.warn(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
```
```diff
@@ -327,17 +331,13 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
 
+    # For mixed chunekd prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
```
```diff
@@ -385,43 +385,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
```
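The batched sampling state that `batch_sampling_params` used to build in place (temperature, top-p, and top-k tensors, the penalizer orchestrator, and the logit bias) now lives in the new `sglang/srt/sampling/sampling_batch_info.py` (+209 lines). The sketch below is inferred from the deleted code only and is not the actual `SamplingBatchInfo` API; it shows the kind of per-batch container that `prepare_for_extend`, `filter_batch`, and `merge` now delegate to:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class SamplingBatchInfoSketch:
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    logit_bias: Optional[torch.Tensor] = None

    @classmethod
    def from_sampling_params(cls, params: List[dict], device: str = "cpu"):
        # One row per request, mirroring the tensors ScheduleBatch used to own.
        return cls(
            temperatures=torch.tensor(
                [p["temperature"] for p in params], dtype=torch.float, device=device
            ).view(-1, 1),
            top_ps=torch.tensor(
                [p["top_p"] for p in params], dtype=torch.float, device=device
            ),
            top_ks=torch.tensor(
                [p["top_k"] for p in params], dtype=torch.int, device=device
            ),
        )

    def filter(self, keep: torch.Tensor):
        # Keep only the rows of still-running requests, as filter_batch did.
        self.temperatures = self.temperatures[keep]
        self.top_ps = self.top_ps[keep]
        self.top_ks = self.top_ks[keep]
        if self.logit_bias is not None:
            self.logit_bias = self.logit_bias[keep]

    def merge(self, other: "SamplingBatchInfoSketch"):
        # Concatenate per-request tensors when two batches are merged.
        self.temperatures = torch.cat([self.temperatures, other.temperatures])
        self.top_ps = torch.cat([self.top_ps, other.top_ps])
        self.top_ks = torch.cat([self.top_ks, other.top_ks])


info = SamplingBatchInfoSketch.from_sampling_params(
    [
        {"temperature": 0.7, "top_p": 0.9, "top_k": 40},
        {"temperature": 1.0, "top_p": 1.0, "top_k": 1},
    ]
)
info.filter(torch.tensor([1]))  # keep only the second request
```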
```diff
@@ -460,8 +423,32 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
+
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+
+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
 
-
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
 
     def check_decode_mem(self):
         bs = self.batch_size()
```
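In `mix_with_running`, a running (decode) request treats everything except its newest token as already-cached prefix: with `N` prompt tokens and `M` generated tokens, the prefix length is `N + M - 1` and `extend_input_len` is 1. A small worked example of that bookkeeping, using plain dicts in place of the real `Req` objects:

```python
# Illustrates the prefix-length arithmetic from mix_with_running above,
# with dicts standing in for sglang's Req objects.
def mixed_prefix_lens(extend_reqs, running_reqs):
    # Extend (prefill) requests: whatever is already cached counts as prefix.
    lens = [len(r["prefix_indices"]) for r in extend_reqs]
    # Running (decode) requests: all but the newest token is prefix, so each
    # contributes exactly one new token to the mixed extend pass.
    lens.extend(
        len(r["origin_input_ids"]) + len(r["output_ids"]) - 1 for r in running_reqs
    )
    return lens


extend_reqs = [{"prefix_indices": [0, 1, 2]}]  # 3 cached prompt tokens
running_reqs = [{"origin_input_ids": [5] * 10, "output_ids": [7] * 4}]  # 10 + 4 tokens
print(mixed_prefix_lens(extend_reqs, running_reqs))  # [3, 13]
```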
```diff
@@ -634,7 +621,7 @@ class ScheduleBatch:
                 for r in self.reqs
             ]
         else:
-            self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
+            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
```
```diff
@@ -647,6 +634,8 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
```
```diff
@@ -667,23 +656,13 @@ class ScheduleBatch:
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
```
```diff
@@ -698,111 +677,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
-    def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
-
-        if not torch.all(success):
-            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
             batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
+                sample_output.success, batch_next_token_ids, argmax_ids
             )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-
-        batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-        for i, req in enumerate(self.reqs):
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                )
-
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
-
-        return batch_next_token_ids
-
-
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
-):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+        return sample_output.batch_next_token_ids
```