sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
 import logging
@@ -36,9 +34,12 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
     BatchTokenIDOut,
+    CloseSessionReqInput,
     FlushCacheReq,
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
+    OpenSessionReqInput,
+    OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -58,16 +59,20 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    crash_on_warnings,
     get_zmq_socket,
+    gpu_proc_affinity,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -76,12 +81,8 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
-# Crash on warning if we are running CI tests
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
 
 
 class Scheduler:
@@ -103,14 +104,17 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
+        # Session info
+        self.sessions = {}
+
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
@@ -160,6 +164,14 @@ class Scheduler:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -223,8 +235,12 @@ class Scheduler:
 
         # Init running status
         self.waiting_queue: List[Req] = []
+        # The running decoding batch for continuous batching
         self.running_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
+        # The current forward batch
+        self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
@@ -286,6 +302,9 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
+        # Tells whether the current running batch is full so that we can skip
+        # the check of whether to prefill new requests.
+        # This is an optimization to reduce the overhead of the prefill check.
         self.batch_is_full = False
 
         # Init watchdog thread
@@ -337,46 +356,34 @@ class Scheduler:
 
         kill_parent_process()
 
-    @torch.
+    @torch.no_grad()
     def event_loop_normal(self):
-        """A normal
-        self.last_batch = None
-
+        """A normal scheduler loop."""
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-
-                # Decode multiple steps to reduce the overhead
-                if batch.forward_mode.is_decode():
-                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
-                        if not self.running_batch:
-                            break
-                        self.update_running_batch()
-                        if not self.running_batch:
-                            break
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
             else:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
-    @torch.
+    @torch.no_grad()
    def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
 
-        self.last_batch = None
-        self.running_batch = None
-
         while True:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -387,17 +394,86 @@ class Scheduler:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # A dummy first batch to start the pipeline for overlap scheduler.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.process_batch_result(tmp_batch, None)
+
             if self.last_batch:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
+                # Self-check and re-init some states when the server is idle
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
+
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
             while True:
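The `prepare_dp_attn_batch` hunk above rests on one invariant: with DP attention enabled, every scheduler rank must join the same collectives each step, so each rank first all-gathers how many tokens it intends to run, and ranks with no work substitute an idle batch. Below is a minimal, self-contained sketch of that token-count synchronization, not sglang code: `sync_token_counts` is a made-up helper and the single-process gloo group exists only so the snippet runs anywhere.

```python
# Illustrative sketch only: mirrors the token-count all-gather idea behind
# prepare_dp_attn_batch, using a single-process CPU (gloo) group.
import os
import torch
import torch.distributed as dist


def sync_token_counts(local_num_tokens: int) -> list:
    """All-gather one token count per rank so every rank sees the global picture."""
    world_size = dist.get_world_size()
    local = torch.tensor([local_num_tokens], dtype=torch.int64)
    gathered = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return [int(t.item()) for t in gathered]


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    counts = sync_token_counts(local_num_tokens=0)
    # A rank with no work still learns whether any peer has work; if so, it
    # would schedule an idle batch so the group-wide forward pass stays in lockstep.
    needs_idle_batch = counts[0] == 0 and max(counts) > 0
    print("per-rank token counts:", counts, "idle batch needed:", needs_idle_batch)

    dist.destroy_process_group()
```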
@@ -409,7 +485,7 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
@@ -433,6 +509,11 @@ class Scheduler:
                     self.start_profile()
                 else:
                     self.stop_profile()
+            elif isinstance(recv_req, OpenSessionReqInput):
+                session_id = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+            elif isinstance(recv_req, CloseSessionReqInput):
+                self.close_session(recv_req)
             elif isinstance(recv_req, GetMemPoolSizeReq):
                 self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -444,14 +525,37 @@ class Scheduler:
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-
-
-        recv_req.
-
-
-
-
-
+        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                recv_req.sampling_params,
+                lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
+            )
+            req.tokenizer = self.tokenizer
+
+            if recv_req.session_id is not None:
+                req.finished_reason = FINISH_ABORT(
+                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                )
+                self.waiting_queue.append(req)
+                return
+        else:
+            # Create a new request from a previsou session
+            session = self.sessions[recv_req.session_id]
+            req = session.create_req(recv_req, self.tokenizer)
+            if isinstance(req.finished_reason, FINISH_ABORT):
+                self.waiting_queue.append(req)
+                return
 
         # Image inputs
         if recv_req.image_inputs is not None:
@@ -462,6 +566,15 @@ class Scheduler:
                 req.origin_input_ids_unpadded, req.image_inputs
             )
 
+            if len(req.origin_input_ids) > self.max_req_input_len:
+                req.finished_reason = FINISH_ABORT(
+                    "Image request length is longer than the KV cache pool size or "
+                    "the max context length aborting because you cannot truncate the image embeds"
+                )
+                req.sampling_params.max_new_tokens = 0
+                self.waiting_queue.append(req)
+                return
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -599,58 +712,50 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_tokens:
-
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+            msg = (
                 "KV cache pool leak detected!"
+                f"{available_size=}, {self.max_total_num_tokens=}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
+            msg = (
                 "Memory pool leak detected!"
+                f"available_size={len(self.req_to_token_pool.free_slots)}, "
+                f"total_size={self.req_to_token_pool.size}\n"
             )
-
+            warnings.warn(msg)
+            if crash_on_warnings():
+                raise ValueError(msg)
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        #
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-
-        self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
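The reworked `check_memory` above drops the module-level `crash_on_warning` flag in favor of a `crash_on_warnings()` helper imported from `sglang.srt.utils`, and turns a detected leak into `warnings.warn` plus a hard failure under CI. A hedged sketch of that warn-or-raise pattern; the helper body is assumed from the removed flag, not taken from the actual utils implementation:

```python
import os
import warnings


def crash_on_warnings() -> bool:
    # Assumption: keyed off the CI env var, like the removed module-level flag.
    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"


def report_leak(available_size: int, max_total_num_tokens: int) -> None:
    """Warn in normal runs; raise in CI so a leak fails the test suite."""
    if available_size != max_total_num_tokens:
        msg = (
            "KV cache pool leak detected!"
            f"{available_size=}, {max_total_num_tokens=}\n"
        )
        warnings.warn(msg)
        if crash_on_warnings():
            raise ValueError(msg)


if __name__ == "__main__":
    report_leak(available_size=4096, max_total_num_tokens=4096)  # no-op
    report_leak(available_size=4000, max_total_num_tokens=4096)  # warns, or raises in CI
```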
@@ -746,14 +851,20 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
-        if
+        if (
+            self.is_mixed_chunk
+            and self.running_batch is not None
+            and not (new_batch.return_logprob or self.running_batch.return_logprob)
+        ):
+            # TODO (lianmin): support return_logprob + mixed chunked prefill
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
                 self.running_batch = None
@@ -762,15 +873,16 @@ class Scheduler:
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -796,11 +908,15 @@ class Scheduler:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.
-                return
+                self.batch_is_full = False
+                return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(
+        batch.prepare_for_decode()
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -812,6 +928,10 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                 model_worker_batch
             )
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            self.tp_worker.forward_batch_idle(model_worker_batch)
+            return
         else:
             logits_output = None
             if self.skip_tokenizer_init:
@@ -834,8 +954,12 @@ class Scheduler:
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
-
+        elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_dummy_first():
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
 
@@ -843,7 +967,7 @@ class Scheduler:
             logits_output, next_token_ids, bid = result
 
             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
                 if batch.return_logprob:
@@ -863,14 +987,19 @@ class Scheduler:
 
             # Check finish conditions
             logprob_pt = 0
-            for i, req in enumerate(batch.reqs):
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                 if req.is_retracted:
                     continue
 
+                if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                    # Free the one delayed token for the mixed decode batch
+                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                    continue
+
                 if req.is_being_chunked <= 0:
-                    # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
-                    req.output_ids.append(
+                    req.output_ids.append(next_token_id)
                     req.check_finished()
 
                     if req.finished():
@@ -878,16 +1007,22 @@ class Scheduler:
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         self.tree_cache.cache_unfinished_req(req)
 
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_ids[i])
-
                     if req.return_logprob:
                         logprob_pt += self.add_logprob_return_values(
                             i, req, logprob_pt, next_token_ids, logits_output
                         )
+
+                    if req.grammar is not None:
+                        req.grammar.accept_token(next_token_id)
                 else:
+                    # Inflight reqs' prefill is not finished
                    req.is_being_chunked -= 1
 
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
         else: # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
@@ -898,18 +1033,18 @@ class Scheduler:
                     continue
 
                 req.embedding = embeddings[i]
-                if req.is_being_chunked
-
-                else:
-                    # Inflight reqs' prefill is not finished
-                    # dummy output token for embedding models
+                if req.is_being_chunked <= 0:
+                    # Dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
 
-
-
+                    if req.finished():
+                        self.tree_cache.cache_finished_req(req)
+                    else:
+                        self.tree_cache.cache_unfinished_req(req)
                 else:
-
+                    # Inflight reqs' prefill is not finished
+                    req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
 
@@ -918,7 +1053,7 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
@@ -936,7 +1071,8 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.
+            if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -944,9 +1080,6 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.grammar is not None:
-                req.grammar.accept_token(next_token_id)
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -957,6 +1090,14 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
+
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs)
 
         self.token_to_kv_pool.free_group_end()
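Several of the hunks above finish a batch by calling `update_regex_vocab_mask()`, synchronizing the CUDA stream, and setting `sampling_info_done`; that event is the handshake which lets the overlap worker sample batch N+1 only after the scheduler has folded batch N's accepted grammar tokens into the vocab mask. A minimal two-thread sketch of the handshake, using a plain `threading.Event` and illustrative class names rather than the real sglang objects:

```python
import threading
import time


class SamplingInfoSketch:
    """Illustrative stand-in for the per-batch sampling info object."""

    def __init__(self):
        self.sampling_info_done = threading.Event()
        self.vocab_mask_version = 0

    def update_regex_vocab_mask(self):
        # In the scheduler this folds newly accepted grammar tokens into the mask.
        self.vocab_mask_version += 1


def scheduler_side(info: SamplingInfoSketch):
    time.sleep(0.05)               # pretend to process the previous batch's results
    info.update_regex_vocab_mask()
    info.sampling_info_done.set()  # release the worker waiting on this batch


def worker_side(info: SamplingInfoSketch):
    info.sampling_info_done.wait()  # do not sample until the mask is final
    print("sampling with vocab mask version", info.vocab_mask_version)


if __name__ == "__main__":
    info = SamplingInfoSketch()
    t = threading.Thread(target=worker_side, args=(info,))
    t.start()
    scheduler_side(info)
    t.join()
```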
@@ -1234,6 +1375,25 @@ class Scheduler:
         )
         logger.info("Profiler is done")
 
+    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+        # handle error
+        session_id = recv_req.session_id
+        if session_id in self.sessions:
+            logger.warning(f"session id {session_id} already exist, cannot open.")
+        else:
+            self.sessions[session_id] = Session(
+                recv_req.capacity_of_str_len, session_id
+            )
+        return session_id
+
+    def close_session(self, recv_req: CloseSessionReqInput):
+        # handle error
+        session_id = recv_req.session_id
+        if session_id not in self.sessions:
+            logger.warning(f"session id {session_id} does not exist, cannot delete.")
+        else:
+            del self.sessions[session_id]
+
 
 def run_scheduler_process(
     server_args: ServerArgs,
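The new `open_session`/`close_session` methods above keep sessions in a plain dict keyed by session id, reducing duplicate opens and unknown closes to warnings. A standalone sketch of the same registry behaviour; `SessionSketch` is a hypothetical stand-in for the `Session` class this release adds in `sglang/srt/managers/session_controller.py`:

```python
import logging

logger = logging.getLogger(__name__)


class SessionSketch:
    """Hypothetical stand-in for the Session class in session_controller.py."""

    def __init__(self, capacity_of_str_len: int, session_id: str):
        self.capacity_of_str_len = capacity_of_str_len
        self.session_id = session_id


class SessionRegistry:
    """Mirrors the open/close semantics shown in the hunk above."""

    def __init__(self):
        self.sessions = {}

    def open_session(self, session_id: str, capacity_of_str_len: int) -> str:
        if session_id in self.sessions:
            logger.warning("session id %s already exists, cannot open.", session_id)
        else:
            self.sessions[session_id] = SessionSketch(capacity_of_str_len, session_id)
        return session_id

    def close_session(self, session_id: str) -> None:
        if session_id not in self.sessions:
            logger.warning("session id %s does not exist, cannot delete.", session_id)
        else:
            del self.sessions[session_id]


if __name__ == "__main__":
    registry = SessionRegistry()
    registry.open_session("demo-session", capacity_of_str_len=32)
    registry.open_session("demo-session", capacity_of_str_len=32)  # warns: already exists
    registry.close_session("demo-session")
    registry.close_session("missing")  # warns: does not exist
```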
@@ -1243,6 +1403,13 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # set cpu affinity to this gpu process
+    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
+    # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None and "DP_RANK" in os.environ:
+        dp_rank = int(os.environ["DP_RANK"])
+
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
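The hunk above pins each scheduler process to a CPU slice via `gpu_proc_affinity(tp_size, nnodes, gpu_id)` before anything else runs. The sketch below only illustrates the kind of `os.sched_setaffinity` call such a helper reduces to; the slicing policy shown is an assumption, not the actual `sglang.srt.utils.gpu_proc_affinity` logic, and it is Linux-only:

```python
import os


def pin_to_cpu_slice(tp_size: int, nnodes: int, gpu_id: int) -> None:
    """Hypothetical sketch: give each local GPU worker an equal slice of host CPUs.

    The real helper may pick CPUs differently (e.g. by NUMA node); this only
    shows the mechanism (Linux-only os.sched_setaffinity on the current process).
    """
    total_cpus = os.cpu_count() or 1
    gpus_per_node = max(tp_size // nnodes, 1)
    cpus_per_gpu = max(total_cpus // gpus_per_node, 1)
    start = (gpu_id % gpus_per_node) * cpus_per_gpu
    cpu_slice = set(range(start, min(start + cpus_per_gpu, total_cpus)))
    if cpu_slice:  # guard against tiny hosts where the slice would be empty
        os.sched_setaffinity(0, cpu_slice)  # 0 = the current process


if __name__ == "__main__":
    pin_to_cpu_slice(tp_size=1, nnodes=1, gpu_id=0)
    print("allowed CPUs:", sorted(os.sched_getaffinity(0)))
```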
@@ -1252,8 +1419,10 @@ def run_scheduler_process(
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send(
-
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
+        if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
             scheduler.event_loop_normal()
|