sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -53,6 +53,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
+    "disable_nan_detection": ServerArgs.disable_nan_detection,
 }
 
 
@@ -196,6 +197,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -203,6 +207,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -391,25 +396,30 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
+bid = 0
+
+
 @dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids:
-    req_pool_indices:
-    seq_lens:
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -419,10 +429,14 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     running_bs: int = None
+    decoding_reqs: List[Req] = None
 
     # Stream
     has_stream: bool = False
 
+    # device
+    device: str = "cuda"
+
     # Has regex
     has_regex: bool = False
 
@@ -439,6 +453,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            device=req_to_token_pool.device,
             has_regex=has_regex,
         )
 
@@ -488,17 +503,24 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx]
-
-
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )
 
-            self.req_to_token_pool.req_to_token[req.req_pool_idx
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
 
@@ -514,10 +536,15 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-
-        self.
-
-
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
@@ -527,7 +554,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
@@ -582,9 +611,11 @@ class ScheduleBatch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
+            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
@@ -593,6 +624,7 @@ class ScheduleBatch:
                 ), "No space left for only one request"
                 break
 
+            first_iter = False
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -633,7 +665,7 @@ class ScheduleBatch:
             req.last_update_decode_tokens = 0
             req.logprob_start_len = 10**9
 
-        self.filter_batch(sorted_indices)
+        self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
         total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
@@ -648,7 +680,7 @@ class ScheduleBatch:
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-
+        keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
@@ -708,63 +740,71 @@ class ScheduleBatch:
                 )
 
                 jump_forward_reqs.append(req)
-
+                keep_indices.remove(i)
 
-        self.filter_batch(
+        self.filter_batch(keep_indices=list(keep_indices))
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-
-
-
-
-
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
-        self.seq_lens.add_(1)
+        self.input_ids = self.output_ids
+        self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-        self.
-
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)
 
-    def filter_batch(
-
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return
 
-        if len(
+        if len(keep_indices) == len(self.reqs):
             # No need to filter
             return
 
-        self.reqs = [self.reqs[i] for i in
-        new_indices = torch.tensor(
-
+        self.reqs = [self.reqs[i] for i in keep_indices]
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter_batch(
+        self.sampling_info.filter_batch(keep_indices, new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -777,6 +817,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -806,8 +848,14 @@ class ScheduleBatch:
             self.sampling_info.regex_fsm_states = [
                 req.regex_fsm_state for req in self.reqs
             ]
+        else:
+            self.sampling_info.regex_fsms = None
+
+        global bid
+        bid += 1
 
         return ModelWorkerBatch(
+            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -823,9 +871,26 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            forward_mode=self.forward_mode,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            decoding_reqs=self.decoding_reqs,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
+    # The batch id
+    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -854,3 +919,21 @@ class ModelWorkerBatch:
 
     # Sampling info
     sampling_info: SamplingBatchInfo
+
+    def copy(self):
+        return ModelWorkerBatch(
+            bid=self.bid,
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids.clone(),
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens.clone(),
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=self.extend_seq_lens,
+            extend_prefix_lens=self.extend_prefix_lens,
+            extend_logprob_start_lens=self.extend_logprob_start_lens,
+            image_inputs=self.image_inputs,
+            lora_paths=self.lora_paths,
+            sampling_info=self.sampling_info.copy(),
+        )
sglang/srt/managers/schedule_policy.py

@@ -45,12 +45,13 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=r.adjust_max_prefix_ids()
                 )
+
             prefix_computed = True
 
         if self.policy == "lpm":