sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
-not_hip = False
+is_hip_flag = False
 if not is_hip():
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-    not_hip = True
+    is_hip_flag = False
+else:
+    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
 
 @triton.jit
 def fused_moe_kernel(
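Both module-level switches are read once at import time. As a minimal illustrative sketch (not part of the diff), the new Triton alignment path added below can be opted into by setting the environment variable before sglang imports this module:

# Hypothetical opt-in sketch: set the flag before import so that
# enable_moe_align_block_size_triton evaluates to True at module load.
import os

os.environ["ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"

from sglang.srt.layers.moe.fused_moe_triton import fused_moe  # reads the flag on import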
@@ -222,6 +228,139 @@ def fused_moe_kernel(
         tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+@triton.jit
+def moe_align_block_size_stage1(
+    topk_ids_ptr,
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    start_idx = pid * tokens_per_thread
+
+    off_c = (pid + 1) * num_experts
+
+    for i in range(tokens_per_thread):
+        if start_idx + i < numel:
+            idx = tl.load(topk_ids_ptr + start_idx + i)
+            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
+            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
+
+
+@triton.jit
+def moe_align_block_size_stage2(
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    last_cnt = 0
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
+        last_cnt = last_cnt + token_cnt
+        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
+
+
+@triton.jit
+def moe_align_block_size_stage3(
+    total_tokens_post_pad_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+):
+    last_cumsum = 0
+    off_cnt = num_experts * num_experts
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
+        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
+        tl.store(cumsum_ptr + i, last_cumsum)
+    tl.store(total_tokens_post_pad_ptr, last_cumsum)
+
+
+@triton.jit
+def moe_align_block_size_stage4(
+    topk_ids_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_idx = tl.load(cumsum_ptr + pid)
+    end_idx = tl.load(cumsum_ptr + pid + 1)
+
+    for i in range(start_idx, end_idx, block_size):
+        tl.store(expert_ids_ptr + i // block_size, pid)
+
+    start_idx = pid * tokens_per_thread
+    off_t = pid * num_experts
+
+    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
+        expert_id = tl.load(topk_ids_ptr + i)
+        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
+        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
+        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
+        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
+
+
+def moe_align_block_size_triton(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    block_size: int,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_pad: torch.Tensor,
+) -> None:
+    numel = topk_ids.numel()
+    grid = (num_experts,)
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
+    tokens_per_thread = ceil_div(numel, num_experts)
+
+    moe_align_block_size_stage1[grid](
+        topk_ids,
+        tokens_cnts,
+        num_experts,
+        numel,
+        tokens_per_thread,
+    )
+    moe_align_block_size_stage2[grid](
+        tokens_cnts,
+        num_experts,
+    )
+    moe_align_block_size_stage3[(1,)](
+        num_tokens_post_pad,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+    )
+    moe_align_block_size_stage4[grid](
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+        numel,
+        tokens_per_thread,
+    )
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
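For orientation, the staged kernels above compute the grouping that moe_align_block_size returns: flattened token indices are bucketed by expert, each expert's bucket is padded up to a multiple of block_size, and expert_ids records which expert owns each block. The following is a rough, illustrative PyTorch restatement of that output, not the kernel itself (simplified; the real kernels write into pre-allocated, fixed-size buffers):

# Illustrative reference only -- plain PyTorch ops, single call, CPU-friendly.
import torch

def moe_align_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    numel = topk_ids.numel()
    flat = topk_ids.flatten()
    counts = torch.bincount(flat, minlength=num_experts)             # tokens per expert
    padded = ((counts + block_size - 1) // block_size) * block_size  # round up per expert
    cumsum = torch.cat([torch.zeros(1, dtype=torch.long), padded.cumsum(0)])
    total = int(cumsum[-1])
    # Unused (padding) slots point one past the last valid flattened index.
    sorted_token_ids = torch.full((total,), numel, dtype=torch.int32)
    expert_ids = torch.repeat_interleave(
        torch.arange(num_experts, dtype=torch.int32), padded // block_size
    )
    offsets = cumsum[:-1].clone()
    for i, e in enumerate(flat.tolist()):                            # stable within an expert
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, torch.tensor([total], dtype=torch.int32)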
@@ -272,24 +411,36 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if not_hip and num_experts >= 224:
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-        )
-        cumsum_buffer = torch.empty(
-            num_experts + 1, dtype=torch.int32, device=topk_ids.device
-        )
+    if num_experts >= 224:
+        if enable_moe_align_block_size_triton or is_hip_flag:
+            moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+            )
+        else:
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )
 
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-        )
+            sgl_moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+                token_cnts_buffer,
+                cumsum_buffer,
+            )
     else:
         ops.moe_align_block_size(
             topk_ids,
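In effect, the rewritten branch keeps the sgl_kernel CUDA path as the default for large expert counts and uses the new pure-Triton implementation only when explicitly enabled or when running on HIP; smaller expert counts still fall through to the vLLM op. A condensed sketch of that decision (an illustrative helper, not part of the library API; names mirror the diff):

# Simplified restatement of the dispatch above.
def alignment_backend(num_experts: int, triton_enabled: bool, on_hip: bool) -> str:
    if num_experts >= 224:
        if triton_enabled or on_hip:
            return "moe_align_block_size_triton"   # new pure-Triton kernels
        return "sgl_moe_align_block_size"          # sgl_kernel CUDA kernel
    return "ops.moe_align_block_size"              # vLLM fallback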
@@ -326,9 +477,9 @@ def invoke_fused_moe_kernel(
 
     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
             assert len(block_shape) == 2
@@ -463,7 +614,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_flag else 4,
             }
             if M <= E:
                 config = {
@@ -472,7 +623,7 @@ def get_default_config(
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 4,
+                    "num_stages": 2 if is_hip_flag else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -482,7 +633,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_flag else 3,
             }
     else:
         config = {
@@ -727,7 +878,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
        padded_size = 0
 
     # Check constraints.
@@ -854,11 +1005,29 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
-        torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-        )
+        if is_hip_flag:
+            ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+        else:
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+
     return out_hidden_states
 
 
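The new branch avoids a generic reduction when top-k is 1 or 2; the specialized forms produce the same result as the previous torch.sum over the top-k dimension. A small self-contained check illustrates this (shapes are made up for the example):

# Illustrative equivalence check for the specialized top-k reductions above.
import torch

cache = torch.randn(8, 2, 16)                     # (tokens, topk, hidden) with topk == 2
assert torch.allclose(cache.sum(dim=1), torch.add(cache[:, 0], cache[:, 1]))

cache1 = torch.randn(8, 1, 16)                    # topk == 1: a plain copy suffices
assert torch.equal(cache1.sum(dim=1), cache1[:, 0])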
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
 
@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights
 
     def _load_per_tensor_weight_scale(
         self,
@@ -321,9 +323,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
-        )
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -347,9 +352,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
-        )
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
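The call that is now skipped is the slice that extracts this rank's shard from a full (unsharded) checkpoint tensor; with use_presharded_weights=True the checkpoint already stores per-rank shards, so the loaded tensor is copied as-is. A small illustration of what narrow does here (tensor shape, rank, and world size are hypothetical):

# Hypothetical example of the tensor-parallel slice that presharded weights skip.
import torch

full = torch.randn(64, 128)              # imagined full down_proj weight
tp_rank, tp_size, shard_dim = 1, 4, 1    # shard along the input dimension
shard_size = full.shape[shard_dim] // tp_size
shard = full.narrow(shard_dim, shard_size * tp_rank, shard_size)
assert shard.shape == (64, 32)           # this rank's 1/tp_size slice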
@@ -390,7 +398,6 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality