sglang 0.1.22__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/bench_serving.py +243 -25
- sglang/global_config.py +3 -2
- sglang/lang/interpreter.py +1 -0
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/radix_attention.py +38 -49
- sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
- sglang/srt/managers/controller/infer_batch.py +51 -22
- sglang/srt/managers/controller/model_runner.py +58 -4
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +9 -11
- sglang/srt/memory_pool.py +13 -5
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/llama2.py +19 -10
- sglang/srt/server.py +26 -1
- sglang/srt/server_args.py +12 -6
- sglang/srt/utils.py +93 -1
- sglang/version.py +1 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/METADATA +10 -6
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/RECORD +25 -36
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/WHEEL +1 -1
- sglang/backend/__init__.py +0 -0
- sglang/backend/anthropic.py +0 -77
- sglang/backend/base_backend.py +0 -80
- sglang/backend/litellm.py +0 -90
- sglang/backend/openai.py +0 -438
- sglang/backend/runtime_endpoint.py +0 -283
- sglang/backend/vertexai.py +0 -149
- sglang/bench.py +0 -627
- sglang/srt/managers/controller/dp_worker.py +0 -113
- sglang/srt/openai_api/api_adapter.py +0 -432
- sglang/srt/openai_api/openai_api_adapter.py +0 -431
- sglang/srt/openai_api/openai_protocol.py +0 -207
- sglang/srt/openai_api_adapter.py +0 -411
- sglang/srt/openai_protocol.py +0 -207
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/LICENSE +0 -0
- {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
CHANGED
@@ -85,32 +85,47 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-            causal=True,
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)

-        if input_metadata.extend_no_prefix:
-            o = o1
-        else:
-            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.
-                causal=
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                causal=True,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
+        else:
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+            )

-            o, _ = merge_state(o1, s1, o2, s2)
+            if input_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = (
+                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                        causal=False,
+                        sm_scale=self.scaling,
+                        logits_soft_cap=self.logit_cap,
+                    )
+                )

-        self.store_kv_cache(k, v, input_metadata)
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            self.store_kv_cache(k, v, input_metadata)

-
-
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()

         return o.view(-1, self.tp_q_head_num * self.head_dim)

@@ -119,7 +134,7 @@ class RadixAttention(nn.Module):

         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
@@ -136,33 +151,7 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-
-try:
-
-    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
-
-    @_store_kv_cache.register_fake
-    def _(k, v, kv_cache, cache_loc):
-        pass
-
-except:
-
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
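Note on the extend_forward_flashinfer hunk above: in the ragged branch, attention over the newly extended tokens (ragged wrapper) and over the cached prefix (paged wrapper, causal=False) is computed separately, and the two partial results are combined with flashinfer's merge_state using the log-sum-exp values returned by forward_return_lse. A minimal PyTorch sketch of the math such a merge performs (illustrative only, not flashinfer's implementation), assuming o1/o2 have shape [tokens, heads, head_dim] and s1/s2 hold the per-token, per-head log-sum-exp:

import torch

def merge_state_sketch(o1, s1, o2, s2):
    # Renormalize two partial softmax-weighted outputs by their log-sum-exp weights.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)          # [tokens, heads]
    w2 = torch.exp(s2 - s_max)
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)      # merged log-sum-exp
    return o, s

Because each partial result carries its own normalizer, the merged output matches what a single softmax over the full key/value set would have produced.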
sglang/srt/managers/controller/cuda_graph_runner.py
CHANGED
@@ -1,11 +1,13 @@
 """Run the model with cuda graph."""

 import bisect
+from contextlib import contextmanager

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
@@ -14,10 +16,44 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm


 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -55,6 +91,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )

+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs

@@ -63,18 +101,23 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
-
-    def capture_one_batch_size(self, bs):
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -127,9 +170,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-
-
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
             run_once()
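For background on what capture_one_batch_size and the `for _ in range(2): run_once()` warm-up above are doing: one CUDA graph is recorded per batch size against fixed input/output buffers and replayed later with fresh data copied into those buffers. A self-contained sketch of that standard PyTorch pattern (the names here are illustrative, not sglang's API):

import torch

def capture(fn, static_input):
    # Warm up on a side stream so lazy initialization is not recorded into the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_output = fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Record the graph; later, copy new data into static_input and replay.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    return graph, static_output

Replay then amounts to copying new inputs into static_input, calling graph.replay(), and reading static_output, which is why CudaGraphRunner keeps per-batch-size input_buffers and output_buffers.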
sglang/srt/managers/controller/infer_batch.py
CHANGED
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs

+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -431,7 +432,8 @@ class Batch:

     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
-
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -443,7 +445,17 @@ class Batch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -468,7 +480,16 @@ class Batch:

         self.filter_batch(sorted_indices)

-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio

     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
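A worked example of the new return value, with made-up numbers: if 8 requests remain after retraction, they have decoded 2,000 tokens in total, their combined max_new_tokens budget is 4,000, and global_config.retract_decode_steps is 20, then new_estimate_ratio = min(1.0, (2000 + 20 * 8) / 4000) = 0.54. The scheduler in tp_worker.py (see its hunk below) now adopts this value directly as self.new_token_ratio instead of bumping the old ratio by a fixed recovery step.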
@@ -668,18 +689,17 @@ class Batch:

         max_top_k_round, batch_size = 32, probs.shape[0]
         uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids,
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
             probs, uniform_samples, self.top_ks, self.top_ps
         )

-
-        illegal_mask = torch.logical_or(
-            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
-        )
-        if torch.any(illegal_mask):
-            warnings.warn("Illegal sampled token ids")
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
-
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
@@ -727,6 +747,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False

     @classmethod
     def create(
@@ -742,7 +763,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -750,6 +774,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )

         batch_size = len(req_pool_indices)
@@ -804,6 +829,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
@@ -824,17 +850,19 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))

-    if
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens

     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +895,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )

         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
sglang/srt/managers/controller/model_runner.py
CHANGED
@@ -15,6 +15,7 @@ from flashinfer import (
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+from torch.nn.parameter import Parameter
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -22,7 +23,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.
+from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
@@ -39,6 +40,18 @@ from sglang.srt.utils import (
 logger = logging.getLogger("srt.model_runner")


+def is_llama3_405b_fp8(model_config):
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
 class ModelRunner:
     def __init__(
         self,
@@ -119,6 +132,9 @@ class ModelRunner:
             seed=42,
             skip_tokenizer_init=True,
         )
+        if is_llama3_405b_fp8(self.model_config):
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -241,16 +257,20 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self,
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
         try:
            self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
@@ -367,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     return model_arch_name_to_cls[model_arch]


+def get_original_weight(loaded_weight, head_dim):
+    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+    dim = loaded_weight.shape[1]
+    for i in range(n_kv_head):
+        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+            2 * i * head_dim : (2 * i + 1) * head_dim, :
+        ]
+    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+    return original_kv_weight
+
+
+def get_weight_loader_srt(weight_loader):
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    return weight_loader_srt
+
+
 # Monkey patch model loader
 setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+original_weight_loader = QKVParallelLinear.weight_loader
+setattr(
+    QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
+)
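A short shape walk-through of get_original_weight, with illustrative numbers rather than values taken from the real checkpoint: with head_dim = 128 and a loaded K (or V) shard of shape (2048, 16384), n_kv_head = 2048 // (2 * 128) = 8. The loop copies every second head_dim-sized block of rows to the front, and the final slice keeps the first n_kv_head * head_dim = 1024 rows, so the returned tensor has shape (1024, 16384), i.e. one copy of each KV head instead of the doubled set stored in the shard. The patched QKVParallelLinear.weight_loader only applies this when the incoming K/V shard is exactly twice the expected size.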
sglang/srt/managers/controller/schedule_heuristic.py
CHANGED
@@ -14,7 +14,7 @@ class ScheduleHeuristic:
         tree_cache,
     ):
         if tree_cache.disable and schedule_heuristic == "lpm":
-            # LMP is
+            # LMP is meaningless when the tree cache is disabled.
             schedule_heuristic = "fcfs"

         self.schedule_heuristic = schedule_heuristic
@@ -28,11 +28,16 @@ class ScheduleHeuristic:
             # longest prefix match
             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
             return forward_queue
+        elif self.schedule_heuristic == "fcfs":
+            # first come first serve
+            return forward_queue
+        elif self.schedule_heuristic == "lof":
+            # longest output first
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
             random.shuffle(forward_queue)
             return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -103,6 +103,9 @@ class ModelTpServer:
             if server_args.max_running_requests is None
             else server_args.max_running_requests
         )
+        self.max_running_requests = min(
+            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -113,13 +116,9 @@ class ModelTpServer:
                 f"[gpu_id={self.gpu_id}] "
                 f"max_total_num_tokens={self.max_total_num_tokens}, "
                 f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
                 f"context_len={self.model_config.context_len}"
             )
-        if self.tp_rank == 0:
-            logger.info(
-                f"[gpu_id={self.gpu_id}] "
-                f"server_args: {server_args.print_mode_args()}"
-            )

         # Init cache
         self.tree_cache = RadixCache(
@@ -161,15 +160,12 @@ class ModelTpServer:
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.new_token_ratio = min(
-            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
-            1.0,
-        )
         self.min_new_token_ratio = min(
             global_config.base_min_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
+        self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

@@ -231,6 +227,7 @@ class ModelTpServer:
                         break
             else:
                 self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio

     def print_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -539,9 +536,10 @@ class ModelTpServer:
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

-            retracted_reqs = batch.retract_decode()
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
sglang/srt/memory_pool.py
CHANGED
@@ -11,6 +11,7 @@ class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

     def __init__(self, size: int, max_context_len: int):
+        self.size = size
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
@@ -57,9 +58,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size,
-        self.
-            torch.empty((size + 1,
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]

@@ -71,10 +76,13 @@ class TokenToKVPool:
         self.clear()

     def get_key_buffer(self, layer_id: int):
-        return self.
+        return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
-        return self.
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]

     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
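Read together with the radix_attention.py hunks above, the pool now stores keys and values in separate per-layer buffers rather than a single buffer whose second dimension selected key (0) or value (1), as in the removed _store_kv_cache helper. A minimal standalone sketch of the new layout and accessors (sizes are illustrative; this is not the TokenToKVPool constructor signature):

import torch

size, head_num, head_dim, layer_num = 1024, 8, 128, 4
dtype, device = torch.float16, "cuda"

# One [size + 1, head_num, head_dim] tensor per layer for keys, and one for values;
# the extra slot absorbs dummy writes from padded tokens.
k_buffer = [torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)]
v_buffer = [torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)]

def get_kv_buffer(layer_id: int):
    return k_buffer[layer_id], v_buffer[layer_id]

def store_kv_cache(layer_id: int, cache_loc: torch.Tensor, cache_k, cache_v):
    # Mirrors RadixAttention.store_kv_cache: scatter rows into the chosen cache slots.
    k_buffer[layer_id][cache_loc] = cache_k
    v_buffer[layer_id][cache_loc] = cache_v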