sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""
 
 import bisect
 from contextlib import contextmanager
-from typing import Callable, List
+from typing import TYPE_CHECKING, Callable
 
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -30,20 +30,20 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE torch native implementaiton is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -58,6 +58,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
@@ -89,28 +90,41 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(
-        self,
-        model_runner: "ModelRunner",
-        max_batch_size_to_capture: int,
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.disable_padding = disable_padding
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
 
         # Common inputs
-        self.max_bs = max_batch_size_to_capture
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -118,56 +132,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        # FlashInfer inputs
-        self.flashinfer_kv_indptr = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
-
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list: List[int]):
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in batch_size_list:
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -175,14 +171,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -195,67 +187,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
-        # FlashInfer inputs
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
         )
 
         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
@@ -277,17 +228,17 @@ class CudaGraphRunner:
         self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
-        return graph, None, out, flashinfer_decode_wrapper
+        return graph, out
 
     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 
         # Pad
-        index = bisect.bisect_left(self.batch_size_list, raw_bs)
-        bs = self.batch_size_list[index]
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
@@ -298,24 +249,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
-        # FlashInfer inputs
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
         )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
@@ -327,11 +268,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
@@ -348,4 +284,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]
 
-        return sample_output, logits_output
+        return logits_output
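
For readers tracing the replay path above: the padded batch size is now picked with bisect over the newly introduced capture_bs list, and the padded tail of seq_lens is filled with seq_len_fill_value rather than zeroed. The following is a minimal, self-contained sketch of that padding logic using the default values from this diff; padded_batch_size is an illustrative helper, not part of the sglang API.

import bisect

# Batch sizes captured by default (mirrors the new __init__ when padding is enabled).
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160

def padded_batch_size(raw_bs: int) -> int:
    """Smallest captured batch size that can serve raw_bs decode requests."""
    index = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[index]

# Example: a decode batch of 13 requests replays the graph captured for bs=16,
# and the extra padded slots are later sliced away in the "Unpad" step of replay().
assert padded_batch_size(13) == 16
assert padded_batch_size(16) == 16
assert padded_batch_size(1) == 1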