sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -92,6 +94,9 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -151,11 +156,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.to("cpu", non_blocking=True)
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.to(
-                    "cpu", non_blocking=True
-                )
-            )
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()
 
@@ -174,9 +174,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.tolist()
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.tolist()
-            )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
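Note on the first hunk above: when the worker runs on a CPU-only device there is no CUDA stream to wait on, so the scheduler stream's `synchronize` is replaced with a no-op. A minimal standalone sketch of the same pattern; the `FakeStream` class and the `device` value are illustrative, not part of sglang:

```python
class FakeStream:
    """Stand-in for a device stream, used only for this illustration."""

    def synchronize(self) -> None:
        # On a real accelerator this would block until queued device work finishes.
        pass


device = "cpu"  # pretend there is no accelerator
stream = FakeStream()

if device == "cpu":
    # Same trick as in the diff: overwrite synchronize with a no-op,
    # since there is no asynchronous device work to wait for on CPU.
    stream.synchronize = lambda: None

stream.synchronize()  # returns immediately
```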
sglang/srt/managers/utils.py
ADDED
@@ -0,0 +1,44 @@
+import logging
+from http import HTTPStatus
+from typing import Optional
+
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+logger = logging.getLogger(__name__)
+
+
+def validate_input_length(
+    req: Req, max_req_input_len: int, allow_auto_truncate: bool
+) -> Optional[str]:
+    """Validate and potentially truncate input length.
+
+    Args:
+        req: The request containing input_ids to validate
+        max_req_input_len: Maximum allowed input length
+        allow_auto_truncate: Whether to truncate long inputs
+
+    Returns:
+        Error message if validation fails, None if successful
+    """
+    if len(req.origin_input_ids) >= max_req_input_len:
+        if allow_auto_truncate:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+            )
+            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+            return None
+        else:
+            error_msg = (
+                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+                f"the maximum allowed length ({max_req_input_len} tokens). "
+                f"Use a shorter input or enable --allow-auto-truncate."
+            )
+            logger.error(error_msg)
+            req.finished_reason = FINISH_ABORT(
+                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+            )
+            return error_msg
+
+    return None
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.
 
@@ -25,8 +27,9 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import psutil
 import torch
 
@@ -35,29 +38,34 @@ from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
+GB = 1024 * 1024 * 1024
+
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-
-
-
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
-        self.write_records = []
-        self.use_records = use_records
-
-        if self.use_records:
-            self.write = self.write_with_records
-        else:
-            self.write = self.write_without_records
 
     def write(self, indices, values):
-
-        raise NotImplementedError()
+        self.req_to_token[indices] = values
 
     def available_size(self):
         return len(self.free_slots)
@@ -79,23 +87,6 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
-        self.write_records = []
-
-    def write_without_records(self, indices, values):
-        self.req_to_token[indices] = values
-
-    def write_with_records(self, indices, values):
-        self.req_to_token[indices] = values
-        self.write_records.append((indices, values))
-
-    def get_write_records(self):
-        ret = self.write_records
-        self.write_records = []
-        return ret
-
-    def apply_write_records(self, write_records: List[Tuple]):
-        for indices, values in write_records:
-            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
@@ -109,8 +100,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype
-        # NOTE: Store as torch.uint8 because Tensor
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
@@ -186,37 +177,60 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
 
+        k_size, v_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+        )
+
     def _create_buffers(self):
-
-
-
-
-        (
-
-
-
-
-
-
-
-        (
-
-
-
-
-
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        k_size_bytes = 0
+        for k_cache in self.k_buffer:
+            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        v_size_bytes = 0
+        for v_cache in self.v_buffer:
+            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return k_size_bytes, v_size_bytes
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
@@ -256,9 +270,15 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
             cache_k = cache_k.to(self.dtype)
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -286,19 +306,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-
-
-
-
-
-
-
-
-
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
@@ -339,26 +366,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for
-
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
sglang/srt/metrics/collector.py
CHANGED
@@ -25,6 +25,7 @@ class SchedulerStats:
     gen_throughput: float = 0.0
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
+    spec_accept_length: float = 0.0
 
 
 class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
-            documentation="The number of running requests",
+            documentation="The number of running requests.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
-            documentation="The number of used tokens",
+            documentation="The number of used tokens.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
        )
 
         self.token_usage = Gauge(
             name="sglang:token_usage",
-            documentation="The token usage",
+            documentation="The token usage.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
 
         self.gen_throughput = Gauge(
             name="sglang:gen_throughput",
-            documentation="The
+            documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
-            documentation="The number of requests in the waiting queue",
+            documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
             multiprocess_mode="sum",
         )
 
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
-            documentation="The cache hit rate",
+            documentation="The prefix cache hit rate.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.spec_accept_length = Gauge(
+            name="sglang:spec_accept_length",
+            documentation="The average acceptance length of speculative decoding.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.gen_throughput, stats.gen_throughput)
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
 
 
 class TokenizerMetricsCollector:
@@ -109,6 +118,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.num_requests_total = Counter(
+            name="sglang:num_requests_total",
+            documentation="Number of requests processed.",
+            labelnames=labels.keys(),
+        )
+
         self.histogram_time_to_first_token = Histogram(
             name="sglang:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
@@ -185,11 +200,10 @@ class TokenizerMetricsCollector:
         # Convenience function for logging to counter.
         counter.labels(**self.labels).inc(data)
 
-    def
-        self.
-
-
-        self._log_counter(self.generation_tokens_total, value)
+    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
+        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.num_requests_total.labels(**self.labels).inc(1)
 
     def observe_time_to_first_token(self, value: Union[float, int]):
         self._log_histogram(self.histogram_time_to_first_token, value)
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
     try:
         if enable_compile:
             _to_torch(model, reverse=False, batch_size=batch_size)
-            monkey_patch_vllm_all_gather()
         backup_ca_comm = tp_group.ca_comm
         # Use custom-allreduce here.
         # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
     finally:
         if enable_compile:
             _to_torch(model, reverse=True, batch_size=batch_size)
-            monkey_patch_vllm_all_gather(reverse=True)
         tp_group.ca_comm = backup_ca_comm
 
 
@@ -122,6 +119,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size
 
         # Batch sizes to capture
         self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +216,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.
+                    self.max_bs * self.dp_size,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,