sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py (new file)
@@ -0,0 +1,380 @@
+import abc
+import logging
+import threading
+from enum import IntEnum
+from functools import wraps
+
+import psutil
+import torch
+
+from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import debug_timing
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
+
+        return wrapper
+
+    return _decorator
+
+
+class HostKVCache(abc.ABC):
+
+    def __init__(
+        self,
+        device_pool: KVCache,
+        host_to_device_ratio: float,
+        host_size: int,
+        pin_memory: bool,
+        device: str,
+        page_size: int,
+    ):
+        self.device_pool = device_pool
+        self.dtype = device_pool.store_dtype
+        self.pin_memory = pin_memory
+        self.device = device
+        self.page_size = page_size
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = self.init_kv_buffer()
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
+
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
+    def clear(self):
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized()
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.available_size():
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
+
+        return select_index
+
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized(debug_only=True)
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_backup(self, indices: torch.Tensor):
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def protect_write(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def protect_load(self, indices: torch.Tensor):
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def complete_io(self, indices: torch.Tensor):
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+
+class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )

+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
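
For orientation, the standalone sketch below (not part of the diff) walks the slot state machine that HostKVCache enforces when its module logger was at DEBUG level at construction time: IDLE -> RESERVED -> PROTECTED -> SYNCED -> BACKUP on write-back, and BACKUP -> PROTECTED -> SYNCED on load. The two helper functions and the `host_cache` argument are hypothetical; they assume an already-constructed subclass such as MHATokenToKVPoolHost and flat data shaped like its kv_buffer.

import torch

def backup_tokens(host_cache, flat_data: torch.Tensor, num_tokens: int):
    # Write-back path: IDLE -> RESERVED -> PROTECTED -> SYNCED -> BACKUP.
    indices = host_cache.alloc(num_tokens)        # RESERVED (None if the pool is full)
    if indices is None:
        return None
    host_cache.protect_write(indices)             # RESERVED -> PROTECTED
    host_cache.transfer(indices, flat_data)       # device -> host copy
    host_cache.complete_io(indices)               # PROTECTED -> SYNCED
    host_cache.update_backup(indices)             # SYNCED -> BACKUP
    return indices

def restore_tokens(host_cache, indices: torch.Tensor) -> torch.Tensor:
    # Load path: BACKUP -> PROTECTED -> SYNCED, then release the slots.
    host_cache.protect_load(indices)              # BACKUP -> PROTECTED
    flat_data = host_cache.get_flat_data(indices)
    host_cache.complete_io(indices)               # PROTECTED -> SYNCED
    host_cache.free(indices)                      # slots return to the free list
    return flat_data

When DEBUG logging is off, the state-checking methods are no-ops (the synchronized(debug_only=True) wrapper returns True without touching mem_state), so the same call sequence runs without the extra bookkeeping.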
sglang/srt/mem_cache/radix_cache.py
@@ -461,23 +461,47 @@ class RadixCache(BasePrefixCache):
         return ret_list
 
     def _record_store_event(self, node: TreeNode):
+        # One BlockStored per ``page_size`` chunk.
         if self.enable_kv_cache_events:
-            block_hash = hash(tuple(node.key))
-            parent_block_hash = hash(tuple(node.parent.key))
-            self.kv_event_queue.append(
-                BlockStored(
-                    block_hashes=[block_hash],
-                    parent_block_hash=parent_block_hash,
-                    token_ids=node.key,
-                    block_size=len(node.key),
-                    lora_id=None,
+            # First chunk links to the last page of the parent node (if any).
+            if node.parent is None:
+                parent_block_hash = None
+            else:
+                last_page_start = (
+                    (len(node.parent.key) - 1) // self.page_size
+                ) * self.page_size
+                parent_parent_tokens = node.parent.key[last_page_start:]
+                parent_block_hash = hash(tuple(parent_parent_tokens))
+
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+
+                block_hash = hash(tuple(page_tokens))
+
+                self.kv_event_queue.append(
+                    BlockStored(
+                        block_hashes=[block_hash],
+                        parent_block_hash=parent_block_hash,
+                        token_ids=page_tokens,
+                        block_size=len(page_tokens),
+                        lora_id=None,
+                    )
                 )
-            )
+
+                # Chain next chunk to this one.
+                parent_block_hash = block_hash
 
     def _record_remove_event(self, node: TreeNode):
+        # One BlockRemoved per chunk.
         if self.enable_kv_cache_events:
-            block_hash = hash(tuple(node.key))
-            self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+            for start in range(0, len(node.key), self.page_size):
+                page_tokens = node.key[start : start + self.page_size]
+                if not page_tokens:
+                    continue
+                block_hash = hash(tuple(page_tokens))
+                self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
 
     def _record_all_cleared_event(self):
         if self.enable_kv_cache_events:
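
As an illustration of the new chunked event layout (not part of the diff), this standalone sketch reproduces the per-page hashing and parent-hash chaining that _record_store_event now performs; the function name, page size, and token values are made up.

from typing import List, Optional, Tuple

def chunked_block_events(
    tokens: List[int], page_size: int, parent_tokens: Optional[List[int]] = None
) -> List[Tuple[int, Optional[int]]]:
    # Return (block_hash, parent_block_hash) pairs, one per page_size chunk,
    # mirroring how BlockStored events are now chained.
    if parent_tokens:
        # The first chunk links to the last page of the parent node.
        last_page_start = ((len(parent_tokens) - 1) // page_size) * page_size
        parent_block_hash = hash(tuple(parent_tokens[last_page_start:]))
    else:
        parent_block_hash = None

    events = []
    for start in range(0, len(tokens), page_size):
        page_tokens = tokens[start : start + page_size]
        block_hash = hash(tuple(page_tokens))
        events.append((block_hash, parent_block_hash))
        # The next chunk chains to this one.
        parent_block_hash = block_hash
    return events

# Example: a 5-token node with page_size=2 yields 3 chained events.
print(chunked_block_events([10, 11, 12, 13, 14], page_size=2, parent_tokens=[1, 2, 3]))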
sglang/srt/model_executor/cuda_graph_runner.py
@@ -17,12 +17,14 @@ from __future__ import annotations
 
 import bisect
 import inspect
+import logging
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
+from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -40,11 +42,14 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
 from sglang.srt.utils import (
+    empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
@@ -147,10 +152,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         )
 
     gpu_mem = get_device_memory_capacity()
-    if gpu_mem is not None and gpu_mem > 96 * 1024:
-        capture_bs += list(range(160, 257, 8))
-    if gpu_mem is not None and gpu_mem > 180 * 1000:
-        capture_bs += list(range(256, 513, 16))
+    if gpu_mem is not None:
+        if gpu_mem > 90 * 1024:  # H200, H20
+            capture_bs += list(range(160, 257, 8))
+        if gpu_mem > 160 * 1000:  # B200, MI300
+            capture_bs += list(range(256, 513, 16))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
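
For reference (not part of the diff), the hypothetical helper below mirrors the new thresholds to show which extra CUDA-graph batch sizes get captured for a given value returned by get_device_memory_capacity().

def extra_capture_batch_sizes(gpu_mem):
    # Mirror of the new thresholds in get_batch_sizes_to_capture();
    # gpu_mem is whatever get_device_memory_capacity() returns (None if unknown).
    extra = []
    if gpu_mem is not None:
        if gpu_mem > 90 * 1024:   # e.g. H200, H20
            extra += list(range(160, 257, 8))
        if gpu_mem > 160 * 1000:  # e.g. B200, MI300
            extra += list(range(256, 513, 16))
    return extra

# A device reporting 141 * 1024 gets the 160-256 range but not the 256-512 range.
print(extra_capture_batch_sizes(141 * 1024))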
@@ -207,6 +213,9 @@ class CudaGraphRunner:
             model_runner.server_args.enable_two_batch_overlap
         )
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+        self.enable_profile_cuda_graph = (
+            model_runner.server_args.enable_profile_cuda_graph
+        )
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
@@ -226,6 +235,10 @@ class CudaGraphRunner:
                 self.model_runner.server_args.speculative_num_draft_tokens
             )
 
+        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+        if model_runner.server_args.enable_return_hidden_states:
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -333,50 +346,91 @@ class CudaGraphRunner:
             else True
         )
 
+        requested_capture_hidden_mode = max(
+            forward_batch.capture_hidden_mode,
+            (
+                forward_batch.spec_info.capture_hidden_mode
+                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                is not None
+                else CaptureHiddenMode.NULL
+            ),
+        )
+        capture_hidden_mode_matches = (
+            requested_capture_hidden_mode == CaptureHiddenMode.NULL
+            or requested_capture_hidden_mode == self.capture_hidden_mode
+        )
         is_tbo_supported = (
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )
 
-        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported
+        return (
+            is_bs_supported
+            and is_encoder_lens_supported
+            and is_tbo_supported
+            and capture_hidden_mode_matches
+        )
 
-    def capture(self):
-        with graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
-            avail_mem = get_available_gpu_memory(
-                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
-            )
-            # Reverse the order to enable better memory sharing across cuda graphs.
-            capture_range = (
-                tqdm.tqdm(list(reversed(self.capture_bs)))
-                if get_tensor_model_parallel_rank() == 0
-                else reversed(self.capture_bs)
+    def capture(self) -> None:
+        profile_context = empty_context()
+        if self.enable_profile_cuda_graph:
+            profile_context = profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                record_shapes=True,
             )
-            for bs in capture_range:
-                if get_tensor_model_parallel_rank() == 0:
-                    avail_mem = get_available_gpu_memory(
-                        self.model_runner.device,
-                        self.model_runner.gpu_id,
-                        empty_cache=False,
-                    )
-                    capture_range.set_description(
-                        f"Capturing batches ({avail_mem=:.2f} GB)"
-                    )
-
-                with patch_model(
-                    self.model_runner.model,
-                    bs in self.compile_bs,
-                    num_tokens=bs * self.num_tokens_per_bs,
-                    tp_group=self.model_runner.tp_group,
-                ) as forward:
-                    (
-                        graph,
-                        output_buffers,
-                    ) = self.capture_one_batch_size(bs, forward)
-                    self.graphs[bs] = graph
-                    self.output_buffers[bs] = output_buffers
 
-                # Save gemlite cache after each capture
-                save_gemlite_cache()
+        with graph_capture() as graph_capture_context:
+            with profile_context as prof:
+                self.stream = graph_capture_context.stream
+                avail_mem = get_available_gpu_memory(
+                    self.model_runner.device,
+                    self.model_runner.gpu_id,
+                    empty_cache=False,
+                )
+                # Reverse the order to enable better memory sharing across cuda graphs.
+                capture_range = (
+                    tqdm.tqdm(list(reversed(self.capture_bs)))
+                    if get_tensor_model_parallel_rank() == 0
+                    else reversed(self.capture_bs)
+                )
+                for i, bs in enumerate(capture_range):
+                    if get_tensor_model_parallel_rank() == 0:
+                        avail_mem = get_available_gpu_memory(
+                            self.model_runner.device,
+                            self.model_runner.gpu_id,
+                            empty_cache=False,
+                        )
+                        capture_range.set_description(
+                            f"Capturing batches ({avail_mem=:.2f} GB)"
+                        )
+
+                    with patch_model(
+                        self.model_runner.model,
+                        bs in self.compile_bs,
+                        num_tokens=bs * self.num_tokens_per_bs,
+                        tp_group=self.model_runner.tp_group,
+                    ) as forward:
+                        (
+                            graph,
+                            output_buffers,
+                        ) = self.capture_one_batch_size(bs, forward)
+                        self.graphs[bs] = graph
+                        self.output_buffers[bs] = output_buffers
+
+                    # Save gemlite cache after each capture
+                    save_gemlite_cache()
+
+        if self.enable_profile_cuda_graph:
+            log_message = (
+                "Sorted by CUDA Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cuda_time_total", row_limit=10
+                )
+                + "\n\nSorted by CPU Time:\n"
+                + prof.key_averages(group_by_input_shape=True).table(
+                    sort_by="cpu_time_total", row_limit=10
+                )
+            )
+            logger.info(log_message)
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
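
The rewritten capture() toggles profiling by swapping a no-op context manager for torch.profiler.profile. Below is a standalone sketch of the same pattern (not sglang code), using contextlib.nullcontext in place of sglang's empty_context helper; the function name is made up.

import contextlib

import torch
from torch.profiler import ProfilerActivity, profile

def run_with_optional_profile(fn, enable_profile: bool = False):
    # A no-op context stands in when profiling is disabled, so the body stays
    # identical in both modes (the same trick capture() uses with empty_context).
    profile_context = contextlib.nullcontext()
    if enable_profile:
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        profile_context = profile(activities=activities, record_shapes=True)

    with profile_context as prof:
        fn()

    if enable_profile:
        # Same kind of summary the new code logs after graph capture.
        print(
            prof.key_averages(group_by_input_shape=True).table(
                sort_by="cpu_time_total", row_limit=10
            )
        )

run_with_optional_profile(
    lambda: torch.randn(256, 256) @ torch.randn(256, 256), enable_profile=True
)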
@@ -443,7 +497,7 @@ class CudaGraphRunner:
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
             out_cache_loc=out_cache_loc,
-            seq_lens_sum=seq_lens.sum(),
+            seq_lens_sum=seq_lens.sum().item(),
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
@@ -509,21 +563,34 @@ class CudaGraphRunner:
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
-        # If the capture_hidden_mode changes, we need to recapture the graph
-        hidden_mode_from_spec_info = getattr(
+
+        # If the required capture_hidden_mode changes, we need to recapture the graph
+
+        # These are the different factors that can influence the capture_hidden_mode
+        capture_hidden_mode_required_by_forward_batch = (
+            forward_batch.capture_hidden_mode
+        )
+        capture_hidden_mode_required_by_spec_info = getattr(
             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
         )
-        if (
-            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != CaptureHiddenMode.FULL
-        ):
-            self.capture_hidden_mode = CaptureHiddenMode.FULL
-            self.capture()
-        elif (
-            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
-            and self.capture_hidden_mode != hidden_mode_from_spec_info
-        ):
-            self.capture_hidden_mode = hidden_mode_from_spec_info
+        capture_hidden_mode_required_for_returning_hidden_states = (
+            CaptureHiddenMode.FULL
+            if self.model_runner.server_args.enable_return_hidden_states
+            else CaptureHiddenMode.NULL
+        )
+
+        # Determine the highest capture_hidden_mode required
+        # (If we have FULL, we can emulate LAST or NULL)
+        # (If we have LAST, we can emulate NULL)
+        required_capture_hidden_mode = max(
+            capture_hidden_mode_required_by_forward_batch,
+            capture_hidden_mode_required_by_spec_info,
+            capture_hidden_mode_required_for_returning_hidden_states,
+        )
+
+        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+        if self.capture_hidden_mode != required_capture_hidden_mode:
+            self.capture_hidden_mode = required_capture_hidden_mode
             self.capture()
 
     def replay_prepare(