sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0

sglang/srt/managers/cache_controller.py
ADDED
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import threading
+from queue import PriorityQueue, Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+logger = logging.getLogger(__name__)
+
+
+class CacheOperation:
+
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        device_indices: torch.Tensor,
+        node_id: int,
+        priority: Optional[int] = None,
+    ):
+        self.host_indices = host_indices
+        self.device_indices = device_indices
+        self.node_ids = [node_id]
+        self.data = None
+
+        self.id = CacheOperation.counter
+        CacheOperation.counter += 1
+        # default priority is the order of creation
+        self.priority = priority if priority is not None else self.id
+
+    def merge(self, other: "CacheOperation") -> None:
+        # multiple operations can be merged into a single operation for batch processing
+        self.host_indices = torch.cat([self.host_indices, other.host_indices])
+        self.device_indices = torch.cat([self.device_indices, other.device_indices])
+        self.priority = min(self.priority, other.priority)
+        self.node_ids.extend(other.node_ids)
+
+    def __lt__(self, other: "CacheOperation"):
+        return self.priority < other.priority
+
+
+class TransferBuffer:
+    """
+    Overlapping buffer preparation and transfer operations to improve throughput.
+    """
+
+    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+        self.buffers = Queue(maxsize=buffer_count)
+        # todo: adjust the buffer size based on throughput profile of the system
+        self.max_buffer_size = max_buffer_size
+
+    def full(self) -> bool:
+        return self.buffers.full()
+
+    def empty(self) -> bool:
+        return self.buffers.empty()
+
+    def put(self, item, block=True) -> None:
+        self.buffers.put(item, block=block)
+
+    def get(self, block=True) -> Optional[CacheOperation]:
+        try:
+            return self.buffers.get(block=block)
+        except Exception as e:
+            logger.error(e)
+
+
+class HiCacheController:
+
+    def __init__(
+        self,
+        mem_pool_device: MHATokenToKVPool,
+        mem_pool_host: MLATokenToKVPoolHost,
+        write_policy: str = "write_through_selective",
+    ):
+
+        self.mem_pool_device = mem_pool_device
+        self.mem_pool_host = mem_pool_host
+        self.write_policy = write_policy
+
+        if write_policy not in [
+            "write_through",
+            "write_through_selective",
+            "write_back",
+        ]:
+            raise ValueError(f"Invalid write policy: {write_policy}")
+
+        self.write_queue = PriorityQueue()
+        self.load_queue = PriorityQueue()
+
+        self.ack_write_queue = Queue()
+        self.ack_load_queue = Queue()
+
+        self.write_buffer = TransferBuffer()
+        self.load_buffer = TransferBuffer()
+
+        self.write_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.write_thread.start()
+        self.load_thread.start()
+
+    def write(
+        self,
+        device_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Back up KV caches from device memory to host memory.
+        """
+        host_indices = self.mem_pool_host.alloc(len(device_indices))
+        if host_indices is None:
+            return None
+        self.write_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_write(host_indices)
+        return host_indices
+
+    def load(
+        self,
+        host_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Load KV caches from host memory to device memory.
+        """
+        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        if device_indices is None:
+            return None
+        self.load_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_load(host_indices)
+        return device_indices
+
+    def write_thread_func_direct(self):
+        """
+        Directly write through KV caches to host memory without buffering.
+        """
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                try:
+                    operation = self.write_queue.get(block=True)
+                    operation.data = self.mem_pool_device.get_flat_data(
+                        operation.device_indices
+                    )
+                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_write_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def load_thread_func_direct(self):
+        """
+        Directly load KV caches from host memory to device memory without buffering.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                try:
+                    operation = self.load_queue.get(block=True)
+                    operation.data = self.mem_pool_host.get_flat_data(
+                        operation.host_indices
+                    )
+                    self.mem_pool_device.transfer(
+                        operation.device_indices, operation.data
+                    )
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_load_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def write_aux_func(self, no_wait=False):
+        """
+        Auxiliary function to prepare the buffer for write operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.write_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    assert (
+                        buffer.device_indices.is_cuda
+                    ), "Device indices should be on GPU"
+                    buffer.data = self.mem_pool_device.get_flat_data(
+                        buffer.device_indices
+                    ).contiguous()
+                    self.write_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def load_aux_func(self):
+        """
+        Auxiliary function to prepare the buffer for load operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.load_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                    or self.load_queue.empty()
+                    or self.load_buffer.empty()
+                ):
+                    buffer.data = (
+                        self.mem_pool_host.get_flat_data(buffer.host_indices)
+                        .contiguous()
+                        .pin_memory()
+                    )
+                    self.load_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def write_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                operation = self.write_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_write_queue.put(node_id)
+
+    def load_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                operation = self.load_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_load_queue.put(node_id)
+
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
+
+    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+        if not backup_only:
+            raise ValueError("Other eviction policies are not supported yet.")
+
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
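
For reference, a minimal sketch (not part of the diff) of the priority ordering and merge behavior the new controller relies on. It assumes sglang 0.4.1.post6 is installed, uses only the CacheOperation class shown above, and the index tensors are made up for illustration:

import torch
from queue import PriorityQueue

from sglang.srt.managers.cache_controller import CacheOperation

q = PriorityQueue()
# Two pending backup operations with made-up host/device index tensors;
# a lower priority value is served first.
op_a = CacheOperation(torch.tensor([0, 1]), torch.tensor([10, 11]), node_id=1, priority=5)
op_b = CacheOperation(torch.tensor([2, 3]), torch.tensor([12, 13]), node_id=2, priority=1)
q.put(op_a)
q.put(op_b)

head = q.get()        # op_b comes out first via CacheOperation.__lt__
head.merge(q.get())   # fold op_a into one batched transfer
print(head.node_ids)      # [2, 1] -- both nodes get acked after a single transfer
print(head.host_indices)  # tensor([2, 3, 0, 1])
print(head.priority)      # 1, the minimum of the merged operations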

sglang/srt/managers/configure_logging.py
ADDED
@@ -0,0 +1,43 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200

sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 

sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -181,8 +181,6 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,

sglang/srt/managers/io_struct.py
CHANGED
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional,
-
-import torch
+from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -323,9 +321,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -356,14 +352,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -468,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -479,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
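
A small illustrative sketch (the handle() function is hypothetical, not sglang code) of how these new empty dataclasses act as typed control messages: the scheduler dispatch shown further below matches them by isinstance and replies with the paired ...Output object as an acknowledgement:

from sglang.srt.managers.io_struct import (
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
)

def handle(recv_req):
    # match the control message by type, then reply with the paired empty
    # output object as an acknowledgement (mirrors the scheduler dispatch below)
    if isinstance(recv_req, ReleaseMemoryOccupationReqInput):
        # ... release GPU memory occupation here ...
        return ReleaseMemoryOccupationReqOutput()

print(handle(ReleaseMemoryOccupationReqInput()))  # ReleaseMemoryOccupationReqOutput()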

sglang/srt/managers/schedule_batch.py
CHANGED
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
             input_embeds=self.input_embeds,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )
 
     def copy(self):
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None
 
 
 @triton.jit

sglang/srt/managers/scheduler.py
CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import faulthandler
 import logging
 import os
 import signal
@@ -46,6 +47,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -77,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerSta
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -356,6 +362,10 @@ class Scheduler:
             t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -399,6 +409,8 @@ class Scheduler:
             self.watchdog_last_time = time.time()
             time.sleep(self.watchdog_timeout / 2)
 
+        # Wait sometimes so that the parent process can print the error.
+        time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
@@ -516,6 +528,12 @@ class Scheduler:
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+            self.release_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+            self.resume_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -962,10 +980,13 @@ class Scheduler:
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-
-
-
-
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.num_generated_tokens += num_accepted_tokens
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1250,7 +1271,6 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
-        origin_input_ids = []
 
         skip_special_tokens = []
         spaces_between_special_tokens = []
@@ -1302,14 +1322,8 @@ class Scheduler:
             decode_ids, read_offset = req.init_incremental_detokenize()
             decode_ids_list.append(decode_ids)
             read_offsets.append(read_offset)
-            if self.skip_tokenizer_init
+            if self.skip_tokenizer_init:
                 output_ids.append(req.output_ids)
-            else:
-                output_ids = None
-            if self.server_args.return_token_ids:
-                origin_input_ids.append(req.origin_input_ids)
-            else:
-                origin_input_ids = None
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
@@ -1341,7 +1355,6 @@ class Scheduler:
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
-                    origin_input_ids,
                     output_ids,
                     skip_special_tokens,
                     spaces_between_special_tokens,
@@ -1513,8 +1526,9 @@ class Scheduler:
         return success, message
 
     def update_weights_from_distributed(
-        self,
-
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
@@ -1539,6 +1553,20 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
 
+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
@@ -1577,6 +1605,20 @@ class Scheduler:
             del self.sessions[session_id]
 
 
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
@@ -1586,6 +1628,7 @@ def run_scheduler_process(
     pipe_writer,
 ):
     setproctitle.setproctitle("sglang::scheduler")
+    faulthandler.enable()
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
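
A toy sketch of what the _export_static_state / _import_static_state helpers above do, namely snapshot a model's registered buffers before memory is released and copy them back in place on resume. TinyModel is a made-up stand-in for the real model, used only to show the round trip:

import torch
import torch.nn as nn

from sglang.srt.managers.scheduler import _export_static_state, _import_static_state


class TinyModel(nn.Module):
    # made-up module standing in for model_runner.model; only its buffer matters here
    def __init__(self):
        super().__init__()
        self.register_buffer("rope_cache", torch.arange(4, dtype=torch.float32))


model = TinyModel()
stash = _export_static_state(model)   # {"buffers": [("rope_cache", tensor([0., 1., 2., 3.]))]}

model.rope_cache.zero_()              # pretend the buffer was clobbered while memory was released
_import_static_state(model, stash)    # copy the snapshot back in place
assert torch.equal(model.rope_cache, torch.arange(4, dtype=torch.float32))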

sglang/srt/managers/session_controller.py
CHANGED
@@ -99,7 +99,7 @@ class Session:
 
         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]
 
         input_ids = (