sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/linear.py
+++ b/sglang/srt/layers/linear.py
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
             loaded_weight = loaded_weight[:1]
         else:
             raise ValueError(f"{loaded_weight} are not all equal")
-
-        assert param.size() == loaded_weight.size()
+        assert (
+            param.size() == loaded_weight.size()
+        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
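The assertion now reports both shapes on failure instead of raising a bare AssertionError. A tiny illustration with hypothetical shapes:

import torch

param = torch.nn.Parameter(torch.empty(4, 8))
loaded_weight = torch.empty(4, 16)  # hypothetical mismatched checkpoint tensor

assert (
    param.size() == loaded_weight.size()
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
# AssertionError: Loading weight error: param: torch.Size([4, 8]),
# loaded_weight: torch.Size([4, 16])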
--- a/sglang/srt/layers/logits_processor.py
+++ b/sglang/srt/layers/logits_processor.py
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # The logprobs of the next tokens. shape: [#seq]
+    # The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
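A plain-PyTorch sketch of the two quantities the new comment distinguishes (illustrative only; RETURN_ORIGINAL_LOGPROB selects which one the sampler stores in next_token_logprobs):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

# RETURN_ORIGINAL_LOGPROB = True: log probs before applying temperature.
logprobs_original = torch.log_softmax(logits, dim=-1)

# RETURN_ORIGINAL_LOGPROB = False: log probs after applying temperature,
# i.e. of the distribution actually sampled from.
logprobs_sampled = torch.log_softmax(logits / temperature, dim=-1)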
--- a/sglang/srt/layers/moe/cutlass_w4a8_moe.py
+++ b/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
 
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
--- a/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
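A usage sketch for the new helper, with made-up sizes. Routing slots whose index is >= num_experts act as "zero experts": their scales are routed to the identity kernel, and they are masked out of the regular expert computation (note the function mutates expert_indices and expert_scales in place):

import torch

num_tokens, hidden_dim, top_k, num_experts = 4, 512, 2, 8
hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda")
expert_indices = torch.randint(0, num_experts + 2, (num_tokens, top_k), device="cuda")
expert_scales = torch.rand(num_tokens, top_k, device="cuda")

# Returns output[t] = hidden_states[t] * sum(scales of t's zero-expert slots);
# the caller adds this to the regular expert outputs.
identity_out = zero_experts_compute_triton(
    expert_indices, expert_scales, num_experts, "identity", hidden_states
)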
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -114,9 +114,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )
 
-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size
 
         if isinstance(quant_config, Fp8Config):
@@ -232,7 +229,7 @@ class EPMoE(FusedMoE):
             (
                 _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                     gateup_input_scale
                 )
             ),
@@ -289,9 +286,7 @@ class EPMoE(FusedMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -746,19 +741,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=output_dtype,
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
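The int8 path now keeps the first grouped matmul in int32 and fuses dequantization, SwiGLU, and dynamic requantization into a single torch_npu op. A rough pure-PyTorch sketch of the per-token semantics as I read them from the diff (ignoring the per-expert grouping via group_index; this is not the torch_npu implementation):

import torch
import torch.nn.functional as F

def dequant_swiglu_quant_reference(x_int32, weight_scale, activation_scale):
    # Dequantize the int32 GEMM accumulator: per-channel weight scale
    # times per-token activation scale.
    x = x_int32.to(torch.float32) * weight_scale * activation_scale[:, None]
    # SwiGLU with activate_left=True: silu on the left half, gate with the right.
    gate, up = x.chunk(2, dim=-1)
    y = F.silu(gate) * up
    # Dynamic per-token requantization to int8 for the down projection.
    out_scale = y.abs().amax(dim=-1, keepdim=True) / 127.0
    y_int8 = torch.round(y / out_scale).clamp(-128, 127).to(torch.int8)
    return y_int8, out_scale.squeeze(-1)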
--- a/sglang/srt/layers/moe/fused_moe_triton/__init__.py
+++ b/sglang/srt/layers/moe/fused_moe_triton/__init__.py
@@ -1,16 +1,18 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)
 
 _config: Optional[Dict[str, Any]] = None
 
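The package namespace keeps re-exporting the same names, so shallow imports should be unaffected; only deep imports of the old fused_moe module internals need to follow the new layout:

# Unchanged for consumers of the package namespace:
from sglang.srt.layers.moe.fused_moe_triton import (
    fused_experts,
    moe_align_block_size,
    try_get_optimal_moe_config,
)

# Deep imports now come from the split-out modules:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
    moe_align_block_size,
)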
--- /dev/null
+++ b/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
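Each top-level key is a token count M, mapped to the Triton tile shape (BLOCK_SIZE_*), group size, warp count, and pipeline depth tuned for this expert/shape/dtype/GPU combination. A minimal sketch of the usual lookup pattern, choosing the tuned M closest to the runtime batch (simplified; in sglang the real logic sits behind try_get_optimal_moe_config):

import json

def pick_moe_config(path: str, M: int) -> dict:
    # Hypothetical standalone loader; keys in the file are strings.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Use the tuned batch size nearest to the actual token count M.
    return configs[min(configs, key=lambda k: abs(k - M))]

cfg = pick_moe_config(
    "E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json", M=40
)
# M=40 resolves to the "32" entry above.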