sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py}
@@ -15,14 +15,14 @@ limitations under the License.
 
 """
 Memory-efficient attention for prefill.
-It supporst page size = 1 and prefill with KV cache (i.e. extend).
+It supports page size = 1 and prefill with KV cache (i.e. extend).
 """
 
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
@@ -61,12 +61,14 @@ def _fwd_kernel(
     stride_buf_vbs,
     stride_buf_vh,
     stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DPE: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -86,13 +88,18 @@ def _fwd_kernel(
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
 
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
     )
-    q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -104,7 +111,7 @@ def _fwd_kernel(
         )
         qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
 
-    # stage1: compute scores with prefix
+    # stage 1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
 
     acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
@@ -125,7 +132,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_kh
             + offs_d[:, None]
         )
-        k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
 
         qk = tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
@@ -157,13 +166,15 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_vh
             + offs_dv[None, :]
         )
-        v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
 
         e_max = n_e_max
 
-    # stage2: compute the trianlge part
+    # stage 2: compute the trianlge part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -176,7 +187,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
-        k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
 
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
@@ -214,7 +227,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
-        v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
 
@@ -226,7 +241,9 @@ def _fwd_kernel(
         + cur_head * stride_oh
         + offs_dv[None, :]
     )
-    tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
 
 
 def extend_attention_fwd(
@@ -238,39 +255,34 @@ def extend_attention_fwd(
     v_buffer,
     req_to_tokens,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
-    b_seq_len_prefix,
-    b_start_loc_extend,
     b_seq_len_extend,
-    max_len_in_batch,
+    b_start_loc_extend,
     max_len_extend,
     sm_scale=None,
-    logit_cap=-1,
+    logit_cap=0.0,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
 
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    Lq, Lk, Lv, Lo = (
+    Lq, Lk, Lv = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
-       o_extend.shape[-1],
    )
 
-    assert Lq == Lk and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256, 576}
-    assert Lv in {16, 32, 64, 128, 256, 512}
-
     if Lq == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
-        BLOCK_DMODEL = Lq
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
-    BLOCK_DV = Lv
+    BLOCK_DV = triton.next_power_of_2(Lv)
 
     if CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
@@ -287,7 +299,7 @@ def extend_attention_fwd(
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
@@ -322,25 +334,24 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
+        logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
         num_warps=num_warps,
         num_stages=num_stages,
-        logit_cap=logit_cap,
     )
 
 
 def redundant_attention(
     q_extend,
-    k_extend,
-    v_extend,
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
     b_req_idx,
     b_start_loc,
     b_seq_len,
@@ -371,106 +382,3 @@ def redundant_attention(
         pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
         o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
         pt += cur_seq_len_extend
-
-
-def test():
-    torch.manual_seed(0)
-
-    B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
-    dtype = torch.float16
-
-    b_seq_len_prefix = torch.randint(
-        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
-    )
-    b_seq_len_extend = torch.randint(
-        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
-    )
-    b_seq_len = b_seq_len_prefix + b_seq_len_extend
-    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
-
-    b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
-    req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
-    b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
-    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-    for i in range(B):
-        req_to_tokens[i, : b_seq_len[i]] = torch.arange(
-            b_start_loc[i], b_start_loc[i] + b_seq_len[i]
-        )
-
-    total_token_num = torch.sum(b_seq_len).item()
-    extend_token_num = torch.sum(b_seq_len_extend).item()
-    k_buffer = torch.empty(
-        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
-    ).normal_(mean=0.1, std=0.2)
-    v_buffer = torch.empty(
-        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
-    ).normal_(mean=0.1, std=0.2)
-
-    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-    for i in range(B):
-        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
-        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
-        extend_start = b_start_loc_extend[i]
-        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
-        k_extend[extend_start:extend_end] = k_buffer[
-            extend_start_in_buffer:extend_end_in_buffer
-        ]
-        v_extend[extend_start:extend_end] = v_buffer[
-            extend_start_in_buffer:extend_end_in_buffer
-        ]
-        q_extend[extend_start:extend_end] = torch.empty(
-            (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
-        ).normal_(mean=0.1, std=0.2)
-
-    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-
-    b_seq_len_extend = b_seq_len - b_seq_len_prefix
-    b_start_loc_extend = torch.zeros_like(b_seq_len)
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-    extend_attention_fwd(
-        q_extend,
-        k_extend,
-        v_extend,
-        o_extend,
-        k_buffer,
-        v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        b_seq_len_prefix,
-        b_start_loc_extend,
-        b_seq_len_extend,
-        max_len_in_batch,
-        max_len_extend,
-    )
-
-    redundant_attention(
-        q_extend,
-        k_extend,
-        v_extend,
-        o_redundant,
-        k_buffer,
-        v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        b_seq_len_prefix,
-        max_len_in_batch,
-    )
-
-    print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
-    print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
-
-    assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
-
-
-if __name__ == "__main__":
-    test()
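The in-file test() removed above also serves as a reference for how the reworked entry point is called. Below is a minimal sketch of a call against the new signature; the argument order and defaults are taken from the diff, while the batch size, head counts, sequence lengths, and the non-power-of-two head dim (96) are hypothetical values chosen only to exercise the new next_power_of_2 padding and Lq/Lv masking, with tensor contents that are placeholders rather than a semantically consistent KV cache.

import torch

from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd

# Hypothetical sizes for illustration only; D = 96 is not a power of two.
B, H_Q, H_KV, D = 2, 8, 2, 96
prefix_len, extend_len = 16, 4
seq_len = prefix_len + extend_len
total_tokens, extend_tokens = B * seq_len, B * extend_len
dtype, device = torch.float16, "cuda"

q_extend = torch.randn(extend_tokens, H_Q, D, dtype=dtype, device=device)
k_extend = torch.randn(extend_tokens, H_KV, D, dtype=dtype, device=device)
v_extend = torch.randn(extend_tokens, H_KV, D, dtype=dtype, device=device)
o_extend = torch.empty_like(q_extend)
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)

# One request per row; token slots laid out contiguously per request.
req_to_tokens = torch.arange(total_tokens, dtype=torch.int32, device=device).view(B, seq_len)
b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
b_seq_len = torch.full((B,), seq_len, dtype=torch.int32, device=device)
b_seq_len_extend = torch.full((B,), extend_len, dtype=torch.int32, device=device)
b_start_loc_extend = torch.arange(0, extend_tokens, extend_len, dtype=torch.int32, device=device)

extend_attention_fwd(
    q_extend, k_extend, v_extend, o_extend,
    k_buffer, v_buffer,
    req_to_tokens, b_req_idx,
    b_seq_len, b_seq_len_extend, b_start_loc_extend,
    max_len_extend=extend_len,
)  # sm_scale defaults to 1/sqrt(Lq); logit_cap's default changed from -1 to 0.0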
sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py}
@@ -48,6 +48,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -72,7 +73,11 @@
     off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
     off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
 
-    q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+    mask_d = offs_d < Lk
+
+    q = tl.load(
+        Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+    )
 
     k_ptrs = K + off_k
     v_ptrs = V + off_v
@@ -89,7 +94,7 @@
         # -- compute qk ----
         k = tl.load(
             k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
-            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
             other=0.0,
         )
         # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
@@ -118,7 +123,7 @@
         # update acc
         v = tl.load(
             v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
             other=0.0,
         )
 
@@ -134,7 +139,9 @@
         + offs_d[None, :]
     )
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+    tl.store(
+        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
+    )
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
@@ -144,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     BLOCK = 64
 
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
-    assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128, 256}
 
     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
@@ -172,8 +177,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         o.stride(1),
         kv_group_num=kv_group_num,
         BLOCK_M=BLOCK,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
    )
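With the asserts removed and BLOCK_DMODEL padded to the next power of two (with the extra lanes masked via the new Lk argument), context_attention_fwd can be invoked with head dims that are not powers of two. A minimal call sketch follows; the sequence lengths and the head dim of 80 are hypothetical, chosen only because 80 would have failed the removed asserts, and the tensor contents are placeholders.

import torch

from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd

seq_lens = [5, 7]                      # two sequences packed along the token dim
total, H, D = sum(seq_lens), 4, 80     # D = 80 is not a power of two
dtype, device = torch.float16, "cuda"

q = torch.randn(total, H, D, dtype=dtype, device=device)
k = torch.randn(total, H, D, dtype=dtype, device=device)
v = torch.randn(total, H, D, dtype=dtype, device=device)
o = torch.empty_like(q)

# Per-sequence start offsets and lengths, as in the removed test() above.
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32, device=device)
b_start_loc = torch.zeros_like(b_seq_len)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max(seq_lens))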