sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -174,7 +174,7 @@ class SchedulePolicy:
         self.waiting_queue_radix_tree.reset()

         for r in waiting_queue:
-            prefix_ids = r.
+            prefix_ids = r.origin_input_ids + r.output_ids
             extra_key = r.extra_key

             # NOTE: the prefix_indices must always be aligned with last_node
sglang/srt/managers/scheduler.py
CHANGED
@@ -25,12 +25,14 @@ from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
 import torch
 import zmq
+from torch.cuda import Stream as CudaStream
+from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

 from sglang.global_config import global_config
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
+    ModelWorkerBatch,
     MultimodalInputs,
     Req,
     RequestStage,
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import (
-
+    ForwardBatch,
     ForwardMode,
     PPProxyTensors,
 )
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

 @dataclass
 class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
-    next_token_ids: Optional[
-
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False

     # For output processing
-    extend_input_len_per_req: List[int]
-    extend_logprob_start_len_per_req: List[int]
-
-
-
-
-
-
-
-    ):
-
-
-
-
-
-
-
-
-
-
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_launch: bool = False
+    forward_batch: Optional[ForwardBatch] = None
+    future_indices: Optional[FutureIndices] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+        self.copy_done.record()

     @classmethod
     def from_pp_proxy(
         cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
     ):
-        # TODO(lsyin):
-        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
-        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        # TODO(lsyin): refactor PP and avoid using dict
         proxy_dict = next_pp_outputs.tensors
         return cls(
             logits_output=logits_output,
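Note on the copy_to_cpu / copy_done pattern introduced above: the device-to-host copies are issued with non_blocking=True and a CUDA event is recorded afterwards, so the scheduler only blocks when it actually needs the host values. A minimal standalone sketch of the same PyTorch idiom (illustrative only, not sglang code; the function names are invented for the example):

    import torch

    def launch_copy(next_token_ids: torch.Tensor, logits: torch.Tensor):
        # Async device-to-host copies; truly asynchronous only with pinned host memory.
        copy_done = torch.cuda.Event()
        ids_cpu = next_token_ids.to("cpu", non_blocking=True)
        logits_cpu = logits.to("cpu", non_blocking=True)
        copy_done.record()  # marks the point after which the host tensors are valid
        return ids_cpu, logits_cpu, copy_done

    def consume(ids_cpu: torch.Tensor, copy_done: torch.cuda.Event):
        copy_done.synchronize()  # block only when the result is really needed
        return ids_cpu.tolist()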
@@ -263,6 +273,48 @@ class Scheduler(
 ):
     """A scheduler that manages a tensor parallel GPU worker."""

+    def launch_draft_worker(
+        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+    ):
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
     def __init__(
         self,
         server_args: ServerArgs,
@@ -388,12 +440,10 @@ class Scheduler(
             logger.info("Overlap scheduler is disabled for embedding models.")

         # Launch a tensor parallel worker
-        if self.enable_overlap:
-            TpWorkerClass = TpModelWorkerClient
-        else:
-            TpWorkerClass = TpModelWorker

-
+        from sglang.srt.managers.tp_worker import TpModelWorker
+
+        self.tp_worker = TpModelWorker(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -404,44 +454,9 @@ class Scheduler(
         )

         # Launch a draft worker for speculative decoding
-
-
-
-            self.draft_worker = EAGLEWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
+        self.launch_draft_worker(
+            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+        )

         # Dispatch the model worker
         if self.spec_algorithm.is_none():
@@ -464,8 +479,8 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        if global_server_args_dict["
-            global_server_args_dict["
+        if global_server_args_dict["pp_max_micro_batch_size"] is None:
+            global_server_args_dict["pp_max_micro_batch_size"] = max(
                 self.max_running_requests // server_args.pp_size, 1
             )

@@ -525,9 +540,11 @@ class Scheduler(
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
         self.sessions: Dict[str, Session] = {}
-        self.
+        self.default_stream: CudaStream = torch.get_device_module(
+            self.device
+        ).current_stream()
         if self.device == "cpu":
-            self.
+            self.default_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

         # Init chunked prefill
@@ -618,6 +635,9 @@ class Scheduler(
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
         self.init_deterministic_inference_config()

+        # Init overlap
+        self.init_overlap()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -777,6 +797,7 @@ class Scheduler(
                 sliding_window_size=self.sliding_window_size,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
         elif server_args.enable_lmcache:
             from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
@@ -931,6 +952,32 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_overlap(self):
+        if not self.enable_overlap:
+            return
+
+        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.forward_stream)
+        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.copy_stream)
+
+        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.batch_record_buf = [None] * 2
+        self.batch_record_ct = 0
+
+    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+        # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+        # NOTE: More Reliable: record all tensors into the forward stream
+        # NOTE: - for all future tensors, we shall always read from future map
+        #       - for all non-future tensors (produced only by schedule stream),
+        #         we shall keep its reference not being release during all the forwarding pass
+        self.batch_record_ct = (self.batch_record_ct + 1) % 2
+        self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
     def init_moe_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
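FutureMap and FutureIndices come from the new sglang.srt.managers.overlap_utils module, which is not shown in this extract. Judging only from the call sites in this diff (alloc_future_indices, store_to_map, resolve_future, and the negated future_indices.indices used as placeholder token ids), a hypothetical simplified version could look like the sketch below; the real class almost certainly differs in detail:

    import torch

    class FutureMapSketch:
        """Hypothetical stand-in for overlap_utils.FutureMap (illustration only)."""

        def __init__(self, max_running_requests: int, device: str):
            # Slot 0 stays unused so a negative placeholder -slot is never confused with token id 0.
            self.buf = torch.zeros(max_running_requests * 16, dtype=torch.int64, device=device)
            self.cursor = 1

        def alloc_future_indices(self, bs: int) -> torch.Tensor:
            if self.cursor + bs >= self.buf.numel():
                self.cursor = 1
            idx = torch.arange(self.cursor, self.cursor + bs, device=self.buf.device)
            self.cursor += bs
            return idx

        def store_to_map(self, indices: torch.Tensor, next_token_ids: torch.Tensor):
            # Called on the forward stream once sampling has produced real token ids.
            self.buf[indices] = next_token_ids

        def resolve_future(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Replace negative placeholders written by the scheduler with the stored ids.
            mask = input_ids < 0
            input_ids[mask] = self.buf[-input_ids[mask]]
            return input_ids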
@@ -957,9 +1004,11 @@ class Scheduler(
     @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        self.result_queue = deque()
+        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

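The structure of event_loop_overlap is a one-slot software pipeline: the batch launched in iteration N is only post-processed in iteration N+1, so the GPU forward of batch N overlaps with the CPU scheduling of batch N+1. A toy sketch of that deque-based pattern (generic Python, not sglang code):

    from collections import deque

    def pipelined_loop(get_batch, run_batch, process_result, num_iters: int):
        result_queue = deque()
        last_batch = None
        for _ in range(num_iters):
            batch = get_batch()
            if batch is not None:
                # Launch the (asynchronous) GPU work and queue the handle.
                result_queue.append((batch, run_batch(batch)))
            if last_batch is not None:
                # Finish the batch launched in the previous iteration.
                process_result(*result_queue.popleft())
            last_batch = batch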
@@ -967,30 +1016,13 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
-                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None, batch.launch_done)
-
             if self.last_batch:
                 # Process the results of the last batch
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
-                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
-                self.process_batch_result(
-                    tmp_batch, tmp_result, batch.launch_done if batch else None
-                )
+                self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
@@ -1745,7 +1777,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.
+            if self.tp_worker.worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -1802,7 +1834,7 @@ class Scheduler(
         return ret

     def get_num_allocatable_reqs(self, running_bs):
-        res = global_server_args_dict["
+        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res
@@ -2055,18 +2087,59 @@ class Scheduler(
             # FIXME(lsyin): remove this if and finally unify the abstraction
             batch_or_worker_batch = batch.get_model_worker_batch()

-
-
-
+            if self.enable_overlap:
+                # FIXME: remove this assert
+                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+                model_worker_batch = batch_or_worker_batch
+                self.record_batch_in_overlap(model_worker_batch)
+
+                # Sampling info will be modified during forward
+                model_worker_batch.sampling_info = (
+                    model_worker_batch.sampling_info.copy_for_forward()
+                )
+
+                bs = len(model_worker_batch.seq_lens)
+                future_indices = self.future_map.alloc_future_indices(bs)
+
+                with self.forward_stream_ctx:
+                    self.forward_stream.wait_stream(self.default_stream)
+                    self.future_map.resolve_future(model_worker_batch)
+                    if batch.sampling_info.grammars is not None:
+                        model_worker_batch.delay_sample_launch = True
+                    batch_result = self.model_worker.forward_batch_generation(
+                        batch_or_worker_batch
+                    )
+                    # FIXME(lsyin): maybe move this to forward_batch_generation
+                    batch_result.copy_done = torch.get_device_module(
+                        self.device
+                    ).Event()
+                    if not model_worker_batch.delay_sample_launch:
+                        self.future_map.store_to_map(
+                            future_indices, batch_result.next_token_ids
+                        )
+                        batch_result.copy_to_cpu()
+                    else:
+                        batch_result.future_indices = future_indices
+
+                # FIXME(lsyin): move this assignment elsewhere
+                maybe_future_next_token_ids = -future_indices.indices
+            else:
+                batch_result = self.model_worker.forward_batch_generation(
+                    batch_or_worker_batch
+                )
+                maybe_future_next_token_ids = batch_result.next_token_ids

             if not self.spec_algorithm.is_none():
                 # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
-                self.
-                    batch.batch_size(),
+                self.update_spec_metrics(
+                    batch.batch_size(), batch_result.num_accepted_tokens
                 )

-            #
-
+            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+            # which can probably be replaced by future_indices later [TODO(lsyin)].
+            # we shall still keep the original outputs, e.g. next_token_ids
+            # in the GenerationBatchOutput for processing after copy_done.
+            batch.output_ids = maybe_future_next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
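The overlap branch above follows the standard two-stream CUDA pattern: inputs are prepared on the default stream, the forward stream waits on it via wait_stream, and the forward plus the async device-to-host copy then run on the side stream while the CPU returns to scheduling. A minimal sketch of the idiom in plain PyTorch (illustrative only; `model` and the tensor shapes are placeholders):

    import torch

    def overlapped_forward(model, inputs: torch.Tensor, forward_stream: torch.cuda.Stream):
        default_stream = torch.cuda.current_stream()
        copy_done = torch.cuda.Event()
        with torch.cuda.stream(forward_stream):
            # Ensure the side stream sees the fully prepared inputs.
            forward_stream.wait_stream(default_stream)
            logits = model(inputs)
            next_ids = logits.argmax(dim=-1)
            next_ids_cpu = next_ids.to("cpu", non_blocking=True)
            copy_done.record()
        # The CPU can schedule the next batch here; call copy_done.synchronize()
        # right before next_ids_cpu is actually read.
        return next_ids_cpu, copy_done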
@@ -2083,39 +2156,60 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-
-
-
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
             )
+            return batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_last_batch_sample_if_needed(
+        self,
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        if len(self.result_queue) == 0:
+            return
+
+        tmp_batch, tmp_result = self.result_queue.popleft()
+
+        tmp_result: GenerationBatchResult
+        if not tmp_result.delay_sample_launch:
+            self.result_queue.appendleft((tmp_batch, tmp_result))
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
+                tmp_result.logits_output,
+                tmp_result.forward_batch,
+            )
+            future_indices = tmp_result.future_indices
+            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
+            tmp_result.copy_to_cpu()
+        self.result_queue.appendleft((tmp_batch, tmp_result))
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result
+            self.process_batch_result_decode(batch, result)
             if self.enable_trace:
                 trace_slice_batch("decode loop", batch.reqs)

         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result
+            self.process_batch_result_prefill(batch, result)
             if self.enable_trace:
                 trace_slice_batch("prefill", batch.reqs)

         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-
-
-        elif batch.forward_mode.is_dummy_first():
-            self.set_next_batch_sampling_info_done(batch)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()

             self.maybe_send_health_check_signal()

@@ -2325,13 +2419,6 @@ class Scheduler(
             self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -2510,7 +2597,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
-                "
+                "pp_max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -2521,7 +2608,7 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
-            elif k == "
+            elif k == "pp_max_micro_batch_size" and (
                 v > self.max_running_requests // self.pp_size or v < 1
             ):
                 logging.warning(
sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

-    def
+    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
         self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
         self.spec_num_total_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens