sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,6 +15,7 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import dataclasses
 import logging
 import os
 import threading
@@ -29,16 +30,19 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -58,15 +62,18 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +83,6 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
 
@@ -103,14 +106,17 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
+        # Session info
+        self.sessions = {}
+
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
@@ -160,6 +166,14 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -223,8 +237,12 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -337,46 +355,34 @@ class Scheduler:
 
         kill_parent_process()
 
-    @torch.
+    @torch.no_grad()
     def event_loop_normal(self):
-        """A normal
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
-    @torch.
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
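
The overlap loop above launches the GPU forward for the current batch while the CPU is still post-processing the previous one, parking pending results in a deque. A minimal standalone sketch of that one-step-lag pattern is below; `overlap_loop`, `run_gpu`, and `postprocess` are illustrative stand-ins for the scheduler's `run_batch` and `process_batch_result`, not sglang APIs.

    from collections import deque

    def overlap_loop(batches, run_gpu, postprocess):
        """Kick off work for batch N, then finish batch N-1.

        `batches`, `run_gpu`, and `postprocess` are placeholders for
        get_next_batch_to_run(), the TP worker forward, and
        process_batch_result() respectively.
        """
        pending = deque()
        last_batch = None
        for batch in batches:
            if batch is not None:
                pending.append((batch, run_gpu(batch)))  # launch batch N
            if last_batch is not None:
                prev_batch, prev_result = pending.popleft()
                postprocess(prev_batch, prev_result)     # finish batch N-1
            last_batch = batch
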
@@ -387,17 +393,85 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
             while True:
@@ -409,7 +483,7 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
@@ -433,6 +507,11 @@ class Scheduler:
                 self.start_profile()
             else:
                 self.stop_profile()
+        elif isinstance(recv_req, OpenSessionReqInput):
+            session_id = self.open_session(recv_req)
+            self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+        elif isinstance(recv_req, CloseSessionReqInput):
+            self.close_session(recv_req)
         elif isinstance(recv_req, GetMemPoolSizeReq):
             self.send_to_tokenizer.send_pyobj(
                 GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -444,14 +523,30 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-
-
-
-
-
-
-
-
+        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                recv_req.sampling_params,
+                lora_path=recv_req.lora_path,
+            )
+            req.tokenizer = self.tokenizer
+            if recv_req.session_id is not None:
+                req.finished_reason = FINISH_ABORT(
+                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                )
+                self.waiting_queue.append(req)
+                return
+        else:
+            # Handle sessions
+            session = self.sessions[recv_req.session_id]
+            req, new_session_id = session.create_req(recv_req, self.tokenizer)
+            del self.sessions[recv_req.session_id]
+            self.sessions[new_session_id] = session
+            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.waiting_queue.append(req)
+                return
 
         # Image inputs
         if recv_req.image_inputs is not None:
@@ -462,6 +557,15 @@ class Scheduler:
                 req.origin_input_ids_unpadded, req.image_inputs
             )
 
+            if len(req.origin_input_ids) > self.max_req_input_len:
+                req.finished_reason = FINISH_ABORT(
+                    "Image request length is longer than the KV cache pool size or "
+                    "the max context length aborting because you cannot truncate the image embeds"
+                )
+                req.sampling_params.max_new_tokens = 0
+                self.waiting_queue.append(req)
+                return
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -599,21 +703,23 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                 "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                 "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
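
The module-level `crash_on_warning` flag removed in an earlier hunk is replaced by the `crash_on_warnings()` helper imported from `sglang.srt.utils` (also changed in this release). Its body is not shown in this diff; judging from the removed flag it presumably amounts to the sketch below.

    import os

    def crash_on_warnings() -> bool:
        # Presumed equivalent of the old module-level flag: escalate warnings
        # to hard failures only when running inside SGLang's CI environment.
        return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
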
@@ -747,7 +853,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
@@ -812,6 +918,10 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                 model_worker_batch
             )
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            self.tp_worker.forward_batch_idle(model_worker_batch)
+            return
         else:
             logits_output = None
             if self.skip_tokenizer_init:
@@ -834,8 +944,12 @@ class Scheduler:
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
 
@@ -843,7 +957,7 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
@@ -863,14 +977,14 @@ class Scheduler:
 
             # Check finish conditions
             logprob_pt = 0
-            for i, req in enumerate(batch.reqs):
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                 if req.is_retracted:
                     continue
 
                 if req.is_being_chunked <= 0:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
-                    req.output_ids.append(
+                    req.output_ids.append(next_token_id)
                     req.check_finished()
 
                     if req.finished():
@@ -879,7 +993,7 @@ class Scheduler:
                         self.tree_cache.cache_unfinished_req(req)
 
                     if req.grammar is not None:
-                        req.grammar.accept_token(
+                        req.grammar.accept_token(next_token_id)
 
                     if req.return_logprob:
                         logprob_pt += self.add_logprob_return_values(
@@ -888,6 +1002,11 @@ class Scheduler:
                 else:
                     req.is_being_chunked -= 1
 
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -918,7 +1037,7 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
@@ -936,7 +1055,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.
+            if self.enable_overlap and req.finished():
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -957,6 +1076,11 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)
 
         self.token_to_kv_pool.free_group_end()
@@ -1055,6 +1179,7 @@ class Scheduler:
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
+            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []
 
@@ -1082,6 +1207,7 @@ class Scheduler:
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
+                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -1132,6 +1258,7 @@ class Scheduler:
                     output_meta_info,
                     output_finished_reason,
                     output_no_stop_trim,
+                    output_session_ids,
                 )
             )
         else:  # embedding or reward model
@@ -1234,6 +1361,25 @@ class Scheduler:
         )
         logger.info("Profiler is done")
 
+    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+        # handle error
+        session_id = recv_req.session_id
+        if session_id in self.sessions:
+            logger.warning(f"session id {session_id} already exist, cannot open.")
+        else:
+            self.sessions[session_id] = Session(
+                recv_req.capacity_of_str_len, session_id
+            )
+        return session_id
+
+    def close_session(self, recv_req: CloseSessionReqInput):
+        # handle error
+        session_id = recv_req.session_id
+        if session_id not in self.sessions:
+            logger.warning(f"session id {session_id} does not exist, cannot delete.")
+        else:
+            del self.sessions[session_id]
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
@@ -1243,6 +1389,10 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None:
+        dp_rank = int(os.getenv("DP_RANK", -1))
+
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
@@ -1253,7 +1403,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
-        if
+        if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
             scheduler.event_loop_normal()
sglang/srt/managers/session_controller.py
ADDED
@@ -0,0 +1,62 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import copy
+import uuid
+from dataclasses import dataclass
+from typing import Optional
+
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+
+
+class Session:
+    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
+        self.capacity_of_str_len = capacity_of_str_len
+        self.reqs: List[Req] = []
+
+    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
+        # renew session id
+        self.session_id = uuid.uuid4().hex
+        if req.session_rid is not None:
+            while len(self.reqs) > 0:
+                if self.reqs[-1].rid == req.session_rid:
+                    break
+                self.reqs = self.reqs[:-1]
+        if len(self.reqs) > 0:
+            input_ids = (
+                self.reqs[-1].origin_input_ids
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
+        else:
+            input_ids = req.input_ids
+        new_req = Req(
+            req.rid,
+            None,
+            input_ids,
+            req.sampling_params,
+            lora_path=req.lora_path,
+            session_id=self.session_id,
+        )
+        new_req.tokenizer = tokenizer
+        if req.session_rid is not None and len(self.reqs) == 0:
+            new_req.finished_reason = FINISH_ABORT(
+                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
+            )
+        else:
+            self.reqs.append(new_req)
+        return new_req, self.session_id
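
`Session.create_req` above rolls the session id on every turn, optionally rewinds the stored history to `session_rid`, and builds the next prompt by concatenating the previous turn's input, its (truncated) output, and the new input ids. A toy standalone illustration of that concatenation rule follows; `ToySession` is a simplified stand-in, not the sglang `Session` class.

    class ToySession:
        """Illustrative only: mimics how Session.create_req builds the next prompt."""

        def __init__(self):
            self.history = []  # list of (input_ids, output_ids) per turn

        def next_input(self, new_input_ids, max_new_tokens):
            if self.history:
                prev_in, prev_out = self.history[-1]
                # Previous prompt + its (truncated) completion + the new turn's tokens.
                merged = prev_in + prev_out[:max_new_tokens] + new_input_ids
            else:
                merged = new_input_ids
            self.history.append((merged, []))
            return merged

    s = ToySession()
    print(s.next_input([1, 2, 3], max_new_tokens=8))  # first turn: [1, 2, 3]
    s.history[-1] = (s.history[-1][0], [4, 5])        # pretend the model produced [4, 5]
    print(s.next_input([6], max_new_tokens=8))        # next turn: [1, 2, 3, 4, 5, 6]
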
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -23,6 +23,7 @@ import os
 import signal
 import sys
 import time
+import uuid
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
+        # For session info
+        self.session_futures = {}  # session_id -> asyncio event
+
         # Others
         self.gracefully_exit = False
 
@@ -211,6 +218,8 @@ class TokenizerManager:
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
+            session_id = obj.session_id
+            session_rid = obj.session_rid
 
         if len(input_ids) >= self.context_len:
             raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
                 obj.lora_path,
+                session_id=session_id,
+                session_rid=session_rid,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def open_session(
+        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        session_id = uuid.uuid4().hex
+        obj.session_id = session_id
+        self.send_to_scheduler.send_pyobj(obj)
+        self.session_futures[session_id] = asyncio.Future()
+        session_id = await self.session_futures[session_id]
+        del self.session_futures[session_id]
+        return session_id
+
+    async def close_session(
+        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
+    ):
+        assert not self.to_create_loop, "close session should not be the first request"
+        await self.send_to_scheduler.send_pyobj(obj)
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
                     if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                         self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                 continue
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+                continue
 
             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
                 out_dict = {
                     "text": recv_obj.output_strs[i],
                     "meta_info": recv_obj.meta_info[i],
+                    "session_id": recv_obj.session_ids[i],
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
                    "meta_info": recv_obj.meta_info[i],
+                    "session_id": recv_obj.session_ids[i],
                 }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)