sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (50)
  1. sglang/bench_serving.py +113 -3
  2. sglang/srt/configs/model_config.py +5 -2
  3. sglang/srt/constrained/__init__.py +2 -66
  4. sglang/srt/constrained/base_grammar_backend.py +72 -0
  5. sglang/srt/constrained/outlines_backend.py +165 -0
  6. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  7. sglang/srt/constrained/xgrammar_backend.py +114 -0
  8. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  10. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  11. sglang/srt/layers/quantization/base_config.py +4 -6
  12. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  13. sglang/srt/managers/io_struct.py +5 -3
  14. sglang/srt/managers/schedule_batch.py +14 -20
  15. sglang/srt/managers/scheduler.py +153 -94
  16. sglang/srt/managers/tokenizer_manager.py +81 -17
  17. sglang/srt/metrics/collector.py +211 -0
  18. sglang/srt/metrics/func_timer.py +108 -0
  19. sglang/srt/mm_utils.py +1 -1
  20. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  21. sglang/srt/model_executor/forward_batch_info.py +7 -3
  22. sglang/srt/model_executor/model_runner.py +2 -1
  23. sglang/srt/models/gemma2_reward.py +69 -0
  24. sglang/srt/models/gpt2.py +31 -37
  25. sglang/srt/models/internlm2_reward.py +62 -0
  26. sglang/srt/models/llama.py +11 -6
  27. sglang/srt/models/llama_reward.py +5 -26
  28. sglang/srt/models/qwen2_vl.py +5 -7
  29. sglang/srt/openai_api/adapter.py +6 -2
  30. sglang/srt/sampling/sampling_batch_info.py +2 -3
  31. sglang/srt/sampling/sampling_params.py +0 -14
  32. sglang/srt/server.py +58 -16
  33. sglang/srt/server_args.py +42 -22
  34. sglang/srt/utils.py +87 -0
  35. sglang/test/simple_eval_common.py +1 -1
  36. sglang/test/simple_eval_humaneval.py +2 -2
  37. sglang/test/simple_eval_mgsm.py +2 -2
  38. sglang/test/test_utils.py +18 -4
  39. sglang/utils.py +1 -0
  40. sglang/version.py +1 -1
  41. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
  42. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
  43. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  44. sglang/srt/constrained/base_tool_cache.py +0 -65
  45. sglang/srt/constrained/bnf_cache.py +0 -61
  46. sglang/srt/constrained/fsm_cache.py +0 -95
  47. sglang/srt/constrained/grammar.py +0 -190
  48. sglang/srt/constrained/jump_forward.py +0 -203
  49. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  50. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -22,6 +22,7 @@ import logging
 import os
 import signal
 import sys
+import time
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
@@ -52,6 +53,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
+from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_child_process
@@ -69,6 +71,10 @@ class ReqState:
     finished: bool
     event: asyncio.Event
 
+    # For metrics
+    created_time: float
+    first_token_time: Optional[float] = None
+
 
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -80,6 +86,7 @@ class TokenizerManager:
     ):
         # Parse args
         self.server_args = server_args
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -142,11 +149,22 @@
         # Others
         self.gracefully_exit = False
 
+        # Metrics
+        if self.enable_metrics:
+            self.metrics_collector = TokenizerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
+        created_time = time.time()
+
         if self.to_create_loop:
             self.create_handle_loop()
 
@@ -164,10 +182,12 @@
         if is_single:
             tokenized_obj = await self._tokenize_one_request(obj)
             self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request):
+            async for response in self._wait_one_response(obj, request, created_time):
                 yield response
         else:
-            async for response in self._handle_batch_request(obj, request):
+            async for response in self._handle_batch_request(
+                obj, request, created_time
+            ):
                 yield response
 
     async def _tokenize_one_request(
@@ -215,7 +235,7 @@
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                obj.lora_path
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -231,10 +251,11 @@
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
+        created_time: Optional[float] = None,
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event)
+        state = ReqState([], False, event, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -272,6 +293,7 @@
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
+        created_time: Optional[float] = None,
    ):
         batch_size = obj.batch_size
 
@@ -283,14 +305,18 @@
                 tmp_obj = obj[i]
                 tokenized_obj = await self._tokenize_one_request(tmp_obj)
                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                generators.append(self._wait_one_response(tmp_obj, request))
+                generators.append(
+                    self._wait_one_response(tmp_obj, request, created_time)
+                )
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
 
             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
-            tokenized_objs = await asyncio.gather(*(self._tokenize_one_request(obj) for obj in objs))
+            tokenized_objs = await asyncio.gather(
+                *(self._tokenize_one_request(obj) for obj in objs)
+            )
 
             # Cache the common prefix for parallel sampling
             for i in range(batch_size):
@@ -301,7 +327,9 @@
                 tokenized_obj.sampling_params.max_new_tokens = 0
                 tokenized_obj.stream = False
                 self.send_to_scheduler.send_pyobj(tokenized_obj)
-                await self._wait_one_response(tmp_obj, request).__anext__()
+                await self._wait_one_response(
+                    tmp_obj, request, created_time
+                ).__anext__()
 
             # Expand requests, assign new rids for them, and send them
             for i in range(batch_size):
@@ -310,7 +338,9 @@
                     tokenized_obj = copy.copy(tokenized_objs[i])
                     tokenized_obj.rid = tmp_obj.regenerate_rid()
                     self.send_to_scheduler.send_pyobj(tokenized_obj)
-                    generators.append(self._wait_one_response(tmp_obj, request))
+                    generators.append(
+                        self._wait_one_response(tmp_obj, request, created_time)
+                    )
                     rids.append(tmp_obj.rid)
 
         # Wait for all requests
@@ -322,7 +352,9 @@
             rid_to_index = {rid: i for i, rid in enumerate(rids)}
             task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
             while task_map:
-                done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
 
                 for task in done:
                     gen = task_map.pop(task)
@@ -367,7 +399,7 @@
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
-        else: # self.server_args.dp_size > 1
+        else:  # self.server_args.dp_size > 1
             self.mem_pool_size_tmp = []
             res = await self.mem_pool_size
             ret = [r.size for r in res]
@@ -384,11 +416,15 @@
             obj.load_format = self.server_args.load_format
 
         if not self.model_update_lock.locked():
-
+
             async with self.model_update_lock:
                 # wait for the previous generation requests to finish
-                while len(self.rid_to_state) > 0:
-                    await asyncio.sleep(0.001)
+                for i in range(3):
+                    while len(self.rid_to_state) > 0:
+                        await asyncio.sleep(0.001)
+                    # FIXME: We add some sleep here to avoid some race conditions.
+                    # We can use a read-write lock as a better fix.
+                    await asyncio.sleep(0.01)
                 self.send_to_scheduler.send_pyobj(obj)
                 self.model_update_result = asyncio.Future()
 
@@ -399,7 +435,7 @@
                         self.server_args.load_format = obj.load_format
                         self.model_path = obj.model_path
                     return result.success, result.message
-                else: # self.server_args.dp_size > 1
+                else:  # self.server_args.dp_size > 1
                     self.model_update_tmp = []
                     result = await self.model_update_result
 
@@ -457,7 +493,7 @@
                 break
 
         kill_child_process(include_self=True)
-        sys.exit(-1)
+        sys.exit(0)
 
     async def handle_loop(self):
         """The event loop that handles requests"""
@@ -470,7 +506,7 @@
             if isinstance(recv_obj, UpdateWeightReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
-                else: # self.server_args.dp_size > 1
+                else:  # self.server_args.dp_size > 1
                     self.model_update_tmp.append(recv_obj)
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
@@ -479,7 +515,7 @@
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 if self.server_args.dp_size == 1:
                     self.mem_pool_size.set_result(recv_obj)
-                else: # self.sever_args.dp_size > 1
+                else:  # self.sever_args.dp_size > 1
                     self.mem_pool_size_tmp.append(recv_obj)
                     # set future if the all results are received
                     if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
@@ -516,6 +552,34 @@
                 state.finished = recv_obj.finished_reason[i] is not None
                 state.event.set()
 
+                if self.enable_metrics:
+                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                    if state.first_token_time is None:
+                        state.first_token_time = time.time()
+                        self.metrics_collector.observe_time_to_first_token(
+                            state.first_token_time - state.created_time
+                        )
+                    else:
+                        if completion_tokens >= 2:
+                            self.metrics_collector.observe_time_per_output_token(
+                                (time.time() - state.first_token_time)
+                                / (completion_tokens - 1)
+                            )
+
+                    if state.finished:
+                        self.metrics_collector.inc_prompt_tokens(
+                            recv_obj.meta_info[i]["prompt_tokens"]
+                        )
+                        self.metrics_collector.inc_generation_tokens(completion_tokens)
+                        self.metrics_collector.observe_e2e_request_latency(
+                            time.time() - state.created_time
+                        )
+                        if completion_tokens >= 1:
+                            self.metrics_collector.observe_time_per_output_token(
+                                (time.time() - state.created_time) / completion_tokens
+                            )
+
     def convert_logprob_style(
         self,
         ret: dict,
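
Note on the bookkeeping above: the first streamed chunk of a request records time-to-first-token against created_time, later chunks amortize the elapsed decode time over completion_tokens - 1, and on completion the whole latency is observed end-to-end and also averaged over all completion tokens. A minimal standalone sketch of that arithmetic follows (the _ReqTiming holder and the collector parameter are illustrative stand-ins, not the sglang API; token counters are omitted for brevity):

import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ReqTiming:
    created_time: float
    first_token_time: Optional[float] = None


def record_chunk(t: _ReqTiming, completion_tokens: int, finished: bool, collector) -> None:
    # Sketch of the per-chunk metric updates; `collector` is any object exposing the
    # TokenizerMetricsCollector observe_* methods shown in the diff above.
    now = time.time()
    if t.first_token_time is None:
        t.first_token_time = now
        collector.observe_time_to_first_token(now - t.created_time)
    elif completion_tokens >= 2:
        # Spread decode time over the tokens emitted after the first one.
        collector.observe_time_per_output_token(
            (now - t.first_token_time) / (completion_tokens - 1)
        )
    if finished:
        collector.observe_e2e_request_latency(now - t.created_time)
        if completion_tokens >= 1:
            collector.observe_time_per_output_token(
                (now - t.created_time) / completion_tokens
            )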

sglang/srt/metrics/collector.py ADDED
@@ -0,0 +1,211 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Utilities for Prometheus Metrics Collection."""
+
+from dataclasses import dataclass
+from typing import Dict, Union
+
+
+@dataclass
+class SchedulerStats:
+    num_running_reqs: int = 0
+    num_used_tokens: int = 0
+    token_usage: float = 0.0
+    gen_throughput: float = 0.0
+    num_queue_reqs: int = 0
+    cache_hit_rate: float = 0.0
+
+
+class SchedulerMetricsCollector:
+
+    def __init__(self, labels: Dict[str, str]) -> None:
+        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
+        from prometheus_client import Gauge
+
+        self.labels = labels
+
+        self.num_running_reqs = Gauge(
+            name="sglang:num_running_reqs",
+            documentation="The number of running requests",
+            labelnames=labels.keys(),
+            multiprocess_mode="sum",
+        )
+
+        self.num_used_tokens = Gauge(
+            name="sglang:num_used_tokens",
+            documentation="The number of used tokens",
+            labelnames=labels.keys(),
+            multiprocess_mode="sum",
+        )
+
+        self.token_usage = Gauge(
+            name="sglang:token_usage",
+            documentation="The token usage",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        self.gen_throughput = Gauge(
+            name="sglang:gen_throughput",
+            documentation="The generate throughput (token/s)",
+            labelnames=labels.keys(),
+            multiprocess_mode="sum",
+        )
+
+        self.num_queue_reqs = Gauge(
+            name="sglang:num_queue_reqs",
+            documentation="The number of requests in the waiting queue",
+            labelnames=labels.keys(),
+            multiprocess_mode="sum",
+        )
+
+        self.cache_hit_rate = Gauge(
+            name="sglang:cache_hit_rate",
+            documentation="The cache hit rate",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
+        # Convenience function for logging to gauge.
+        gauge.labels(**self.labels).set(data)
+
+    def log_stats(self, stats: SchedulerStats) -> None:
+        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
+        self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
+        self._log_gauge(self.token_usage, stats.token_usage)
+        self._log_gauge(self.gen_throughput, stats.gen_throughput)
+        self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
+        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+
+
+class TokenizerMetricsCollector:
+    def __init__(self, labels: Dict[str, str]) -> None:
+        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
+        from prometheus_client import Counter, Histogram
+
+        self.labels = labels
+
+        self.prompt_tokens_total = Counter(
+            name="sglang:prompt_tokens_total",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labels.keys(),
+        )
+
+        self.generation_tokens_total = Counter(
+            name="sglang:generation_tokens_total",
+            documentation="Number of generation tokens processed.",
+            labelnames=labels.keys(),
+        )
+
+        self.histogram_time_to_first_token = Histogram(
+            name="sglang:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labels.keys(),
+            buckets=[
+                0.001,
+                0.005,
+                0.01,
+                0.02,
+                0.04,
+                0.06,
+                0.08,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+                5.0,
+                7.5,
+                10.0,
+                15.0,
+                20.0,
+                25.0,
+                30.0,
+            ],
+        )
+
+        self.histogram_time_per_output_token = Histogram(
+            name="sglang:time_per_output_token_seconds",
+            documentation="Histogram of time per output token in seconds.",
+            labelnames=labels.keys(),
+            buckets=[
+                0.005,
+                0.01,
+                0.015,
+                0.02,
+                0.025,
+                0.03,
+                0.04,
+                0.05,
+                0.075,
+                0.1,
+                0.15,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+            ],
+        )
+
+        self.histogram_e2e_request_latency = Histogram(
+            name="sglang:e2e_request_latency_seconds",
+            documentation="Histogram of End-to-end request latency in seconds",
+            labelnames=labels.keys(),
+            buckets=[
+                0.3,
+                0.5,
+                0.8,
+                1.0,
+                1.5,
+                2.0,
+                2.5,
+                5.0,
+                10.0,
+                15.0,
+                20.0,
+                30.0,
+                40.0,
+                50.0,
+                60.0,
+            ],
+        )
+
+    def _log_histogram(self, histogram, data: Union[int, float]) -> None:
+        histogram.labels(**self.labels).observe(data)
+
+    def _log_counter(self, counter, data: Union[int, float]) -> None:
+        # Convenience function for logging to counter.
+        counter.labels(**self.labels).inc(data)
+
+    def inc_prompt_tokens(self, value: int):
+        self._log_counter(self.prompt_tokens_total, value)
+
+    def inc_generation_tokens(self, value: int):
+        self._log_counter(self.generation_tokens_total, value)
+
+    def observe_time_to_first_token(self, value: Union[float, int]):
+        self._log_histogram(self.histogram_time_to_first_token, value)
+
+    def observe_time_per_output_token(self, value: Union[float, int]):
+        self._log_histogram(self.histogram_time_per_output_token, value)
+
+    def observe_e2e_request_latency(self, value: Union[float, int]):
+        self._log_histogram(self.histogram_e2e_request_latency, value)
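
For context, a hedged sketch of how these collectors could be exercised outside the server (the label values and the temporary multiprocess directory are illustrative assumptions; inside sglang the scheduler and tokenizer manager wire this up):

import os
import tempfile

# prometheus_client should only be imported after PROMETHEUS_MULTIPROC_DIR is set,
# as the in-file comments above note.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

from prometheus_client import CollectorRegistry, generate_latest, multiprocess

from sglang.srt.metrics.collector import (
    SchedulerMetricsCollector,
    SchedulerStats,
    TokenizerMetricsCollector,
)

labels = {"model_name": "example-model"}  # illustrative label set

scheduler_metrics = SchedulerMetricsCollector(labels=labels)
scheduler_metrics.log_stats(SchedulerStats(num_running_reqs=3, token_usage=0.42))

tokenizer_metrics = TokenizerMetricsCollector(labels=labels)
tokenizer_metrics.inc_prompt_tokens(128)
tokenizer_metrics.inc_generation_tokens(42)
tokenizer_metrics.observe_time_to_first_token(0.08)

# Aggregate across processes and render the Prometheus exposition format.
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
print(generate_latest(registry).decode())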

sglang/srt/metrics/func_timer.py ADDED
@@ -0,0 +1,108 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Records the latency of some functions
+"""
+
+import asyncio
+import time
+from functools import wraps
+from typing import Any, Callable, List, Optional
+
+enable_metrics = False
+
+
+def enable_func_timer():
+    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
+    from prometheus_client import Histogram
+
+    global enable_metrics, FUNC_LATENCY
+    enable_metrics = True
+
+    FUNC_LATENCY = Histogram(
+        "sglang:func_latency_seconds",
+        "Function latency in seconds",
+        # captures latency in range [50ms - ~50s]
+        buckets=exponential_buckets(start=0.05, width=1.5, length=18),
+        labelnames=["name"],
+    )
+
+
+FUNC_LATENCY = None
+
+
+def exponential_buckets(start: float, width: float, length: int) -> List[float]:
+    buckets = []
+    for i in range(length):
+        buckets.append(start * (width**i))
+    return buckets
+
+
+def time_func_latency(
+    func: Callable = None, name: Optional[str] = None
+) -> Callable[..., Any]:
+    """
+    A decorator to observe the latency of a function's execution. Supports both sync and async functions.
+
+    NOTE: We use our own implementation of a timer decorator since prometheus_client does not support async
+    context manager yet.
+
+    Overhead: The overhead introduced here in case of an async function could likely be because of `await` introduced
+    which will return in another coroutine object creation and under heavy load could see longer wall time
+    (scheduling delays due to introduction of another awaitable).
+    """
+
+    def measure(func: Callable[..., Any]) -> Callable[..., Any]:
+        nonlocal name
+
+        name = name or func.__name__
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            if not enable_metrics:
+                return await func(*args, **kwargs)
+
+            metric = FUNC_LATENCY
+            start = time.monotonic()
+            ret = func(*args, **kwargs)
+            if isinstance(ret, asyncio.Future) or asyncio.iscoroutine(ret):
+                try:
+                    ret = await ret
+                finally:
+                    metric.labels(name=name).observe(time.monotonic() - start)
+            return ret
+
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            if not enable_metrics:
+                return func(*args, **kwargs)
+
+            metric = FUNC_LATENCY
+            start = time.monotonic()
+            try:
+                ret = func(*args, **kwargs)
+            finally:
+                metric.labels(name=name).observe(time.monotonic() - start)
+            return ret
+
+        if asyncio.iscoroutinefunction(func):
+            return async_wrapper
+        return sync_wrapper
+
+    if func:
+        return measure(func)
+    else:
+        return measure
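
A small, hypothetical usage sketch of this decorator (the function name is illustrative; in the server, enable_func_timer() is only called once metrics are enabled and PROMETHEUS_MULTIPROC_DIR has been configured):

import asyncio

from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency


@time_func_latency(name="demo_sleep")  # name is optional and defaults to func.__name__
async def demo_sleep() -> None:
    await asyncio.sleep(0.05)


# Until enable_func_timer() runs, the wrapper is a pass-through with no metric overhead.
enable_func_timer()
asyncio.run(demo_sleep())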
sglang/srt/mm_utils.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 """
 Utilities for multi-modal models.
 
-This python file mainly contains utilities that were used in the 
+This python file mainly contains utilities that were used in the
 image processing logic of llava-next including operations such as
 anyres and anyres_max
 

sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -32,7 +32,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import monkey_patch_vllm_all_gather
+from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -92,7 +92,7 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024
 
 
-@torch.compile(dynamic=True)
+@maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
 
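maybe_torch_compile comes from sglang/srt/utils.py, which is not shown in this excerpt. A plausible sketch of such a gate, assuming it simply applies torch.compile only when compilation has been turned on (an assumption about its shape, not the packaged implementation):

import torch

_TORCH_COMPILE_ENABLED = False  # hypothetical module-level switch


def maybe_torch_compile(*args, **kwargs):
    """Apply torch.compile(*args, **kwargs) only when the switch is enabled."""

    def decorator(func):
        if _TORCH_COMPILE_ENABLED:
            return torch.compile(func, *args, **kwargs)
        return func

    return decorator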

sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -136,8 +136,13 @@ class ForwardBatch:
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    batch.mrope_positions_delta[i][0],
+                    mrope_position_delta,
                     int(self.seq_lens[i]) - 1,
                     int(self.seq_lens[i]),
                 )
@@ -159,7 +164,6 @@
                         )
                     ]
                 ] * 3
-                mrope_position_delta = 0
             else:
                 # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                 mrope_positions, mrope_position_delta = (
@@ -173,8 +177,8 @@
                         context_len=0,
                     )
                 )
+                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
-                batch.mrope_positions_delta[i].append(mrope_position_delta)
 
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],

sglang/srt/model_executor/model_runner.py CHANGED
@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
-from sglang.srt.constrained import disable_cache
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
@@ -129,6 +128,8 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
         if server_args.disable_disk_cache:
+            from outlines.caching import disable_cache
+
            disable_cache()
 
         global_server_args_dict.update(
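
The disable_cache call is now resolved lazily from outlines.caching inside the branch that needs it, rather than through sglang.srt.constrained at module import time. The pattern in isolation (hypothetical helper name, shown only to illustrate the deferred import):

def _maybe_disable_disk_cache(disable_disk_cache: bool) -> None:
    # Import outlines only when the option is actually set, so importing this
    # module does not pull in outlines' caching machinery.
    if disable_disk_cache:
        from outlines.caching import disable_cache

        disable_cache()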