sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index =
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if
-
-
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text =
-
-
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-
-        if len(input_id_result) > 1 and input_id_result is not None:
+        if input_id_result is not None and len(input_id_result) > 1:
             obj.input_ids = input_id_result
         elif input_id_result is not None:
             obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text =
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text =
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -293,7 +308,6 @@ class TokenizerManager:
                 event = asyncio.Event()
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state
-
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -326,7 +340,6 @@ class TokenizerManager:
                 )
                 assert state.finished
                 del self.rid_to_state[rid]
-
         yield output_list

     def _validate_input_length(self, input_ids: List[int]):
@@ -375,8 +388,13 @@ class TokenizerManager:
                 obj.return_text_in_logprobs,
             )

+            # Log requests
             if self.server_args.log_requests and state.finished:
-
+                if obj.text is None:
+                    in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                else:
+                    in_obj = {"text": obj.text}
+                logger.info(f"in={in_obj}, out={out}")

             state.out_list = []
             if state.finished:
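The tokenizer-manager change above normalizes how a batched request selects its per-sample prompt: a request may carry one shared text prompt, one text per sample, a single token-id list, or one token-id list per sample. Below is a minimal standalone sketch of that indexing pattern; BatchReq and resolve_prompt are hypothetical names used only for illustration, not sglang APIs.

# Sketch of the per-index prompt selection used for parallel-sampling prefill.
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class BatchReq:
    rid: List[str]
    text: Optional[Union[str, List[str]]] = None
    input_ids: Optional[Union[List[int], List[List[int]]]] = None


def resolve_prompt(obj: BatchReq, index: int):
    """Return (rid, text, input_ids) for one sub-request of a batch."""
    if obj.text is not None:
        if isinstance(obj.text, list):      # one prompt per sample
            return obj.rid[index], obj.text[index], None
        return obj.rid[0], obj.text, None   # one shared prompt
    # token-id input: List[List[int]] means one sequence per sample
    if isinstance(obj.input_ids, list) and isinstance(obj.input_ids[0], list):
        return obj.rid[index], None, obj.input_ids[index]
    return obj.rid[0], None, obj.input_ids


print(resolve_prompt(BatchReq(rid=["a", "b"], text=["hi", "yo"]), 1))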
sglang/srt/managers/tp_worker.py
CHANGED
@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch:
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         try:
@@ -290,10 +289,10 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )

-        if self.req_to_token_pool.
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
@@ -353,7 +352,7 @@ class ModelTpServer:
             )
             self.waiting_queue.append(req)

-    def get_new_prefill_batch(self) -> Optional[
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -364,12 +363,13 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
+            try_match_ids = req.input_ids
+            if req.return_logprob:
+                try_match_ids = req.input_ids[: req.logprob_start_len]
+            # NOTE: the prefix_indices must always be aligned with last_node
             prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
+                rid=req.rid, key=try_match_ids
             )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node
@@ -525,7 +525,7 @@ class ModelTpServer:
         )

         # Return the new batch
-        new_batch =
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
@@ -534,7 +534,7 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch

-    def forward_prefill_batch(self, batch:
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -623,14 +623,13 @@ class ModelTpServer:
                 )
                 req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch:
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+    def cache_filled_batch(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=
+                req_pool_idx=req.req_pool_idx,
                 del_in_memory_pool=False,
                 old_last_node=req.last_node,
             )
@@ -638,9 +637,9 @@ class ModelTpServer:

             if req is self.current_inflight_req:
                 # inflight request would get a new req idx
-                self.req_to_token_pool.free(
+                self.req_to_token_pool.free(req.req_pool_idx)

-    def forward_decode_batch(self, batch:
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -699,7 +698,7 @@ class ModelTpServer:

         self.handle_finished_requests(batch)

-    def handle_finished_requests(self, batch:
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_vids = []
         decoded_texts = []
@@ -781,14 +780,13 @@ class ModelTpServer:
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
                     rid=req.rid,
                     token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=
+                    req_pool_idx=req.req_pool_idx,
                 )

                 self.tree_cache.dec_lock_ref(req.last_node)
@@ -799,7 +797,7 @@ class ModelTpServer:
         else:
             batch.reqs = []

-    def filter_out_inflight(self, batch:
+    def filter_out_inflight(self, batch: ScheduleBatch):
         # TODO(lsyin): reduce the overhead, make a special version for this
         if self.current_inflight_req is None:
             return
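The prefix-matching change in get_new_prefill_batch above lets the radix cache serve only the tokens before logprob_start_len when logprobs are requested, so the remaining tokens are recomputed and their logprobs can be returned. A rough, self-contained sketch of that behavior follows; the toy longest_cached_prefix stands in for the radix-tree lookup and is not the sglang RadixCache API.

# Sketch: restrict cache reuse when logprobs must be recomputed.
from typing import List


def longest_cached_prefix(cache: List[List[int]], key: List[int]) -> List[int]:
    best: List[int] = []
    for entry in cache:
        n = 0
        while n < min(len(entry), len(key)) and entry[n] == key[n]:
            n += 1
        if n > len(best):
            best = key[:n]
    return best


def match_for_request(cache, input_ids, return_logprob, logprob_start_len):
    try_match_ids = input_ids
    if return_logprob:
        # Do not reuse cache past logprob_start_len; those tokens get recomputed
        try_match_ids = input_ids[:logprob_start_len]
    prefix = longest_cached_prefix(cache, try_match_ids)
    extend_len = len(input_ids) - len(prefix)
    return prefix, extend_len


cache = [[1, 2, 3, 4, 5]]
print(match_for_request(cache, [1, 2, 3, 4, 9], return_logprob=True, logprob_start_len=2))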
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
+from typing import List

 import torch

@@ -27,62 +28,42 @@ class ReqToTokenPool:

     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size

-    def alloc(self, need_size: int):
-        if need_size > self.
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None

-        select_index =
-
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

         return select_index

     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.
+            self.free_slots.append(free_index)
         else:
-            self.
+            self.free_slots.extend(free_index)

     def clear(self):
-        self.
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))


-class
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +71,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

@@ -139,3 +111,67 @@ class TokenToKVPool:

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
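ReqToTokenPool above switches from a boolean mem_state mask to a plain Python free list for request slots. A CPU-only sketch of that allocator pattern follows (no torch; FreeListPool is an illustrative stand-in, not the sglang class itself).

# Sketch: free-list slot allocation instead of a boolean availability mask.
from typing import List, Optional


class FreeListPool:
    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None                      # not enough slots; caller must wait or evict
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index) -> None:
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self) -> None:
        self.free_slots = list(range(self.size))


pool = FreeListPool(4)
slots = pool.alloc(3)        # [0, 1, 2]
pool.free(slots[0])          # slot 0 returns to the tail of the free list
print(slots, pool.free_slots)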
sglang/srt/model_config.py
CHANGED
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
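The new AttentionArch flag lets downstream code choose the right KV layout for DeepSeek-V2's MLA versus standard MHA. The actual dispatch lives in model_runner.py, which this excerpt does not show, so the sketch below is only an illustration of the per-token storage difference implied by the MHATokenToKVPool and MLATokenToKVPool shapes above; the helper name and default sizes are assumptions.

# Sketch: how the attention architecture flag could describe KV-cache layout.
from enum import IntEnum, auto


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


def describe_kv_cache(arch: AttentionArch, *, head_num=8, head_dim=128,
                      kv_lora_rank=512, qk_rope_head_dim=64) -> str:
    if arch == AttentionArch.MLA:
        # MLA stores one compressed latent per token: kv_lora_rank + rope dims
        return f"per-token KV entry: 1 x {kv_lora_rank + qk_rope_head_dim}"
    # MHA stores full K and V tensors per head
    return f"per-token KV entry: 2 x {head_num} x {head_dim}"


print(describe_kv_cache(AttentionArch.MLA))
print(describe_kv_cache(AttentionArch.MHA))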
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import
-
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

-    def replay(self, batch:
+    def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
