sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -22,8 +22,10 @@ import time
 import warnings
 from collections import deque
 from concurrent import futures
+from dataclasses import dataclass
+from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
@@ -32,7 +34,9 @@ import zmq

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -76,6 +80,7 @@ from sglang.srt.managers.schedule_policy import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -93,7 +98,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -101,6 +106,19 @@ logger = logging.getLogger(__name__)
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


+@dataclass
+class GenerationBatchResult:
+    logits_output: LogitsProcessorOutput
+    next_token_ids: List[int]
+    bid: int
+
+
+@dataclass
+class EmbeddingBatchResult:
+    embeddings: torch.Tensor
+    bid: int
+
+
 class Scheduler:
     """A scheduler that manages a tensor parallel GPU worker."""

@@ -132,26 +150,36 @@ class Scheduler:
             else 1
         )

+        # Distributed rank info
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
         # Init inter-process communication
         context = zmq.Context(2)
-
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
-                context, zmq.PULL, port_args.scheduler_input_ipc_name
+                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
             )
             self.send_to_tokenizer = get_zmq_socket(
-                context, zmq.PUSH, port_args.tokenizer_ipc_name
+                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )

             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                 )
             else:
                 # Send to the DetokenizerManager
                 self.send_to_detokenizer = get_zmq_socket(
-                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                 )
         else:
             self.recv_from_tokenizer = None
@@ -179,6 +207,7 @@ class Scheduler:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -186,6 +215,7 @@ class Scheduler:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )

         # Check whether overlap can be enabled
@@ -214,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

-        # Launch worker for speculative decoding if
+        # Launch a worker for speculative decoding if needed
         if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -244,10 +274,10 @@ class Scheduler:
             _,
         ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
-
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -287,9 +317,13 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.current_stream.synchronize = lambda: None  # No-op for CPU

         # Session info
         self.sessions: Dict[str, Session] = {}
@@ -306,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-
-
-
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None

@@ -393,22 +408,51 @@ class Scheduler:
             },
         )

+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
+            [
+                (TokenizedGenerateReqInput, self.handle_generate_request),
+                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (FlushCacheReq, self.flush_cache_wrapped),
+                (AbortReq, self.abort_request),
+                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    UpdateWeightsFromDistributedReqInput,
+                    self.update_weights_from_distributed,
+                ),
+                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+                (GetWeightsByNameReqInput, self.get_weights_by_name),
+                (ProfileReq, self.profile),
+                (OpenSessionReqInput, self.open_session),
+                (CloseSessionReqInput, self.close_session),
+                (
+                    ReleaseMemoryOccupationReqInput,
+                    lambda _: self.release_memory_occupation(),
+                ),
+                (
+                    ResumeMemoryOccupationReqInput,
+                    lambda _: self.resume_memory_occupation(),
+                ),
+            ]
+        )
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()

         while True:
+            current = time.time()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
-                    if
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
                         logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time =
-            time.sleep(self.watchdog_timeout
-
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
         # Wait sometimes so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
@@ -421,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:
@@ -454,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -479,7 +519,7 @@ class Scheduler:

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.
+        if self.attn_tp_rank == 0:
             recv_reqs = []

             while True:
@@ -491,63 +531,48 @@ class Scheduler:
         else:
             recv_reqs = None

-        if self.
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-
-
-
-                self.handle_embedding_request(recv_req)
-            elif isinstance(recv_req, FlushCacheReq):
-                self.flush_cache()
-            elif isinstance(recv_req, AbortReq):
-                self.abort_request(recv_req)
-            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
-                success, message = self.update_weights_from_disk(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightFromDiskReqOutput(success, message)
-                )
-            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
-                success, message = self.init_weights_update_group(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    InitWeightsUpdateGroupReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
-                success, message = self.update_weights_from_distributed(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromDistributedReqOutput(success, message)
-                )
-            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
-                success, message = self.update_weights_from_tensor(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    UpdateWeightsFromTensorReqOutput(success, message)
-                )
-            elif isinstance(recv_req, GetWeightsByNameReqInput):
-                parameter = self.get_weights_by_name(recv_req)
-                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
-            elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
-                self.release_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
-                self.resume_memory_occupation()
-                self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
-            elif isinstance(recv_req, ProfileReq):
-                if recv_req == ProfileReq.START_PROFILE:
-                    self.start_profile()
-                else:
-                    self.stop_profile()
-            elif isinstance(recv_req, OpenSessionReqInput):
-                session_id, success = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
-                    OpenSessionReqOutput(session_id=session_id, success=success)
-                )
-            elif isinstance(recv_req, CloseSessionReqInput):
-                self.close_session(recv_req)
-            else:
-                raise ValueError(f"Invalid request: {recv_req}")
+            output = self._request_dispatcher(recv_req)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
@@ -566,6 +591,19 @@ class Scheduler:
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

+        # Handle custom logit processor passed to the request
+        custom_logit_processor = recv_req.custom_logit_processor
+        if (
+            not self.server_args.enable_custom_logit_processor
+            and custom_logit_processor is not None
+        ):
+            logger.warning(
+                "The SGLang server is not configured to enable custom logit processor."
+                "The custom logit processor passed in will be ignored."
+                "Please set --enable-custom-logits-processor to enable this feature."
+            )
+            custom_logit_processor = None
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -576,6 +614,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            custom_logit_processor=custom_logit_processor,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
@@ -607,15 +646,16 @@ class Scheduler:
             req.extend_image_inputs(image_inputs)

             if len(req.origin_input_ids) >= self.max_req_input_len:
-
+                error_msg = (
                     "Multimodal prompt is too long after expanding multimodal tokens. "
-                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}.
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
+                logger.error(error_msg)
                 req.origin_input_ids = [0]
                 req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
-
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                 )
                 self.waiting_queue.append(req)
                 return
@@ -627,13 +667,16 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(req.origin_input_ids) - 1

-        #
-
-
-
-
-
-
+        # Validate prompts length
+        error_msg = validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )
+
+        if error_msg:
+            self.waiting_queue.append(req)
+            return

         req.sampling_params.max_new_tokens = min(
             (
@@ -681,13 +724,12 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer

-        #
-
-
-
-
-
-            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        # Validate prompts length
+        validate_input_length(
+            req,
+            self.max_req_input_len,
+            self.server_args.allow_auto_truncate,
+        )

         self.waiting_queue.append(req)

@@ -733,21 +775,40 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )

+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+            spec_accept_length = 0
+        else:
+            spec_accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {spec_accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)

     def check_memory(self):
@@ -790,16 +851,23 @@ class Scheduler:
         else:
             self.running_batch.merge_batch(self.last_batch)

-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch

-        #
-        if self.
-
-
-        return
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
@@ -823,9 +891,9 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
+            self.token_to_kv_pool,
             self.running_batch,
             self.new_token_ratio,
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
@@ -886,7 +954,7 @@ class Scheduler:
             self.being_chunked_req.is_being_chunked += 1

         # Print stats
-        if self.
+        if self.attn_tp_rank == 0:
             self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

         # Create a new batch
@@ -898,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()

@@ -968,12 +1037,14 @@ class Scheduler:
         batch.prepare_for_decode()
         return batch

-    def run_batch(
+    def run_batch(
+        self, batch: ScheduleBatch
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
         """Run a batch."""
         self.forward_ct += 1

         if self.is_generation:
-            if batch.forward_mode.
+            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                 if self.spec_algorithm.is_none():
                     model_worker_batch = batch.get_model_worker_batch()
                     logits_output, next_token_ids = (
@@ -986,45 +1057,65 @@ class Scheduler:
                         model_worker_batch,
                         num_accepted_tokens,
                     ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                    self.spec_num_total_accepted_tokens += (
+                        num_accepted_tokens + batch.batch_size()
+                    )
+                    self.spec_num_total_forward_ct += batch.batch_size()
                     self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
-                model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.forward_batch_idle(model_worker_batch)
-                return
             else:
-
-                if self.skip_tokenizer_init:
-                    next_token_ids = torch.full(
-                        (batch.batch_size(),), self.tokenizer.eos_token_id
-                    )
-                else:
-                    next_token_ids = torch.full((batch.batch_size(),), 0)
+                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
-
+
+            ret = GenerationBatchResult(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                bid=model_worker_batch.bid,
+            )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret =
+            ret = EmbeddingBatchResult(
+                embeddings=embeddings, bid=model_worker_batch.bid
+            )
         return ret

-    def process_batch_result(
+    def process_batch_result(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result.bid)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-    def process_batch_result_prefill(
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
         skip_stream_req = None

         if self.is_generation:
-
+            (
+                logits_output,
+                next_token_ids,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.bid,
+            )

             if self.enable_overlap:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1038,9 +1129,6 @@ class Scheduler:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )

             # Check finish conditions
             logprob_pt = 0
@@ -1085,7 +1173,7 @@ class Scheduler:
                     batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
-            embeddings, bid = result
+            embeddings, bid = result.embeddings, result.bid
             embeddings = embeddings.tolist()

             # Check finish conditions
@@ -1109,8 +1197,16 @@ class Scheduler:

         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

-    def process_batch_result_decode(
-
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
@@ -1168,7 +1264,7 @@ class Scheduler:

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
-            self.
+            self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats()
@@ -1188,9 +1284,6 @@ class Scheduler:
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

-        if req.normalized_prompt_logprob is None:
-            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
         if req.input_token_logprobs_val is None:
             input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1288,15 +1381,12 @@ class Scheduler:
             input_top_logprobs_idx = []
             output_top_logprobs_val = []
             output_top_logprobs_idx = []
-            normalized_prompt_logprob = []
         else:
             input_token_logprobs_val = input_token_logprobs_idx = (
                 output_token_logprobs_val
             ) = output_token_logprobs_idx = input_top_logprobs_val = (
                 input_top_logprobs_idx
-            ) = output_top_logprobs_val = output_top_logprobs_idx =
-                normalized_prompt_logprob
-            ) = None
+            ) = output_top_logprobs_val = output_top_logprobs_idx = None

         for req in reqs:
             if req is skip_req:
@@ -1343,7 +1433,6 @@ class Scheduler:
                     input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                    output_top_logprobs_val.append(req.output_top_logprobs_val)
                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    normalized_prompt_logprob.append(req.normalized_prompt_logprob)

             # Send to detokenizer
             if rids:
@@ -1370,7 +1459,6 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-                        normalized_prompt_logprob,
                     )
                 )
         else:  # embedding or reward model
@@ -1412,12 +1500,7 @@ class Scheduler:
         # Check forward mode for cuda graph
         if not self.server_args.disable_cuda_graph:
             forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                 dtype=torch.int32,
             )
             torch.distributed.all_reduce(
@@ -1438,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -1466,6 +1550,9 @@ class Scheduler:
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+        self.flush_cache()
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1518,12 +1605,12 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightFromDiskReqOutput(success, message)

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
-        return success, message
+        return InitWeightsUpdateGroupReqOutput(success, message)

     def update_weights_from_distributed(
         self,
@@ -1536,7 +1623,7 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromDistributedReqOutput(success, message)

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         """Update the online model parameter from tensors."""
@@ -1547,11 +1634,11 @@ class Scheduler:
             assert flash_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
-        return success, message
+        return UpdateWeightsFromTensorReqOutput(success, message)

     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
-        return parameter
+        return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self):
         self.stashed_model_static_state = _export_static_state(
@@ -1559,6 +1646,7 @@ class Scheduler:
         )
         self.memory_saver_adapter.pause()
         self.flush_cache()
+        return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self):
         self.memory_saver_adapter.resume()
@@ -1566,6 +1654,13 @@ class Scheduler:
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
         )
         del self.stashed_model_static_state
+        return ResumeMemoryOccupationReqOutput()
+
+    def profile(self, recv_req: ProfileReq):
+        if recv_req == ProfileReq.START_PROFILE:
+            self.start_profile()
+        else:
+            self.stop_profile()

     def start_profile(self) -> None:
         if self.profiler is None:
@@ -1581,20 +1676,20 @@ class Scheduler:
         )
         logger.info("Profiler is done")

-    def open_session(self, recv_req: OpenSessionReqInput)
+    def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         elif session_id is None:
             logger.warning(f"session id is None, cannot open.")
-            return session_id, False
+            return OpenSessionReqOutput(session_id, False)
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-            return session_id, True
+            return OpenSessionReqOutput(session_id, True)

     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
@@ -1651,7 +1746,11 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(
-            {
+            {
+                "status": "ready",
+                "max_total_num_tokens": scheduler.max_total_num_tokens,
+                "max_req_input_len": scheduler.max_req_input_len,
+            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()