sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_serving.py +113 -3
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +72 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +114 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +153 -94
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +2 -1
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +0 -14
- sglang/srt/server.py +58 -16
- sglang/srt/server_args.py +42 -22
- sglang/srt/utils.py +87 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +18 -4
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -22,6 +22,7 @@ import logging
 import os
 import signal
 import sys
+import time
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -52,6 +53,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
+from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_child_process
@@ -69,6 +71,10 @@ class ReqState:
     finished: bool
     event: asyncio.Event
 
+    # For metrics
+    created_time: float
+    first_token_time: Optional[float] = None
+
 
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -80,6 +86,7 @@ class TokenizerManager:
     ):
         # Parse args
         self.server_args = server_args
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -142,11 +149,22 @@ class TokenizerManager:
         # Others
         self.gracefully_exit = False
 
+        # Metrics
+        if self.enable_metrics:
+            self.metrics_collector = TokenizerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
+        created_time = time.time()
+
         if self.to_create_loop:
             self.create_handle_loop()
 
@@ -164,10 +182,12 @@ class TokenizerManager:
         if is_single:
             tokenized_obj = await self._tokenize_one_request(obj)
             self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request):
+            async for response in self._wait_one_response(obj, request, created_time):
                 yield response
         else:
-            async for response in self._handle_batch_request(
+            async for response in self._handle_batch_request(
+                obj, request, created_time
+            ):
                 yield response
 
     async def _tokenize_one_request(
@@ -215,7 +235,7 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                obj.lora_path
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -231,10 +251,11 @@ class TokenizerManager:
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
+        created_time: Optional[float] = None,
     ):
         """Wait for the response of one request."""
        event = asyncio.Event()
-        state = ReqState([], False, event)
+        state = ReqState([], False, event, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -272,6 +293,7 @@ class TokenizerManager:
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
+        created_time: Optional[float] = None,
     ):
         batch_size = obj.batch_size
 
@@ -283,14 +305,18 @@ class TokenizerManager:
                 tmp_obj = obj[i]
                 tokenized_obj = await self._tokenize_one_request(tmp_obj)
                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                generators.append(
+                generators.append(
+                    self._wait_one_response(tmp_obj, request, created_time)
+                )
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
 
             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
-            tokenized_objs = await asyncio.gather(
+            tokenized_objs = await asyncio.gather(
+                *(self._tokenize_one_request(obj) for obj in objs)
+            )
 
             # Cache the common prefix for parallel sampling
             for i in range(batch_size):
@@ -301,7 +327,9 @@ class TokenizerManager:
                 tokenized_obj.sampling_params.max_new_tokens = 0
                 tokenized_obj.stream = False
                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                await self._wait_one_response(
+                await self._wait_one_response(
+                    tmp_obj, request, created_time
+                ).__anext__()
 
             # Expand requests, assign new rids for them, and send them
             for i in range(batch_size):
@@ -310,7 +338,9 @@ class TokenizerManager:
                 tokenized_obj = copy.copy(tokenized_objs[i])
                 tokenized_obj.rid = tmp_obj.regenerate_rid()
                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                generators.append(
+                generators.append(
+                    self._wait_one_response(tmp_obj, request, created_time)
+                )
                 rids.append(tmp_obj.rid)
 
         # Wait for all requests
@@ -322,7 +352,9 @@ class TokenizerManager:
             rid_to_index = {rid: i for i, rid in enumerate(rids)}
             task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
             while task_map:
-                done, _ = await asyncio.wait(
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
 
                 for task in done:
                     gen = task_map.pop(task)
@@ -367,7 +399,7 @@ class TokenizerManager:
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
-        else:
+        else:  # self.server_args.dp_size > 1
             self.mem_pool_size_tmp = []
             res = await self.mem_pool_size
             ret = [r.size for r in res]
@@ -384,11 +416,15 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
 
         if not self.model_update_lock.locked():
-
+
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
-
-
+                for i in range(3):
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    # FIXME: We add some sleep here to avoid some race conditions.
+                    # We can use a read-write lock as a better fix.
+                    await asyncio.sleep(0.01)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
 
@@ -399,7 +435,7 @@ class TokenizerManager:
             self.server_args.load_format = obj.load_format
             self.model_path = obj.model_path
             return result.success, result.message
-        else:
+        else:  # self.server_args.dp_size > 1
             self.model_update_tmp = []
             result = await self.model_update_result
 
@@ -457,7 +493,7 @@ class TokenizerManager:
                 break
 
         kill_child_process(include_self=True)
-        sys.exit(
+        sys.exit(0)
 
     async def handle_loop(self):
         """The event loop that handles requests"""
@@ -470,7 +506,7 @@ class TokenizerManager:
             if isinstance(recv_obj, UpdateWeightReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
-                else:
+                else:  # self.server_args.dp_size > 1
                     self.model_update_tmp.append(recv_obj)
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +515,7 @@ class TokenizerManager:
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 if self.server_args.dp_size == 1:
                     self.mem_pool_size.set_result(recv_obj)
-                else:
+                else:  # self.sever_args.dp_size > 1
                     self.mem_pool_size_tmp.append(recv_obj)
                     # set future if the all results are received
                     if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
@@ -516,6 +552,34 @@ class TokenizerManager:
                 state.finished = recv_obj.finished_reason[i] is not None
                 state.event.set()
 
+                if self.enable_metrics:
+                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                    if state.first_token_time is None:
+                        state.first_token_time = time.time()
+                        self.metrics_collector.observe_time_to_first_token(
+                            state.first_token_time - state.created_time
+                        )
+                    else:
+                        if completion_tokens >= 2:
+                            self.metrics_collector.observe_time_per_output_token(
+                                (time.time() - state.first_token_time)
+                                / (completion_tokens - 1)
+                            )
+
+                    if state.finished:
+                        self.metrics_collector.inc_prompt_tokens(
+                            recv_obj.meta_info[i]["prompt_tokens"]
+                        )
+                        self.metrics_collector.inc_generation_tokens(completion_tokens)
+                        self.metrics_collector.observe_e2e_request_latency(
+                            time.time() - state.created_time
+                        )
+                        if completion_tokens >= 1:
+                            self.metrics_collector.observe_time_per_output_token(
+                                (time.time() - state.created_time) / completion_tokens
+                            )
+
     def convert_logprob_style(
         self,
         ret: dict,
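The per-request fields added above (created_time, first_token_time) feed the time-to-first-token, per-output-token, and end-to-end latency observations. A minimal standalone sketch of that timing arithmetic, outside the TokenizerManager plumbing; the function and field names below are illustrative, not sglang APIs:

import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class RequestTiming:
    created_time: float
    first_token_time: Optional[float] = None


def observe_chunk(timing: RequestTiming, completion_tokens: int) -> dict:
    """Observations produced by one streamed chunk of a request."""
    now = time.time()
    observations = {}
    if timing.first_token_time is None:
        # First chunk: time-to-first-token, measured from request creation.
        timing.first_token_time = now
        observations["time_to_first_token"] = now - timing.created_time
    elif completion_tokens >= 2:
        # Later chunks: mean inter-token latency since the first token.
        observations["time_per_output_token"] = (
            now - timing.first_token_time
        ) / (completion_tokens - 1)
    return observations


def observe_finish(timing: RequestTiming, completion_tokens: int) -> dict:
    """Observations recorded once, when the request finishes."""
    now = time.time()
    observations = {"e2e_request_latency": now - timing.created_time}
    if completion_tokens >= 1:
        # Whole-request average time per output token, prefill included.
        observations["time_per_output_token"] = (
            now - timing.created_time
        ) / completion_tokens
    return observations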
sglang/srt/metrics/collector.py
ADDED
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Utilities for Prometheus Metrics Collection."""

from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class SchedulerStats:
    num_running_reqs: int = 0
    num_used_tokens: int = 0
    token_usage: float = 0.0
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0


class SchedulerMetricsCollector:

    def __init__(self, labels: Dict[str, str]) -> None:
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Gauge

        self.labels = labels

        self.num_running_reqs = Gauge(
            name="sglang:num_running_reqs",
            documentation="The number of running requests",
            labelnames=labels.keys(),
            multiprocess_mode="sum",
        )

        self.num_used_tokens = Gauge(
            name="sglang:num_used_tokens",
            documentation="The number of used tokens",
            labelnames=labels.keys(),
            multiprocess_mode="sum",
        )

        self.token_usage = Gauge(
            name="sglang:token_usage",
            documentation="The token usage",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.gen_throughput = Gauge(
            name="sglang:gen_throughput",
            documentation="The generate throughput (token/s)",
            labelnames=labels.keys(),
            multiprocess_mode="sum",
        )

        self.num_queue_reqs = Gauge(
            name="sglang:num_queue_reqs",
            documentation="The number of requests in the waiting queue",
            labelnames=labels.keys(),
            multiprocess_mode="sum",
        )

        self.cache_hit_rate = Gauge(
            name="sglang:cache_hit_rate",
            documentation="The cache hit rate",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        # Convenience function for logging to gauge.
        gauge.labels(**self.labels).set(data)

    def log_stats(self, stats: SchedulerStats) -> None:
        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
        self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
        self._log_gauge(self.token_usage, stats.token_usage)
        self._log_gauge(self.gen_throughput, stats.gen_throughput)
        self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)


class TokenizerMetricsCollector:
    def __init__(self, labels: Dict[str, str]) -> None:
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Counter, Histogram

        self.labels = labels

        self.prompt_tokens_total = Counter(
            name="sglang:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labels.keys(),
        )

        self.generation_tokens_total = Counter(
            name="sglang:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labels.keys(),
        )

        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labels.keys(),
            buckets=[
                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 25.0, 30.0,
            ],
        )

        self.histogram_time_per_output_token = Histogram(
            name="sglang:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labels.keys(),
            buckets=[
                0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.04, 0.05, 0.075, 0.1,
                0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5,
            ],
        )

        self.histogram_e2e_request_latency = Histogram(
            name="sglang:e2e_request_latency_seconds",
            documentation="Histogram of End-to-end request latency in seconds",
            labelnames=labels.keys(),
            buckets=[
                0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0,
                20.0, 30.0, 40.0, 50.0, 60.0,
            ],
        )

    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
        histogram.labels(**self.labels).observe(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        # Convenience function for logging to counter.
        counter.labels(**self.labels).inc(data)

    def inc_prompt_tokens(self, value: int):
        self._log_counter(self.prompt_tokens_total, value)

    def inc_generation_tokens(self, value: int):
        self._log_counter(self.generation_tokens_total, value)

    def observe_time_to_first_token(self, value: Union[float, int]):
        self._log_histogram(self.histogram_time_to_first_token, value)

    def observe_time_per_output_token(self, value: Union[float, int]):
        self._log_histogram(self.histogram_time_per_output_token, value)

    def observe_e2e_request_latency(self, value: Union[float, int]):
        self._log_histogram(self.histogram_e2e_request_latency, value)
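Both collectors import prometheus_client lazily because the metrics are written from multiple processes and prometheus_client reads `PROMETHEUS_MULTIPROC_DIR` at import time. A hedged sketch of how such multiprocess metrics are typically aggregated for a scrape, using the standard prometheus_client multiprocess API; the wiring below is illustrative and not code from this package:

import os

# Must be set before prometheus_client is first imported; the collectors above
# rely on the same rule via their deferred imports.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/sglang_prometheus")
os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True)

from prometheus_client import CollectorRegistry, generate_latest, multiprocess


def render_metrics() -> bytes:
    # Merge the samples written by every worker process into one scrape body.
    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    return generate_latest(registry)


if __name__ == "__main__":
    from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats

    collector = SchedulerMetricsCollector(labels={"model_name": "demo-model"})
    collector.log_stats(SchedulerStats(num_running_reqs=2, token_usage=0.4))
    print(render_metrics().decode())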
sglang/srt/metrics/func_timer.py
ADDED
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Records the latency of some functions
"""

import asyncio
import time
from functools import wraps
from typing import Any, Callable, List, Optional

enable_metrics = False


def enable_func_timer():
    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
    from prometheus_client import Histogram

    global enable_metrics, FUNC_LATENCY
    enable_metrics = True

    FUNC_LATENCY = Histogram(
        "sglang:func_latency_seconds",
        "Function latency in seconds",
        # captures latency in range [50ms - ~50s]
        buckets=exponential_buckets(start=0.05, width=1.5, length=18),
        labelnames=["name"],
    )


FUNC_LATENCY = None


def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    buckets = []
    for i in range(length):
        buckets.append(start * (width**i))
    return buckets


def time_func_latency(
    func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]:
    """
    A decorator to observe the latency of a function's execution. Supports both sync and async functions.

    NOTE: We use our own implementation of a timer decorator since prometheus_client does not support async
    context manager yet.

    Overhead: The overhead introduced here in case of an async function could likely be because of `await` introduced
    which will return in another coroutine object creation and under heavy load could see longer wall time
    (scheduling delays due to introduction of another awaitable).
    """

    def measure(func: Callable[..., Any]) -> Callable[..., Any]:
        nonlocal name

        name = name or func.__name__

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            if not enable_metrics:
                return await func(*args, **kwargs)

            metric = FUNC_LATENCY
            start = time.monotonic()
            ret = func(*args, **kwargs)
            if isinstance(ret, asyncio.Future) or asyncio.iscoroutine(ret):
                try:
                    ret = await ret
                finally:
                    metric.labels(name=name).observe(time.monotonic() - start)
            return ret

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            if not enable_metrics:
                return func(*args, **kwargs)

            metric = FUNC_LATENCY
            start = time.monotonic()
            try:
                ret = func(*args, **kwargs)
            finally:
                metric.labels(name=name).observe(time.monotonic() - start)
            return ret

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    if func:
        return measure(func)
    else:
        return measure
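A short usage sketch for the helpers above; the decorated coroutine and its label name are made up for illustration:

import asyncio

from sglang.srt.metrics.func_timer import (
    enable_func_timer,
    exponential_buckets,
    time_func_latency,
)

# 18 buckets growing by 1.5x from 50 ms; the last one is roughly 49 s.
print(exponential_buckets(start=0.05, width=1.5, length=18)[-1])


@time_func_latency(name="demo_sleep")
async def demo_sleep():
    await asyncio.sleep(0.1)


# Registers sglang:func_latency_seconds and turns the timer on; until this is
# called, the decorator is a no-op passthrough.
enable_func_timer()
asyncio.run(demo_sleep())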
sglang/srt/mm_utils.py
CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 """
 Utilities for multi-modal models.
 
-This python file mainly contains utilities that were used in the 
+This python file mainly contains utilities that were used in the
 image processing logic of llava-next including operations such as
 anyres and anyres_max
 
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import monkey_patch_vllm_all_gather
+from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -92,7 +92,7 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024
 
 
-@
+@maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
 
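maybe_torch_compile is imported from sglang.srt.utils, whose 87 new lines are not shown in this section. A plausible minimal sketch, assuming the helper applies torch.compile only when compilation has been enabled and otherwise returns the function unchanged; the module-level switch below is an assumption for illustration, not the actual implementation:

from typing import Callable

import torch

_ENABLE_TORCH_COMPILE = False  # assumed switch; illustrative name only


def maybe_torch_compile(*compile_args, **compile_kwargs):
    """Apply torch.compile(...) to the function only when compilation is enabled."""

    def decorator(func: Callable) -> Callable:
        if _ENABLE_TORCH_COMPILE:
            return torch.compile(func, *compile_args, **compile_kwargs)
        return func

    return decorator


@maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # With the switch off, this runs as plain eager PyTorch.
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)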
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -136,8 +136,13 @@ class ForwardBatch:
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-
+                    mrope_position_delta,
                     int(self.seq_lens[i]) - 1,
                     int(self.seq_lens[i]),
                 )
@@ -159,7 +164,6 @@ class ForwardBatch:
                         )
                     ]
                 ] * 3
-                mrope_position_delta = 0
             else:
                 # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                 mrope_positions, mrope_position_delta = (
@@ -173,8 +177,8 @@ class ForwardBatch:
                         context_len=0,
                     )
                 )
+                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
-                batch.mrope_positions_delta[i].append(mrope_position_delta)
 
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
-from sglang.srt.constrained import disable_cache
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
@@ -129,6 +128,8 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
         if server_args.disable_disk_cache:
+            from outlines.caching import disable_cache
+
             disable_cache()
 
         global_server_args_dict.update(