sglang 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/interpreter.py +21 -5
- sglang/lang/ir.py +1 -2
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
- sglang/srt/constrained/fsm_cache.py +17 -2
- sglang/srt/constrained/jump_forward.py +17 -2
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +15 -0
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +41 -13
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +17 -2
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +36 -3
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
- sglang/srt/managers/tokenizer_manager.py +39 -16
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
- sglang/srt/mem_cache/base_cache.py +43 -0
- sglang/srt/mem_cache/chunk_cache.py +60 -0
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +15 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
- sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +16 -1
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +21 -22
- sglang/srt/models/llama_classification.py +16 -1
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +569 -131
- sglang/srt/openai_api/protocol.py +84 -2
- sglang/srt/sampling_params.py +15 -0
- sglang/srt/server.py +92 -23
- sglang/srt/server_args.py +52 -11
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +9 -6
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
- sglang-0.2.8.dist-info/RECORD +95 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.6.dist-info/RECORD +0 -93
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
- {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/tp_worker.py → tp_worker.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A tensor parallel worker."""
 
 import logging
@@ -14,23 +29,24 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.controller.infer_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Batch,
-    ForwardMode,
-    Req,
-)
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
@@ -40,7 +56,7 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger("srt.tp_worker")
+logger = logging.getLogger(__name__)
 
 
 class ModelTpServer:
@@ -59,9 +75,13 @@ class ModelTpServer:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
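Note: the two fields added above are the entire state needed for chunked prefill: a per-batch prefill budget and at most one partially prefilled ("inflight") request. A minimal sketch of the resulting chunking arithmetic, using a hypothetical chunk_lengths helper that is not part of the diff:

def chunk_lengths(prompt_len: int, prefix_len: int, chunk_size: int):
    # Extend lengths the worker would process on successive scheduling steps.
    remaining = prompt_len - prefix_len
    while remaining > 0:
        step = min(remaining, chunk_size)
        yield step
        remaining -= step

# A 10000-token prompt under an 8192-token budget prefills in two steps.
assert list(chunk_lengths(10000, 0, 8192)) == [8192, 1808]

Only the final chunk ends with sampling; intermediate chunks merely extend the KV cache (see the forward_prefill_batch hunk below).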
@@ -117,7 +137,7 @@ class ModelTpServer:
 
         # Print info
         logger.info(
-            f"[
+            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -125,14 +145,23 @@ class ModelTpServer:
         )
 
         # Init cache
-        self.tree_cache = RadixCache(
-            req_to_token_pool=self.model_runner.req_to_token_pool,
-            token_to_kv_pool=self.model_runner.token_to_kv_pool,
-            disable=server_args.disable_radix_cache,
-        )
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            )
+        else:
+            self.tree_cache = RadixCache(
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                disable=server_args.disable_radix_cache,
+            )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = ScheduleHeuristic(
-            self.schedule_heuristic,
+        self.scheduler = PolicyScheduler(
+            self.schedule_policy,
             self.max_running_requests,
             self.max_prefill_tokens,
             self.max_total_num_tokens,
@@ -142,7 +171,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
-        self.forward_queue: List[Req] = []
+        self.waiting_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -205,6 +234,7 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
+            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -241,12 +271,12 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"[
+                f"[gpu={self.gpu_id}] Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {throughput:.2f}, "
-                f"#queue-req: {len(self.forward_queue)}"
+                f"#queue-req: {len(self.waiting_queue)}"
             )
 
     def check_memory(self):
@@ -260,6 +290,14 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )
 
+        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -313,9 +351,10 @@ class ModelTpServer:
             ),
             self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        self.forward_queue.append(req)
+        self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
+        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -323,9 +362,12 @@ class ModelTpServer:
            return
 
         # Compute matched prefix length
-        for req in self.forward_queue:
+        for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+            prefix_indices, last_node = self.tree_cache.match_prefix(
+                rid=req.rid,
+                key=req.input_ids,
+            )
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -333,7 +375,7 @@ class ModelTpServer:
             req.last_node = last_node
 
         # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
         # Add requests if there is available space
         can_run_list = []
@@ -352,7 +394,33 @@ class ModelTpServer:
                 ]
             )
 
-        for req in self.forward_queue:
+        # Handle the current inflight request
+        take_inflight = 0
+        if self.current_inflight_req:
+            take_inflight = 1
+            r = self.current_inflight_req
+            r.input_ids = r.origin_input_ids + r.output_ids
+            truncated = (
+                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+            )
+            r.extend_input_len = min(
+                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            )
+            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+            can_run_list.append(r)
+
+            if not truncated:
+                # Finish inflight
+                self.current_inflight_req = None
+                new_batch_total_tokens += (
+                    r.extend_input_len + r.sampling_params.max_new_tokens
+                )
+                new_batch_input_tokens += r.extend_input_len
+            else:
+                new_batch_total_tokens += r.extend_input_len
+                new_batch_input_tokens += r.extend_input_len
+
+        for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
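Note: the inflight block above finishes or re-truncates the carried-over request before any queued request is considered, and its token accounting is asymmetric: only the final chunk reserves the request's decode budget. A hedged sketch of that rule (names are illustrative, not from the source):

def tokens_reserved(extend_input_len: int, max_new_tokens: int, is_last_chunk: bool) -> int:
    # Intermediate chunks occupy KV slots only for the tokens they prefill;
    # the last chunk must also reserve room for the tokens to be decoded.
    return extend_input_len + (max_new_tokens if is_last_chunk else 0)

assert tokens_reserved(512, 128, is_last_chunk=False) == 512
assert tokens_reserved(512, 128, is_last_chunk=True) == 640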
@@ -394,11 +462,39 @@ class ModelTpServer:
                     break
                 else:
                     # Add this request to the running batch
-                    can_run_list.append(req)
-                    new_batch_total_tokens += (
-                        req.extend_input_len + req.sampling_params.max_new_tokens
-                    )
-                    new_batch_input_tokens += req.extend_input_len
+                    if (
+                        self.chunked_prefill_size is None
+                        or (
+                            new_batch_input_tokens + req.extend_input_len
+                            <= self.chunked_prefill_size
+                        )
+                        or (
+                            req.return_logprob and req.normalized_prompt_logprob is None
+                        )
+                    ):
+                        can_run_list.append(req)
+                        new_batch_total_tokens += (
+                            req.extend_input_len + req.sampling_params.max_new_tokens
+                        )
+                        new_batch_input_tokens += req.extend_input_len
+                    else:
+                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                        if trunc_len <= 0:
+                            # Undo locking
+                            delta = self.tree_cache.dec_lock_ref(req.last_node)
+                            available_size += delta
+                            break
+
+                        req.extend_input_len = trunc_len
+                        req.input_ids = req.input_ids[
+                            : len(req.prefix_indices) + req.extend_input_len
+                        ]
+                        can_run_list.append(req)
+                        self.current_inflight_req = req
+                        new_batch_input_tokens += req.extend_input_len
+                        new_batch_total_tokens += req.extend_input_len
+                        break
             else:
                 break
 
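Note: the admission branch above has three outcomes per queued request: admit it whole, truncate it (it then becomes the new inflight request and admission stops), or stop immediately when the budget is already spent. A standalone sketch of the same decision, with assumed names and without the logprob carve-out that the real condition also allows through:

def admit(extend_input_len, batch_input_tokens, chunked_prefill_size):
    # Returns (tokens_to_prefill_now, becomes_inflight), or None to stop admitting.
    if (
        chunked_prefill_size is None
        or batch_input_tokens + extend_input_len <= chunked_prefill_size
    ):
        return extend_input_len, False
    trunc_len = chunked_prefill_size - batch_input_tokens
    if trunc_len <= 0:
        return None
    return trunc_len, True

assert admit(300, 0, 1024) == (300, False)    # fits: full prefill
assert admit(2000, 512, 1024) == (512, True)  # truncated: becomes inflight
assert admit(10, 1024, 1024) is None          # budget spent: stop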
@@ -419,13 +515,13 @@ class ModelTpServer:
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
-                f"[
+                f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
             )
 
         # Return the new batch
@@ -435,7 +531,7 @@ class ModelTpServer:
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
     def forward_prefill_batch(self, batch: Batch):
@@ -467,9 +563,10 @@ class ModelTpServer:
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_ids[i])
-            req.check_finished()
+            if req is not self.current_inflight_req:
+                req.completion_tokens_wo_jump_forward += 1
+                req.output_ids.append(next_token_ids[i])
+                req.check_finished()
 
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -530,7 +627,8 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids),
+                rid=req.rid,
+                token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -538,6 +636,10 @@ class ModelTpServer:
             )
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
+            if req is self.current_inflight_req:
+                # inflight request would get a new req idx
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
     def forward_decode_batch(self, batch: Batch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -551,7 +653,7 @@ class ModelTpServer:
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.forward_queue.extend(retracted_reqs)
+            self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -561,7 +663,7 @@ class ModelTpServer:
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.forward_queue.extend(jump_forward_reqs)
+            self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
 
@@ -683,6 +785,7 @@ class ModelTpServer:
         for i in finished_indices:
             req = batch.reqs[i]
             self.tree_cache.cache_req(
+                rid=req.rid,
                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
@@ -696,8 +799,18 @@ class ModelTpServer:
         else:
             batch.reqs = []
 
+    def filter_out_inflight(self, batch: Batch):
+        # TODO(lsyin): reduce the overhead, make a special version for this
+        if self.current_inflight_req is None:
+            return
+
+        to_remove = batch.reqs.index(self.current_inflight_req)
+        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+        batch.filter_batch(unfinished_indices)
+
     def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
+        if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
@@ -710,20 +823,20 @@ class ModelTpServer:
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
+                f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.forward_queue):
+        for i, req in enumerate(self.waiting_queue):
             if req.rid == recv_req.rid:
                 to_del = i
                 break
 
         if to_del is not None:
-            del self.forward_queue[to_del]
+            del self.waiting_queue[to_del]
 
         # Delete requests in the running batch
         if self.running_batch:
sglang/srt/mem_cache/base_cache.py
ADDED
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+
+
+class BasePrefixCache(ABC):
+    """Cache can be indexed by either rid or key."""
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def match_prefix(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def insert(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def evict(self, num_tokens, evict_callback):
+        pass
+
+    @abstractmethod
+    def inc_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def dec_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def evictable_size(self):
+        pass
+
+    def total_size(self):
+        raise NotImplementedError
+
+    def pretty_print(self):
+        raise NotImplementedError
sglang/srt/mem_cache/chunk_cache.py
ADDED
@@ -0,0 +1,60 @@
+"""Cache for chunked prefill, used when RadixCache is disabled."""
+
+from sglang.srt.mem_cache.base_cache import BasePrefixCache
+
+
+class ChunkCacheEntry:
+    def __init__(self, rid, value):
+        self.rid = rid
+        self.value = value
+
+
+class ChunkCache(BasePrefixCache):
+    def __init__(self, req_to_token_pool, token_to_kv_pool):
+        self.disable = True
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
+
+        self.reset()
+
+    def reset(self):
+        self.entries = {}
+
+    def match_prefix(self, rid, **kwargs):
+        if rid not in self.entries:
+            return [], None
+
+        entry = self.entries[rid]
+        return entry.value, entry
+
+    def cache_req(
+        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
+    ):
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        if del_in_memory_pool:
+            assert rid in self.entries
+            self.req_to_token_pool.free(req_pool_idx)
+            self.token_to_kv_pool.free(indices)
+            return
+
+        if rid not in self.entries:
+            self.entries[rid] = ChunkCacheEntry(rid, indices)
+
+        entry = self.entries[rid]
+        entry.value = indices
+        return indices, entry
+
+    def insert(self):
+        raise NotImplementedError
+
+    def evict(self, num_tokens, evict_callback):
+        pass
+
+    def inc_lock_ref(self, node):
+        return 0
+
+    def dec_lock_ref(self, node):
+        return 0
+
+    def evictable_size(self):
+        return 0
sglang/srt/mem_cache/flush_cache.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Flush the KV cache.
+
+Usage:
+python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
sglang/srt/{memory_pool.py → mem_cache/memory_pool.py}
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Memory pool."""
 
 import logging
@@ -30,7 +45,7 @@ class ReqToTokenPool:
 
         return select_index
 
-    def free(self, free_index: int):
+    def free(self, free_index):
         self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
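Note: the dropped annotation matches what the body already does: free() branches on the argument type, so it can take one slot index (as in the inflight path of cache_filled_batch above and in ChunkCache.cache_req) or a batch of indices. The array branch is not shown in this diff; the sketch below fills it in as an assumption:

def free(self, free_index):
    self.mem_state[free_index] = True  # works for a scalar or an index array
    if isinstance(free_index, (int,)):
        self.can_use_mem_size += 1
    else:
        # assumed: free_index is an array-like of request slot indices
        self.can_use_mem_size += len(free_index)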
sglang/srt/{managers/controller → mem_cache}/radix_cache.py
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The radix tree data structure for managing the KV cache.
 """
@@ -8,6 +23,8 @@ from collections import defaultdict
 
 import torch
 
+from sglang.srt.mem_cache.base_cache import BasePrefixCache
+
 
 class TreeNode:
     def __init__(self):
@@ -31,7 +48,7 @@ def _key_match(key0, key1):
     return i
 
 
-class RadixCache:
+class RadixCache(BasePrefixCache):
     def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -47,7 +64,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key):
+    def match_prefix(self, key, **kwargs):
         if self.disable:
             return [], self.root_node
 
@@ -75,6 +92,7 @@ class RadixCache(BasePrefixCache):
         req_pool_idx,
         del_in_memory_pool=True,
         old_last_node=None,
+        **kwargs,
     ):
         # Insert the request into radix cache
         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
sglang/srt/mm_utils.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
 import ast
 import base64
sglang/srt/model_config.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from typing import Optional
 
 from transformers import PretrainedConfig
sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py
RENAMED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Run the model with cuda graph."""
 
 import bisect
@@ -14,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
|