sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +124 -665
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +78 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +9 -4
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -11
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
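The largest structural change in this release is in sglang/srt/managers/scheduler.py: several hundred lines of batch-result processing move into the new scheduler_output_processor_mixin.py, Scheduler now inherits from SchedulerOutputProcessorMixin, and the running batch becomes an always-instantiated ScheduleBatch carrying its own batch_is_full flag instead of an Optional value. The following is a minimal sketch of that mixin-plus-always-present-batch pattern, using hypothetical simplified classes (Batch, OutputProcessorMixin, MiniScheduler), not the actual sglang ones:

    # Hypothetical, simplified sketch of the refactor pattern seen in this diff:
    # output-processing logic lives in a mixin and the scheduler inherits it,
    # while the running batch is always a real object rather than Optional.
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class Batch:
        reqs: List[str] = field(default_factory=list)
        batch_is_full: bool = False

        def is_empty(self) -> bool:
            return len(self.reqs) == 0


    class OutputProcessorMixin:
        """Holds the result-processing methods that used to sit on the scheduler."""

        def process_batch_result(self, batch: Batch, result: List[str]) -> None:
            # In the real code this streams tokens to the detokenizer; here we
            # just record them so the example stays self-contained.
            self.outputs.extend(f"{req}:{tok}" for req, tok in zip(batch.reqs, result))


    class MiniScheduler(OutputProcessorMixin):
        def __init__(self) -> None:
            # The running batch is always present, never None.
            self.running_batch = Batch()
            self.outputs: List[str] = []

        def step(self, result: List[str]) -> None:
            if not self.running_batch.is_empty():
                self.process_batch_result(self.running_batch, result)


    sched = MiniScheduler()
    sched.running_batch.reqs = ["req-1", "req-2"]
    sched.step(["hello", "world"])
    print(sched.outputs)  # ['req-1:hello', 'req-2:world']

Keeping the batch object always instantiated lets call sites replace None checks with is_empty(), which is exactly the substitution repeated throughout the scheduler.py diff below.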
sglang/srt/managers/scheduler.py
CHANGED
@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
     GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    BaseFinishReason,
     ImageInputs,
     Req,
     ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_output_processor_mixin import (
+    SchedulerOutputProcessorMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler:
+class Scheduler(SchedulerOutputProcessorMixin):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
@@ -159,6 +160,7 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.page_size = server_args.page_size
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -270,17 +272,18 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
-        self.running_batch:
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
@@ -308,7 +311,11 @@ class Scheduler:
             self.grammar_backend = None
 
         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(
+        self.policy = SchedulePolicy(
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -327,11 +334,6 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -431,11 +433,13 @@ class Scheduler:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    page_size=self.page_size,
                     disable=server_args.disable_radix_cache,
                 )
 
@@ -457,6 +461,7 @@ class Scheduler:
         # The largest context length (prefill + generation) of a single request
         self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
         self.spec_num_total_accepted_tokens = 0
         self.spec_num_total_forward_ct = 0
@@ -486,7 +491,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -526,7 +531,7 @@ class Scheduler:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -587,7 +592,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -767,6 +772,30 @@ class Scheduler:
             )
             req.tokenizer = self.tokenizer
 
+        # Handle multimodal inputs
+        if recv_req.image_inputs is not None:
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
+            req.origin_input_ids = self.pad_input_ids_func(
+                req.origin_input_ids, image_inputs
+            )
+            req.extend_image_inputs(image_inputs)
+
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                error_msg = (
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                )
+                logger.error(error_msg)
+                req.origin_input_ids = [0]
+                req.image_inputs = None
+                req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+                self.waiting_queue.append(req)
+                return
+
         # Validate prompts length
         error_msg = validate_input_length(
             req,
@@ -787,6 +816,11 @@ class Scheduler:
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -822,7 +856,7 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -886,8 +920,10 @@ class Scheduler:
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -913,7 +949,7 @@ class Scheduler:
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(self.running_batch.reqs)
         self.stats.num_running_reqs = num_running_reqs
         self.stats.num_used_tokens = num_used
         self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -931,14 +967,20 @@ class Scheduler:
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False
 
+            # Filter batch
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.running_batch.batch_is_full = False
+
+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
-                if self.running_batch
+                if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
-                    #
+                    # Merge running_batch with prefill batch
                     self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
@@ -947,11 +989,11 @@ class Scheduler:
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None
 
         # Handle DP attention
         if self.server_args.enable_dp_attention:
@@ -966,15 +1008,20 @@ class Scheduler:
 
         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None
 
-        running_bs = len(self.running_batch.reqs)
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None
 
+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
@@ -989,17 +1036,13 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )
 
-
-        if is_chunked:
+        if self.chunked_req is not None:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.lora_paths:
-            lora_set = (
-
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -1011,49 +1054,33 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
             if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
-            req.init_next_round_input(
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )
 
-
-
-
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-                if req.rid in self.staging_reqs:
-                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
-                        self.batch_is_full = len(
+                        self.running_batch.batch_is_full = len(
+                            adder.can_run_list
+                        ) > 0 or (
                             self.running_batch is not None
                             and not self.running_batch.is_empty()
                         )
                     else:
-                        self.batch_is_full = True
+                        self.running_batch.batch_is_full = True
                     break
 
         # Update waiting queue
@@ -1064,6 +1091,9 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
 
+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1091,7 +1121,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if (
             self.is_mixed_chunk
-            and self.running_batch
+            and not self.running_batch.is_empty()
             and not (new_batch.return_logprob or self.running_batch.return_logprob)
         ):
             # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1100,7 +1130,9 @@ class Scheduler:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch =
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
         else:
             new_batch.decoding_reqs = None
 
@@ -1112,8 +1144,8 @@ class Scheduler:
 
         batch.filter_batch()
         if batch.is_empty():
-
-            return
+            batch.batch_is_full = False
+            return batch
 
         # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1137,7 +1169,7 @@ class Scheduler:
             )
 
         if batch.batch_size() < initial_bs:
-
+            batch.batch_is_full = False
 
         # Update batch tensors
         batch.prepare_for_decode()
@@ -1212,8 +1244,6 @@ class Scheduler:
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
@@ -1235,578 +1265,6 @@ class Scheduler:
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
-    def process_batch_result_prefill(
-        self,
-        batch: ScheduleBatch,
-        result: Union[GenerationBatchResult, EmbeddingBatchResult],
-    ):
-        skip_stream_req = None
-
-        if self.is_generation:
-            (
-                logits_output,
-                next_token_ids,
-                extend_input_len_per_req,
-                extend_logprob_start_len_per_req,
-                bid,
-            ) = (
-                result.logits_output,
-                result.next_token_ids,
-                result.extend_input_len_per_req,
-                result.extend_logprob_start_len_per_req,
-                result.bid,
-            )
-
-            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
-
-            hidden_state_offset = 0
-
-            # Check finish conditions
-            logprob_pt = 0
-            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
-                    continue
-
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
-                    # Free the one delayed token for the mixed decode batch
-                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
-                    continue
-
-                if req.is_chunked <= 0:
-                    # req output_ids are set here
-                    req.output_ids.append(next_token_id)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
-                        # This updates radix so others can match
-                        self.tree_cache.cache_unfinished_req(req)
-
-                    if req.return_logprob:
-                        assert extend_logprob_start_len_per_req is not None
-                        assert extend_input_len_per_req is not None
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
-                        logprob_pt += num_input_logprobs
-
-                    if (
-                        req.return_hidden_states
-                        and logits_output.hidden_states is not None
-                    ):
-                        req.hidden_states.append(
-                            logits_output.hidden_states[
-                                hidden_state_offset : (
-                                    hidden_state_offset := hidden_state_offset
-                                    + len(req.origin_input_ids)
-                                )
-                            ]
-                            .cpu()
-                            .clone()
-                        )
-
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_id)
-                        req.grammar.finished = req.finished()
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-                    # There is only at most one request being currently chunked.
-                    # Because this request does not finish prefill,
-                    # we don't want to stream the request currently being chunked.
-                    skip_stream_req = req
-
-                    # Incrementally update input logprobs.
-                    if req.return_logprob:
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        if extend_logprob_start_len < extend_input_len:
-                            # Update input logprobs.
-                            num_input_logprobs = (
-                                extend_input_len - extend_logprob_start_len
-                            )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
-                            logprob_pt += num_input_logprobs
-
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
-        else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
-
-            # Check finish conditions
-            for i, req in enumerate(batch.reqs):
-                if req.is_retracted:
-                    continue
-
-                req.embedding = embeddings[i]
-                if req.is_chunked <= 0:
-                    # Dummy output token for embedding models
-                    req.output_ids.append(0)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_unfinished_req(req)
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-
-        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-
-    def process_batch_result_decode(
-        self,
-        batch: ScheduleBatch,
-        result: GenerationBatchResult,
-    ):
-        logits_output, next_token_ids, bid = (
-            result.logits_output,
-            result.next_token_ids,
-            result.bid,
-        )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if self.enable_overlap:
-            assert batch.spec_algorithm.is_none()
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs.tolist()
-
-        self.token_to_kv_pool_allocator.free_group_begin()
-
-        # Check finish condition
-        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
-        # We should ignore using next_token_ids for spec decoding cases.
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
-                continue
-
-            if batch.spec_algorithm.is_none():
-                # speculative worker will solve the output_ids in speculative decoding
-                req.output_ids.append(next_token_id)
-
-            req.check_finished()
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-
-            if req.return_logprob and batch.spec_algorithm.is_none():
-                # speculative worker handles logprob in speculative decoding
-                req.output_token_logprobs_val.append(next_token_logprobs[i])
-                req.output_token_logprobs_idx.append(next_token_id)
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs_val.append(
-                        logits_output.next_token_top_logprobs_val[i]
-                    )
-                    req.output_top_logprobs_idx.append(
-                        logits_output.next_token_top_logprobs_idx[i]
-                    )
-                if req.token_ids_logprob is not None:
-                    req.output_token_ids_logprobs_val.append(
-                        logits_output.next_token_token_ids_logprobs_val[i]
-                    )
-                    req.output_token_ids_logprobs_idx.append(
-                        logits_output.next_token_token_ids_logprobs_idx[i]
-                    )
-
-            if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
-
-            if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
-                req.grammar.finished = req.finished()
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
-        self.stream_output(batch.reqs, batch.return_logprob)
-
-        self.token_to_kv_pool_allocator.free_group_end()
-
-        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if (
-            self.attn_tp_rank == 0
-            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
-        ):
-            self.log_decode_stats()
-
-    def add_input_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        output: LogitsProcessorOutput,
-        logprob_pt: int,
-        num_input_logprobs: int,
-        last_prefill_chunk: bool,  # If True, it means prefill is finished.
-    ):
-        """Incrementally add input logprobs to `req`.
-
-        Args:
-            i: The request index in a batch.
-            req: The request. Input logprobs inside req are modified as a
-                consequence of the API
-            fill_ids: The prefill ids processed.
-            output: Logit processor output that's used to compute input logprobs
-            last_prefill_chunk: True if it is the last prefill (when chunked).
-                Some of input logprob operation should only happen at the last
-                prefill (e.g., computing input token logprobs).
-        """
-        assert output.input_token_logprobs is not None
-        if req.input_token_logprobs is None:
-            req.input_token_logprobs = []
-        if req.temp_input_top_logprobs_val is None:
-            req.temp_input_top_logprobs_val = []
-        if req.temp_input_top_logprobs_idx is None:
-            req.temp_input_top_logprobs_idx = []
-        if req.temp_input_token_ids_logprobs_val is None:
-            req.temp_input_token_ids_logprobs_val = []
-        if req.temp_input_token_ids_logprobs_idx is None:
-            req.temp_input_token_ids_logprobs_idx = []
-
-        if req.input_token_logprobs_val is not None:
-            # The input logprob has been already computed. It only happens
-            # upon retract.
-            if req.top_logprobs_num > 0:
-                assert req.input_token_logprobs_val is not None
-            return
-
-        # Important for the performance.
-        assert isinstance(output.input_token_logprobs, tuple)
-        input_token_logprobs: Tuple[int] = output.input_token_logprobs
-        input_token_logprobs = input_token_logprobs[
-            logprob_pt : logprob_pt + num_input_logprobs
-        ]
-        req.input_token_logprobs.extend(input_token_logprobs)
-
-        if req.top_logprobs_num > 0:
-            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
-            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.temp_input_token_ids_logprobs_val.append(
-                output.input_token_ids_logprobs_val[i]
-            )
-            req.temp_input_token_ids_logprobs_idx.append(
-                output.input_token_ids_logprobs_idx[i]
-            )
-
-        if last_prefill_chunk:
-            input_token_logprobs = req.input_token_logprobs
-            req.input_token_logprobs = None
-            assert req.input_token_logprobs_val is None
-            assert req.input_token_logprobs_idx is None
-            assert req.input_top_logprobs_val is None
-            assert req.input_top_logprobs_idx is None
-
-            # Compute input_token_logprobs_val
-            # Always pad the first one with None.
-            req.input_token_logprobs_val = [None]
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val,
-                    req.temp_input_top_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
-
-            if req.return_logprob:
-                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
-                assert len(req.input_token_logprobs_val) == relevant_tokens_len
-                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
-                if req.top_logprobs_num > 0:
-                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
-                if req.token_ids_logprob is not None:
-                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
-
-    def add_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        pt: int,
-        next_token_ids: List[int],
-        num_input_logprobs: int,
-        output: LogitsProcessorOutput,
-    ):
-        """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
-
-        if req.top_logprobs_num > 0:
-            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
-            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
-            req.output_token_ids_logprobs_idx.append(
-                output.next_token_token_ids_logprobs_idx[i]
-            )
-
-        return num_input_logprobs
-
-    def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
-    ):
-        """Stream the output to detokenizer."""
-        rids = []
-        finished_reasons: List[BaseFinishReason] = []
-
-        if self.is_generation:
-            decoded_texts = []
-            decode_ids_list = []
-            read_offsets = []
-            output_ids = []
-
-            skip_special_tokens = []
-            spaces_between_special_tokens = []
-            no_stop_trim = []
-            prompt_tokens = []
-            completion_tokens = []
-            cached_tokens = []
-            spec_verify_ct = []
-            output_hidden_states = None
-
-            if return_logprob:
-                input_token_logprobs_val = []
-                input_token_logprobs_idx = []
-                output_token_logprobs_val = []
-                output_token_logprobs_idx = []
-                input_top_logprobs_val = []
-                input_top_logprobs_idx = []
-                output_top_logprobs_val = []
-                output_top_logprobs_idx = []
-                input_token_ids_logprobs_val = []
-                input_token_ids_logprobs_idx = []
-                output_token_ids_logprobs_val = []
-                output_token_ids_logprobs_idx = []
-            else:
-                input_token_logprobs_val = input_token_logprobs_idx = (
-                    output_token_logprobs_val
-                ) = output_token_logprobs_idx = input_top_logprobs_val = (
-                    input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    input_token_ids_logprobs_val
-                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
-                    output_token_ids_logprobs_idx
-                ) = None
-
-            for req in reqs:
-                if req is skip_req:
-                    continue
-
-                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-                if self.model_config.is_multimodal_gen and req.to_abort:
-                    continue
-
-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
-                    rids.append(req.rid)
-                    finished_reasons.append(
-                        req.finished_reason.to_json() if req.finished_reason else None
-                    )
-                    decoded_texts.append(req.decoded_text)
-                    decode_ids, read_offset = req.init_incremental_detokenize()
-                    decode_ids_list.append(decode_ids)
-                    read_offsets.append(read_offset)
-                    if self.skip_tokenizer_init:
-                        output_ids.append(req.output_ids)
-                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
-                    spaces_between_special_tokens.append(
-                        req.sampling_params.spaces_between_special_tokens
-                    )
-                    no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    completion_tokens.append(len(req.output_ids))
-                    cached_tokens.append(req.cached_tokens)
-
-                    if not self.spec_algorithm.is_none():
-                        spec_verify_ct.append(req.spec_verify_ct)
-
-                    if return_logprob:
-                        input_token_logprobs_val.append(req.input_token_logprobs_val)
-                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                        output_token_logprobs_val.append(req.output_token_logprobs_val)
-                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                        input_top_logprobs_val.append(req.input_top_logprobs_val)
-                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                        output_top_logprobs_val.append(req.output_top_logprobs_val)
-                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                        input_token_ids_logprobs_val.append(
-                            req.input_token_ids_logprobs_val
-                        )
-                        input_token_ids_logprobs_idx.append(
-                            req.input_token_ids_logprobs_idx
-                        )
-                        output_token_ids_logprobs_val.append(
-                            req.output_token_ids_logprobs_val
-                        )
-                        output_token_ids_logprobs_idx.append(
-                            req.output_token_ids_logprobs_idx
-                        )
-
-                    if req.return_hidden_states:
-                        if output_hidden_states is None:
-                            output_hidden_states = []
-                        output_hidden_states.append(req.hidden_states)
-
-            # Send to detokenizer
-            if rids:
-                if self.model_config.is_multimodal_gen:
-                    raise NotImplementedError()
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        rids,
-                        finished_reasons,
-                        decoded_texts,
-                        decode_ids_list,
-                        read_offsets,
-                        output_ids,
-                        skip_special_tokens,
-                        spaces_between_special_tokens,
-                        no_stop_trim,
-                        prompt_tokens,
-                        completion_tokens,
-                        cached_tokens,
-                        spec_verify_ct,
-                        input_token_logprobs_val,
-                        input_token_logprobs_idx,
-                        output_token_logprobs_val,
-                        output_token_logprobs_idx,
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                        input_token_ids_logprobs_val,
-                        input_token_ids_logprobs_idx,
-                        output_token_ids_logprobs_val,
-                        output_token_ids_logprobs_idx,
-                        output_hidden_states,
-                    )
-                )
-        else:  # embedding or reward model
-            embeddings = []
-            prompt_tokens = []
-            cached_tokens = []
-            for req in reqs:
-                if req.finished():
-                    rids.append(req.rid)
-                    finished_reasons.append(req.finished_reason.to_json())
-                    embeddings.append(req.embedding)
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    cached_tokens.append(req.cached_tokens)
-            self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(
-                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
-                )
-            )
-
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1926,9 +1384,7 @@ class Scheduler:
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1954,7 +1410,7 @@ class Scheduler:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {
+                f"#running-req: {len(self.running_batch.reqs)}"
             )
             if_success = False
         return if_success
@@ -2004,24 +1460,24 @@ class Scheduler:
 
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
-        to_del =
+        to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid
-                to_del
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                 break
 
-
-
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
             logger.debug(f"Abort queued request. {req.rid=}")
             return
 
         # Delete requests in the running batch
-
-
-
-
-
-            break
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return
 
     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2228,9 +1684,16 @@ def run_scheduler_process(
     dp_rank: Optional[int],
    pipe_writer,
 ):
+
+    # Generate the prefix
+    if dp_rank is None:
+        prefix = f" TP{tp_rank}"
+    else:
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+
     # Config the process
     # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
-    setproctitle.setproctitle(f"sglang::
+    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
 
@@ -2239,10 +1702,6 @@ def run_scheduler_process(
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
     # Configure the logger
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()
 
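The rewritten abort_request matches request ids by prefix: matching queued requests are popped from the waiting queue, while matching unfinished requests in the running batch are only flagged with to_abort. A self-contained sketch of that matching logic, using hypothetical simplified types rather than the actual sglang Req/Scheduler classes:

    # Illustrative sketch (not the sglang implementation) of the new abort behavior:
    # a request is aborted if its rid starts with the prefix carried by the abort
    # request; queued entries are removed, running ones are only flagged.
    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Req:
        rid: str
        finished: bool = False
        to_abort: bool = False


    def abort_request(waiting_queue: List[Req], running_reqs: List[Req], rid_prefix: str) -> None:
        # Delete a matching request from the waiting queue first.
        for i, req in enumerate(waiting_queue):
            if req.rid.startswith(rid_prefix):
                waiting_queue.pop(i)
                return
        # Otherwise flag a matching, still-running request for abort.
        for req in running_reqs:
            if req.rid.startswith(rid_prefix) and not req.finished:
                req.to_abort = True
                return


    queue = [Req("abc-1"), Req("xyz-2")]
    running = [Req("abc-3")]
    abort_request(queue, running, "xyz")
    print([r.rid for r in queue])   # ['abc-1']
    abort_request(queue, running, "abc-3")
    print(running[0].to_abort)      # True

Flagging running requests instead of removing them lets the scheduler finish the in-flight step cleanly and drop the request on its next output pass.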