sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
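
Of the 81 changed files, the diff reproduced below is for sglang/srt/model_executor/cuda_graph_runner.py (item 33, +60 -133). Its main structural change is that CudaGraphRunner no longer builds FlashInfer decode wrappers and sampling buffers itself; it delegates all CUDA-graph bookkeeping to the model runner's attention backend, introduced in the new sglang/srt/layers/attention_backend.py. The following is a minimal, hypothetical sketch of the hook interface implied by the call sites in the hunks below; the method names come from the diff, but the signatures and docstrings are inferred from usage rather than copied from the released file.

from abc import ABC, abstractmethod

import torch


class AttentionBackend(ABC):
    """Sketch of the cuda-graph hooks CudaGraphRunner relies on (inferred, not copied)."""

    @abstractmethod
    def init_cuda_graph_state(self, max_bs: int) -> None:
        """Allocate fixed-size buffers shared by all captured graphs."""

    @abstractmethod
    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Value used to pad seq_lens when a batch is smaller than the graph."""

    @abstractmethod
    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None:
        """Prepare per-batch-size metadata once, right before graph capture."""

    @abstractmethod
    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None:
        """Refresh metadata in place before each graph replay."""
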
sglang/srt/model_executor/cuda_graph_runner.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""
 
 import bisect
 from contextlib import contextmanager
-from typing import Callable, List
+from typing import TYPE_CHECKING, Callable
 
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -30,24 +30,23 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
 
@@ -56,6 +55,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
@@ -87,28 +87,33 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(
-        self,
-        model_runner: "ModelRunner",
-        max_batch_size_to_capture: int,
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.disable_padding = disable_padding
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
 
         # Common inputs
-        self.max_bs = max_batch_size_to_capture
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -116,56 +121,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        # FlashInfer inputs
-        self.flashinfer_kv_indptr = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
 
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
-
-        if use_torch_compile:
+        if self.use_torch_compile:
            set_torch_compile_config()
 
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list: List[int]):
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in batch_size_list:
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -173,14 +160,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -193,67 +176,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
-        # FlashInfer inputs
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
         )
 
         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
@@ -275,17 +217,17 @@ class CudaGraphRunner:
             self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
-        return graph, None, out, flashinfer_decode_wrapper
+        return graph, out
 
     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 
         # Pad
-        index = bisect.bisect_left(self.batch_size_list, raw_bs)
-        bs = self.batch_size_list[index]
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
@@ -296,24 +238,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
-        # FlashInfer inputs
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
         )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
@@ -325,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
@@ -346,4 +273,4 @@ class CudaGraphRunner:
             logits_output.next_token_logprobs, logits_metadata
         )[1]
 
-        return sample_output, logits_output
+        return logits_output
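
The replay path above pads each incoming batch up to the nearest captured batch size, fills the padded seq_lens slots with seq_len_fill_value, and trims the outputs back to the real size afterwards. A small, self-contained illustration of that bisect-based selection (plain Python, independent of the sglang classes; the helper name is ours, not the library's):

import bisect

# Capture list used when cuda-graph padding is enabled (from __init__ above).
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160


def padded_batch_size(raw_bs: int) -> int:
    """Smallest captured batch size that can hold raw_bs requests."""
    # bisect_left returns the index of the first captured size >= raw_bs,
    # mirroring the lookup in CudaGraphRunner.replay().
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


assert padded_batch_size(3) == 4
assert padded_batch_size(9) == 16
assert padded_batch_size(160) == 160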