sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""
 
 import bisect
 from contextlib import contextmanager
-from typing import Callable, List
+from typing import TYPE_CHECKING, Callable
 
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -30,16 +30,13 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
@@ -58,6 +55,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
@@ -89,28 +87,33 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(
-        self,
-        model_runner: "ModelRunner",
-        max_batch_size_to_capture: int,
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.disable_padding = disable_padding
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
 
         # Common inputs
-        self.max_bs = max_batch_size_to_capture
+        self.max_bs = max(self.capture_bs)
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
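
Note (not part of the diff): the batch sizes to capture and to compile with torch.compile are now derived inside __init__ from model_runner.server_args rather than passed in by the caller, and max(self.capture_bs) sizes the static input buffers. A small standalone Python sketch that recomputes the two lists exactly as the new constructor does (flag names mirror the ServerArgs fields referenced above):

    # Illustrative recomputation of the lists chosen in the new __init__ above.
    def capture_and_compile_bs(disable_cuda_graph_padding: bool, enable_torch_compile: bool):
        if disable_cuda_graph_padding:
            capture_bs = list(range(1, 32)) + [64, 128]
        else:
            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
        compile_bs = [1, 2, 4, 8, 16, 24, 32] if enable_torch_compile else []
        return capture_bs, compile_bs

    capture_bs, compile_bs = capture_and_compile_bs(False, True)
    print(max(capture_bs))  # 160 -> becomes self.max_bs and sizes input_ids, seq_lens, etc.
    print(compile_bs)       # [1, 2, 4, 8, 16, 24, 32]
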
@@ -118,56 +121,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        # FlashInfer inputs
-        self.flashinfer_kv_indptr = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
 
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
-
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list: List[int]):
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in batch_size_list:
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -175,14 +160,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -195,67 +176,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
-        # FlashInfer inputs
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
         )
 
         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
@@ -277,17 +217,17 @@ class CudaGraphRunner:
         self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
-        return graph, None, out, flashinfer_decode_wrapper
+        return graph, out
 
     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 
         # Pad
-        index = bisect.bisect_left(self.batch_size_list, raw_bs)
-        bs = self.batch_size_list[index]
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
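
Note (not part of the diff): replay now pads an incoming batch up to the nearest captured batch size. bisect_left over the sorted capture_bs list picks the smallest captured size that is greater than or equal to the raw batch size, and the padded tail of seq_lens is filled with the attention backend's fill value instead of zeros. A minimal sketch of just that selection step, using a hypothetical standalone helper rather than the class method:

    import bisect

    def padded_batch_size(capture_bs, raw_bs):
        # Smallest captured batch size that can hold raw_bs requests.
        # Callers are expected to check can_run() first, so raw_bs <= max(capture_bs).
        index = bisect.bisect_left(capture_bs, raw_bs)
        return capture_bs[index]

    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # default list from __init__
    assert padded_batch_size(capture_bs, 3) == 4
    assert padded_batch_size(capture_bs, 17) == 24
    assert padded_batch_size(capture_bs, 160) == 160
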
@@ -298,24 +238,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
-        # FlashInfer inputs
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
         )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
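
Note (not part of the diff): all FlashInfer-specific wrapper and buffer management is now hidden behind model_runner.attn_backend. Judging only from the call sites visible in this file, the graph runner relies on four backend hooks; the outline below is an inferred sketch (the actual definitions live in the new sglang/srt/layers/attention_backend.py and may differ in signatures and return values):

    import torch

    class AttentionBackendHooksSketch:
        # Inferred interface outline; not the real sglang class.

        def init_cuda_graph_state(self, max_bs: int) -> None:
            """Allocate fixed-size metadata buffers once, sized for the largest captured batch."""

        def get_cuda_graph_seq_len_fill_value(self) -> int:
            """Value written into the padded tail of seq_lens before replay."""
            raise NotImplementedError

        def init_forward_metadata_capture_cuda_graph(
            self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
        ) -> None:
            """Prepare decode metadata for the batch size being captured."""

        def init_forward_metadata_replay_cuda_graph(
            self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
        ) -> None:
            """Refresh metadata in place before replaying the captured graph."""
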
@@ -327,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
@@ -348,4 +273,4 @@ class CudaGraphRunner:
                    logits_output.next_token_logprobs, logits_metadata
                )[1]
 
-        return sample_output, logits_output
+        return logits_output