sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
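Both this guard and the matching one added to flashinfer_mla_backend.py (below) index `os.environ` directly, so merely importing the backend raises `KeyError` when `SGLANG_ENABLE_TORCH_COMPILE` is unset. A tolerant variant of the same gate would look like this sketch (an illustration, not what ships in the wheel):

```python
import os

import torch

# Treat an unset variable the same as "0" instead of raising KeyError.
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    import torch._dynamo

    # Fall back to eager execution when compilation fails, rather than raising.
    torch._dynamo.config.suppress_errors = True
```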
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
                 cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
             ]
 
+            # Ensure tensors are properly allocated
+            for i in range(self.num_wrappers):
+                # Force allocation by performing a small operation
+                if len(self.cuda_graph_kv_indices[i]) > 0:
+                    self.cuda_graph_kv_indices[i][0] = 0
+
             if not self.skip_prefill:
                 self.cuda_graph_custom_mask = torch.zeros(
                     (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
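Across all of these call sites, the locals gated on `--kv-cache-dtype` are replaced by unconditional `layer.k_scale` / `layer.v_scale` attributes. Such scales exist to dequantize an FP8 KV cache on read; a minimal sketch of the round trip (shapes and the e4m3fn constant 448 are illustrative, not sglang code):

```python
import torch

# Quantize keys to float8_e4m3fn for storage, remember the scale,
# and dequantize on read. 448 is the largest finite |value| of e4m3fn.
k = torch.randn(4, 8, 128)
k_scale = k.abs().amax() / 448.0
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)   # what the KV pool stores
k_restored = k_fp8.to(torch.float32) * k_scale  # what attention consumes
print((k - k_restored).abs().max())  # small quantization error
```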
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
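The precedence implemented above: an explicit `data_type` seeds both `q_data_type` and `kv_data_type`, otherwise `q_data_type` falls back to `"float16"`, and `kv_data_type` always falls back to `q_data_type`. Extracted as a standalone helper for illustration (the function name is hypothetical):

```python
def resolve_dtypes(data_type=None, q_data_type=None, kv_data_type=None):
    # Mirrors the branch added to fast_decode_plan above.
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type

assert resolve_dtypes() == ("float16", "float16")
assert resolve_dtypes(data_type="bfloat16") == ("bfloat16", "bfloat16")
assert resolve_dtypes(q_data_type="float16", kv_data_type="float8_e4m3fn") == (
    "float16",
    "float8_e4m3fn",
)
```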
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
 
     indptr_host = (
         global_override_indptr_cpu
@@ -1215,48 +1231,57 @@ def fast_decode_plan(
         else indptr.cpu()
     )
 
+    with torch.cuda.device(self.device):
+
+        if self.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
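The new `with torch.cuda.device(self.device):` wrapper pins every allocation and kernel launch inside the block to that GPU, which matters in multi-GPU processes where the current device may differ from the wrapper's. A minimal demonstration of the context manager (assumes at least two visible GPUs):

```python
import torch

torch.cuda.set_device(0)
with torch.cuda.device(1):
    cpu_t = torch.empty(8)                   # CPU tensor, unaffected
    gpu_t = torch.empty(8, device="cuda")    # lands on cuda:1, not cuda:0
print(gpu_t.device)  # cuda:1
```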
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k,
                 v,
             )
+
+        # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
             reshaped_q[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
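The slicing splits the last dimension at `v_head_dim`: MLA stores the compressed KV component and the rotary key component concatenated in one buffer, so both `q` and the cached keys are split the same way before the paged MLA kernel runs. A shape sketch with DeepSeek-V2-style sizes (576 = 512 compressed + 64 rope; illustrative numbers only):

```python
import torch

v_head_dim, rope_dim = 512, 64
q = torch.randn(3, 16, v_head_dim + rope_dim)  # (tokens, heads, 576)
q_nope, q_pe = q[:, :, :v_head_dim], q[:, :, v_head_dim:]
print(q_nope.shape, q_pe.shape)  # (3, 16, 512) and (3, 16, 64)
```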
@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")
sglang/srt/layers/attention/utils.py

@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)
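Promoting `PAGED_SIZE` from a hard-coded constant in the kernel body to a `tl.constexpr` parameter with a default keeps existing call sites working while letting new callers specialize the kernel for other page sizes at compile time. A minimal standalone illustration of the pattern (not the sglang kernel body; needs a CUDA device):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_paged(src_ptr, dst_ptr, n, PAGED_SIZE: tl.constexpr = 64):
    # Each program instance copies one page of PAGED_SIZE elements.
    pid = tl.program_id(axis=0)
    offs = pid * PAGED_SIZE + tl.arange(0, PAGED_SIZE)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


x = torch.arange(128, device="cuda", dtype=torch.int32)
y = torch.empty_like(x)
copy_paged[(2,)](x, y, x.numel())                  # default 64-element pages
copy_paged[(4,)](x, y, x.numel(), PAGED_SIZE=32)   # compile-time specialization
```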
sglang/srt/layers/dp_attention.py

@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
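The new `local_rank` tells the `GroupCoordinator` where each rank sits inside its attention-TP subgroup: with DP attention enabled, the TP ranks split into `dp_size` groups of `tp_size // dp_size`, so the position within a group is a modulo. A worked example of the arithmetic (plain Python, no sglang imports):

```python
tp_size, dp_size = 8, 2
attn_tp_size = tp_size // dp_size  # 4 ranks per attention-TP group
for tp_rank in range(tp_size):
    group, local_rank = tp_rank // attn_tp_size, tp_rank % attn_tp_size
    print(f"tp_rank={tp_rank} -> group={group}, local_rank={local_rank}")
# ranks 0-3 become local_rank 0-3 of group 0; ranks 4-7 of group 1
```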
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -84,9 +84,7 @@ class DeepEPBuffer:
                 num_nvl_bytes,
                 num_rdma_bytes,
                 low_latency_mode=deepep_mode.enable_low_latency(),
-                num_qps_per_rank=(
-                    num_experts // group.size() if deepep_mode.enable_low_latency() else 1
-                ),
+                num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
             )
             return cls._buffer
 
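In low-latency mode the per-rank QP count was already `num_experts // group.size()`; the change mainly affects normal mode, which previously got a single queue pair and now gets at least half the device's SM count. For instance (the SM count is an assumption, e.g. an H100 has 132):

```python
num_experts, world_size, num_sms = 128, 16, 132
old_normal_mode = 1
new = max(num_experts // world_size, num_sms // 2)
print(old_normal_mode, "->", new)  # 1 -> 66 queue pairs per rank
```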
sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
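Tuning tables like this one are keyed by the per-expert token count M; at runtime the launcher picks the entry whose key is nearest to the observed M. A sketch of that lookup (sglang's real version lives in `fused_moe_triton/fused_moe.py`; the nearest-key rule is the convention these config files are tuned for):

```python
import json


def select_moe_config(path: str, m: int) -> dict:
    # Pick the tuned kernel parameters for the batch size closest to m.
    with open(path) as f:
        configs = json.load(f)
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]


cfg = select_moe_config(
    "E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json", m=100
)
print(cfg)  # the "96" entry: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, ...
```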