sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -17,10 +17,11 @@ limitations under the License.
 
 import json
 import logging
-import multiprocessing
 import os
 import time
 import warnings
+from collections import deque
+from types import SimpleNamespace
 from typing import List, Optional, Union
 
 import torch
@@ -37,6 +38,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     TokenizedRewardReqInput,
@@ -76,6 +78,9 @@ logger = logging.getLogger(__name__)
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
+# Test retract decode
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+
 
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
@@ -106,7 +111,8 @@ class Scheduler:
             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
-            self.recv_from_tokenizer =
+            self.recv_from_tokenizer = None
+            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
         self.model_config = ModelConfig(
@@ -141,9 +147,10 @@ class Scheduler:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
-            nccl_port=port_args.
+            nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
+        self.device = self.tp_worker.device
 
         # Get token and memory info from the model worker
         (
@@ -189,8 +196,8 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: ScheduleBatch = None
-        self.
+        self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -229,15 +236,92 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+        # Init profiler
+        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
+            self.profiler = None
+        else:
+            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
+            logger.info(
+                "Profiling enabled. Traces will be saved to: %s",
+                self.torch_profiler_trace_dir,
+            )
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+            )
+
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
+    @torch.inference_mode()
+    def event_loop_normal(self):
+        self.last_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result(batch, result)
+
+                # Decode multiple steps to reduce the overhead
+                if batch.forward_mode.is_decode():
+                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
+                        if not self.running_batch:
+                            break
+                        self.update_running_batch()
+                        if not self.running_batch:
+                            break
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     @torch.inference_mode()
-    def
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
-            self.
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
 
-            self.
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
 
     def recv_requests(self):
         if self.tp_rank == 0:
```
```diff
@@ -270,7 +354,14 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.
+                self.send_to_detokenizer.send_pyobj(
+                    UpdateWeightReqOutput(success, message)
+                )
+            elif isinstance(recv_req, ProfileReq):
+                if recv_req == ProfileReq.START_PROFILE:
+                    self.start_profile()
+                else:
+                    self.stop_profile()
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -363,12 +454,6 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def send_results(self):
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
-
     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -376,9 +461,10 @@ class Scheduler:
         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
-            f"#running-req: {
+            f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {throughput:.2f}, "
@@ -406,41 +492,45 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None
 
-    def
+    def get_next_batch_to_run(self):
+        # Merge the prefill batch into the running batch
+        if (
+            self.last_batch
+            and not self.last_batch.forward_mode.is_decode()
+            and not self.last_batch.is_empty()
+        ):
+            if self.current_inflight_req:
+                self.last_batch.filter_batch(
+                    current_inflight_req=self.current_inflight_req
+                )
+                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                # Inflight request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+                self.batch_is_full = False
+            if not self.last_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = self.last_batch
+                else:
+                    self.running_batch.merge_batch(self.last_batch)
+
+        # Prefill first
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-
-            # replace run_batch with the uncommented line to use pytorch profiler
-            # result = pytorch_profile(
-            #     "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
-            # )
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-        else:
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        # replace run_batch with the uncommented line to use pytorch profiler
-                        # result = pytorch_profile(
-                        #     "profile_decode_step",
-                        #     self.run_batch,
-                        #     batch,
-                        #     data_size=len(batch.reqs),
-                        # )
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
+            return new_batch
 
-
-
+        # Check memory
+        if self.running_batch is None:
+            return
 
-
-
-
-
-
+        # Run decode
+        before_bs = self.running_batch.batch_size()
+        self.update_running_batch()
+        if not self.running_batch:
+            self.batch_is_full = False
+            return None
+        if before_bs != self.running_batch.batch_size():
+            self.batch_is_full = False
+        return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
@@ -449,9 +539,7 @@ class Scheduler:
         ) and self.current_inflight_req is None:
             return None
 
-        running_bs = (
-            len(self.running_batch.reqs) if self.running_batch is not None else 0
-        )
+        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
         if running_bs >= self.max_running_requests:
             self.batch_is_full = True
             return None
@@ -472,7 +560,7 @@ class Scheduler:
         )
 
         has_inflight = self.current_inflight_req is not None
-        if
+        if has_inflight:
             self.current_inflight_req.init_next_round_input(
                 None if prefix_computed else self.tree_cache
             )
@@ -480,7 +568,7 @@ class Scheduler:
                 self.current_inflight_req
             )
 
-        if self.lora_paths
+        if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
                 if self.running_batch is not None
@@ -489,7 +577,7 @@ class Scheduler:
 
         for req in self.waiting_queue:
             if (
-                self.lora_paths
+                self.lora_paths
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -511,16 +599,20 @@ class Scheduler:
                 self.batch_is_full = True
                 break
 
+        # Update waiting queue
         can_run_list = adder.can_run_list
+        if len(can_run_list) == 0:
+            return None
+        self.waiting_queue = [
+            x for x in self.waiting_queue if x not in set(can_run_list)
+        ]
 
         if adder.new_inflight_req is not None:
             assert self.current_inflight_req is None
             self.current_inflight_req = adder.new_inflight_req
 
-        if
-
-
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        if self.current_inflight_req:
+            self.current_inflight_req.is_inflight_req += 1
 
         # Print stats
         if self.tp_rank == 0:
@@ -573,21 +665,27 @@ class Scheduler:
         new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
-        decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
-            decoding_reqs = self.running_batch.reqs
+            new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
-
+        else:
+            new_batch.decoding_reqs = None
 
         return new_batch
 
-    def
+    def update_running_batch(self):
+        global test_retract
         batch = self.running_batch
 
+        batch.filter_batch()
+        if batch.is_empty():
+            self.running_batch = None
+            return
+
         # Check if decode out of memory
-        if not batch.check_decode_mem():
+        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio = batch.retract_decode()
@@ -610,17 +708,17 @@ class Scheduler:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-
+                self.running_batch = None
+                return
 
         # Update batch tensors
         batch.prepare_for_decode()
-        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -631,34 +729,32 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
-
+            batch.output_ids = next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-
+            ret = embeddings, model_worker_batch.bid
+        return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
+            if batch.is_empty():
+                self.running_batch = None
         else:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
-            batch.
-                next_token_ids
-            )
-
-            if logits_output:
+            logits_output, next_token_ids, bid = result
+            if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
-                            torch.arange(
-                                len(next_token_ids), device=next_token_ids.device
-                            ),
+                            torch.arange(len(next_token_ids), device=self.device),
                             next_token_ids,
                         ].tolist()
                     )
@@ -669,84 +765,76 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
-            next_token_ids =
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
 
-
-
-
-
+                    if req.finished():
+                        self.cache_finished_req(req)
+                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                        self.tree_cache.cache_unfinished_req(req)
 
-
-
-
-
-                    self.tree_cache.cache_unfinished_req(req)
-
-                if req is self.current_inflight_req:
-                    # Inflight request would get a new req idx
-                    self.req_to_token_pool.free(req.req_pool_idx)
+                    if req.regex_fsm is not None:
+                        req.regex_fsm_state = req.regex_fsm.get_next_state(
+                            req.regex_fsm_state, next_token_ids[i]
+                        )
 
-
-
-
-
+                    if req.return_logprob:
+                        logprob_pt += self.add_logprob_return_values(
+                            i, req, logprob_pt, next_token_ids, logits_output
+                        )
         else:  # embedding or reward model
-
-            embeddings =
+            embeddings, bid = result
+            embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req
+                if req.is_inflight_req > 0:
+                    req.is_inflight_req -= 1
+                else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
 
                     if req.finished():
-                        self.
+                        self.cache_finished_req(req)
                     else:
                         self.tree_cache.cache_unfinished_req(req)
 
-
-                    # Inflight request would get a new req idx
-                    self.req_to_token_pool.free(req.req_pool_idx)
-
-        self.handle_finished_requests(batch)
-
-        if not batch.is_empty():
-            if self.running_batch is None:
-                self.running_batch = batch
-            else:
-                self.running_batch.merge_batch(batch)
+        self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
-        if
+        if batch.return_logprob:
             next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=
+                torch.arange(len(next_token_ids), device=self.device),
                 next_token_ids,
             ].tolist()
 
-        next_token_ids =
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+
+        self.token_to_kv_pool.free_group_begin()
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -757,7 +845,7 @@ class Scheduler:
                 )
 
             if req.finished():
-                self.
+                self.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -766,15 +854,14 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.
+        self.stream_output(batch.reqs)
+
+        self.token_to_kv_pool.free_group_end()
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
             self.print_decode_stats()
 
-        if self.running_batch.is_empty():
-            self.running_batch = None
-
     def add_logprob_return_values(
         self,
         i: int,
@@ -848,7 +935,7 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def
+    def stream_output(self, reqs: List[Req]):
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
@@ -859,22 +946,15 @@ class Scheduler:
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
+            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []
-            unfinished_indices = []
 
-
-            if not req.finished() and req is not self.current_inflight_req:
-                unfinished_indices.append(i)
-            else:
-                self.batch_is_full = False
+        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
 
+        for req in reqs:
            if req.finished() or (
-                req.stream
-                and (
-                    self.decode_forward_ct % self.stream_interval == 0
-                    or len(req.output_ids) == 1
-                )
+                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
@@ -890,11 +970,13 @@ class Scheduler:
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
+                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
 
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
@@ -926,7 +1008,7 @@ class Scheduler:
        # Send to detokenizer
        if output_rids:
            if self.is_generation:
-                self.
+                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
@@ -937,10 +1019,11 @@ class Scheduler:
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
+                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
-                self.
+                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
@@ -949,9 +1032,6 @@ class Scheduler:
                    )
                )
 
-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
    def flush_cache(self):
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -987,8 +1067,9 @@ class Scheduler:
        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid:
+                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
+                    self.cache_finished_req(req)
                    break
 
    def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1000,21 +1081,43 @@ class Scheduler:
            logger.error(message)
        return success, message
 
+    def start_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self) -> None:
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+        self.profiler.export_chrome_trace(
+            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+        )
+        logger.info("Profiler is done")
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    dp_rank: Optional[int],
     pipe_writer,
 ):
-
+    if dp_rank is None:
+        configure_logger(server_args, prefix=f" TP{tp_rank}")
+    else:
+        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+
     suppress_other_loggers()
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
```
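The profiling additions in this last hunk (together with the `ProfileReq` handling earlier) reduce to the standard `torch.profiler` start/stop/export flow, gated on the `SGLANG_TORCH_PROFILER_DIR` environment variable. A minimal standalone sketch of that flow follows; the placeholder workload and the fallback trace directory are illustrative assumptions, not part of sglang.

```python
# Illustrative sketch only: the same torch.profiler start/stop/export pattern
# the scheduler wires up when SGLANG_TORCH_PROFILER_DIR is set.
import os
import time

import torch


def main() -> None:
    # Same environment variable the scheduler reads; the fallback is a placeholder.
    trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp/sglang_traces")
    os.makedirs(trace_dir, exist_ok=True)

    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profiler = torch.profiler.profile(activities=activities, with_stack=True)
    profiler.start()

    # Placeholder workload standing in for scheduler forward steps.
    x = torch.randn(1024, 1024)
    for _ in range(10):
        x = x @ x.T
        x = x / x.norm()

    profiler.stop()
    # A .gz suffix makes export_chrome_trace write a compressed trace,
    # mirroring the filename used by Scheduler.stop_profile.
    profiler.export_chrome_trace(os.path.join(trace_dir, f"{time.time()}.trace.json.gz"))


if __name__ == "__main__":
    main()
```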