sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -84,13 +86,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -136,13 +145,20 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
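Note on the hunk above: with the new disable_padding flag, can_run only accepts batch sizes for which a graph was actually captured, while the default behavior lets any batch up to max_bs be padded up to a captured size (the replay path later unpads the result). A minimal standalone sketch of that dispatch; the class name and captured batch sizes below are made up for illustration and are not code from the package:

# Sketch only: the two branches mirror the new can_run() above.
class PaddingCheck:
    def __init__(self, disable_padding: bool):
        self.disable_padding = disable_padding
        self.max_bs = 8
        self.graphs = {1: None, 2: None, 4: None, 8: None}  # captured graphs by batch size

    def can_run(self, batch_size: int) -> bool:
        if self.disable_padding:
            # Exact match only: the batch is never padded for replay.
            return batch_size in self.graphs
        else:
            # The batch may be padded up to a captured size.
            return batch_size <= self.max_bs


print(PaddingCheck(disable_padding=True).can_run(3))   # False
print(PaddingCheck(disable_padding=False).can_run(3))  # True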
@@ -224,6 +240,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -239,12 +256,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
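For background, the warm-up / synchronize / capture sequence above is the standard PyTorch CUDA-graph recipe; the added tp_group.barrier() calls additionally keep tensor-parallel ranks in step around warm-up and capture. A standalone sketch of that recipe using only public torch APIs (assumes a CUDA device; tensor names are illustrative, not sglang code):

import torch

static_x = torch.zeros(8, device="cuda")

def step():
    return static_x * 2 + 1

# Warm up on a side stream, as the PyTorch CUDA-graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        step()
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay cheaply with new inputs written in place.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step()

static_x.fill_(1.0)
graph.replay()
print(static_out)  # tensor of 3.0s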
@@ -277,25 +305,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        output = self.output_buffers[bs]
+        torch.cuda.synchronize()
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
         if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -303,8 +341,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return output
+        return sample_output, logits_output
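Also visible above: replay now returns a (sample_output, logits_output) pair instead of a single logits object, and both are sliced back to the real batch size when the graph ran padded. A standalone sketch of that unpadding step, using a stand-in tuple whose field names follow the diff (sizes are made up; this is not the package's class):

from typing import NamedTuple

import torch

class SampleOutput(NamedTuple):
    # Stand-in with the same field names as in the diff above.
    success: torch.Tensor
    probs: torch.Tensor
    batch_next_token_ids: torch.Tensor

bs, raw_bs, vocab = 4, 3, 10  # three real requests padded up to a captured size of four
padded = SampleOutput(
    success=torch.ones(bs, dtype=torch.bool),
    probs=torch.rand(bs, vocab),
    batch_next_token_ids=torch.zeros(bs, dtype=torch.long),
)
# Unpad exactly as the replay path does: keep only the real requests.
sample_output = SampleOutput(
    padded.success[:raw_bs],
    padded.probs[:raw_bs],
    padded.batch_next_token_ids[:raw_bs],
)
print(sample_output.batch_next_token_ids.shape)  # torch.Size([3])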
sglang/srt/model_executor/forward_batch_info.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +18,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +45,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -61,9 +65,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
-    # Output options
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +92,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-            (
-                (r.image_offset - len(r.prefix_indices))
-                if r.image_offset is not None
-                else 0
-            )
-            for r in reqs
-        ]
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
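The rewritten multimodal-info hunk above now accepts image_offset as a list (several images per request), a single int, or None, rebasing every offset by the request's cached prefix length. A standalone sketch of that normalization with a made-up request type and values:

from dataclasses import dataclass
from typing import List, Union


@dataclass
class Req:
    # Stand-in request: only the two fields the logic above reads.
    image_offset: Union[int, List[int], None]
    prefix_indices: List[int]


def normalize_image_offsets(reqs):
    # Mirrors the new branching: list, int, or None per request.
    offsets = []
    for r in reqs:
        if isinstance(r.image_offset, list):
            offsets.append([off - len(r.prefix_indices) for off in r.image_offset])
        elif isinstance(r.image_offset, int):
            offsets.append(r.image_offset - len(r.prefix_indices))
        elif r.image_offset is None:
            offsets.append(0)
    return offsets


reqs = [Req([10, 20], [0] * 4), Req(7, [0] * 2), Req(None, [])]
print(normalize_image_offsets(reqs))  # [[6, 16], 5, 0]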
@@ -109,8 +120,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +134,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
@@ -139,14 +150,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]
 
     @classmethod
     def from_schedule_batch(
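compute_extend_infos above now also keeps CPU-side copies of the extend lengths and derives logprob_start_lens_cpu: each request's logprob start is rebased into its extend segment and clamped to the last extend position, and a start that falls inside the cached prefix (a "fake extend" that is effectively a decode step) clamps to that last position as well. A standalone sketch with made-up lengths:

# Sketch of the logprob_start_lens_cpu computation; all numbers are invented.
def logprob_start_lens(prefix_lens, extend_lens, logprob_start_lens_abs):
    out = []
    for prefix, extend, start in zip(prefix_lens, extend_lens, logprob_start_lens_abs):
        if start >= prefix:
            out.append(min(start - prefix, extend - 1))
        else:
            out.append(extend - 1)  # fake extend, actually decode
    return out


#              prefix  extend  logprob_start_len
# request 0:      4      10        6   -> 6 - 4 = 2
# request 1:      0       8       20   -> clamped to 8 - 1 = 7
# request 2:      5       1        3   -> start inside the prefix -> 1 - 1 = 0
print(logprob_start_lens([4, 0, 5], [10, 8, 1], [6, 20, 3]))  # [2, 7, 0]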
@@ -157,6 +183,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -167,6 +194,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)
 
         ret.compute_extend_infos(batch)
@@ -180,14 +209,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
 
-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)
 
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +221,35 @@ class InputMetadata:
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, prefix_lens, flashinfer_use_ragged
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )
 
         return ret
 
-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
 
         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
@@ -294,6 +322,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
@@ -372,6 +402,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part