sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -92,6 +94,9 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -151,11 +156,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.to("cpu", non_blocking=True)
            )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.to(
-                    "cpu", non_blocking=True
-                )
-            )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()
 
@@ -174,9 +174,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.tolist()
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.tolist()
-            )
             next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
@@ -0,0 +1,44 @@
+import logging
+from http import HTTPStatus
+from typing import Optional
+
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+logger = logging.getLogger(__name__)
+
+
+def validate_input_length(
+    req: Req, max_req_input_len: int, allow_auto_truncate: bool
+) -> Optional[str]:
+    """Validate and potentially truncate input length.
+
+    Args:
+        req: The request containing input_ids to validate
+        max_req_input_len: Maximum allowed input length
+        allow_auto_truncate: Whether to truncate long inputs
+
+    Returns:
+        Error message if validation fails, None if successful
+    """
+    if len(req.origin_input_ids) >= max_req_input_len:
+        if allow_auto_truncate:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+            )
+            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+            return None
+        else:
+            error_msg = (
+                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+                f"the maximum allowed length ({max_req_input_len} tokens). "
+                f"Use a shorter input or enable --allow-auto-truncate."
+            )
+            logger.error(error_msg)
+            req.finished_reason = FINISH_ABORT(
+                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+            )
+            return error_msg
+
+    return None
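The new `validate_input_length` helper returns an error string (and marks the request aborted via `FINISH_ABORT`) instead of raising. A minimal sketch of how a caller might use it; the `admit_request` wrapper and its arguments are illustrative, not part of the package:

```python
# Hypothetical caller sketch (not from the package): `req` is assumed to be an
# already-constructed Req whose origin_input_ids were set by the tokenizer.
from sglang.srt.managers.utils import validate_input_length

def admit_request(req, max_req_input_len: int, allow_auto_truncate: bool) -> bool:
    """Return True if the request may proceed to scheduling."""
    error = validate_input_length(req, max_req_input_len, allow_auto_truncate)
    if error is not None:
        # validate_input_length has already attached FINISH_ABORT to
        # req.finished_reason; the caller only needs to skip scheduling.
        return False
    return True
```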
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
@@ -25,8 +27,9 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import psutil
 import torch
 
@@ -35,29 +38,34 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
+GB = 1024 * 1024 * 1024
+
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
-        self.write_records = []
-        self.use_records = use_records
-
-        if self.use_records:
-            self.write = self.write_with_records
-        else:
-            self.write = self.write_without_records
 
     def write(self, indices, values):
-        # Keep the signature for type checking. It will be assigned during runtime.
-        raise NotImplementedError()
+        self.req_to_token[indices] = values
 
     def available_size(self):
         return len(self.free_slots)
@@ -79,23 +87,6 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
-        self.write_records = []
-
-    def write_without_records(self, indices, values):
-        self.req_to_token[indices] = values
-
-    def write_with_records(self, indices, values):
-        self.req_to_token[indices] = values
-        self.write_records.append((indices, values))
-
-    def get_write_records(self):
-        ret = self.write_records
-        self.write_records = []
-        return ret
-
-    def apply_write_records(self, write_records: List[Tuple]):
-        for indices, values in write_records:
-            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
@@ -109,8 +100,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
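The uint8 workaround above can be reproduced in isolation. This sketch (assuming a recent PyTorch build with float8 dtypes) shows why reinterpreting the bytes is lossless: both dtypes are one byte wide, so `view()` round-trips exactly while `index_put` stays on a supported dtype:

```python
import torch

# FP8 values are kept in a uint8 buffer because index_put is not implemented
# for float8 dtypes; view() reinterprets the same bytes in either direction.
buf = torch.zeros(16, dtype=torch.uint8)                    # store_dtype buffer
vals = torch.randn(4).to(torch.float8_e4m3fn)               # incoming FP8 values
buf[torch.tensor([0, 1, 2, 3])] = vals.view(torch.uint8)    # index_put on uint8 works
roundtrip = buf[:4].view(torch.float8_e4m3fn)                # reinterpret when reading
assert torch.equal(roundtrip.view(torch.uint8), vals.view(torch.uint8))
```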
@@ -186,37 +177,60 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
 
+        k_size, v_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+        )
+
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        k_size_bytes = 0
+        for k_cache in self.k_buffer:
+            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        v_size_bytes = 0
+        for v_cache in self.v_buffer:
+            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return k_size_bytes, v_size_bytes
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
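The new `get_kv_size_bytes` log line is straightforward per-layer arithmetic over the K/V buffer shapes. A back-of-envelope version with made-up dimensions (the real values come from the model config, not from this diff):

```python
# Illustrative numbers only; not taken from any particular model.
size, head_num, head_dim, layer_num = 500_000, 8, 128, 32
itemsize = 2  # bytes per element for float16/bfloat16

GB = 1024 * 1024 * 1024
k_size_bytes = (size + 1) * head_num * head_dim * itemsize * layer_num
v_size_bytes = k_size_bytes  # V buffers share the K buffer shape
print(f"K size: {k_size_bytes / GB:.2f} GB, V size: {v_size_bytes / GB:.2f} GB")
# K size: 30.52 GB, V size: 30.52 GB
```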
@@ -256,9 +270,15 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
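The new `k_scale`/`v_scale` arguments divide the incoming values before the FP8 cast so they fit the narrow float8 range; the reader is then expected to multiply the scale back when dequantizing. A standalone sketch of that idea (illustrative values, not the attention backend's actual read path):

```python
import torch

cache_k = torch.randn(4, 8, 128, dtype=torch.float16) * 50.0
k_scale = 2.0  # illustrative scale; in practice derived from calibration

quantized = (cache_k / k_scale).to(torch.float8_e5m2)   # store path: scale down, cast
dequantized = quantized.to(torch.float16) * k_scale     # read path: cast up, scale back
print((cache_k - dequantized).abs().max())               # remaining error is quantization only
```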
@@ -286,19 +306,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -339,26 +366,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
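Every pool constructor now routes its large allocations through `TorchMemorySaverAdapter.create(...).region()`. Only those two members are visible in this diff, so the following is a no-op stand-in that mirrors the calling pattern, not the adapter's real implementation (which lives in the new `sglang/srt/torch_memory_saver_adapter.py`):

```python
from contextlib import nullcontext

class NoopMemorySaverAdapter:
    """Stand-in with the same surface as used above: create() + region()."""

    @classmethod
    def create(cls, enable: bool):
        # The real adapter presumably switches between a pass-through and a
        # memory-saver-backed implementation depending on `enable`.
        return cls()

    def region(self):
        # Real adapter: track allocations made inside this context so the KV
        # cache can later be released and restored. Here: do nothing.
        return nullcontext()

adapter = NoopMemorySaverAdapter.create(enable=False)
with adapter.region():
    buffer = [0] * 1024  # allocations happen inside the region
```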
@@ -25,6 +25,7 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    spec_accept_length: float = 0.0
 
 
 class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
-            documentation="The number of running requests",
+            documentation="The number of running requests.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
-            documentation="The number of used tokens",
+            documentation="The number of used tokens.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.token_usage = Gauge(
             name="sglang:token_usage",
-            documentation="The token usage",
+            documentation="The token usage.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
 
         self.gen_throughput = Gauge(
             name="sglang:gen_throughput",
-            documentation="The generate throughput (token/s)",
+            documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
-            documentation="The number of requests in the waiting queue",
+            documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
-            documentation="The cache hit rate",
+            documentation="The prefix cache hit rate.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.spec_accept_length = Gauge(
+            name="sglang:spec_accept_length",
+            documentation="The average acceptance length of speculative decoding.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
 
 
 class TokenizerMetricsCollector:
@@ -109,6 +118,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -185,11 +200,10 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def inc_prompt_tokens(self, value: int):
-        self._log_counter(self.prompt_tokens_total, value)
-
-    def inc_generation_tokens(self, value: int):
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.num_requests_total.labels(**self.labels).inc(1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
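The two per-token counter helpers are collapsed into a single per-request call, which also feeds the new `sglang:num_requests_total` counter. A hedged sketch of the intended call site; the `on_request_finished` wrapper is illustrative and assumes `collector` is an already-constructed `TokenizerMetricsCollector`:

```python
def on_request_finished(collector, prompt_tokens: int, generation_tokens: int) -> None:
    # One call updates sglang:prompt_tokens_total, sglang:generation_tokens_total,
    # and the new sglang:num_requests_total together.
    collector.observe_one_finished_request(prompt_tokens, generation_tokens)
```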
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
         tp_group.ca_comm = backup_ca_comm
 
 
@@ -122,6 +119,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size
 
         # Batch sizes to capture
         self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +216,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.tp_size,
+                    self.max_bs * self.dp_size,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,