sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,6 +15,7 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import dataclasses
 import logging
 import os
 import threading
@@ -29,16 +30,19 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -58,15 +62,18 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
     kill_parent_process,
     set_random_seed,
@@ -76,10 +83,6 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
 test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
 
@@ -103,17 +106,23 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
+        # Session info
+        self.sessions = {}
+
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
+            self.send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name
+            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the tokenizer/api
@@ -127,6 +136,7 @@ class Scheduler:
             )
         else:
             self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
@@ -156,6 +166,14 @@ class Scheduler:
                 trust_remote_code=server_args.trust_remote_code,
             )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -219,8 +237,12 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -333,46 +355,34 @@ class Scheduler:
 
         kill_parent_process()
 
-    @torch.
+    @torch.no_grad()
     def event_loop_normal(self):
-        """A normal
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
-    @torch.
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -383,17 +393,85 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
             while True:
@@ -405,7 +483,7 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
@@ -421,7 +499,7 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.
+                self.send_to_tokenizer.send_pyobj(
                     UpdateWeightReqOutput(success, message)
                 )
             elif isinstance(recv_req, ProfileReq):
@@ -429,8 +507,13 @@ class Scheduler:
                     self.start_profile()
                 else:
                     self.stop_profile()
+            elif isinstance(recv_req, OpenSessionReqInput):
+                session_id = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+            elif isinstance(recv_req, CloseSessionReqInput):
+                self.close_session(recv_req)
             elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.
+                self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                 )
             else:
@@ -440,14 +523,30 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-
-
-
-
-
-
-
-
+        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                recv_req.sampling_params,
+                lora_path=recv_req.lora_path,
+            )
+            req.tokenizer = self.tokenizer
+            if recv_req.session_id is not None:
+                req.finished_reason = FINISH_ABORT(
+                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                )
+                self.waiting_queue.append(req)
+                return
+        else:
+            # Handle sessions
+            session = self.sessions[recv_req.session_id]
+            req, new_session_id = session.create_req(recv_req, self.tokenizer)
+            del self.sessions[recv_req.session_id]
+            self.sessions[new_session_id] = session
+            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.waiting_queue.append(req)
+                return
 
         # Image inputs
         if recv_req.image_inputs is not None:
@@ -458,6 +557,15 @@ class Scheduler:
                 req.origin_input_ids_unpadded, req.image_inputs
             )
 
+            if len(req.origin_input_ids) > self.max_req_input_len:
+                req.finished_reason = FINISH_ABORT(
+                    "Image request length is longer than the KV cache pool size or "
+                    "the max context length aborting because you cannot truncate the image embeds"
+                )
+                req.sampling_params.max_new_tokens = 0
+                self.waiting_queue.append(req)
+                return
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -595,21 +703,23 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                 "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                 "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
@@ -743,7 +853,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
         )
-        new_batch.prepare_for_extend()
+        new_batch.prepare_for_extend(self.enable_overlap)
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
@@ -808,6 +918,10 @@ class Scheduler:
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
+            elif batch.forward_mode.is_idle():
+                model_worker_batch = batch.get_model_worker_batch()
+                self.tp_worker.forward_batch_idle(model_worker_batch)
+                return
             else:
                 logits_output = None
                 if self.skip_tokenizer_init:
@@ -830,8 +944,12 @@ class Scheduler:
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
 
@@ -839,7 +957,7 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
@@ -859,14 +977,14 @@ class Scheduler:
 
             # Check finish conditions
             logprob_pt = 0
-            for i, req in enumerate(batch.reqs):
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                 if req.is_retracted:
                     continue
 
                 if req.is_being_chunked <= 0:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
-                    req.output_ids.append(
+                    req.output_ids.append(next_token_id)
                     req.check_finished()
 
                     if req.finished():
@@ -875,7 +993,7 @@ class Scheduler:
                         self.tree_cache.cache_unfinished_req(req)
 
                     if req.grammar is not None:
-                        req.grammar.accept_token(
+                        req.grammar.accept_token(next_token_id)
 
                     if req.return_logprob:
                         logprob_pt += self.add_logprob_return_values(
@@ -884,6 +1002,11 @@ class Scheduler:
                 else:
                     req.is_being_chunked -= 1
 
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else: # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -914,7 +1037,7 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
@@ -932,7 +1055,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.
+            if self.enable_overlap and req.finished():
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -953,6 +1076,11 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)
 
         self.token_to_kv_pool.free_group_end()
@@ -1051,6 +1179,7 @@ class Scheduler:
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
+            output_session_ids = []
         else: # embedding or reward model
             output_embeddings = []
 
@@ -1078,6 +1207,7 @@ class Scheduler:
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
+                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -1128,6 +1258,7 @@ class Scheduler:
                     output_meta_info,
                     output_finished_reason,
                     output_no_stop_trim,
+                    output_session_ids,
                 )
             )
         else: # embedding or reward model
@@ -1230,6 +1361,25 @@ class Scheduler:
         )
         logger.info("Profiler is done")
 
+    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+        # handle error
+        session_id = recv_req.session_id
+        if session_id in self.sessions:
+            logger.warning(f"session id {session_id} already exist, cannot open.")
+        else:
+            self.sessions[session_id] = Session(
+                recv_req.capacity_of_str_len, session_id
+            )
+        return session_id
+
+    def close_session(self, recv_req: CloseSessionReqInput):
+        # handle error
+        session_id = recv_req.session_id
+        if session_id not in self.sessions:
+            logger.warning(f"session id {session_id} does not exist, cannot delete.")
+        else:
+            del self.sessions[session_id]
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
@@ -1239,6 +1389,10 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None:
+        dp_rank = int(os.getenv("DP_RANK", -1))
+
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
@@ -1249,7 +1403,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
-        if
+        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
         else:
             scheduler.event_loop_normal()
sglang/srt/managers/session_controller.py
ADDED
@@ -0,0 +1,62 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import copy
+import uuid
+from dataclasses import dataclass
+from typing import Optional
+
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+
+
+class Session:
+    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
+        self.capacity_of_str_len = capacity_of_str_len
+        self.reqs: List[Req] = []
+
+    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
+        # renew session id
+        self.session_id = uuid.uuid4().hex
+        if req.session_rid is not None:
+            while len(self.reqs) > 0:
+                if self.reqs[-1].rid == req.session_rid:
+                    break
+                self.reqs = self.reqs[:-1]
+        if len(self.reqs) > 0:
+            input_ids = (
+                self.reqs[-1].origin_input_ids
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
+        else:
+            input_ids = req.input_ids
+        new_req = Req(
+            req.rid,
+            None,
+            input_ids,
+            req.sampling_params,
+            lora_path=req.lora_path,
+            session_id=self.session_id,
+        )
+        new_req.tokenizer = tokenizer
+        if req.session_rid is not None and len(self.reqs) == 0:
+            new_req.finished_reason = FINISH_ABORT(
+                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
+            )
+        else:
+            self.reqs.append(new_req)
+        return new_req, self.session_id