sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import faulthandler
 import logging
 import os
 import signal
@@ -21,8 +22,10 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
+from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import psutil
 import setproctitle
@@ -31,7 +34,9 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -46,6 +51,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -71,12 +80,14 @@ from sglang.srt.managers.schedule_policy import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -87,7 +98,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -95,6 +106,19 @@ logger = logging.getLogger(__name__)
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""
 
@@ -126,26 +150,36 @@ class Scheduler:
             else 1
         )
 
+        # Distributed rank info
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
         # Init inter-process communication
         context = zmq.Context(2)
-
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
-                context, zmq.PULL, port_args.scheduler_input_ipc_name
+                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
             self.send_to_tokenizer = get_zmq_socket(
-                context, zmq.PUSH, port_args.tokenizer_ipc_name
+                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                 )
             else:
                 # Send to the DetokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                 )
         else:
             self.recv_from_tokenizer = None
@@ -173,6 +207,7 @@ class Scheduler:
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
                 )
                 self.tokenizer = self.processor.tokenizer
             else:
@@ -180,6 +215,7 @@ class Scheduler:
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
                 )
 
         # Check whether overlap can be enabled
@@ -208,7 +244,7 @@ class Scheduler:
                 nccl_port=port_args.nccl_port,
             )
 
-        # Launch worker for speculative decoding if
+        # Launch a worker for speculative decoding if needed
         if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
@@ -238,10 +274,10 @@ class Scheduler:
             _,
         ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
-
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -281,9 +317,13 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.current_stream.synchronize = lambda: None  # No-op for CPU
 
         # Session info
         self.sessions: Dict[str, Session] = {}
@@ -300,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-
-
-
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None
 
@@ -356,6 +377,10 @@ class Scheduler:
             t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -383,22 +408,53 @@ class Scheduler:
                 },
             )
 
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.handle_generate_request),
+                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (FlushCacheReq, self.flush_cache_wrapped),
+                (AbortReq, self.abort_request),
+                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    UpdateWeightsFromDistributedReqInput,
+                    self.update_weights_from_distributed,
+                ),
+                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+                (GetWeightsByNameReqInput, self.get_weights_by_name),
+                (ProfileReq, self.profile),
+                (OpenSessionReqInput, self.open_session),
+                (CloseSessionReqInput, self.close_session),
+                (
+                    ReleaseMemoryOccupationReqInput,
+                    lambda _: self.release_memory_occupation(),
+                ),
+                (
+                    ResumeMemoryOccupationReqInput,
+                    lambda _: self.resume_memory_occupation(),
+                ),
+            ]
+        )
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
 
         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time =
-            time.sleep(self.watchdog_timeout
-
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+        # Wait sometimes so that the parent process can print the error.
+        time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
@@ -409,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
 
             if batch:
@@ -442,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -467,7 +519,7 @@ class Scheduler:
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.
+        if self.attn_tp_rank == 0:
             recv_reqs = []
 
             while True:
@@ -479,57 +531,48 @@
             else:
                 recv_reqs = None
 
-        if self.
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-
-
-
-                self.handle_embedding_request(recv_req)
-            elif isinstance(recv_req, FlushCacheReq):
-                self.flush_cache()
-            elif isinstance(recv_req, AbortReq):
-                self.abort_request(recv_req)
-            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
-                success, message = self.update_weights_from_disk(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightFromDiskReqOutput(success, message)
-                )
-            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
-                success, message = self.init_weights_update_group(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    InitWeightsUpdateGroupReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
-                success, message = self.update_weights_from_distributed(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromDistributedReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
-                success, message = self.update_weights_from_tensor(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromTensorReqOutput(success, message)
-                )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
-            elif isinstance(recv_req, ProfileReq):
-                if recv_req == ProfileReq.START_PROFILE:
-                    self.start_profile()
-                else:
-                    self.stop_profile()
-            elif isinstance(recv_req, OpenSessionReqInput):
-                session_id, success = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    OpenSessionReqOutput(session_id=session_id, success=success)
-                )
-            elif isinstance(recv_req, CloseSessionReqInput):
-                self.close_session(recv_req)
-            else:
-                raise ValueError(f"Invalid request: {recv_req}")
+            output = self._request_dispatcher(recv_req)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -548,6 +591,19 @@ class Scheduler:
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids
 
+        # Handle custom logit processor passed to the request
+        custom_logit_processor = recv_req.custom_logit_processor
+        if (
+            not self.server_args.enable_custom_logit_processor
+            and custom_logit_processor is not None
+        ):
+            logger.warning(
+                "The SGLang server is not configured to enable custom logit processor."
+                "The custom logit processor passed in will be ignored."
+                "Please set --enable-custom-logits-processor to enable this feature."
+            )
+            custom_logit_processor = None
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -558,6 +614,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            custom_logit_processor=custom_logit_processor,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
@@ -589,15 +646,16 @@ class Scheduler:
             req.extend_image_inputs(image_inputs)
 
             if len(req.origin_input_ids) >= self.max_req_input_len:
-
+                error_msg = (
                     "Multimodal prompt is too long after expanding multimodal tokens. "
-                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}.
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
+                logger.error(error_msg)
                 req.origin_input_ids = [0]
                 req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
-
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                 )
                 self.waiting_queue.append(req)
                 return
@@ -609,13 +667,16 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
        req.logprob_start_len = len(req.origin_input_ids) - 1
 
-        #
-
-
-
-
-
-
+        # Validate prompts length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
 
         req.sampling_params.max_new_tokens = min(
             (
@@ -663,13 +724,12 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer
 
-        #
-
-
-
-
-
-        req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompts length
+        validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
 
         self.waiting_queue.append(req)
 
@@ -715,21 +775,40 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
 
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+            spec_accept_length = 0
+        else:
+            spec_accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {spec_accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
 
     def check_memory(self):
@@ -772,16 +851,23 @@ class Scheduler:
         else:
             self.running_batch.merge_batch(self.last_batch)
 
-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch
 
-        #
-        if self.
-
-
-        return
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
@@ -805,9 +891,9 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
+            self.token_to_kv_pool,
             self.running_batch,
             self.new_token_ratio,
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
@@ -868,7 +954,7 @@ class Scheduler:
             self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
-        if self.
+        if self.attn_tp_rank == 0:
             self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
         # Create a new batch
@@ -880,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
 
@@ -950,12 +1037,14 @@ class Scheduler:
         batch.prepare_for_decode()
         return batch
 
-    def run_batch(
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1
 
         if self.is_generation:
-            if batch.forward_mode.
+            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                 if self.spec_algorithm.is_none():
                     model_worker_batch = batch.get_model_worker_batch()
                     logits_output, next_token_ids = (
@@ -968,45 +1057,65 @@ class Scheduler:
                         model_worker_batch,
                         num_accepted_tokens,
                     ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                    self.spec_num_total_accepted_tokens += (
+                        num_accepted_tokens + batch.batch_size()
+                    )
+                    self.spec_num_total_forward_ct += batch.batch_size()
                     self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
-                model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.forward_batch_idle(model_worker_batch)
-                return
             else:
-
-                if self.skip_tokenizer_init:
-                    next_token_ids = torch.full(
-                        (batch.batch_size(),), self.tokenizer.eos_token_id
-                    )
-                else:
-                    next_token_ids = torch.full((batch.batch_size(),), 0)
+                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
-
+
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret =
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret
 
-    def process_batch_result(
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-    def process_batch_result_prefill(
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
        skip_stream_req = None
 
         if self.is_generation:
-
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )
 
             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1020,9 +1129,6 @@ class Scheduler:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )
 
             # Check finish conditions
             logprob_pt = 0
@@ -1067,7 +1173,7 @@ class Scheduler:
                     batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()
 
             # Check finish conditions
@@ -1091,8 +1197,16 @@ class Scheduler:
 
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
-    def process_batch_result_decode(
-
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
@@ -1150,7 +1264,7 @@ class Scheduler:
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
-            self.
+            self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats()
@@ -1170,9 +1284,6 @@ class Scheduler:
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
-        if req.normalized_prompt_logprob is None:
-            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
         if req.input_token_logprobs_val is None:
             input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1253,7 +1364,6 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
-        origin_input_ids = []
 
         skip_special_tokens = []
         spaces_between_special_tokens = []
@@ -1271,15 +1381,12 @@ class Scheduler:
             input_top_logprobs_idx = []
             output_top_logprobs_val = []
             output_top_logprobs_idx = []
-            normalized_prompt_logprob = []
         else:
             input_token_logprobs_val = input_token_logprobs_idx = (
                 output_token_logprobs_val
            ) = output_token_logprobs_idx = input_top_logprobs_val = (
                 input_top_logprobs_idx
-            ) = output_top_logprobs_val = output_top_logprobs_idx =
-                normalized_prompt_logprob
-            ) = None
+            ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
         for req in reqs:
             if req is skip_req:
@@ -1305,14 +1412,8 @@ class Scheduler:
             decode_ids, read_offset = req.init_incremental_detokenize()
             decode_ids_list.append(decode_ids)
             read_offsets.append(read_offset)
-            if self.skip_tokenizer_init
+            if self.skip_tokenizer_init:
                 output_ids.append(req.output_ids)
-            else:
-                output_ids = None
-            if self.server_args.return_token_ids:
-                origin_input_ids.append(req.origin_input_ids)
-            else:
-                origin_input_ids = None
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
@@ -1332,7 +1433,6 @@ class Scheduler:
                 input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                 output_top_logprobs_val.append(req.output_top_logprobs_val)
                 output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
         # Send to detokenizer
         if rids:
@@ -1344,7 +1444,6 @@ class Scheduler:
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
-                    origin_input_ids,
                     output_ids,
                     skip_special_tokens,
                     spaces_between_special_tokens,
@@ -1360,7 +1459,6 @@ class Scheduler:
                     input_top_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
-                    normalized_prompt_logprob,
                 )
             )
         else:  # embedding or reward model
@@ -1402,12 +1500,7 @@ class Scheduler:
         # Check forward mode for cuda graph
         if not self.server_args.disable_cuda_graph:
             forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                 dtype=torch.int32,
             )
             torch.distributed.all_reduce(
@@ -1428,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -1456,6 +1550,9 @@ class Scheduler:
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
+    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+        self.flush_cache()
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1508,12 +1605,12 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightFromDiskReqOutput(success, message)
 
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return success, message
+        return InitWeightsUpdateGroupReqOutput(success, message)
 
     def update_weights_from_distributed(
         self,
@@ -1526,7 +1623,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromDistributedReqOutput(success, message)
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         """Update the online model parameter from tensors."""
@@ -1537,11 +1634,33 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromTensorReqOutput(success, message)
 
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return parameter
+        return GetWeightsByNameReqOutput(parameter)
+
+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+        return ReleaseMemoryOccupationReqOutput()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+        return ResumeMemoryOccupationReqOutput()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req == ProfileReq.START_PROFILE:
+            self.start_profile()
+        else:
+            self.stop_profile()
 
     def start_profile(self) -> None:
         if self.profiler is None:
@@ -1557,20 +1676,20 @@ class Scheduler:
         )
         logger.info("Profiler is done")
 
-    def open_session(self, recv_req: OpenSessionReqInput)
+    def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
             logger.warning(f"session id is None, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id, True
+            return OpenSessionReqOutput(session_id, True)
 
     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
@@ -1581,6 +1700,20 @@ class Scheduler:
             del self.sessions[session_id]
 
 
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
@@ -1590,6 +1723,7 @@ def run_scheduler_process(
     pipe_writer,
 ):
     setproctitle.setproctitle("sglang::scheduler")
+    faulthandler.enable()
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
@@ -1612,7 +1746,11 @@
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
-            {
+            {
+                "status": "ready",
+                "max_total_num_tokens": scheduler.max_total_num_tokens,
+                "max_req_input_len": scheduler.max_req_input_len,
+            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()