sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +4 -4
- sglang/lang/interpreter.py +24 -9
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/base_cache.py +15 -0
- sglang/srt/constrained/fsm_cache.py +36 -1
- sglang/srt/constrained/jump_forward.py +15 -0
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +18 -1
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +109 -72
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +21 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +38 -5
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
- sglang/srt/managers/tokenizer_manager.py +99 -57
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +20 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
- sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +532 -0
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +16 -1
- sglang/srt/models/llama_classification.py +19 -4
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +545 -160
- sglang/srt/openai_api/protocol.py +65 -1
- sglang/srt/sampling_params.py +20 -4
- sglang/srt/server.py +90 -37
- sglang/srt/server_args.py +76 -17
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +5 -1
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
- sglang-0.2.7.dist-info/RECORD +93 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.5.dist-info/RECORD +0 -92
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/tp_worker.py → tp_worker.py}
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A tensor parallel worker."""
 
 import logging
@@ -14,23 +29,23 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.controller.infer_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Batch,
-    ForwardMode,
-    Req,
-)
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
@@ -40,7 +55,7 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class ModelTpServer:
@@ -59,9 +74,13 @@ class ModelTpServer:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
-        self.
+        self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -98,22 +117,26 @@ class ModelTpServer:
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)
 
         # Print info
         logger.info(
-            f"[
+            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -127,8 +150,8 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler =
-        self.
+        self.scheduler = PolicyScheduler(
+            self.schedule_policy,
             self.max_running_requests,
             self.max_prefill_tokens,
             self.max_total_num_tokens,
@@ -138,7 +161,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
-        self.
+        self.waiting_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -201,6 +224,7 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
+            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -237,12 +261,12 @@ class ModelTpServer:
                 self.num_generated_tokens = 0
                 self.last_stats_tic = time.time()
                 logger.info(
-                    f"[
+                    f"[gpu={self.gpu_id}] Decode batch. "
                     f"#running-req: {len(self.running_batch.reqs)}, "
                     f"#token: {num_used}, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"gen throughput (token/s): {throughput:.2f}, "
-                    f"#queue-req: {len(self.
+                    f"#queue-req: {len(self.waiting_queue)}"
                 )
 
     def check_memory(self):
@@ -295,21 +319,24 @@ class ModelTpServer:
         )
 
         # Truncate prompts that are too long
-        req.origin_input_ids
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-
-
-
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-
-        req.origin_input_ids = req.origin_input_ids[
-            : self.max_total_num_tokens - 128
-        ]
-        logger.error("Request longer than memory pool size, truncated!!!")
-
-        self.forward_queue.append(req)
+        self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
+        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -317,7 +344,7 @@ class ModelTpServer:
             return
 
         # Compute matched prefix length
-        for req in self.
+        for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -327,7 +354,7 @@ class ModelTpServer:
             req.last_node = last_node
 
         # Get priority queue
-        self.
+        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
         # Add requests if there is available space
         can_run_list = []
@@ -346,7 +373,33 @@ class ModelTpServer:
                 ]
             )
 
-
+        # Handle the current inflight request
+        take_inflight = 0
+        if self.current_inflight_req:
+            take_inflight = 1
+            r = self.current_inflight_req
+            r.input_ids = r.origin_input_ids + r.output_ids
+            truncated = (
+                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+            )
+            r.extend_input_len = min(
+                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            )
+            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+            can_run_list.append(r)
+
+            if not truncated:
+                # Finish inflight
+                self.current_inflight_req = None
+                new_batch_total_tokens += (
+                    r.extend_input_len + r.sampling_params.max_new_tokens
+                )
+                new_batch_input_tokens += r.extend_input_len
+            else:
+                new_batch_total_tokens += r.extend_input_len
+                new_batch_input_tokens += r.extend_input_len
+
+        for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
@@ -388,11 +441,39 @@ class ModelTpServer:
                     break
                 else:
                     # Add this request to the running batch
-
-
-
-
-
+                    if (
+                        self.chunked_prefill_size is None
+                        or (
+                            new_batch_input_tokens + req.extend_input_len
+                            <= self.chunked_prefill_size
+                        )
+                        or (
+                            req.return_logprob and req.normalized_prompt_logprob is None
+                        )
+                    ):
+                        can_run_list.append(req)
+                        new_batch_total_tokens += (
+                            req.extend_input_len + req.sampling_params.max_new_tokens
+                        )
+                        new_batch_input_tokens += req.extend_input_len
+                    else:
+                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                        if trunc_len <= 0:
+                            # Undo locking
+                            delta = self.tree_cache.dec_lock_ref(req.last_node)
+                            available_size += delta
+                            break
+
+                        req.extend_input_len = trunc_len
+                        req.input_ids = req.input_ids[
+                            : len(req.prefix_indices) + req.extend_input_len
+                        ]
+                        can_run_list.append(req)
+                        self.current_inflight_req = req
+                        new_batch_input_tokens += req.extend_input_len
+                        new_batch_total_tokens += req.extend_input_len
+                        break
             else:
                 break
 
@@ -413,13 +494,13 @@ class ModelTpServer:
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
-                f"[
+                f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
             )
 
         # Return the new batch
@@ -429,7 +510,7 @@ class ModelTpServer:
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
     def forward_prefill_batch(self, batch: Batch):
@@ -449,7 +530,7 @@ class ModelTpServer:
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
-            output.
+            output.input_token_logprobs = output.input_token_logprobs.tolist()
            output.normalized_prompt_logprobs = (
                 output.normalized_prompt_logprobs.tolist()
             )
@@ -461,9 +542,10 @@ class ModelTpServer:
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            req
-
-
+            if req is not self.current_inflight_req:
+                req.completion_tokens_wo_jump_forward += 1
+                req.output_ids.append(next_token_ids[i])
+                req.check_finished()
 
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -475,24 +557,24 @@ class ModelTpServer:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
-        if req.
+        if req.input_token_logprobs is None:
             # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-            req.
+            req.input_token_logprobs = list(
                 zip(
-                    output.
+                    output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                     req.input_ids[-req.extend_input_len + 1 :],
                 )
             )
            if req.logprob_start_len == 0:
-                req.
+                req.input_token_logprobs = [
                     (None, req.input_ids[0])
-                ] + req.
+                ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
-            req.
+            req.output_token_logprobs.extend(
                 list(
                     zip(
-                        output.
+                        output.input_token_logprobs[
                             pt
                             + req.extend_input_len
                             - req.last_update_decode_tokens : pt
@@ -504,27 +586,27 @@ class ModelTpServer:
                 )
             )
 
-            req.
+        req.output_token_logprobs.append(
             (output.next_token_logprobs[i], next_token_ids[i])
         )
 
         if req.top_logprobs_num > 0:
-            if req.
-                req.
+            if req.input_top_logprobs is None:
+                req.input_top_logprobs = output.input_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.
+                    req.input_top_logprobs = [None] + req.input_top_logprobs
 
             if req.last_update_decode_tokens != 0:
-                req.
-                    output.
+                req.output_top_logprobs.extend(
+                    output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                 )
-            req.
+            req.output_top_logprobs.append(output.output_top_logprobs[i])
 
     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.
+                token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -532,6 +614,10 @@ class ModelTpServer:
             )
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
+            if req is self.current_inflight_req:
+                # inflight request would get a new req idx
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
     def forward_decode_batch(self, batch: Batch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -545,7 +631,7 @@ class ModelTpServer:
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.
+            self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -555,7 +641,7 @@ class ModelTpServer:
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.
+            self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
 
@@ -583,11 +669,11 @@ class ModelTpServer:
            req.check_finished()
 
            if req.return_logprob:
-                req.
+                req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
                 )
                if req.top_logprobs_num > 0:
-                    req.
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -639,16 +725,16 @@ class ModelTpServer:
                 }
                 if req.return_logprob:
                     (
-                        meta_info["
-                        meta_info["
-                        meta_info["
-                        meta_info["
+                        meta_info["input_token_logprobs"],
+                        meta_info["output_token_logprobs"],
+                        meta_info["input_top_logprobs"],
+                        meta_info["output_top_logprobs"],
                         meta_info["normalized_prompt_logprob"],
                     ) = (
-                        req.
-                        req.
-                        req.
-                        req.
+                        req.input_token_logprobs,
+                        req.output_token_logprobs,
+                        req.input_top_logprobs,
+                        req.output_top_logprobs,
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
@@ -690,8 +776,18 @@ class ModelTpServer:
         else:
             batch.reqs = []
 
+    def filter_out_inflight(self, batch: Batch):
+        # TODO(lsyin): reduce the overhead, make a special version for this
+        if self.current_inflight_req is None:
+            return
+
+        to_remove = batch.reqs.index(self.current_inflight_req)
+        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+        batch.filter_batch(unfinished_indices)
+
     def flush_cache(self):
-        if len(self.
+        if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
@@ -704,20 +800,20 @@ class ModelTpServer:
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.
+                f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.
+        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                 to_del = i
                 break
 
         if to_del is not None:
-            del self.
+            del self.waiting_queue[to_del]
 
         # Delete requests in the running batch
         if self.running_batch:
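Reading note on the tp_worker.py hunks above: chunked prefill extends a request by at most chunked_prefill_size tokens past its cached prefix on each scheduling pass and keeps it as the current_inflight_req until the remainder fits. A minimal standalone sketch of that splitting arithmetic, with hypothetical helper names that are not part of the package:

def split_prefill(num_input_tokens: int, num_cached_tokens: int, chunk_size: int):
    """Return (tokens to prefill this pass, whether the request stays inflight)."""
    remaining = num_input_tokens - num_cached_tokens
    extend_len = min(remaining, chunk_size)      # mirrors r.extend_input_len in the diff
    still_inflight = remaining > chunk_size      # mirrors the `truncated` flag in the diff
    return extend_len, still_inflight

# A 10,000-token prompt with 2,000 cached tokens and chunk_size=4096 is prefilled
# in two passes: 4,096 tokens (stays inflight), then the remaining 3,904 tokens.
print(split_prefill(10_000, 2_000, 4_096))   # (4096, True)
print(split_prefill(10_000, 6_096, 4_096))   # (3904, False)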
sglang/srt/mem_cache/flush_cache.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Flush the KV cache.
+
+Usage:
+python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
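The new module is a thin CLI around the server's /flush_cache endpoint. The same call can be made programmatically; a small hedged sketch (assuming a server on the default port; per the tp_worker.py hunks above, the worker only resets its radix cache when no requests are queued or running):

import requests

def flush_kv_cache(base_url: str = "http://localhost:30000") -> None:
    # Hypothetical helper; endpoint path and default port come from the diff above.
    resp = requests.get(f"{base_url}/flush_cache", timeout=30)
    resp.raise_for_status()  # surface connection or HTTP errors instead of asserting

flush_kv_cache()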
sglang/srt/{memory_pool.py → mem_cache/memory_pool.py}
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Memory pool."""
 
 import logging
@@ -30,7 +45,7 @@ class ReqToTokenPool:
 
         return select_index
 
-    def free(self, free_index
+    def free(self, free_index):
         self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
sglang/srt/{managers/controller → mem_cache}/radix_cache.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The radix tree data structure for managing the KV cache.
 """
sglang/srt/mm_utils.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
 import ast
 import base64
sglang/srt/model_config.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 from typing import Optional
 
 from transformers import PretrainedConfig
@@ -36,6 +51,11 @@ class ModelConfig:
             "head_dim",
             self.hf_config.hidden_size // self.hf_config.num_attention_heads,
         )
+
+        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 256
+
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
 