sglang-0.2.12-py3-none-any.whl → sglang-0.2.14-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/layers/decode_attention.py
@@ -26,7 +26,7 @@ import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -78,10 +77,6 @@ def _fwd_kernel_stage1(

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -106,19 +101,6 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -249,8 +224,7 @@
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
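Note: dropping the 576 case here is not a loss of support. A 576-dim key head is DeepSeek-V2's MLA layout (a 512-dim compressed component plus a 64-dim rotary component), and it is now handled by the grouped path added below, where BLOCK_DMODEL=512 and BLOCK_DPE=64 split the head. A small sketch of that split (tensor shapes chosen for illustration, not taken from the diff):

    import torch

    # Hypothetical MLA decode query: head dim 576 = 512 "nope" dims + 64 rotary dims.
    q = torch.randn(4, 16, 576)            # [tokens, heads, head_dim]
    BLOCK_DMODEL, BLOCK_DPE = 512, 64      # mirrors the Lk == 576 branch below
    q_nope, q_pe = q.split([BLOCK_DMODEL, BLOCK_DPE], dim=-1)
    assert q_nope.shape[-1] == 512 and q_pe.shape[-1] == 64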
@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
     )


+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
+
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
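Note: _fwd_grouped_kernel_stage2 reduces the per-token attention logits with the standard online-softmax recurrence: a running max (e_max), a running normalizer (e_sum), and an accumulator rescaled whenever the running max grows. A minimal PyTorch sketch of the same recurrence (function and variable names are ours, for illustration only):

    import torch

    def online_softmax_weighted_sum(scores, values, block=128):
        # Same e_max / e_sum / acc recurrence as the Triton kernel, block by block.
        e_max = torch.tensor(float("-inf"))
        e_sum = torch.tensor(0.0)
        acc = torch.zeros(values.shape[-1])
        for s in range(0, scores.numel(), block):
            qk, v = scores[s : s + block], values[s : s + block]
            n_e_max = torch.maximum(qk.max(), e_max)
            old_scale = torch.exp(e_max - n_e_max)  # rescale earlier contributions
            p = torch.exp(qk - n_e_max)
            e_sum = e_sum * old_scale + p.sum()
            acc = acc * old_scale + p @ v
            e_max = n_e_max
        return acc / e_sum  # equals torch.softmax(scores, 0) @ values

    scores, values = torch.randn(1000), torch.randn(1000, 64)
    ref = torch.softmax(scores, 0) @ values
    assert torch.allclose(online_softmax_weighted_sum(scores, values), ref, atol=1e-4)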
@@ -316,24 +577,51 @@ def decode_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )

-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
+
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
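Note: the new dispatch keys on kv_group_num, the number of query heads sharing each KV head. A worked example of the routing (head counts chosen for illustration):

    # kv_group_num = num_query_heads // num_kv_heads selects the kernel path.
    q_heads, kv_heads = 32, 32       # MHA: one KV head per query head
    assert q_heads // kv_heads == 1  # -> _decode_att_m_fwd (original path)

    q_heads, kv_heads = 32, 8        # GQA: four query heads share one KV head
    assert q_heads // kv_heads == 4  # -> _decode_grouped_att_m_fwd (new path)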
sglang/srt/layers/extend_attention.py
@@ -275,7 +275,9 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = Lv

-    if CUDA_CAPABILITY[0] >= 8:
+    if CUDA_CAPABILITY[0] >= 9:
+        BLOCK_M, BLOCK_N = (128, 64)
+    elif CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
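Note: CUDA_CAPABILITY is the device's compute-capability tuple, so this change gives Hopper-class GPUs (sm_90) their own tile size rather than reusing the Ampere one. A sketch of how such a tuple is typically obtained (the torch call is standard; the comparison mirrors the diff):

    import torch

    # Compute capability as (major, minor): e.g. (8, 0) on A100, (9, 0) on H100.
    CUDA_CAPABILITY = torch.cuda.get_device_capability()
    print(CUDA_CAPABILITY[0] >= 9)  # True on Hopper -> BLOCK_M, BLOCK_N = (128, 64)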
sglang/srt/layers/fused_moe/__init__.py
@@ -0,0 +1 @@
+from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
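Note: this new __init__.py makes fused_moe a package (matching the fused_moe.py → fused_moe/fused_moe.py move in the file list) and re-exports the layer classes, so downstream code can import them from the package root:

    # Resolves via the re-export added above.
    from sglang.srt.layers.fused_moe import FusedMoE, FusedMoEMethodBase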