sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -22,8 +22,10 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
+from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import psutil
 import setproctitle
@@ -32,7 +34,9 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -76,6 +80,7 @@ from sglang.srt.managers.schedule_policy import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -93,7 +98,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -101,6 +106,19 @@ logger = logging.getLogger(__name__)
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
 
@@ -132,26 +150,36 @@ class Scheduler:
             else 1
         )
 
+        # Distributed rank info
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
         # Init inter-process communication
         context = zmq.Context(2)
-
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
-                context, zmq.PULL, port_args.scheduler_input_ipc_name
+                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
             self.send_to_tokenizer = get_zmq_socket(
-                context, zmq.PUSH, port_args.tokenizer_ipc_name
+                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                 )
             else:
                 # Send to the DetokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                 )
         else:
             self.recv_from_tokenizer = None
@@ -179,6 +207,7 @@ class Scheduler:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -186,6 +215,7 @@ class Scheduler:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
         # Check whether overlap can be enabled
@@ -214,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )
 
-        # Launch worker for speculative decoding if
+        # Launch a worker for speculative decoding if needed
         if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
@@ -244,13 +274,14 @@ class Scheduler:
             _,
         ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
-
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -287,9 +318,13 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.current_stream.synchronize = lambda: None  # No-op for CPU
 
         # Session info
         self.sessions: Dict[str, Session] = {}
@@ -306,28 +341,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            if server_args.grammar_backend == "outlines":
-                from sglang.srt.constrained.outlines_backend import (
-                    OutlinesGrammarBackend,
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None
 
@@ -393,22 +409,56 @@ class Scheduler:
             },
         )
 
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.handle_generate_request),
+                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (FlushCacheReq, self.flush_cache_wrapped),
+                (AbortReq, self.abort_request),
+                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    UpdateWeightsFromDistributedReqInput,
+                    self.update_weights_from_distributed,
+                ),
+                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+                (GetWeightsByNameReqInput, self.get_weights_by_name),
+                (ProfileReq, self.profile),
+                (OpenSessionReqInput, self.open_session),
+                (CloseSessionReqInput, self.close_session),
+                (
+                    ReleaseMemoryOccupationReqInput,
+                    lambda _: self.release_memory_occupation(),
+                ),
+                (
+                    ResumeMemoryOccupationReqInput,
+                    lambda _: self.resume_memory_occupation(),
+                ),
+            ]
+        )
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
 
         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = time.time()
-            time.sleep(self.watchdog_timeout / 2)
-
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
```
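The `_request_dispatcher` table above replaces the long `isinstance` chain that previously lived in `process_input_requests` (see the hunk below that rewrites that method): each request type maps to one handler, and any non-`None` handler return value is sent back to the tokenizer. The following is only a minimal sketch of such a type-based dispatch pattern, assuming a simplified interface and hypothetical request/handler names rather than the actual `sglang.utils.TypeBasedDispatcher`:

```python
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcher:
    """Map the type of an incoming request object to its handler (illustrative sketch)."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Dispatch on the first matching type; isinstance keeps subclass support.
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")


# Hypothetical request types and handlers, for illustration only.
class FlushCacheReq: ...


class AbortReq:
    def __init__(self, rid: str):
        self.rid = rid


def flush_cache(_req: FlushCacheReq) -> None:
    print("cache flushed")  # returns None -> nothing is sent back


def abort_request(req: AbortReq) -> str:
    return f"aborted {req.rid}"  # non-None -> would be sent back to the tokenizer


dispatcher = TypeBasedDispatcher([(FlushCacheReq, flush_cache), (AbortReq, abort_request)])
print(dispatcher(AbortReq("req-1")))  # -> "aborted req-1"
```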
```diff
@@ -421,10 +471,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
 
             if batch:
@@ -440,7 +486,7 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        result_queue = deque()
+        self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
@@ -451,10 +497,10 @@ class Scheduler:
 
             if batch:
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -465,7 +511,7 @@ class Scheduler:
 
             if self.last_batch:
                 # Process the results of the last batch
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
@@ -479,7 +525,7 @@ class Scheduler:
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             recv_reqs = []
 
             while True:
@@ -491,63 +537,48 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            if isinstance(recv_req, TokenizedGenerateReqInput):
-                self.handle_generate_request(recv_req)
-            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
-                self.handle_embedding_request(recv_req)
-            elif isinstance(recv_req, FlushCacheReq):
-                self.flush_cache()
-            elif isinstance(recv_req, AbortReq):
-                self.abort_request(recv_req)
-            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
-                success, message = self.update_weights_from_disk(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightFromDiskReqOutput(success, message)
-                )
-            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
-                success, message = self.init_weights_update_group(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    InitWeightsUpdateGroupReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
-                success, message = self.update_weights_from_distributed(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromDistributedReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
-                success, message = self.update_weights_from_tensor(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromTensorReqOutput(success, message)
-                )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
-            elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
-                self.release_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
-                self.resume_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ProfileReq):
-                if recv_req == ProfileReq.START_PROFILE:
-                    self.start_profile()
-                else:
-                    self.stop_profile()
-            elif isinstance(recv_req, OpenSessionReqInput):
-                session_id, success = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    OpenSessionReqOutput(session_id=session_id, success=success)
-                )
-            elif isinstance(recv_req, CloseSessionReqInput):
-                self.close_session(recv_req)
-            else:
-                raise ValueError(f"Invalid request: {recv_req}")
+            output = self._request_dispatcher(recv_req)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -566,6 +597,19 @@ class Scheduler:
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids
 
+        # Handle custom logit processor passed to the request
+        custom_logit_processor = recv_req.custom_logit_processor
+        if (
+            not self.server_args.enable_custom_logit_processor
+            and custom_logit_processor is not None
+        ):
+            logger.warning(
+                "The SGLang server is not configured to enable custom logit processor."
+                "The custom logit processor passed in will be ignored."
+                "Please set --enable-custom-logits-processor to enable this feature."
+            )
+            custom_logit_processor = None
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -576,6 +620,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            custom_logit_processor=custom_logit_processor,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
@@ -597,7 +642,7 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Handle
+        # Handle multimodal inputs
         if recv_req.image_inputs is not None:
             image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -607,33 +652,36 @@ class Scheduler:
             req.extend_image_inputs(image_inputs)
 
             if len(req.origin_input_ids) >= self.max_req_input_len:
-
+                error_msg = (
                     "Multimodal prompt is too long after expanding multimodal tokens. "
-                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}.
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
+                logger.error(error_msg)
                 req.origin_input_ids = [0]
                 req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
-
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                 )
                 self.waiting_queue.append(req)
                 return
 
-        #
-
+        # Validate prompts length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
-
+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(req.origin_input_ids) - 1
-
-
-        if len(req.origin_input_ids) > self.max_req_input_len:
-            logger.warning(
-                "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
-            )
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
 
         req.sampling_params.max_new_tokens = min(
             (
@@ -681,17 +729,27 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer
 
-        #
-
-
-
-
-
-
+        # Validate prompts length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)
 
-    def log_prefill_stats(
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: ScheduleBatch,
+        has_being_chunked: bool,
+    ):
         self.tree_cache_metrics["total"] += (
             adder.log_input_tokens + adder.log_hit_tokens
         ) / 10**9
@@ -733,21 +791,40 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
 
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+            spec_accept_length = 0
+        else:
+            spec_accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {spec_accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
 
     def check_memory(self):
```
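In the decode-stats hunk above, the reported `accept len` is the ratio of the two new counters, which `run_batch` (further below) increments by `num_accepted_tokens + batch.batch_size()` and by `batch.batch_size()` respectively, so the value is the average number of tokens emitted per request per verify step. A purely illustrative calculation with made-up numbers, not code from the package:

```python
# Illustrative only: mirrors the counter arithmetic, not the Scheduler class itself.
spec_num_total_accepted_tokens = 0
spec_num_total_forward_ct = 0

# Suppose three verify steps over a batch of 2 requests accept 3, 1, and 2 draft tokens.
for num_accepted_tokens, batch_size in [(3, 2), (1, 2), (2, 2)]:
    spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
    spec_num_total_forward_ct += batch_size

spec_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(f"accept len: {spec_accept_length:.2f}")  # -> accept len: 2.00
```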
```diff
@@ -790,16 +867,23 @@ class Scheduler:
         else:
             self.running_batch.merge_batch(self.last_batch)
 
-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch
 
-        #
-        if self.
-
-
-        return
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
@@ -823,9 +907,9 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
+            self.token_to_kv_pool,
             self.running_batch,
             self.new_token_ratio,
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
@@ -886,7 +970,7 @@ class Scheduler:
             self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
-        if self.tp_rank == 0:
+        if self.attn_tp_rank == 0:
             self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
         # Create a new batch
@@ -898,6 +982,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
 
@@ -954,7 +1039,7 @@ class Scheduler:
         )
 
         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -968,63 +1053,81 @@ class Scheduler:
         batch.prepare_for_decode()
         return batch
 
-    def run_batch(
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
 
         if self.is_generation:
-            if
-            if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = (
-                    self.tp_worker.forward_batch_generation(model_worker_batch)
-                )
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    model_worker_batch,
-                    num_accepted_tokens,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
+            if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.
-
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-
-
-                next_token_ids
-
-
-
-
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
-
+
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret =
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret
 
-    def process_batch_result(
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-    def process_batch_result_prefill(
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
        skip_stream_req = None
 
         if self.is_generation:
-
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )
 
             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
```
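As the `run_batch` hunk above shows, batch results are now the typed `GenerationBatchResult` / `EmbeddingBatchResult` dataclasses defined near the top of the file, so downstream handlers read named fields such as `result.bid` instead of unpacking bare tuples. The following is a minimal, self-contained sketch of that pattern only, with stand-in field types and hypothetical values rather than real scheduler objects:

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class GenerationBatchResult:
    logits_output: object      # stand-in for LogitsProcessorOutput
    next_token_ids: List[int]
    bid: int                   # batch id used to resolve overlapped results


@dataclass
class EmbeddingBatchResult:
    embeddings: List[float]    # stand-in for a torch.Tensor
    bid: int


def process_batch_result(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> str:
    # Named fields replace positional unpacking like `embeddings, bid = result`.
    if isinstance(result, GenerationBatchResult):
        return f"batch {result.bid}: {len(result.next_token_ids)} new tokens"
    return f"batch {result.bid}: {len(result.embeddings)}-dim embedding"


print(process_batch_result(GenerationBatchResult(None, [11, 42], bid=7)))
print(process_batch_result(EmbeddingBatchResult([0.1, 0.2, 0.3], bid=8)))
```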
```diff
@@ -1038,9 +1141,6 @@ class Scheduler:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )
 
             # Check finish conditions
             logprob_pt = 0
@@ -1085,7 +1185,7 @@ class Scheduler:
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()
 
             # Check finish conditions
@@ -1109,8 +1209,16 @@ class Scheduler:
 
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
-    def process_batch_result_decode(
-
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
@@ -1168,7 +1276,7 @@ class Scheduler:
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
-            self.tp_rank == 0
+            self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats()
@@ -1188,9 +1296,6 @@ class Scheduler:
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
-        if req.normalized_prompt_logprob is None:
-            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
         if req.input_token_logprobs_val is None:
             input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1278,6 +1383,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1288,15 +1394,12 @@ class Scheduler:
                 input_top_logprobs_idx = []
                 output_top_logprobs_val = []
                 output_top_logprobs_idx = []
-                normalized_prompt_logprob = []
             else:
                 input_token_logprobs_val = input_token_logprobs_idx = (
                     output_token_logprobs_val
                 ) = output_token_logprobs_idx = input_top_logprobs_val = (
                     input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    normalized_prompt_logprob
-                ) = None
+                ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
             for req in reqs:
                 if req is skip_req:
@@ -1334,6 +1437,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)
 
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1343,7 +1449,6 @@ class Scheduler:
                     input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
             # Send to detokenizer
             if rids:
@@ -1362,6 +1467,7 @@ class Scheduler:
                         prompt_tokens,
                         completion_tokens,
                         cached_tokens,
+                        spec_verify_ct,
                         input_token_logprobs_val,
                         input_token_logprobs_idx,
                         output_token_logprobs_val,
@@ -1370,7 +1476,6 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-                        normalized_prompt_logprob,
                     )
                 )
         else:  # embedding or reward model
@@ -1412,12 +1517,7 @@ class Scheduler:
         # Check forward mode for cuda graph
         if not self.server_args.disable_cuda_graph:
             forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                 dtype=torch.int32,
             )
             torch.distributed.all_reduce(
@@ -1438,6 +1538,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -1466,6 +1567,9 @@ class Scheduler:
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
+    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+        self.flush_cache()
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1477,6 +1581,15 @@ class Scheduler:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -1518,12 +1631,12 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightFromDiskReqOutput(success, message)
 
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return success, message
+        return InitWeightsUpdateGroupReqOutput(success, message)
 
     def update_weights_from_distributed(
         self,
@@ -1536,7 +1649,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromDistributedReqOutput(success, message)
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         """Update the online model parameter from tensors."""
@@ -1547,11 +1660,11 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return parameter
+        return GetWeightsByNameReqOutput(parameter)
 
     def release_memory_occupation(self):
         self.stashed_model_static_state = _export_static_state(
@@ -1559,6 +1672,7 @@ class Scheduler:
         )
         self.memory_saver_adapter.pause()
         self.flush_cache()
+        return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self):
         self.memory_saver_adapter.resume()
@@ -1566,6 +1680,13 @@ class Scheduler:
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
         )
         del self.stashed_model_static_state
+        return ResumeMemoryOccupationReqOutput()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req == ProfileReq.START_PROFILE:
+            self.start_profile()
+        else:
+            self.stop_profile()
 
     def start_profile(self) -> None:
         if self.profiler is None:
@@ -1581,20 +1702,20 @@ class Scheduler:
         )
         logger.info("Profiler is done")
 
-    def open_session(self, recv_req: OpenSessionReqInput)
+    def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
             logger.warning(f"session id is None, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id, True
+            return OpenSessionReqOutput(session_id, True)
 
     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
@@ -1651,7 +1772,11 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
-            {
+            {
+                "status": "ready",
+                "max_total_num_tokens": scheduler.max_total_num_tokens,
+                "max_req_input_len": scheduler.max_req_input_len,
+            }
        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
```