sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
@@ -15,22 +15,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""ModelRunner runs the forward passes of the models."""
+"""Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
-import triton
-import triton.language as tl
-
-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention_backend import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -40,6 +37,20 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
+    # Contains both PREFILL and EXTEND.
+    MIXED = auto()
+
+    def is_prefill(self):
+        return self == ForwardMode.PREFILL
+
+    def is_extend(self):
+        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+
+    def is_mixed(self):
+        return self == ForwardMode.MIXED
 
 
 @dataclass
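The helper predicates added above let call sites ask for intent instead of comparing enum members, which matters now that the new MIXED mode has to be treated as an extend-style batch (is_extend() covers both EXTEND and MIXED). A minimal sketch of the behaviour, assuming sglang 0.3.1 is importable; the variable name mode is illustrative, not from the diff:

    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    mode = ForwardMode.MIXED
    # A plain equality check written for 0.2.15 would miss MIXED batches:
    assert mode != ForwardMode.EXTEND
    # The new predicates classify MIXED as an extend batch:
    assert mode.is_extend() and mode.is_mixed() and not mode.is_decode()

Later hunks in this file switch their own checks to the same predicates (e.g. self.forward_mode.is_decode()).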
@@ -47,18 +58,16 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
+    attn_backend: AttentionBackend
 
     # Output location of the KV cache
     out_cache_loc: torch.Tensor
 
-    total_num_tokens: int = None
-
     # Position information
     positions: torch.Tensor = None
 
@@ -72,35 +81,25 @@ class InputMetadata:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
     extend_seq_lens_cpu: List[int] = None
-    logprob_start_lens_cpu: List[int] = None
+    extend_logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[List[int]]] = None
     image_offsets: List[List[int]] = None
-
-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
+    modalities: List[List[str]] = None
 
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_sizes for r in reqs]
         self.image_offsets = [r.image_offsets for r in reqs]
+        self.modalities = [r.modalities for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
 
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:
@@ -139,315 +138,39 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)
 
     def compute_extend_infos(self, batch: ScheduleBatch):
-        if self.forward_mode == ForwardMode.DECODE:
-            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
-            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
-        else:
-            extend_lens_cpu = [
-                len(r.fill_ids) - batch.prefix_lens_cpu[i]
-                for i, r in enumerate(batch.reqs)
-            ]
-            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
-            self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            self.extend_start_loc = torch.zeros_like(self.seq_lens)
-            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
-
-            self.extend_seq_lens_cpu = extend_lens_cpu
-            self.logprob_start_lens_cpu = [
-                (
-                    min(
-                        req.logprob_start_len - batch.prefix_lens_cpu[i],
-                        extend_lens_cpu[i] - 1,
-                    )
-                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
-                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
-                )
-                for i, req in enumerate(batch.reqs)
-            ]
+        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
+        self.extend_seq_lens_cpu = batch.extend_lens_cpu
+        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
 
     @classmethod
     def from_schedule_batch(
         cls,
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
-        forward_mode: ForwardMode,
     ):
         ret = cls(
-            forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
+            forward_mode=batch.forward_mode,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
-        ret.sampling_info.prepare_penalties()
-
         ret.compute_positions(batch)
 
-        ret.compute_extend_infos(batch)
-
-        if (
-            forward_mode != ForwardMode.DECODE
-            or model_runner.server_args.disable_flashinfer
-        ):
-            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
-
-        if forward_mode != ForwardMode.DECODE:
+        if not batch.forward_mode.is_decode():
             ret.init_multimuldal_info(batch)
+            ret.compute_extend_infos(batch)
 
-        if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch)
-
-        flashinfer_use_ragged = False
-        if not model_runner.server_args.disable_flashinfer:
-            if (
-                forward_mode != ForwardMode.DECODE
-                and int(torch.sum(ret.seq_lens)) > 4096
-                and model_runner.sliding_window_size is None
-            ):
-                flashinfer_use_ragged = True
-            ret.init_flashinfer_handlers(
-                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
-            )
+        model_runner.attn_backend.init_forward_metadata(batch, ret)
 
         return ret
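Two things change in from_schedule_batch above: the caller no longer passes a ForwardMode (it now rides on the ScheduleBatch), and the backend-specific setup that used to happen inline is delegated to the attn_backend handle stored on InputMetadata. A hedged call-site sketch; model_runner and batch are illustrative names for a live ModelRunner and ScheduleBatch:

    # sglang 0.2.15: the caller supplied the mode, and this method set up
    # the triton/flashinfer buffers itself.
    input_metadata = InputMetadata.from_schedule_batch(
        model_runner, batch, ForwardMode.EXTEND
    )

    # sglang 0.3.1: the mode is read from batch.forward_mode, and
    # model_runner.attn_backend.init_forward_metadata(batch, ret) performs
    # whatever per-backend setup the attention kernels need.
    input_metadata = InputMetadata.from_schedule_batch(model_runner, batch)

The AttentionBackend class itself is imported from the new sglang/srt/layers/attention_backend.py (+480 lines in the file list above).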
-
-    def init_triton_args(self, batch: ScheduleBatch):
-        """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
-        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.triton_max_extend_len = None
-        else:
-            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
-            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
-
-    def init_flashinfer_handlers(
-        self,
-        model_runner,
-        prefix_lens_cpu,
-        flashinfer_use_ragged,
-    ):
-        if self.forward_mode == ForwardMode.DECODE:
-            prefix_lens = None
-        else:
-            prefix_lens = self.extend_prefix_lens
-
-        update_flashinfer_indices(
-            self.forward_mode,
-            model_runner,
-            self.req_pool_indices,
-            self.seq_lens,
-            prefix_lens,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        (
-            self.flashinfer_prefill_wrapper_ragged,
-            self.flashinfer_prefill_wrapper_paged,
-            self.flashinfer_decode_wrapper,
-            self.flashinfer_use_ragged,
-        ) = (
-            model_runner.flashinfer_prefill_wrapper_ragged,
-            model_runner.flashinfer_prefill_wrapper_paged,
-            model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged,
-        )
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    max_context_len,
-    kv_indices_ptr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper=None,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-
-    if model_runner.sliding_window_size is None:
-        if flashinfer_use_ragged:
-            paged_kernel_lens = prefix_lens
-        else:
-            paged_kernel_lens = seq_lens
-
-        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(batch_size,)](
-            model_runner.req_to_token_pool.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            model_runner.req_to_token_pool.req_to_token.size(1),
-            kv_indices,
-        )
-
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-        if forward_mode == ForwardMode.DECODE:
-            # CUDA graph uses different flashinfer_decode_wrapper
-            if flashinfer_decode_wrapper is None:
-                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-            flashinfer_decode_wrapper.end_forward()
-            flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                data_type=model_runner.kv_cache_dtype,
-                q_data_type=model_runner.dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
-            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-    else:
-        # window attention use paged only
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        for wrapper_id in range(2):
-            if wrapper_id == 0:
-                if forward_mode == ForwardMode.DECODE:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens,
-                        torch.tensor(model_runner.sliding_window_size)
-                        + seq_lens
-                        - prefix_lens,
-                    )
-            else:
-                paged_kernel_lens = seq_lens
-
-            kv_start_idx = seq_lens - paged_kernel_lens
-
-            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-            create_flashinfer_kv_indices_triton[(batch_size,)](
-                model_runner.req_to_token_pool.req_to_token,
-                req_pool_indices,
-                paged_kernel_lens,
-                kv_indptr,
-                kv_start_idx,
-                model_runner.req_to_token_pool.req_to_token.size(1),
-                kv_indices,
-            )
-
-            if forward_mode == ForwardMode.DECODE:
-                # CUDA graph uses different flashinfer_decode_wrapper
-                if flashinfer_decode_wrapper is None:
-                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-                flashinfer_decode_wrapper[wrapper_id].end_forward()
-                flashinfer_decode_wrapper[wrapper_id].begin_forward(
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                    data_type=model_runner.kv_cache_dtype,
-                    q_data_type=model_runner.dtype,
-                )
-            else:
-                # extend part
-                qo_indptr = torch.zeros(
-                    (batch_size + 1,), dtype=torch.int32, device="cuda"
-                )
-                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
-                    qo_indptr,
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                )
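The init_triton_args, init_flashinfer_handlers, create_flashinfer_kv_indices_triton, and update_flashinfer_indices helpers removed here no longer live in forward_batch_info.py; per the file list above, the new sglang/srt/layers/attention_backend.py and sglang/srt/layers/flashinfer_utils.py presumably absorb this logic. For reference, the index layout the deleted Triton kernel produced (non-sliding-window path, kv_start_idx=None) is an exclusive prefix sum plus a per-request gather from the req_to_token table. A pure-torch sketch with toy values, not taken from the diff:

    import torch

    # Toy request-to-token pool: row i holds the KV-cache slots of request i.
    req_to_token = torch.tensor([[10, 11, 12, 13],
                                 [20, 21, 22, 23]], dtype=torch.int32)
    req_pool_indices = torch.tensor([0, 1])
    paged_kernel_lens = torch.tensor([3, 2], dtype=torch.int32)

    # kv_indptr: leading zero followed by the cumulative sum of per-request lengths.
    kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

    # kv_indices: the first `len` slots of each request's row, concatenated --
    # the same layout create_flashinfer_kv_indices_triton wrote block by block.
    kv_indices = torch.cat([
        req_to_token[req, :n]
        for req, n in zip(req_pool_indices.tolist(), paged_kernel_lens.tolist())
    ])

    assert kv_indptr.tolist() == [0, 3, 5]
    assert kv_indices.tolist() == [10, 11, 12, 20, 21]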