sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
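
One structural change in the file list is worth calling out before the hunks below: the monolithic OpenAI-compatible layer (`sglang/srt/openai_api/adapter.py`, removed) is split into per-endpoint modules under `sglang/srt/entrypoints/openai/` (`serving_chat.py`, `serving_completions.py`, `serving_embedding.py`, `serving_rerank.py`, `serving_score.py`), and `protocol.py` moves with it. Code importing the old module path needs the new one; a minimal, hedged sketch of a fallback import (only the module paths come from the file list above; the symbol name is an assumption):

```python
# Hedged sketch, not sglang documentation: use whichever protocol module the
# installed sglang version provides. ChatCompletionRequest is assumed to keep
# its name across the move; substitute the symbols you actually import.
try:
    from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest  # 0.4.8 layout
except ImportError:
    from sglang.srt.openai_api.protocol import ChatCompletionRequest  # 0.4.7 layout
```

The hunks reproduced below cover three of the listed files: `sglang/srt/layers/moe/ep_moe/layer.py`, `sglang/srt/layers/moe/ep_moe/token_dispatcher.py`, and the new fused-MoE Triton config JSON.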
```diff
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,30 +1,11 @@
 import logging
 from typing import Callable, List, Optional, Tuple

+import einops
 import torch
+from sgl_kernel import silu_and_mul
 from torch.nn import Module

-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -35,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
+    moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
@@ -45,19 +27,33 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)

 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()

 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
@@ -183,6 +179,7 @@ class EPMoE(torch.nn.Module):
         assert (
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -232,13 +229,182 @@ class EPMoE(torch.nn.Module):

         self.grouped_gemm_runner = None

+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, router_logits)
+        else:
+            return self.forward_normal(hidden_states, router_logits)
+
+    def forward_deepgemm(
+        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

-
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )
+        del gateup_input
+        del gateup_input_fp8

+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        return output
+
+    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
@@ -254,6 +420,7 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -445,6 +612,7 @@ class EPMoE(torch.nn.Module):
             self.end_expert_id,
             self.top_k,
             hidden_states_shape[1],
+            0,
             BLOCK_SIZE=512,
         )
         return output
@@ -680,7 +848,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn

@@ -852,6 +1019,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 torch.max(layer.w13_weight_scale, dim=1).values,
                 requires_grad=False,
             )
+            if self.block_quant:
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if _is_fp8_fnuz:
+                    # activation_scheme: dynamic
+                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w13_weight,
+                        weight_scale=layer.w13_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w2_weight,
+                        weight_scale=layer.w2_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    # Reset the parameter
+                    layer.w13_weight = torch.nn.Parameter(
+                        w13_weight, requires_grad=False
+                    )
+                    layer.w13_weight_scale_inv = torch.nn.Parameter(
+                        w13_weight_scale, requires_grad=False
+                    )
+                    layer.w13_input_scale = None
+                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                    layer.w2_weight_scale_inv = torch.nn.Parameter(
+                        w2_weight_scale, requires_grad=False
+                    )
+                    layer.w2_input_scale = None
             return

     def apply(
@@ -920,7 +1114,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -948,7 +1144,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1341,7 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
-
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
@@ -1169,7 +1365,7 @@ class DeepEPMoE(EPMoE):
         )
         del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
-
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
             down_output,
@@ -1202,8 +1398,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-
-            hidden_states_fp8,
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])

@@ -1233,6 +1434,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output

@@ -1240,13 +1442,24 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-
-            down_input_fp8,
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )

         return down_output
@@ -1255,6 +1468,9 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
```
```diff
--- a/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass

-from sglang.srt.layers.quantization
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
@@ -107,6 +107,8 @@ class DeepEPBuffer:
             num_rdma_bytes,
             low_latency_mode=deepep_mode.enable_low_latency(),
             num_qps_per_rank=num_qps_per_rank,
+            # TODO can be false when unneeded
+            allow_mnnvl=True,
         )
         return cls._buffer

@@ -234,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
             hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event

     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             (
                 hidden_states,
                 topk_idx,
@@ -343,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if
+            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )

@@ -407,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -540,38 +542,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         use_fp8: bool = False,
     ):
-        """
-        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
-        # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
-        # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
-
-        diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
-        index 76ae2e2..8ecd08f 100644
-        --- a/csrc/kernels/internode_ll.cu
-        +++ b/csrc/kernels/internode_ll.cu
-        @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
-        int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
-        void* workspace, cudaStream_t stream, int phases) {
-        constexpr int kNumMaxTopK = 9;
-        - constexpr int kNumWarpsPerGroup = 10;
-        - constexpr int kNumWarpGroups = 3;
-        + constexpr int kNumWarpsPerGroup = 8;
-        + constexpr int kNumWarpGroups = 4;
-        EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
-
-        const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-        @@ -501,8 +501,8 @@ void combine(void* combined_x,
-        int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
-        int num_topk, int num_experts, int rank, int num_ranks,
-        void* workspace, cudaStream_t stream, int phases) {
-        - constexpr int kNumWarpsPerGroup = 10;
-        - constexpr int kNumWarpGroups = 3;
-        + constexpr int kNumWarpsPerGroup = 8;
-        + constexpr int kNumWarpGroups = 4;
-        constexpr int kNumMaxTopk = 9;
-
-        const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
-        """
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
@@ -582,6 +552,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
+                round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+                use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
```
```diff
--- /dev/null
+++ b/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
```