sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
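Both attention backends now gate torch._dynamo error suppression on the SGLANG_ENABLE_TORCH_COMPILE environment variable. Note that the released code indexes os.environ directly, which raises KeyError if the variable is not set before the module is imported (sglang may export it elsewhere). A minimal usage sketch, with a tolerant os.environ.get variant shown purely for illustration:

```python
import os

# Opt in before the attention backends are imported (illustrative usage).
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = "1"

# Tolerant variant of the same guard; the released code indexes os.environ
# directly and therefore expects the variable to already be set.
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True
```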
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, k_scale, v_scale
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, k_scale, v_scale
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )
 
+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
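The hunks above switch the prefill and decode paths from locally computed k_scale/v_scale floats to the scale attributes stored on the layer, both when writing to the KV cache and when calling the FlashInfer wrappers. As a rough illustration of what such per-tensor scales mean for an FP8 KV cache (a conceptual sketch assuming PyTorch's float8_e4m3fn dtype is available, not this kernel's internals):

```python
import torch

# Hypothetical [tokens, heads, head_dim] key tensor.
k = torch.randn(4, 8, 128)

# A per-tensor scale maps the fp32 range onto the fp8 representable range.
k_scale = k.abs().max() / torch.finfo(torch.float8_e4m3fn).max
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)    # quantized cache entry
k_restored = k_fp8.to(torch.float32) * k_scale   # what a kernel rescales on read

print("max abs error:", (k - k_restored).abs().max().item())
```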
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
-
-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
-
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
 
     indptr_host = (
         global_override_indptr_cpu
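In the reworked fast_decode_plan, the legacy data_type argument is only a fallback: it seeds whichever of q_data_type / kv_data_type is missing, q_data_type otherwise defaults to "float16", and kv_data_type follows q_data_type; the zero-element tensors above exist only to carry those dtypes into the plan call. A standalone sketch of the same precedence (the helper name is mine, not part of sglang):

```python
from typing import Optional, Tuple, Union

import torch

DType = Union[str, torch.dtype]


def resolve_dtypes(
    data_type: Optional[DType] = None,
    q_data_type: Optional[DType] = None,
    kv_data_type: Optional[DType] = None,
) -> Tuple[DType, DType]:
    # Legacy `data_type` fills in whichever specific dtype is missing.
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    # The KV dtype follows the query dtype unless set explicitly.
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type


assert resolve_dtypes() == ("float16", "float16")
assert resolve_dtypes(data_type="bfloat16") == ("bfloat16", "bfloat16")
assert resolve_dtypes(q_data_type=torch.float16) == (torch.float16, torch.float16)
```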
@@ -1215,48 +1231,57 @@
         else indptr.cpu()
     )
 
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(
-            indptr_host, self.last_page_len[:batch_size], page_size
-        )
-
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            indptr_host,
-            kv_lens_arr_host,
-            batch_size,  # total_num_rows
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim,
-            head_dim,
-            False,  # causal
-            torch.cuda.current_stream().cuda_stream,
-        )
-    else:
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            indptr_host,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            window_left,
-            logits_soft_cap,
-            head_dim,
-            head_dim,
-            self.empty_q_data,
-            self.empty_kv_cache,
-            torch.cuda.current_stream().cuda_stream,
-        )
+    with torch.cuda.device(self.device):
+
+        if self.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k,
                 v,
             )
+
+        # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
             reshaped_q[:, :, layer.v_head_dim :],
-            reshaped_k[:, :, : layer.v_head_dim],
-            reshaped_k[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
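The MLA decode path now slices the buffer returned by get_key_buffer directly instead of re-viewing it, splitting the combined head dimension into a compressed-KV part and a rotary part. A shape-only sketch with hypothetical sizes (512 + 64, DeepSeek-style; not a statement about the actual buffer layout):

```python
import torch

num_tokens, v_head_dim, rope_dim = 8, 512, 64

# Assumed layout: one latent KV "head" whose head_dim concatenates both parts.
k_buffer = torch.randn(num_tokens, 1, v_head_dim + rope_dim)

ckv = k_buffer[:, :, :v_head_dim]    # compressed KV component
k_pe = k_buffer[:, :, v_head_dim:]   # rotary positional component

assert ckv.shape == (num_tokens, 1, v_head_dim)
assert k_pe.shape == (num_tokens, 1, rope_dim)
```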
@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_cpu,
-            kv_indptr_cpu,
-            kv_len_arr_cpu,
-            num_heads,
-            head_dim_ckv,
-            causal,
-            stream,
-        )
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")
@@ -241,6 +241,9 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
             seq_lens_cpu,
         )
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1024
+
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)
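Promoting PAGED_SIZE from a hard-coded constant to a tl.constexpr parameter with a default lets callers specialize the kernel for other page sizes at launch time. A minimal, self-contained Triton sketch of that pattern (not the sglang kernel itself; it requires a CUDA device, and the page size must stay a power of two for tl.arange):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, PAGED_SIZE: tl.constexpr = 64):
    # Each program instance handles one page of PAGED_SIZE elements.
    pid = tl.program_id(axis=0)
    offsets = pid * PAGED_SIZE + tl.arange(0, PAGED_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


src = torch.arange(256, dtype=torch.float32, device="cuda")
dst = torch.empty_like(src)
# Override the compile-time default (64) at launch.
copy_kernel[(triton.cdiv(src.numel(), 128),)](src, dst, src.numel(), PAGED_SIZE=128)
assert torch.equal(src, dst)
```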
@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
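With pipeline parallelism in the picture, the attention-TP groups are now laid out over pp_size * tp_size ranks, and each rank's position inside its group is derived from tp_rank rather than taken from the TP group's local_rank. A worked example of the rank arithmetic, assuming _ATTN_TP_SIZE equals tp_size // dp_size and using made-up sizes:

```python
tp_size, dp_size, pp_size = 8, 2, 1
attn_tp_size = tp_size // dp_size  # 4 ranks per attention-TP group

# Group layout spans pp_size * tp_size ranks: [[0, 1, 2, 3], [4, 5, 6, 7]]
groups = [
    list(range(head, head + attn_tp_size))
    for head in range(0, pp_size * tp_size, attn_tp_size)
]

for tp_rank in range(tp_size):
    local_rank = tp_rank % attn_tp_size  # e.g. tp_rank 5 -> local_rank 1
    print(tp_rank, local_rank, groups[tp_rank // attn_tp_size])
```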
@@ -84,9 +84,7 @@ class DeepEPBuffer:
                 num_nvl_bytes,
                 num_rdma_bytes,
                 low_latency_mode=deepep_mode.enable_low_latency(),
-                num_qps_per_rank=(
-                    num_experts // group.size() if deepep_mode.enable_low_latency() else 1
-                ),
+                num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
             )
         return cls._buffer
 
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "256": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "512": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  }
+}
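The keys in these new fused-MoE tuning tables are token-batch sizes (M) and the values are Triton tile/launch parameters benchmarked for that shape on the named GPU. A hedged sketch of how such a file could be loaded and an entry picked for a given batch (the helper names and the nearest-key rule are illustrative, not sglang's actual lookup API):

```python
import json
from typing import Dict


def load_moe_config(path: str) -> Dict[int, dict]:
    # JSON keys are strings; convert to ints so we can do numeric lookup.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: Dict[int, dict], num_tokens: int) -> dict:
    # Hypothetical rule: use the benchmarked batch size closest to the real one.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]


# Example: 40 tokens would map to the "48" entry of the table above.
# cfg = pick_config(load_moe_config("E=128,N=192,device_name=NVIDIA_H20.json"), 40)
```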