sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  9. sglang/srt/layers/attention/vision.py +243 -40
  10. sglang/srt/layers/layernorm.py +1 -5
  11. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  22. sglang/srt/layers/moe/topk.py +4 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/fp8.py +7 -0
  46. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  47. sglang/srt/layers/rotary_embedding.py +29 -15
  48. sglang/srt/layers/sampler.py +9 -6
  49. sglang/srt/lora/backend/__init__.py +8 -0
  50. sglang/srt/lora/backend/base_backend.py +95 -0
  51. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  52. sglang/srt/lora/backend/triton_backend.py +61 -0
  53. sglang/srt/lora/lora.py +127 -112
  54. sglang/srt/lora/lora_manager.py +50 -18
  55. sglang/srt/lora/triton_ops/__init__.py +5 -0
  56. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  57. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  59. sglang/srt/managers/image_processor.py +77 -38
  60. sglang/srt/managers/scheduler.py +17 -3
  61. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  62. sglang/srt/mem_cache/chunk_cache.py +3 -0
  63. sglang/srt/mem_cache/radix_cache.py +30 -1
  64. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  65. sglang/srt/model_executor/forward_batch_info.py +58 -59
  66. sglang/srt/model_executor/model_runner.py +2 -2
  67. sglang/srt/models/minicpmv.py +129 -76
  68. sglang/srt/models/mllama.py +16 -56
  69. sglang/srt/models/qwen2.py +4 -1
  70. sglang/srt/models/qwen2_vl.py +19 -9
  71. sglang/srt/server_args.py +19 -2
  72. sglang/srt/speculative/build_eagle_tree.py +4 -2
  73. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  74. sglang/srt/speculative/eagle_utils.py +361 -372
  75. sglang/srt/speculative/eagle_worker.py +177 -45
  76. sglang/srt/utils.py +7 -2
  77. sglang/test/runners.py +2 -0
  78. sglang/utils.py +42 -0
  79. sglang/version.py +1 -1
  80. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
  81. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
  82. sglang/srt/layers/custom_op_util.py +0 -25
  83. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  84. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  85. {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
  K_Buffer,
  V_Buffer,
  sm_scale,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  Att_Out,
- stride_req_to_tokens_b,
  stride_qbs,
  stride_qh,
  stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
  offs_dv = tl.arange(0, BLOCK_DV)
  mask_d = offs_d < Lk
  mask_dv = offs_dv < Lv
- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
- cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+ cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx

  off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
  q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
  for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
  offs_n = start_n + tl.arange(0, BLOCK_N)
  kv_loc = tl.load(
- Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+ kv_indices + cur_batch_kv_start_idx + offs_n,
  mask=offs_n < split_kv_end,
  other=0,
  )
@@ -173,19 +172,21 @@ def _decode_att_m_fwd(
  k_buffer,
  v_buffer,
  att_out,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  num_kv_splits,
  sm_scale,
  logit_cap,
  ):
  BLOCK = 64
+ # [TODO] work around SGPR limit on MI3xx
+ if is_hip_:
+ BLOCK = 8
  NUM_KV_SPLITS = num_kv_splits
  Lk = k_buffer.shape[-1]
  Lv = v_buffer.shape[-1]

- batch, head_num = B_req_idx.shape[0], q.shape[1]
+ batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]

  grid = (batch, head_num, NUM_KV_SPLITS)
  kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -194,6 +195,8 @@ def _decode_att_m_fwd(
  num_warps = 4
  else:
  num_warps = 2
+ if is_hip_:
+ num_warps = 1

  BLOCK_DMODEL = triton.next_power_of_2(Lk)
  BLOCK_DV = triton.next_power_of_2(Lv)
@@ -203,11 +206,9 @@ def _decode_att_m_fwd(
  k_buffer,
  v_buffer,
  sm_scale,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  att_out,
- Req_to_tokens.stride(0),
  q.stride(0),
  q.stride(1),
  k_buffer.stride(0),
@@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
  K_Buffer,
  V_Buffer,
  sm_scale,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  Att_Out,
- stride_req_to_tokens_b,
  stride_qbs,
  stride_qh,
  stride_buf_kbs,
@@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
  offs_dv = tl.arange(0, BLOCK_DV)
  mask_d = offs_d < Lk
  mask_dv = offs_dv < Lv
- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
- cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+ cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx

  offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
  q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
  for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
  offs_n = start_n + tl.arange(0, BLOCK_N)
  kv_loc = tl.load(
- Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+ kv_indices + cur_batch_kv_start_idx + offs_n,
  mask=offs_n < split_kv_end,
  other=0,
  )
@@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd(
  k_buffer,
  v_buffer,
  att_out,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  num_kv_splits,
  sm_scale,
  logit_cap,
@@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd(
  BLOCK_DPE = 0
  BLOCK_DV = triton.next_power_of_2(Lv)

- batch, head_num = B_req_idx.shape[0], q.shape[1]
+ batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
  kv_group_num = q.shape[1] // k_buffer.shape[1]

  BLOCK_H = 16
@@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd(
  )

  extra_kargs = {}
+ num_stages = 2
  if is_hip_:
  # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
  # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
- extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+ extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+ num_stages = 1

  _fwd_grouped_kernel_stage1[grid](
  q,
  k_buffer,
  v_buffer,
  sm_scale,
- Req_to_tokens,
- B_req_idx,
- B_Seqlen,
+ kv_indptr,
+ kv_indices,
  att_out,
- Req_to_tokens.stride(0),
  q.stride(0),
  q.stride(1),
  k_buffer.stride(0),
@@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd(
  NUM_KV_SPLITS=NUM_KV_SPLITS,
  logit_cap=logit_cap,
  num_warps=4,
- num_stages=2,
+ num_stages=num_stages,
  Lk=Lk,
  Lv=Lv,
  **extra_kargs,
@@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd(
  def _fwd_kernel_stage2(
  Mid_O,
  O,
- B_Seqlen,
+ kv_indptr,
  stride_mid_ob,
  stride_mid_oh,
  stride_mid_os,
@@ -491,7 +490,9 @@ def _fwd_kernel_stage2(
  cur_batch = tl.program_id(0)
  cur_head = tl.program_id(1)

- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+ kv_indptr + cur_batch
+ )

  offs_d = tl.arange(0, BLOCK_DV)
  mask_d = offs_d < Lv
@@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd(
  q,
  o,
  v_buffer,
- b_seq_len,
+ kv_indptr,
  num_kv_splits,
  ):
  batch, head_num = q.shape[0], q.shape[1]
@@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd(
  _fwd_kernel_stage2[grid](
  logits,
  o,
- b_seq_len,
+ kv_indptr,
  logits.stride(0),
  logits.stride(1),
  logits.stride(2),
@@ -574,9 +575,8 @@ def decode_attention_fwd_normal(
  k_buffer,
  v_buffer,
  o,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  attn_logits,
  num_kv_splits,
  sm_scale,
@@ -587,14 +587,13 @@ def decode_attention_fwd_normal(
  k_buffer,
  v_buffer,
  attn_logits,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  num_kv_splits,
  sm_scale,
  logit_cap,
  )
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+ _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)


  def decode_attention_fwd_grouped(
@@ -602,9 +601,8 @@ def decode_attention_fwd_grouped(
  k_buffer,
  v_buffer,
  o,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  attn_logits,
  num_kv_splits,
  sm_scale,
@@ -615,14 +613,13 @@ def decode_attention_fwd_grouped(
  k_buffer,
  v_buffer,
  attn_logits,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  num_kv_splits,
  sm_scale,
  logit_cap,
  )
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+ _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)


  def decode_attention_fwd(
@@ -630,9 +627,8 @@ def decode_attention_fwd(
  k_buffer,
  v_buffer,
  o,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  attn_logits,
  num_kv_splits,
  sm_scale,
@@ -648,9 +644,8 @@ def decode_attention_fwd(
  k_buffer,
  v_buffer,
  o,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  attn_logits,
  num_kv_splits,
  sm_scale,
@@ -663,9 +658,8 @@ def decode_attention_fwd(
  k_buffer,
  v_buffer,
  o,
- req_to_token,
- b_req_idx,
- b_seq_len,
+ kv_indptr,
+ kv_indices,
  attn_logits,
  num_kv_splits,
  sm_scale,
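Taken together, these hunks switch the Triton decode path from request-table indexing (Req_to_tokens, B_req_idx, B_Seqlen) to a flattened CSR-style pair: kv_indptr stores per-request offsets, so the batch size falls out as kv_indptr.shape[0] - 1 and each request's KV length as the difference of adjacent entries, while kv_indices stores the actual KV-cache slot of every token; this is also why stride_req_to_tokens_b disappears from the kernel signatures. A minimal sketch of how such a pair can be assembled from the old inputs; tensor names and values are illustrative, not part of the sglang API:

import torch

# Toy stand-ins for the old per-request indexing (values are illustrative).
req_to_token = torch.tensor([[10, 11, 12, 13],   # KV slots of request row 0
                             [20, 21, 22, 23]])  # KV slots of request row 1
b_req_idx = torch.tensor([0, 1])                 # which table rows are in the batch
b_seq_len = torch.tensor([3, 4])                 # current KV length per request

# Flattened CSR-style layout the updated kernels consume.
kv_indptr = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(b_seq_len, dim=0)   # -> [0, 3, 7]
kv_indices = torch.cat(
    [req_to_token[r, :l] for r, l in zip(b_req_idx.tolist(), b_seq_len.tolist())]
)                                                # -> [10, 11, 12, 20, 21, 22, 23]

assert kv_indptr.shape[0] - 1 == b_seq_len.numel()   # batch recovered from indptr alone

This mirrors how the wrappers above now pass kv_indptr both to the stage-1 kernels and to _decode_softmax_reducev_fwd, which recomputes each request's length from adjacent indptr entries.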
sglang/srt/layers/attention/triton_ops/prefill_attention.py

@@ -166,6 +166,12 @@ def _fwd_kernel(
  def context_attention_fwd(
  q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
  ):
+ """
+ q, k, v: [b * s, head, head_dim]
+ b_start_loc: [b]
+ b_seq_len: [b]
+ out: [b * s, head, head_dim]
+ """
  if is_cuda_available and CUDA_CAPABILITY[0] > 8:
  BLOCK = 128
  else:
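The added docstring pins down the packed varlen layout context_attention_fwd expects. A hedged call sketch with toy shapes (sizes and dtypes chosen here purely for illustration):

import torch

heads, head_dim = 8, 64
b_seq_len = torch.tensor([5, 7], device="cuda")    # per-sequence lengths, shape [b]
b_start_loc = torch.tensor([0, 5], device="cuda")  # start offset of each sequence, shape [b]
total_tokens = int(b_seq_len.sum())                # the packed b * s dimension

q = torch.randn(total_tokens, heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

context_attention_fwd(
    q, k, v, o,
    b_start_loc, b_seq_len,
    max_input_len=int(b_seq_len.max()),
    is_causal=False,   # the ViT callers below use non-causal attention
)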
sglang/srt/layers/attention/vision.py

@@ -4,6 +4,7 @@ from typing import Optional

  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
  from einops import rearrange, repeat

  from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T


  class VisionAttention(nn.Module):
- """Multi-headed attention without any cache, mostly used for ViT."""
+ r"""
+ Multi-headed attention without any cache, mostly used for ViT.
+
+
+ Args:
+ use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+ use_context_forward (bool, default to True):
+ if ``True``, a flash_attn style attention will be applied
+ Otherwise, a full-sequence attention will be applied.
+ use_full_precision_softmax (bool, default to False):
+ if ``True``, the softmax will be performed in full-precision
+ Otherwise, it will be performed in half-precision
+
+ """

  def __init__(
  self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
  projection_size: int,
  use_qkv_parallel: bool,
  quant_config: Optional[QuantizationConfig] = None,
+ dropout: float = 0.0,
+ use_context_forward: bool = True,
+ use_full_precision_softmax: bool = False,
+ flatten_batch: bool = False,
  prefix: str = "",
  ):
  super().__init__()
+ self.use_context_forward = use_context_forward
  world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+ self.dropout = dropout
+ self.head_size = embed_dim // num_heads
  self.hidden_size_per_attention_head = dist_utils.divide(
  projection_size, num_heads
  )
  self.num_attention_heads_per_partition = dist_utils.divide(
  num_heads, world_size
  )
- # self.tp_size = get_tensor_model_parallel_world_size()
- # num_heads = self.num_heads_per_partition
+
+ if self.use_context_forward:
+ self.qkv_backend = VisionTritonAttention()
+ else:
+ self.qkv_backend = VisionSdpaAttention(
+ head_size=self.head_size,
+ dropout=dropout,
+ flatten_batch=flatten_batch,
+ use_full_precision_softmax=use_full_precision_softmax,
+ )
+
  self.use_qkv_parallel = use_qkv_parallel
  if use_qkv_parallel:
- self.head_dim = embed_dim // num_heads
  self.qkv_proj = QKVParallelLinear(
  hidden_size=embed_dim,
- head_size=self.head_dim,
+ head_size=self.head_size,
  total_num_heads=num_heads,
  quant_config=quant_config,
  prefix=f"{prefix}.qkv_proj",
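The constructor now picks its backend from these flags: use_context_forward=True routes q/k/v through the varlen Triton kernel (VisionTritonAttention), while use_context_forward=False selects the masked SDPA path, optionally with an fp32 softmax. A hedged construction sketch, assuming sglang's tensor-parallel state is already initialized; the sizes are placeholders rather than values from any particular vision tower:

# Placeholder sizes; a real caller takes these from the vision tower config.
attn = VisionAttention(
    embed_dim=1280,
    num_heads=16,
    projection_size=1280,
    use_qkv_parallel=True,
    dropout=0.0,
    use_context_forward=False,        # pick VisionSdpaAttention instead of the Triton path
    use_full_precision_softmax=True,  # do the softmax in fp32 before casting back
    flatten_batch=True,               # one (1, 1, s, s) block-diagonal mask over packed patches
)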
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
  x: torch.Tensor,
  cu_seqlens: Optional[torch.Tensor] = None,
  rotary_pos_emb: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
+ r"""
+ Args:
+ x: [b, s, embed_dim]
+ cu_seqlens: [b]
+ Returns:
+ [s, b, num_heads * head]
  """
- Input shape: [b, s, embed_dim]
- Output shape: [s, b, num_heads * head_size]
- """
-
  bsz, s, _ = x.shape
  if self.use_qkv_parallel:
  # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
  else:
  # [b, s, embed_dim] --> [s, b, embed_dim]
  x = rearrange(x, "b s ... -> s b ...")
- # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+ # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
  qkv, _ = self.qkv_proj(x)
- # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+ # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
  new_x_shape = qkv.size()[:-1] + (
  self.num_attention_heads_per_partition,
  3 * self.hidden_size_per_attention_head,
  )
  qkv = qkv.view(*new_x_shape)

- # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+ # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
  q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

- # [s, b, head, head_dim] --> [b, s, head, head_dim]
+ # [s, b, head, head_size] --> [b, s, head, head_size]
  q, k, v = [
  rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
  ]
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
  if self.use_qkv_parallel:
  pass
  else:
- # [b, s, head, head_dim] --> [b * s, head, head_dim]
+ # [b, s, head, head_size] --> [b * s, head, head_size]
  q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

- # [b * s, num_heads, head_size]
- output = torch.empty_like(q)
-
- seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
- max_seqlen = seq_lens.max().item()
-
- context_attention_fwd(
- q,
- k,
- v,
- output,
- cu_seqlens.cuda(),
- seq_lens,
- max_seqlen,
- is_causal=False,
- )
+ output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)

  if self.use_qkv_parallel:
-
- # [b * s, head, head_dim] --> [b, s, head * head_dim]
+ # [b * s, h, head_size] --> [b, s, h * head_size]
  output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

- # [b, s, head, head_dim] --> [b, s, head, head_dim]
+ # [b, s, h * head_size] --> [b, s, h * head_size]
  output, _ = self.proj(output)
  else:
- # [b * s, head, head_dim] --> [b, s, head, head_dim]
- context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
- # [s, b, num_heads * head_size]
+ # [b * s, h, head_size] --> [s, b, h * head_size]
  context_layer = rearrange(
- context_layer, "b s h d -> s b (h d)"
+ output, "(b s) h d -> s b (h d)", b=bsz, s=s
  ).contiguous()

- # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+ # [s, b, h * head_size] --> [s, b, h * head_size]
  output, _ = self.proj(context_layer)

+ # [s, b, h * head_size] --> [b, s, h * head_size]
  output = output.view(bsz, s, -1)

  return output
+
+
+ class VisionSdpaAttention(nn.Module):
+ r"""
+ Scaled Dot Product Attention inner product
+
+ """
+
+ # TODO: Should it be released after used?
+ _mask_cache = {}
+
+ def __init__(
+ self,
+ head_size: int,
+ dropout: float = 0.0,
+ flatten_batch: bool = False,
+ use_full_precision_softmax: bool = False,
+ ):
+ super().__init__()
+ self.head_size = head_size
+ self.flatten_batch = flatten_batch
+ self.use_full_precision_softmax = use_full_precision_softmax
+ self.dropout = dropout
+
+ def generate_patch_attention_mask(
+ self,
+ s: int,
+ bsz: int,
+ device,
+ cu_seqlens: Optional[torch.Tensor],
+ flatten_batch: bool = False,
+ dtype=torch.bfloat16,
+ ) -> torch.Tensor:
+ r"""
+ Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+ When `flatten_batch` is True:
+ - All sequences in the batch are flattened into a single dimension
+ - `s` represents the total number of tokens across all sequences in the batch
+ - Returns a unified mask of shape `(1, 1, s, s)`
+
+ When `flatten_batch` is False:
+ - Each sequence has its own attention mask
+ - `s` represents the maximum sequence length in the batch
+ - Returns separate masks of shape `(b, 1, s, s)`
+
+ Args:
+ flatten_batch: (bool):
+ If True, treats all sequences in the batch as a single flattened sequence
+ If False, generates separate masks for each sequence
+
+ Returns:
+ Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+ """
+
+ cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+ if cache_key in VisionSdpaAttention._mask_cache:
+ cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+ # print(f"cache hit for key: {cache_key}")
+ return cached_mask.to(device=device, dtype=dtype)
+
+ if cu_seqlens is None:
+ raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+ if flatten_batch:
+ mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+ for i in range(1, len(cu_seqlens)):
+ start = cu_seqlens[i - 1]
+ end = cu_seqlens[i]
+ mask[
+ ...,
+ start:end,
+ start:end,
+ ] = True
+ else:
+ # [1, 1, 1, s]
+ row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+ # [1, 1, s, 1]
+ col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+ # [b, 1, 1, 1]
+ seq_lens = (
+ (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+ )
+
+ mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+ # Convert to attention mask format (False -> 0, True -> -inf)
+ mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+ VisionSdpaAttention._mask_cache[cache_key] = mask
+
+ return mask
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ bsz: int,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ cu_seqlens: [b]
+ Returns:
+ [b * s, h, head_size]
+ """
+
+ s = q.shape[0] // bsz
+
+ # [b, 1, s, s]
+ if attention_mask is None:
+ attention_mask = self.generate_patch_attention_mask(
+ s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+ )
+ q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+ # [b, 1, s]
+ if self.use_full_precision_softmax:
+ scale = self.head_size**-0.5
+ k_transposed = rearrange(k, "b h s d -> b h d s")
+ attn_weights = torch.matmul(q, k_transposed) * scale
+ del k, k_transposed
+ attn_weights = attn_weights + attention_mask
+ del attention_mask
+ # full-precision
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(q.dtype)
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=False
+ )
+ output = torch.matmul(attn_weights, v)
+ del attn_weights, v
+ else:
+ # SDPA
+ # [b, h, s, head_size]
+ output = F.scaled_dot_product_attention(
+ q, k, v, attention_mask, dropout_p=self.dropout
+ )
+
+ # [b, h, s, head_size] --> [b * s, h, head_size]
+ output = rearrange(output, "b h s d -> (b s) h d")
+
+ return output
+
+
+ class VisionTritonAttention(nn.Module):
+ """
+ Triton-implemented attention without a causal mask
+ """
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ _bsz: int,
+ cu_seqlens: Optional[torch.Tensor],
+ **kwargs,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+ cu_seqlens: [b]
+ Returns:
+ [b * s, h, head_size]
+ """
+
+ # [b * s, head, head_size]
+ output = torch.empty_like(q)
+ seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+ max_seqlen = seq_lens.max().item()
+ context_attention_fwd(
+ q,
+ k,
+ v,
+ output,
+ cu_seqlens.cuda(),
+ seq_lens.cuda(),
+ max_seqlen,
+ is_causal=False,
+ )
+
+ return output
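The SDPA path boils down to block-diagonal masking: each patch may only attend to patches from the same image. A standalone toy reproduction of the non-flattened mask logic above (values are illustrative):

import torch

cu_seqlens = torch.tensor([0, 2, 5])     # two images with 2 and 3 patches
s = 3                                    # max sequence length in this toy batch
seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).view(-1, 1, 1, 1)   # [b, 1, 1, 1]

row = torch.arange(s).view(1, 1, 1, s)   # key positions
col = torch.arange(s).view(1, 1, s, 1)   # query positions
keep = (row < seq_lens) & (col < seq_lens)                # True inside each image's own block
mask = (~keep).float() * torch.finfo(torch.float32).min   # additive mask fed to SDPA

print(keep[0, 0])   # image 0 (2 patches): only the top-left 2x2 block is True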
sglang/srt/layers/layernorm.py

@@ -29,14 +29,11 @@ if is_cuda_available():
  rmsnorm,
  )

- from vllm.model_executor.custom_op import CustomOp
-
- from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.custom_op import CustomOp

  logger = logging.getLogger(__name__)


- @register_custom_op("sglang_rmsnorm")
  class RMSNorm(CustomOp):
  def __init__(
  self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
  return x, residual


- @register_custom_op("sglang_gemma_rmsnorm")
  class GemmaRMSNorm(CustomOp):
  def __init__(
  self,
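layernorm.py now takes CustomOp from the new sglang.srt.custom_op module (file 2 in the list above) instead of vLLM, and the register_custom_op decorators go away. A rough sketch of the dispatch pattern such a base class typically provides, following the upstream vLLM convention; the method names are illustrative, not a verbatim copy of sglang/srt/custom_op.py:

import torch
from torch import nn


class CustomOp(nn.Module):
    """Picks a backend-specific forward implementation once, at construction time."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Pure-PyTorch reference path.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Fused-kernel path (e.g. the rmsnorm import guarded by is_cuda_available above).
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # ROCm builds reuse the CUDA path unless a HIP-specific kernel is provided.
        return self.forward_cuda(*args, **kwargs)

    def dispatch_forward(self):
        if torch.cuda.is_available():
            return self.forward_hip if torch.version.hip else self.forward_cuda
        return self.forward_native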