sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/decode_attention.py CHANGED
@@ -37,6 +37,9 @@ logger.warning(
 )


+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(

     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )

@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]

-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     if kv_group_num == 1:
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,
@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(

     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(
-            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
-        )

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )

     extra_kargs = {}
@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,
@@ -476,14 +496,17 @@ def _decode_grouped_att_m_fwd(
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
         kv_indptr + cur_batch
     )
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )

-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)

             old_scale = tl.exp(e_max - n_e_max)
@@ -533,17 +559,19 @@ def _fwd_kernel_stage2(

 def _decode_softmax_reducev_fwd(
     logits,
+    lse,
     q,
     o,
     v_buffer,
     kv_indptr,
     num_kv_splits,
+    max_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits

     extra_kargs = {}
     if _is_hip:
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
+        lse,
         o,
         kv_indptr,
+        num_kv_splits,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MAX_KV_SPLITS=MAX_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
         num_warps=4,
@@ -578,7 +609,9 @@ def decode_attention_fwd_normal(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -587,13 +620,24 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        attn_lse,
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd_grouped(
@@ -604,7 +648,9 @@ def decode_attention_fwd_grouped(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -613,13 +659,24 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        attn_lse,
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd(
@@ -630,11 +687,13 @@ def decode_attention_fwd(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert num_kv_splits == attn_logits.shape[2]
+    assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
     assert q.shape[0] <= attn_logits.shape[0]

@@ -650,7 +709,9 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )
@@ -664,7 +725,9 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )
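The decode kernels above replace the fixed NUM_KV_SPLITS constant with a per-request num_kv_splits tensor bounded by max_kv_splits, round each split length up to a multiple of _MIN_BLOCK_KV, and write the running log-sum-exp to a separate Att_Lse buffer instead of an extra slot in Att_Out. A minimal host-side sketch of the split-sizing arithmetic (illustrative only; split_bounds is a hypothetical helper, not part of this package):

import math

MIN_BLOCK_KV = 32  # mirrors _MIN_BLOCK_KV above

def split_bounds(seq_len: int, kv_splits: int):
    # Same rounding as the kernel:
    # tl.cdiv(tl.cdiv(seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
    per_split = math.ceil(math.ceil(seq_len / kv_splits) / MIN_BLOCK_KV) * MIN_BLOCK_KV
    for split_id in range(kv_splits):
        start = per_split * split_id
        end = min(start + per_split, seq_len)
        if end > start:
            yield (start, end)

print(list(split_bounds(seq_len=300, kv_splits=4)))
# [(0, 96), (96, 192), (192, 288), (288, 300)]

Short requests therefore leave trailing splits empty, which the kernels skip via the split_kv_end > split_kv_start guard (and now also skip loading q for those splits).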
sglang/srt/layers/attention/triton_ops/extend_attention.py CHANGED
@@ -341,12 +341,21 @@ def extend_attention_fwd(
         else:
             BLOCK_M, BLOCK_N = (32, 64)
     elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
+        # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+        if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (64, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 32)
         else:
-            BLOCK_M, BLOCK_N = (32, 64)
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

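The new branch in extend_attention_fwd shrinks the prefill tile sizes on sm86/sm89 GPUs, whose per-SM shared memory (about 100 KB) is much smaller than sm80's (about 160 KB). A standalone restatement of that heuristic, assuming CUDA_CAPABILITY is the (major, minor) tuple this file already uses:

def pick_blocks_sm8x(minor: int, Lq: int):
    # sm86 / sm89 get smaller tiles because of their smaller shared memory.
    if minor in (6, 9):
        if Lq <= 128:
            return 64, 128
        elif Lq <= 256:
            return 64, 64
        return 32, 32
    # other sm8x parts (e.g. sm80) keep the larger tiles
    if Lq <= 128:
        return 128, 128
    elif Lq <= 256:
        return 64, 64
    return 32, 64

print(pick_blocks_sm8x(minor=9, Lq=128))  # (64, 128)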
sglang/srt/layers/attention/utils.py CHANGED
@@ -15,6 +15,7 @@ def create_flashinfer_kv_indices_triton(
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)

+    # find the req pool idx, this is for batch to token
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)

@@ -37,3 +38,55 @@ def create_flashinfer_kv_indices_triton(
             mask=mask,
         )
         tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
+
+
+@triton.jit
+def create_flashmla_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+    kv_indices_ptr_stride: tl.constexpr,
+):
+    PAGED_SIZE: tl.constexpr = 64
+    BLOCK_SIZE: tl.constexpr = 4096
+    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
+    pid = tl.program_id(axis=0)
+
+    # find the req pool idx, this is for batch to token
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
+    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+
+    for i in range(num_pages_loop):
+        paged_offset = (
+            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+        ) * PAGED_SIZE
+        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+
+        mask = paged_offset <= num_paged * PAGED_SIZE
+        mask_out = paged_offset_out <= num_paged
+
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + paged_offset,
+            mask=mask,
+        )
+        tl.store(
+            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
+            data // PAGED_SIZE,
+            mask=mask_out,
+        )
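create_flashmla_kv_indices_triton converts the token-granular req_to_token table into a page table with 64-token pages by keeping the first token slot of each page and dividing by the page size. A rough host-side equivalent for a single request (illustrative; flashmla_page_row is a hypothetical helper, not part of this package):

import torch

PAGED_SIZE = 64  # same page size the kernel hard-codes

def flashmla_page_row(req_to_token_row: torch.Tensor, kv_start: int, kv_len: int) -> torch.Tensor:
    tokens = req_to_token_row[kv_start : kv_start + kv_len]
    # first token slot of every 64-token page, converted to a page index
    first_token_of_each_page = tokens[::PAGED_SIZE]
    return first_token_of_each_page // PAGED_SIZE

row = torch.arange(4096)  # pretend token slots are laid out contiguously
print(flashmla_page_row(row, kv_start=0, kv_len=200))  # tensor([0, 1, 2, 3])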
sglang/srt/layers/attention/vision.py CHANGED
@@ -19,34 +19,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
 from sglang.srt.utils import add_prefix


-# Copied from transformers, modeling_qwen2_vl.py
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-
-    return q_embed, k_embed
-
-
 class VisionAttention(nn.Module):
     r"""
     Multi-headed attention without any cache, mostly used for ViT.
@@ -167,9 +143,14 @@ class VisionAttention(nn.Module):
         if position_embeddings is not None:
             cos, sin = position_embeddings
             original_shape = q.shape
-
-            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
+
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+            q = q.view(original_shape)
+            k = k.view(original_shape)

         if self.use_qkv_parallel:
             pass
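VisionAttention now reuses rotate_half and apply_rotary_pos_emb from sglang.srt.layers.rotary_embedding instead of the local Qwen2-VL copy deleted above. For reference, a compact sketch of the rotation that deleted helper performed (not the imported implementation itself):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin are broadcast over the head dimension, as in the deleted helper
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin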
sglang/srt/layers/dp_attention.py CHANGED
@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     return attn_tp_rank, attn_tp_size, dp_rank


-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -46,7 +51,11 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        _DP_SIZE = dp_size
+    else:
+        _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -54,7 +63,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -169,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[int, str],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()
@@ -205,6 +213,22 @@ def dp_gather(
     global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input
@@ -225,16 +249,3 @@ def dp_scatter(
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
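dp_gather is split into dp_gather_partial and dp_gather_replicate around a shared _dp_gather. The only behavioural difference is which attention-TP ranks contribute tokens before the all-reduce; a one-line restatement of that condition (illustrative only, not code from this package):

def should_contribute(is_partial: bool, attn_tp_rank: int, num_local_tokens: int) -> bool:
    # Replicated inputs (is_partial=False) are identical on every attention-TP
    # rank, so only rank 0 adds them; partial inputs are summed from all ranks.
    return num_local_tokens > 0 and (is_partial or attn_tp_rank == 0)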
sglang/srt/layers/layernorm.py CHANGED
@@ -21,7 +21,9 @@ import torch.nn as nn

 from sglang.srt.utils import is_cuda_available

-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,27 @@ class GemmaRMSNorm(CustomOp):
         return out


-if not is_cuda_available():
+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
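The new Gemma3RMSNorm normalises in float32, scales by (1 + weight), and only then casts back to the input dtype, matching the Gemma3 ordering referenced in the comment above. A standalone restatement of that math (illustrative only):

import torch

def gemma3_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).type_as(x)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.zeros(8)  # the module initialises weight to zeros, i.e. a unit scale
print(gemma3_rmsnorm(x, w).dtype)  # torch.bfloat16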