sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -37,6 +37,9 @@ logger.warning(
 )
 
 
+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
 
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
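The per-batch split sizing is the core of this change: the split length is now derived from a per-request kv_splits value and rounded up to a multiple of _MIN_BLOCK_KV (32), so short sequences occupy only the first few of the max_kv_splits grid programs and the remaining splits see an empty range. A minimal host-side sketch of that arithmetic (plain Python; cdiv stands in for tl.cdiv):

def cdiv(a, b):
    # ceiling division, mirroring tl.cdiv
    return -(-a // b)

def split_bounds(seq_len, kv_splits, split_kv_id, min_block_kv=32):
    kv_len_per_split = cdiv(cdiv(seq_len, kv_splits), min_block_kv) * min_block_kv
    start = kv_len_per_split * split_kv_id
    end = min(start + kv_len_per_split, seq_len)
    return start, end

# seq_len=100 with kv_splits=8 quantizes to 32-token splits: splits 0-3
# cover [0,32), [32,64), [64,96), [96,100); splits 4-7 get start >= end and
# are skipped by the kernel's `if split_kv_end > split_kv_start` guard.
print([split_bounds(100, 8, i) for i in range(8)])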
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )
 
@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,
@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
 
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(
-            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
-        )
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )
 
     extra_kargs = {}
@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,
@@ -476,14 +496,17 @@ def _decode_grouped_att_m_fwd(
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
         kv_indptr + cur_batch
     )
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
 
-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)
 
             old_scale = tl.exp(e_max - n_e_max)
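Stage 2 now reads the per-split log-sum-exp from the separate Att_Lse/Mid_O_1 tensor instead of a slot appended to Att_Out. The reduction itself is the usual streaming log-sum-exp merge; a host-side sketch of the loop above (illustrative reference math, not the kernel):

import torch

def merge_kv_splits(tv_list, lse_list):
    # tv_list[i]: partial (per-split normalized) output of split i, shape [Lv]
    # lse_list[i]: scalar tensor, log-sum-exp of split i's attention scores
    e_max = torch.tensor(float("-inf"))
    e_sum = torch.tensor(0.0)
    acc = torch.zeros_like(tv_list[0])
    for tv, tlogic in zip(tv_list, lse_list):
        n_e_max = torch.maximum(tlogic, e_max)
        old_scale = torch.exp(e_max - n_e_max)   # rescale the running sums
        exp_logic = torch.exp(tlogic - n_e_max)  # weight of this split
        acc = acc * old_scale + exp_logic * tv
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum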
@@ -533,17 +559,19 @@ def _fwd_kernel_stage2(
 
 def _decode_softmax_reducev_fwd(
     logits,
+    lse,
     q,
     o,
     v_buffer,
     kv_indptr,
     num_kv_splits,
+    max_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
 
     extra_kargs = {}
     if _is_hip:
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
+        lse,
         o,
         kv_indptr,
+        num_kv_splits,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MAX_KV_SPLITS=MAX_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
         num_warps=4,
@@ -578,7 +609,9 @@ def decode_attention_fwd_normal(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -587,13 +620,24 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        attn_lse,
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )
 
 
 def decode_attention_fwd_grouped(
@@ -604,7 +648,9 @@ def decode_attention_fwd_grouped(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -613,13 +659,24 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        attn_lse,
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )
 
 
 def decode_attention_fwd(
@@ -630,11 +687,13 @@ def decode_attention_fwd(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert num_kv_splits == attn_logits.shape[2]
+    assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
     assert q.shape[0] <= attn_logits.shape[0]
 
@@ -650,7 +709,9 @@ def decode_attention_fwd(
         kv_indptr,
         kv_indices,
         attn_logits,
+        attn_lse,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
@@ -664,7 +725,9 @@ def decode_attention_fwd(
         kv_indptr,
         kv_indices,
         attn_logits,
+        attn_lse,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
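Callers of decode_attention_fwd now pass both a per-request num_kv_splits tensor and a scalar max_kv_splits that fixes the grid and buffer shapes. A hedged sketch of the expected setup; the tensor shapes below are inferred from the asserts and strides in the hunks above, not taken from the package's documentation:

import torch

batch, num_heads, Lv, max_kv_splits = 4, 32, 128, 8

# Per-request split counts (<= max_kv_splits); short sequences use fewer.
num_kv_splits = torch.full((batch,), max_kv_splits, dtype=torch.int32, device="cuda")

# attn_logits.shape[2] must equal max_kv_splits (asserted above); attn_lse
# holds one log-sum-exp per (request, head, split).
attn_logits = torch.empty(batch, num_heads, max_kv_splits, Lv,
                          dtype=torch.float32, device="cuda")
attn_lse = torch.empty(batch, num_heads, max_kv_splits,
                       dtype=torch.float32, device="cuda")

# decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices,
#                      attn_logits, attn_lse, num_kv_splits, max_kv_splits,
#                      sm_scale, logit_cap=0.0)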
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -341,12 +341,21 @@ def extend_attention_fwd(
         else:
             BLOCK_M, BLOCK_N = (32, 64)
     elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
+        # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+        if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (64, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 32)
         else:
-            BLOCK_M, BLOCK_N = (32, 64)
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
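The new branch keys tile sizes off the minor compute capability because sm86/sm89 expose roughly 100KB of shared memory per SM versus about 160KB on sm80, so the (128, 128) tile no longer fits. A standalone sketch of the same dispatch (hypothetical helper, thresholds copied from the hunk above; not part of sglang's API):

import torch

def pick_extend_blocks(head_dim: int) -> tuple[int, int]:
    # Illustrative re-statement of the CUDA_CAPABILITY[0] >= 8 branch in
    # extend_attention_fwd.
    major, minor = torch.cuda.get_device_capability()
    assert major >= 8
    if minor in (6, 9):  # sm86/sm89: ~100KB shared memory vs ~160KB on sm80
        if head_dim <= 128:
            return (64, 128)
        return (64, 64) if head_dim <= 256 else (32, 32)
    if head_dim <= 128:
        return (128, 128)
    return (64, 64) if head_dim <= 256 else (32, 64)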
sglang/srt/layers/attention/utils.py
@@ -15,6 +15,7 @@ def create_flashinfer_kv_indices_triton(
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
 
+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
 
@@ -37,3 +38,55 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+@triton.jit
+def create_flashmla_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+    kv_indices_ptr_stride: tl.constexpr,
+):
+    PAGED_SIZE: tl.constexpr = 64
+    BLOCK_SIZE: tl.constexpr = 4096
+    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
+    pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
+    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+
+    for i in range(num_pages_loop):
+        paged_offset = (
+            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+        ) * PAGED_SIZE
+        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+
+        mask = paged_offset <= num_paged * PAGED_SIZE
+        mask_out = paged_offset_out <= num_paged
+
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + paged_offset,
+            mask=mask,
+        )
+        tl.store(
+            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
+            data // PAGED_SIZE,
+            mask=mask_out,
+        )
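FlashMLA consumes page indices rather than per-token slots, with a fixed 64-token page. Host-side, the kernel's output for one request is equivalent to taking the token slot at the start of each page and dividing by the page size; a minimal sketch (assumes, as the kernel does, that a request's tokens fill pages contiguously):

import torch

PAGE_SIZE = 64  # PAGED_SIZE in the kernel above

def flashmla_page_indices(req_to_token_row: torch.Tensor, kv_len: int) -> torch.Tensor:
    # Token slot at the start of each page -> page number.
    num_pages = (kv_len + PAGE_SIZE - 1) // PAGE_SIZE
    first_slots = req_to_token_row[: num_pages * PAGE_SIZE : PAGE_SIZE]
    return first_slots // PAGE_SIZE

row = torch.arange(4096)                # toy req_to_token row
print(flashmla_page_indices(row, 130))  # tensor([0, 1, 2])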
sglang/srt/layers/attention/vision.py
@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
 from sglang.srt.utils import add_prefix
 
 
-# Copied from transformers, modeling_qwen2_vl.py
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-
-    return q_embed, k_embed
-
-
 class VisionAttention(nn.Module):
     r"""
     Multi-headed attention without any cache, mostly used for ViT.
@@ -167,9 +143,14 @@ class VisionAttention(nn.Module):
         if position_embeddings is not None:
             cos, sin = position_embeddings
             original_shape = q.shape
-            q, k = q.view(s, head, -1), k.view(s, head, -1)
-            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-            q, k = q.reshape(original_shape), k.reshape(original_shape)
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
+
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+            q = q.view(original_shape)
+            k = k.view(original_shape)
 
             if self.use_qkv_parallel:
                 pass
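vision.py drops its local copy of the rotary helpers and imports rotate_half and apply_rotary_pos_emb from sglang.srt.layers.rotary_embedding. For reference, the math the removed helpers implemented (and that the shared implementation must reproduce) is the standard rotate-half RoPE; a sketch based on the deleted code above, not on the shared module's exact source:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the last dimension
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # cos/sin: [tokens, head_size]; unsqueeze(-2) broadcasts over heads,
    # matching the [total_tokens, head, head_size] layout used above.
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed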
sglang/srt/layers/dp_attention.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -46,7 +51,11 @@ def initialize_dp_attention(
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        _DP_SIZE = dp_size
+    else:
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -54,7 +63,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -169,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()
@@ -205,6 +213,22 @@ def _dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input
@@ -227,14 +251,12 @@ def dp_scatter(
     )
 
 
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
+def tp_reduce_scatter(
+    output: torch.Tensor,
+    input_list: List[torch.Tensor],
+):
+    return get_attention_tp_group().reduce_scatter(output, input_list)
+
 
-    return do_logits_dp_scatter
+def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
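The only difference between the two new gather wrappers is which attention-TP ranks contribute before the all-reduce. With partial inputs every rank holds a distinct shard, so all ranks add theirs; with replicated inputs only rank 0 writes, otherwise the all-reduce would sum the same tokens attn_tp_size times. A one-function sketch of the guard inside _dp_gather:

def should_contribute(is_partial: bool, attn_tp_rank: int) -> bool:
    # dp_gather_partial: every attention-TP rank holds a partial shard, so
    # all contribute. dp_gather_replicate: inputs are identical across ranks,
    # so only rank 0 writes and the all-reduce does not multiply-count them.
    return is_partial or attn_tp_rank == 0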