sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
     GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    BaseFinishReason,
     ImageInputs,
     Req,
     ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_output_processor_mixin import (
+    SchedulerOutputProcessorMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler:
+class Scheduler(SchedulerOutputProcessorMixin):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
@@ -159,6 +160,7 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.page_size = server_args.page_size

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -270,17 +272,18 @@ class Scheduler:

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
-        self.running_batch:
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
@@ -308,7 +311,11 @@ class Scheduler:
             self.grammar_backend = None

         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(
+        self.policy = SchedulePolicy(
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -327,11 +334,6 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio

-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -431,11 +433,14 @@ class Scheduler:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                    page_size=self.page_size,
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    page_size=self.page_size,
                     disable=server_args.disable_radix_cache,
                 )

@@ -457,6 +462,7 @@ class Scheduler:
         # The largest context length (prefill + generation) of a single request
         self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
         self.spec_num_total_accepted_tokens = 0
         self.spec_num_total_forward_ct = 0
@@ -486,7 +492,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -526,7 +532,7 @@ class Scheduler:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle,
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -587,7 +593,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -767,6 +773,30 @@ class Scheduler:
             )
         req.tokenizer = self.tokenizer

+        # Handle multimodal inputs
+        if recv_req.image_inputs is not None:
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
+            req.origin_input_ids = self.pad_input_ids_func(
+                req.origin_input_ids, image_inputs
+            )
+            req.extend_image_inputs(image_inputs)
+
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                error_msg = (
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                )
+                logger.error(error_msg)
+                req.origin_input_ids = [0]
+                req.image_inputs = None
+                req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+                self.waiting_queue.append(req)
+                return
+
         # Validate prompts length
         error_msg = validate_input_length(
             req,
@@ -787,6 +817,11 @@ class Scheduler:
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -822,7 +857,7 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -886,8 +921,10 @@ class Scheduler:
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -913,7 +950,7 @@ class Scheduler:
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        num_running_reqs = len(self.running_batch.reqs)
+        num_running_reqs = len(self.running_batch.reqs)
         self.stats.num_running_reqs = num_running_reqs
         self.stats.num_used_tokens = num_used
         self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -931,14 +968,20 @@ class Scheduler:
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False

+            # Filter batch
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.running_batch.batch_is_full = False
+
+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
-                if self.running_batch
+                if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
-                    #
+                    # Merge running_batch with prefill batch
                     self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
@@ -947,15 +990,15 @@ class Scheduler:
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None

         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

         return ret

@@ -966,15 +1009,20 @@ class Scheduler:

         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None

-        running_bs = len(self.running_batch.reqs)
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

@@ -989,17 +1037,13 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )

-
-        if is_chunked:
+        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
-            lora_set = (
-
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
@@ -1011,49 +1055,33 @@ class Scheduler:
                )
                > self.max_loras_per_batch
            ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                break

-            req.init_next_round_input(
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )

-
-
-
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node,
-                    req.prefix_indices,
-                    adder.rem_total_tokens,
-                )
-                if req.last_node.loading:
-                    # to prevent frequent cache invalidation
-                    if req.rid in self.staging_reqs:
-                        self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    self.tree_cache.inc_lock_ref(req.last_node)
-                    self.staging_reqs[req.rid] = req.last_node
-                    continue
-            elif req.last_node.loading:
-                if not self.tree_cache.loading_complete(req.last_node):
-                    continue
-
-            if req.rid in self.staging_reqs:
-                self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
-                        self.batch_is_full = len(
+                        self.running_batch.batch_is_full = len(
+                            adder.can_run_list
+                        ) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
-                        self.batch_is_full = True
+                        self.running_batch.batch_is_full = True
                    break

        # Update waiting queue
@@ -1064,6 +1092,9 @@ class Scheduler:
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req
@@ -1091,7 +1122,7 @@ class Scheduler:
        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
-            and self.running_batch
+            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1100,7 +1131,9 @@ class Scheduler:
            self.running_batch.prepare_for_decode()
            new_batch.mix_with_running(self.running_batch)
            new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch =
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
        else:
            new_batch.decoding_reqs = None

@@ -1112,8 +1145,8 @@ class Scheduler:

        batch.filter_batch()
        if batch.is_empty():
-
-            return
+            batch.batch_is_full = False
+            return batch

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1137,7 +1170,7 @@ class Scheduler:
            )

        if batch.batch_size() < initial_bs:
-
+            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
@@ -1212,8 +1245,6 @@ class Scheduler:
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_idle():
@@ -1235,615 +1266,76 @@ class Scheduler:
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def process_batch_result_prefill(
-        self,
-        batch: ScheduleBatch,
-        result: Union[GenerationBatchResult, EmbeddingBatchResult],
-    ):
-        skip_stream_req = None
-
-        if self.is_generation:
-            (
-                logits_output,
-                next_token_ids,
-                extend_input_len_per_req,
-                extend_logprob_start_len_per_req,
-                bid,
-            ) = (
-                result.logits_output,
-                result.next_token_ids,
-                result.extend_input_len_per_req,
-                result.extend_logprob_start_len_per_req,
-                result.bid,
-            )
-
-            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
-
-            hidden_state_offset = 0
-
-            # Check finish conditions
-            logprob_pt = 0
-            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
-                    continue
-
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
-                    # Free the one delayed token for the mixed decode batch
-                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
-                    continue
-
-                if req.is_chunked <= 0:
-                    # req output_ids are set here
-                    req.output_ids.append(next_token_id)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
-                        # This updates radix so others can match
-                        self.tree_cache.cache_unfinished_req(req)
-
-                    if req.return_logprob:
-                        assert extend_logprob_start_len_per_req is not None
-                        assert extend_input_len_per_req is not None
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
-                        logprob_pt += num_input_logprobs
-
-                    if (
-                        req.return_hidden_states
-                        and logits_output.hidden_states is not None
-                    ):
-                        req.hidden_states.append(
-                            logits_output.hidden_states[
-                                hidden_state_offset : (
-                                    hidden_state_offset := hidden_state_offset
-                                    + len(req.origin_input_ids)
-                                )
-                            ]
-                            .cpu()
-                            .clone()
-                        )
-
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_id)
-                        req.grammar.finished = req.finished()
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-                    # There is only at most one request being currently chunked.
-                    # Because this request does not finish prefill,
-                    # we don't want to stream the request currently being chunked.
-                    skip_stream_req = req
-
-                    # Incrementally update input logprobs.
-                    if req.return_logprob:
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        if extend_logprob_start_len < extend_input_len:
-                            # Update input logprobs.
-                            num_input_logprobs = (
-                                extend_input_len - extend_logprob_start_len
-                            )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
-                            logprob_pt += num_input_logprobs
-
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
-        else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
-
-            # Check finish conditions
-            for i, req in enumerate(batch.reqs):
-                if req.is_retracted:
-                    continue
-
-                req.embedding = embeddings[i]
-                if req.is_chunked <= 0:
-                    # Dummy output token for embedding models
-                    req.output_ids.append(0)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_unfinished_req(req)
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-
-        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-
-    def process_batch_result_decode(
-        self,
-        batch: ScheduleBatch,
-        result: GenerationBatchResult,
-    ):
-        logits_output, next_token_ids, bid = (
-            result.logits_output,
-            result.next_token_ids,
-            result.bid,
-        )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if self.enable_overlap:
-            assert batch.spec_algorithm.is_none()
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs.tolist()
-
-        self.token_to_kv_pool_allocator.free_group_begin()
-
-        # Check finish condition
-        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
-        # We should ignore using next_token_ids for spec decoding cases.
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
-                continue
-
-            if batch.spec_algorithm.is_none():
-                # speculative worker will solve the output_ids in speculative decoding
-                req.output_ids.append(next_token_id)
-
-            req.check_finished()
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-
-            if req.return_logprob and batch.spec_algorithm.is_none():
-                # speculative worker handles logprob in speculative decoding
-                req.output_token_logprobs_val.append(next_token_logprobs[i])
-                req.output_token_logprobs_idx.append(next_token_id)
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs_val.append(
-                        logits_output.next_token_top_logprobs_val[i]
-                    )
-                    req.output_top_logprobs_idx.append(
-                        logits_output.next_token_top_logprobs_idx[i]
-                    )
-                if req.token_ids_logprob is not None:
-                    req.output_token_ids_logprobs_val.append(
-                        logits_output.next_token_token_ids_logprobs_val[i]
-                    )
-                    req.output_token_ids_logprobs_idx.append(
-                        logits_output.next_token_token_ids_logprobs_idx[i]
-                    )
-
-            if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
-
-            if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
-                req.grammar.finished = req.finished()
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
-        self.stream_output(batch.reqs, batch.return_logprob)
-
-        self.token_to_kv_pool_allocator.free_group_end()
-
-        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if (
-            self.attn_tp_rank == 0
-            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
-        ):
-            self.log_decode_stats()
-
-    def add_input_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        output: LogitsProcessorOutput,
-        logprob_pt: int,
-        num_input_logprobs: int,
-        last_prefill_chunk: bool,  # If True, it means prefill is finished.
-    ):
-        """Incrementally add input logprobs to `req`.
-
-        Args:
-            i: The request index in a batch.
-            req: The request. Input logprobs inside req are modified as a
-                consequence of the API
-            fill_ids: The prefill ids processed.
-            output: Logit processor output that's used to compute input logprobs
-            last_prefill_chunk: True if it is the last prefill (when chunked).
-                Some of input logprob operation should only happen at the last
-                prefill (e.g., computing input token logprobs).
-        """
-        assert output.input_token_logprobs is not None
-        if req.input_token_logprobs is None:
-            req.input_token_logprobs = []
-        if req.temp_input_top_logprobs_val is None:
-            req.temp_input_top_logprobs_val = []
-        if req.temp_input_top_logprobs_idx is None:
-            req.temp_input_top_logprobs_idx = []
-        if req.temp_input_token_ids_logprobs_val is None:
-            req.temp_input_token_ids_logprobs_val = []
-        if req.temp_input_token_ids_logprobs_idx is None:
-            req.temp_input_token_ids_logprobs_idx = []
-
-        if req.input_token_logprobs_val is not None:
-            # The input logprob has been already computed. It only happens
-            # upon retract.
-            if req.top_logprobs_num > 0:
-                assert req.input_token_logprobs_val is not None
-            return
-
-        # Important for the performance.
-        assert isinstance(output.input_token_logprobs, tuple)
-        input_token_logprobs: Tuple[int] = output.input_token_logprobs
-        input_token_logprobs = input_token_logprobs[
-            logprob_pt : logprob_pt + num_input_logprobs
-        ]
-        req.input_token_logprobs.extend(input_token_logprobs)
-
-        if req.top_logprobs_num > 0:
-            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
-            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.temp_input_token_ids_logprobs_val.append(
-                output.input_token_ids_logprobs_val[i]
-            )
-            req.temp_input_token_ids_logprobs_idx.append(
-                output.input_token_ids_logprobs_idx[i]
-            )
-
-        if last_prefill_chunk:
-            input_token_logprobs = req.input_token_logprobs
-            req.input_token_logprobs = None
-            assert req.input_token_logprobs_val is None
-            assert req.input_token_logprobs_idx is None
-            assert req.input_top_logprobs_val is None
-            assert req.input_top_logprobs_idx is None
-
-            # Compute input_token_logprobs_val
-            # Always pad the first one with None.
-            req.input_token_logprobs_val = [None]
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val,
-                    req.temp_input_top_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
-
-            if req.return_logprob:
-                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
-                assert len(req.input_token_logprobs_val) == relevant_tokens_len
-                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
-                if req.top_logprobs_num > 0:
-                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
-                if req.token_ids_logprob is not None:
-                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
-
-    def add_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        pt: int,
-        next_token_ids: List[int],
-        num_input_logprobs: int,
-        output: LogitsProcessorOutput,
-    ):
-        """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
-
-        if req.top_logprobs_num > 0:
-            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
-            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
-            req.output_token_ids_logprobs_idx.append(
-                output.next_token_token_ids_logprobs_idx[i]
-            )
-
-        return num_input_logprobs
-
-    def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
-    ):
-        """Stream the output to detokenizer."""
-        rids = []
-        finished_reasons: List[BaseFinishReason] = []
-
-        if self.is_generation:
-            decoded_texts = []
-            decode_ids_list = []
-            read_offsets = []
-            output_ids = []
-
-            skip_special_tokens = []
-            spaces_between_special_tokens = []
-            no_stop_trim = []
-            prompt_tokens = []
-            completion_tokens = []
-            cached_tokens = []
-            spec_verify_ct = []
-            output_hidden_states = None
-
-            if return_logprob:
-                input_token_logprobs_val = []
-                input_token_logprobs_idx = []
-                output_token_logprobs_val = []
-                output_token_logprobs_idx = []
-                input_top_logprobs_val = []
-                input_top_logprobs_idx = []
-                output_top_logprobs_val = []
-                output_top_logprobs_idx = []
-                input_token_ids_logprobs_val = []
-                input_token_ids_logprobs_idx = []
-                output_token_ids_logprobs_val = []
-                output_token_ids_logprobs_idx = []
-            else:
-                input_token_logprobs_val = input_token_logprobs_idx = (
-                    output_token_logprobs_val
-                ) = output_token_logprobs_idx = input_top_logprobs_val = (
-                    input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    input_token_ids_logprobs_val
-                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
-                    output_token_ids_logprobs_idx
-                ) = None
-
-            for req in reqs:
-                if req is skip_req:
-                    continue
-
-                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-                if self.model_config.is_multimodal_gen and req.to_abort:
-                    continue
-
-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
-                    rids.append(req.rid)
-                    finished_reasons.append(
-                        req.finished_reason.to_json() if req.finished_reason else None
-                    )
-                    decoded_texts.append(req.decoded_text)
-                    decode_ids, read_offset = req.init_incremental_detokenize()
-                    decode_ids_list.append(decode_ids)
-                    read_offsets.append(read_offset)
-                    if self.skip_tokenizer_init:
-                        output_ids.append(req.output_ids)
-                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
-                    spaces_between_special_tokens.append(
-                        req.sampling_params.spaces_between_special_tokens
-                    )
-                    no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    completion_tokens.append(len(req.output_ids))
-                    cached_tokens.append(req.cached_tokens)
-
-                    if not self.spec_algorithm.is_none():
-                        spec_verify_ct.append(req.spec_verify_ct)
-
-                    if return_logprob:
-                        input_token_logprobs_val.append(req.input_token_logprobs_val)
-                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                        output_token_logprobs_val.append(req.output_token_logprobs_val)
-                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                        input_top_logprobs_val.append(req.input_top_logprobs_val)
-                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                        output_top_logprobs_val.append(req.output_top_logprobs_val)
-                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                        input_token_ids_logprobs_val.append(
-                            req.input_token_ids_logprobs_val
-                        )
-                        input_token_ids_logprobs_idx.append(
-                            req.input_token_ids_logprobs_idx
-                        )
-                        output_token_ids_logprobs_val.append(
-                            req.output_token_ids_logprobs_val
-                        )
-                        output_token_ids_logprobs_idx.append(
-                            req.output_token_ids_logprobs_idx
-                        )
-
-                    if req.return_hidden_states:
-                        if output_hidden_states is None:
-                            output_hidden_states = []
-                        output_hidden_states.append(req.hidden_states)
-
-            # Send to detokenizer
-            if rids:
-                if self.model_config.is_multimodal_gen:
-                    raise NotImplementedError()
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        rids,
-                        finished_reasons,
-                        decoded_texts,
-                        decode_ids_list,
-                        read_offsets,
-                        output_ids,
-                        skip_special_tokens,
-                        spaces_between_special_tokens,
-                        no_stop_trim,
-                        prompt_tokens,
-                        completion_tokens,
-                        cached_tokens,
-                        spec_verify_ct,
-                        input_token_logprobs_val,
-                        input_token_logprobs_idx,
-                        output_token_logprobs_val,
-                        output_token_logprobs_idx,
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                        input_token_ids_logprobs_val,
-                        input_token_ids_logprobs_idx,
-                        output_token_ids_logprobs_val,
-                        output_token_ids_logprobs_idx,
-                        output_hidden_states,
-                    )
-                )
-        else:  # embedding or reward model
-            embeddings = []
-            prompt_tokens = []
-            cached_tokens = []
-            for req in reqs:
-                if req.finished():
-                    rids.append(req.rid)
-                    finished_reasons.append(req.finished_reason.to_json())
-                    embeddings.append(req.embedding)
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    cached_tokens.append(req.cached_tokens)
-            self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(
-                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
-                )
-            )
-
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
+            global_num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0

-
-
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0
+
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
        torch.distributed.all_gather_into_tensor(
-
-
+            global_info.flatten(),
+            local_info,
            group=self.tp_cpu_group,
        )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and
+        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
-
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
@@ -1926,9 +1418,7 @@ class Scheduler:

    def flush_cache(self):
        """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
@@ -1954,7 +1444,7 @@ class Scheduler:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {
+                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success
@@ -2004,24 +1494,24 @@ class Scheduler:

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
-        to_del =
+        to_del = []
        for i, req in enumerate(self.waiting_queue):
-            if req.rid
-                to_del
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                break

-
-
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
-
-
-
-
-
-                break
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return

    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()
@@ -2228,9 +1718,16 @@ def run_scheduler_process(
    dp_rank: Optional[int],
    pipe_writer,
):
+
+    # Generate the prefix
+    if dp_rank is None:
+        prefix = f" TP{tp_rank}"
+    else:
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+
    # Config the process
    # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
-    setproctitle.setproctitle(f"sglang::
+    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

@@ -2239,10 +1736,6 @@ def run_scheduler_process(
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()
