sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. The information is provided for informational purposes only.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -17,6 +17,7 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -51,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
 
     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -65,7 +66,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -84,13 +85,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner: "ModelRunner",
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -141,10 +149,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()
 
-    def can_run(self, batch_size):
-        return batch_size < self.max_bs
+    def can_run(self, batch_size: int):
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
@@ -165,7 +176,7 @@ class CudaGraphRunner:
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
@@ -239,12 +250,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
@@ -278,7 +300,9 @@ class CudaGraphRunner:
         )
 
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]
 
         # Unpad
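
Note: one behavioral change worth calling out in cuda_graph_runner.py is the new can_run check. With padding enabled, any batch size up to the largest captured graph is eligible (the bound also changed from a strict < to an inclusive <=); with the new disable_padding flag, only batch sizes that were captured exactly can replay a graph. A minimal standalone sketch of that decision logic follows (plain data structures only, not the sglang implementation):

def can_run(batch_size: int, captured_sizes: set, max_bs: int, disable_padding: bool) -> bool:
    if disable_padding:
        # Without padding, only exactly-captured batch sizes can replay a graph.
        return batch_size in captured_sizes
    # With padding, smaller batches are padded up to the nearest captured size,
    # so anything up to max_bs qualifies (note the inclusive <= bound).
    return batch_size <= max_bs

captured = {1, 2, 4, 8, 16}
print(can_run(3, captured, max_bs=16, disable_padding=False))   # True: padded up to 4
print(can_run(3, captured, max_bs=16, disable_padding=True))    # False: 3 was not captured
print(can_run(16, captured, max_bs=16, disable_padding=False))  # True under the inclusive bound
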
sglang/srt/model_executor/forward_batch_info.py

@@ -61,9 +61,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
-    # Output options
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +88,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            (
-                (r.image_offset - len(r.prefix_indices))
-                if r.image_offset is not None
-                else 0
-            )
-            for r in reqs
-        ]
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -109,8 +116,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +130,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
@@ -139,14 +146,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]
 
     @classmethod
     def from_schedule_batch(
@@ -180,14 +202,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
 
-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)
 
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +214,35 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged
+            model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
         )
 
         return ret
 
-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
 
         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
@@ -294,6 +315,8 @@ def update_flashinfer_indices(
            num_kv_heads,
            head_dim,
            1,
+           data_type=model_runner.kv_cache_dtype,
+           q_data_type=model_runner.dtype,
        )
    else:
        # extend part
@@ -372,6 +395,8 @@ def update_flashinfer_indices(
            num_kv_heads,
            head_dim,
            1,
+           data_type=model_runner.kv_cache_dtype,
+           q_data_type=model_runner.dtype,
        )
    else:
        # extend part
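
Note: in forward_batch_info.py the per-request logprob start positions are now precomputed on the CPU. Each requested logprob_start_len is shifted by the cached prefix length and clamped to the last extend position; a request whose start falls inside the cached prefix ("fake extend, actually decode") only reports the last position. A minimal standalone sketch of that clamping rule (hypothetical helper, not the sglang API):

from typing import List

def logprob_start_lens(
    requested_starts: List[int],  # logprob_start_len per request (absolute position)
    prefix_lens: List[int],       # cached prefix length per request
    extend_lens: List[int],       # number of newly extended tokens per request
) -> List[int]:
    out = []
    for start, prefix, extend in zip(requested_starts, prefix_lens, extend_lens):
        if start >= prefix:
            # Shift into the extend region, clamped to the last extend position.
            out.append(min(start - prefix, extend - 1))
        else:
            # Start lies inside the cached prefix: fake extend, actually decode.
            out.append(extend - 1)
    return out

print(logprob_start_lens([5, 0], prefix_lens=[3, 4], extend_lens=[6, 1]))  # [2, 0]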