sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/test/test_block_fp8_ep.py (deleted)
@@ -1,358 +0,0 @@
- import itertools
- import random
- import unittest
- from typing import Any, Callable, Dict, List, Optional, Tuple
-
- import torch
-
- from sglang.srt.layers.moe.ep_moe.kernels import (
-     grouped_gemm_triton,
-     post_reorder_triton_kernel,
-     pre_reorder_triton_kernel,
-     run_moe_ep_preproess,
-     silu_and_mul_triton_kernel,
- )
- from sglang.srt.layers.moe.topk import TopKConfig, select_experts
- from sglang.test.test_utils import CustomTestCase
-
-
- # For test
- def ep_moe(
-     hidden_states: torch.Tensor,
-     w1: torch.Tensor,
-     w2: torch.Tensor,
-     router_logits: torch.Tensor,
-     topk_config: TopKConfig,
-     # ep config
-     num_experts: int = 256,
-     fp8_dtype: torch.types = torch.float8_e4m3fn,
-     num_experts_per_partition: int = 128,
-     start_expert_id: int = 0,
-     end_expert_id: int = 127,
-     use_fp8_w8a8: bool = False,
-     w1_scale_inv: Optional[torch.Tensor] = None,
-     w2_scale_inv: Optional[torch.Tensor] = None,
-     block_shape: Optional[List[int]] = None,
- ):
-     use_blockwise_fp8 = block_shape is not None
-     top_k = topk_config.top_k
-     topk_output = select_experts(
-         hidden_states=hidden_states,
-         router_logits=router_logits,
-         topk_config=topk_config,
-     )
-     topk_weights, topk_ids, _ = topk_output
-
-     reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
-
-     gateup_input = torch.empty(
-         (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
-         device=hidden_states.device,
-         dtype=(
-             fp8_dtype
-             if (use_fp8_w8a8 and not use_blockwise_fp8)
-             else hidden_states.dtype
-         ),
-     )
-
-     if use_fp8_w8a8 and not use_blockwise_fp8:
-         max_value = (
-             torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
-         )
-         w1_input_scale = max_value / torch.finfo(fp8_dtype).max
-     else:
-         w1_input_scale = None
-
-     # PreReorder
-     pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-         hidden_states,
-         gateup_input,
-         src2dst,
-         topk_ids,
-         w1_input_scale,
-         start_expert_id,
-         end_expert_id,
-         top_k,
-         hidden_states.shape[1],
-         BLOCK_SIZE=512,
-         use_per_token_if_dynamic=True,
-     )
-
-     seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
-     weight_indices_cur_rank = torch.arange(
-         0,
-         num_experts_per_partition,
-         device=hidden_states.device,
-         dtype=torch.int64,
-     )
-
-     # GroupGemm-0
-     gateup_output = torch.empty(
-         gateup_input.shape[0],
-         w1.shape[1],
-         device=hidden_states.device,
-         dtype=hidden_states.dtype,
-     )
-
-     gateup_output = grouped_gemm_triton(
-         a=gateup_input,
-         b=w1,
-         c=gateup_output,
-         batch_size=num_experts_per_partition,
-         weight_column_major=True,
-         seg_indptr=seg_indptr_cur_rank,
-         weight_indices=weight_indices_cur_rank,
-         use_fp8_w8a8=use_fp8_w8a8,
-         scale_a=w1_input_scale,
-         scale_b=w1_scale_inv,
-         block_shape=block_shape,
-     )
-
-     # Act
-     down_input = torch.empty(
-         gateup_output.shape[0],
-         gateup_output.shape[1] // 2,
-         device=gateup_output.device,
-         dtype=(
-             fp8_dtype
-             if (use_fp8_w8a8 and not use_blockwise_fp8)
-             else hidden_states.dtype
-         ),
-     )
-     if use_fp8_w8a8 and not use_blockwise_fp8:
-         w2_input_scale = torch.ones(
-             num_experts_per_partition,
-             dtype=torch.float32,
-             device=hidden_states.device,
-         )
-     else:
-         w2_input_scale = None
-
-     silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-         gateup_output,
-         down_input,
-         gateup_output.shape[1],
-         reorder_topk_ids,
-         w2_input_scale,
-         start_expert_id,
-         end_expert_id,
-         BLOCK_SIZE=512,
-     )
-
-     # GroupGemm-1
-     down_output = torch.empty(
-         down_input.shape[0],
-         w2.shape[1],
-         device=hidden_states.device,
-         dtype=hidden_states.dtype,
-     )
-
-     down_output = grouped_gemm_triton(
-         a=down_input,
-         b=w2,
-         c=down_output,
-         batch_size=num_experts_per_partition,
-         weight_column_major=True,
-         seg_indptr=seg_indptr_cur_rank,
-         weight_indices=weight_indices_cur_rank,
-         use_fp8_w8a8=use_fp8_w8a8,
-         scale_a=w2_input_scale,
-         scale_b=w2_scale_inv,
-         block_shape=block_shape,
-     )
-
-     # PostReorder
-     output = torch.empty_like(hidden_states)
-     post_reorder_triton_kernel[(hidden_states.size(0),)](
-         down_output,
-         output,
-         src2dst,
-         topk_ids,
-         topk_weights,
-         start_expert_id,
-         end_expert_id,
-         top_k,
-         hidden_states.size(1),
-         0,
-         BLOCK_SIZE=512,
-     )
-     return output
-
-
- # test util
- def block_dequant(
-     x_q_block: torch.Tensor,
-     x_s: torch.Tensor,
-     block_size: List[int],
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """This function converts block-wise quantization to tensor-wise quantization.
-     The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
-     and the block size.
-     The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
-     Note only float8 is supported for now.
-     """
-
-     # process 3D tensor
-     if x_q_block.dim() == 3:
-         batch_size = x_q_block.size(0)
-         return torch.stack(
-             [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
-         )
-
-     block_n, block_k = block_size[0], block_size[1]
-     n, k = x_q_block.shape
-     n_tiles = (n + block_n - 1) // block_n
-     k_tiles = (k + block_k - 1) // block_k
-     assert n_tiles == x_s.shape[0]
-     assert k_tiles == x_s.shape[1]
-
-     x_dq_block = x_q_block.to(torch.float32)
-
-     x_dq_block_tiles = [
-         [
-             x_dq_block[
-                 j * block_n : min((j + 1) * block_n, n),
-                 i * block_k : min((i + 1) * block_k, k),
-             ]
-             for i in range(k_tiles)
-         ]
-         for j in range(n_tiles)
-     ]
-
-     for i in range(k_tiles):
-         for j in range(n_tiles):
-             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
-
-     return x_dq_block
-
-
- class TestW8A8BlockFP8EPMoE(CustomTestCase):
-     DTYPES = [torch.half, torch.bfloat16]
-     M = [1, 222, 1024, 2048]
-     N = [128, 1024, 2048]
-     K = [256, 4096, 5120]
-     E = [8, 16]
-     ep_size = [2, 4]
-     TOP_KS = [2, 4]
-     BLOCK_SIZE = [[128, 128]]
-     SEEDS = [0]
-
-     @classmethod
-     def setUpClass(cls):
-         if not torch.cuda.is_available():
-             raise unittest.SkipTest("CUDA is not available")
-         torch.set_default_device("cuda")
-
-     def _w8a8_block_fp8_ep_moe(
-         self, M, N, K, E, ep_size, topk, block_size, dtype, seed
-     ):
-         torch.manual_seed(seed)
-         random.seed(seed)
-         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
-         factor_for_scale = 1e-2
-         fp8_info = torch.finfo(torch.float8_e4m3fn)
-         fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-         a = torch.randn((M, K), dtype=dtype) / 10
-
-         w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
-         w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-         w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
-         w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
-
-         block_n, block_k = block_size[0], block_size[1]
-         n_tiles_w1 = (2 * N + block_n - 1) // block_n
-         n_tiles_w2 = (K + block_n - 1) // block_n
-         k_tiles_w1 = (K + block_k - 1) // block_k
-         k_tiles_w2 = (N + block_k - 1) // block_k
-
-         w1_s = (
-             torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
-             * factor_for_scale
-         )
-         w2_s = (
-             torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
-             * factor_for_scale
-         )
-
-         w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
-         w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
-
-         score = torch.randn((M, E), dtype=dtype)
-         num_experts_per_partition = E // ep_size
-         cur_rank = random.randint(0, ep_size - 1)
-         start_id = cur_rank * num_experts_per_partition
-         end_id = start_id + num_experts_per_partition - 1
-
-         topk_config = TopKConfig(
-             top_k=topk,
-             renormalize=False,
-         )
-
-         with torch.inference_mode():
-             out = ep_moe(
-                 hidden_states=a,
-                 w1=w1,
-                 w2=w2,
-                 router_logits=score,
-                 topk_config=topk_config,
-                 use_fp8_w8a8=True,
-                 w1_scale_inv=w1_s,
-                 w2_scale_inv=w2_s,
-                 block_shape=block_size,
-                 num_experts=E,
-                 num_experts_per_partition=num_experts_per_partition,
-                 start_expert_id=start_id,
-                 end_expert_id=end_id,
-             )
-             ref_out = ep_moe(
-                 hidden_states=a,
-                 w1=w1_ref,
-                 w2=w2_ref,
-                 router_logits=score,
-                 topk_config=topk_config,
-                 use_fp8_w8a8=False,
-                 w1_scale_inv=None,
-                 w2_scale_inv=None,
-                 block_shape=None,
-                 num_experts=E,
-                 num_experts_per_partition=num_experts_per_partition,
-                 start_expert_id=start_id,
-                 end_expert_id=end_id,
-             )
-         self.assertTrue(
-             torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
-             / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
-             < 0.06
-         )
-
-     def test_w8a8_block_fp8_ep_moe(self):
-         for params in itertools.product(
-             self.M,
-             self.N,
-             self.K,
-             self.E,
-             self.ep_size,
-             self.TOP_KS,
-             self.BLOCK_SIZE,
-             self.DTYPES,
-             self.SEEDS,
-         ):
-             with self.subTest(
-                 M=params[0],
-                 N=params[1],
-                 K=params[2],
-                 E=params[3],
-                 ep_size=params[4],
-                 topk=params[5],
-                 block_size=params[6],
-                 dtype=params[7],
-                 seed=params[8],
-             ):
-                 self._w8a8_block_fp8_ep_moe(*params)
-                 torch.cuda.empty_cache()
-
-
- if __name__ == "__main__":
-     unittest.main(verbosity=2)