sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, exactly as they were published to a supported registry. It is provided for informational purposes only.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py (new file, +307 -0)
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import threading
+from queue import PriorityQueue, Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+logger = logging.getLogger(__name__)
+
+
+class CacheOperation:
+
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        device_indices: torch.Tensor,
+        node_id: int,
+        priority: Optional[int] = None,
+    ):
+        self.host_indices = host_indices
+        self.device_indices = device_indices
+        self.node_ids = [node_id]
+        self.data = None
+
+        self.id = CacheOperation.counter
+        CacheOperation.counter += 1
+        # default priority is the order of creation
+        self.priority = priority if priority is not None else self.id
+
+    def merge(self, other: "CacheOperation") -> None:
+        # multiple operations can be merged into a single operation for batch processing
+        self.host_indices = torch.cat([self.host_indices, other.host_indices])
+        self.device_indices = torch.cat([self.device_indices, other.device_indices])
+        self.priority = min(self.priority, other.priority)
+        self.node_ids.extend(other.node_ids)
+
+    def __lt__(self, other: "CacheOperation"):
+        return self.priority < other.priority
+
+
+class TransferBuffer:
+    """
+    Overlapping buffer preparation and transfer operations to improve throughput.
+    """
+
+    def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+        self.buffers = Queue(maxsize=buffer_count)
+        # todo: adjust the buffer size based on throughput profile of the system
+        self.max_buffer_size = max_buffer_size
+
+    def full(self) -> bool:
+        return self.buffers.full()
+
+    def empty(self) -> bool:
+        return self.buffers.empty()
+
+    def put(self, item, block=True) -> None:
+        self.buffers.put(item, block=block)
+
+    def get(self, block=True) -> Optional[CacheOperation]:
+        try:
+            return self.buffers.get(block=block)
+        except Exception as e:
+            logger.error(e)
+
+
+class HiCacheController:
+
+    def __init__(
+        self,
+        mem_pool_device: MHATokenToKVPool,
+        mem_pool_host: MLATokenToKVPoolHost,
+        write_policy: str = "write_through_selective",
+    ):
+
+        self.mem_pool_device = mem_pool_device
+        self.mem_pool_host = mem_pool_host
+        self.write_policy = write_policy
+
+        if write_policy not in [
+            "write_through",
+            "write_through_selective",
+            "write_back",
+        ]:
+            raise ValueError(f"Invalid write policy: {write_policy}")
+
+        self.write_queue = PriorityQueue()
+        self.load_queue = PriorityQueue()
+
+        self.ack_write_queue = Queue()
+        self.ack_load_queue = Queue()
+
+        self.write_buffer = TransferBuffer()
+        self.load_buffer = TransferBuffer()
+
+        self.write_stream = torch.cuda.Stream()
+        self.load_stream = torch.cuda.Stream()
+
+        self.write_thread = threading.Thread(
+            target=self.write_thread_func_buffer, daemon=True
+        )
+        self.load_thread = threading.Thread(
+            target=self.load_thread_func_buffer, daemon=True
+        )
+        self.write_thread.start()
+        self.load_thread.start()
+
+    def write(
+        self,
+        device_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Back up KV caches from device memory to host memory.
+        """
+        host_indices = self.mem_pool_host.alloc(len(device_indices))
+        if host_indices is None:
+            return None
+        self.write_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_write(host_indices)
+        return host_indices
+
+    def load(
+        self,
+        host_indices: torch.Tensor,
+        priority: Optional[int] = None,
+        node_id: int = 0,
+    ) -> Optional[torch.Tensor]:
+        """
+        Load KV caches from host memory to device memory.
+        """
+        device_indices = self.mem_pool_device.alloc(len(host_indices))
+        if device_indices is None:
+            return None
+        self.load_queue.put(
+            CacheOperation(host_indices, device_indices, node_id, priority)
+        )
+        self.mem_pool_host.protect_load(host_indices)
+        return device_indices
+
+    def write_thread_func_direct(self):
+        """
+        Directly write through KV caches to host memory without buffering.
+        """
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                try:
+                    operation = self.write_queue.get(block=True)
+                    operation.data = self.mem_pool_device.get_flat_data(
+                        operation.device_indices
+                    )
+                    self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_write_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def load_thread_func_direct(self):
+        """
+        Directly load KV caches from host memory to device memory without buffering.
+        """
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                try:
+                    operation = self.load_queue.get(block=True)
+                    operation.data = self.mem_pool_host.get_flat_data(
+                        operation.host_indices
+                    )
+                    self.mem_pool_device.transfer(
+                        operation.device_indices, operation.data
+                    )
+                    self.mem_pool_host.complete_io(operation.host_indices)
+                    for node_id in operation.node_ids:
+                        self.ack_load_queue.put(node_id)
+                except Exception as e:
+                    logger.error(e)
+
+    def write_aux_func(self, no_wait=False):
+        """
+        Auxiliary function to prepare the buffer for write operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.write_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    assert (
+                        buffer.device_indices.is_cuda
+                    ), "Device indices should be on GPU"
+                    buffer.data = self.mem_pool_device.get_flat_data(
+                        buffer.device_indices
+                    ).contiguous()
+                    self.write_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def load_aux_func(self):
+        """
+        Auxiliary function to prepare the buffer for load operations.
+        """
+        buffer = None
+        while True:
+            try:
+                operation = self.load_queue.get(block=True)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                    or self.load_queue.empty()
+                    or self.load_buffer.empty()
+                ):
+                    buffer.data = (
+                        self.mem_pool_host.get_flat_data(buffer.host_indices)
+                        .contiguous()
+                        .pin_memory()
+                    )
+                    self.load_buffer.put(buffer, block=True)
+                    buffer = None
+            except Exception as e:
+                logger.error(e)
+
+    def write_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.write_stream):
+            while True:
+                operation = self.write_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_write_queue.put(node_id)
+
+    def load_thread_func_buffer(self):
+        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+        aux_thread.start()
+        with torch.cuda.stream(self.load_stream):
+            while True:
+                operation = self.load_buffer.get()
+                if operation is None:
+                    continue
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    self.ack_load_queue.put(node_id)
+
+    def evict_device(
+        self, device_indices: torch.Tensor, host_indices: torch.Tensor
+    ) -> int:
+        if self.mem_pool_host.is_synced(host_indices):
+            self.mem_pool_device.free(device_indices)
+            self.mem_pool_host.update_backup(host_indices)
+            return len(device_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
+
+    def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+        if not backup_only:
+            raise ValueError("Other eviction policies are not supported yet.")
+
+        if self.mem_pool_host.is_backup(host_indices):
+            self.mem_pool_host.free(host_indices)
+            return len(host_indices)
+        else:
+            raise ValueError(
+                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+            )
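
A minimal usage sketch of the controller above (a sketch only, not from the diff; `device_pool` and `host_pool` are assumed to be already-constructed MHATokenToKVPool and MLATokenToKVPoolHost instances):

    import torch

    controller = HiCacheController(device_pool, host_pool, write_policy="write_through")

    # write() allocates host slots, enqueues the copy, and returns immediately.
    # A None return means the host pool had no free slots for the backup.
    host_indices = controller.write(torch.arange(8, device="cuda"), node_id=42)

    # Completion is acknowledged asynchronously: each node id is pushed onto
    # ack_write_queue once the background write thread finishes the transfer.
    if host_indices is not None:
        assert controller.ack_write_queue.get() == 42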
sglang/srt/managers/configure_logging.py (new file, +43 -0)
@@ -0,0 +1,43 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
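
The `/configure_logging` endpoint that this script calls presumably maps the JSON body onto the `ConfigureLoggingReq` dataclass added in io_struct.py below, so request logging could be toggled in the same call (a sketch under that assumption, against a server on the default port):

    import requests

    requests.post(
        "http://localhost:30000/configure_logging",
        json={
            "log_requests": True,  # field not sent by the script above
            "dump_requests_folder": "/tmp/sglang_request_dump",
            "dump_requests_threshold": 1000,
        },
    )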
sglang/srt/managers/data_parallel_controller.py (+2 -0)
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
sglang/srt/managers/detokenizer_manager.py (+0 -2)
@@ -181,8 +181,6 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
sglang/srt/managers/io_struct.py (+29 -13)
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
+from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -323,9 +321,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when --return-token-ids` is set
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -356,14 +352,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -468,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -479,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
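
The four memory-occupation classes are deliberately empty: the message type itself is the command. The scheduler (see the scheduler.py hunks below) dispatches on isinstance() and answers with the paired *ReqOutput. A self-contained sketch of that pattern, using stand-in classes rather than the real ZMQ wiring:

    from dataclasses import dataclass

    @dataclass
    class ReleaseMemoryOccupationReqInput:  # stand-in mirroring the diff
        pass

    @dataclass
    class ReleaseMemoryOccupationReqOutput:
        pass

    def handle(recv_req):
        # mirrors the isinstance() dispatch in the scheduler's request loop
        if isinstance(recv_req, ReleaseMemoryOccupationReqInput):
            # ... release model weights and KV cache here ...
            return ReleaseMemoryOccupationReqOutput()

    assert isinstance(
        handle(ReleaseMemoryOccupationReqInput()), ReleaseMemoryOccupationReqOutput
    )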
sglang/srt/managers/schedule_batch.py (+7 -1)
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
             input_embeds=self.input_embeds,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )
 
     def copy(self):
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None
 
 
 @triton.jit
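
The capture_hidden_mode expression added above guards two cases: spec_info may be None (no speculative decoding), and, presumably, not every spec_info object defines the attribute, hence the getattr() default. A stand-alone sketch of the same fallback, with stand-in classes:

    from enum import Enum, auto

    class CaptureHiddenMode(Enum):  # stand-in for the real enum
        NULL = auto()
        FULL = auto()

    class SpecInfoWithMode:  # hypothetical variant that sets the attribute
        capture_hidden_mode = CaptureHiddenMode.FULL

    class SpecInfoWithoutMode:  # hypothetical variant that does not
        pass

    for spec_info in (None, SpecInfoWithoutMode(), SpecInfoWithMode()):
        mode = (
            getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
            if spec_info
            else CaptureHiddenMode.NULL
        )
        print(type(spec_info).__name__, "->", mode)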
sglang/srt/managers/scheduler.py (+58 -15)
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import faulthandler
 import logging
 import os
 import signal
@@ -46,6 +47,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -77,6 +82,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -356,6 +362,10 @@ class Scheduler:
             t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -399,6 +409,8 @@ class Scheduler:
             self.watchdog_last_time = time.time()
             time.sleep(self.watchdog_timeout / 2)
 
+        # Wait sometimes so that the parent process can print the error.
+        time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
@@ -516,6 +528,12 @@ class Scheduler:
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+            self.release_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+            self.resume_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
        elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -962,10 +980,13 @@ class Scheduler:
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                logits_output, next_token_ids, model_worker_batch, spec_info = (
-                    self.draft_worker.forward_batch_speculative_generation(batch)
-                )
-                batch.spec_info = spec_info
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.num_generated_tokens += num_accepted_tokens
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1250,7 +1271,6 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
-        origin_input_ids = []
 
         skip_special_tokens = []
         spaces_between_special_tokens = []
@@ -1302,14 +1322,8 @@ class Scheduler:
                 decode_ids, read_offset = req.init_incremental_detokenize()
                 decode_ids_list.append(decode_ids)
                 read_offsets.append(read_offset)
-                if self.skip_tokenizer_init or self.server_args.return_token_ids:
+                if self.skip_tokenizer_init:
                     output_ids.append(req.output_ids)
-                else:
-                    output_ids = None
-                if self.server_args.return_token_ids:
-                    origin_input_ids.append(req.origin_input_ids)
-                else:
-                    origin_input_ids = None
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -1341,7 +1355,6 @@ class Scheduler:
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
-                    origin_input_ids,
                     output_ids,
                     skip_special_tokens,
                     spaces_between_special_tokens,
@@ -1513,8 +1526,9 @@ class Scheduler:
         return success, message
 
     def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
@@ -1539,6 +1553,20 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
 
+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
@@ -1577,6 +1605,20 @@ class Scheduler:
             del self.sessions[session_id]
 
 
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
@@ -1586,6 +1628,7 @@ def run_scheduler_process(
     pipe_writer,
 ):
     setproctitle.setproctitle("sglang::scheduler")
+    faulthandler.enable()
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
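
A toy round trip through the _export_static_state / _import_static_state helpers added above (a sketch only: any torch module with named buffers stands in for the real model, and the memory_saver_adapter pause/resume and flush_cache() steps are elided):

    import torch
    import torch.nn as nn

    model = nn.BatchNorm1d(4)   # has named buffers: running_mean, running_var, ...
    model.running_mean += 1.0   # pretend the buffers hold live state

    stash = _export_static_state(model)  # clone every named buffer
    model.running_mean.zero_()           # state would be lost while memory is released
    _import_static_state(model, stash)   # copy the clones back in place

    assert torch.all(model.running_mean == 1.0)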
sglang/srt/managers/session_controller.py (+1 -1)
@@ -99,7 +99,7 @@ class Session:
 
         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]
 
         input_ids = (