sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -23,11 +23,12 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -222,16 +223,18 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        aux_hidden_states: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
             pruned_states = hidden_states
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         elif (
@@ -255,6 +258,8 @@ class LogitsProcessor(nn.Module):
                 - 1
             )
             pruned_states = hidden_states[last_index]
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         else:
@@ -318,13 +323,27 @@ class LogitsProcessor(nn.Module):
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
-                hidden_states_to_store = hidden_states
+                if aux_hidden_states is not None:
+                    aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
+                    hidden_states_to_store = aux_hidden_states
+                else:
+                    hidden_states_to_store = hidden_states
             elif logits_metadata.capture_hidden_mode.is_last():
                 # Get the last token hidden states. If sample_indices is None,
                 # pruned states only contain the last tokens already.
-                hidden_states_to_store = (
-                    pruned_states[sample_indices] if sample_indices else pruned_states
-                )
+                if aux_hidden_states is not None:
+                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
+                    hidden_states_to_store = (
+                        aux_pruned_states[sample_indices]
+                        if sample_indices
+                        else aux_pruned_states
+                    )
+                else:
+                    hidden_states_to_store = (
+                        pruned_states[sample_indices]
+                        if sample_indices
+                        else pruned_states
+                    )
             else:
                 assert False, "Should never reach"
 
@@ -409,7 +428,7 @@
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
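The `aux_hidden_states` plumbing above captures hidden states from several auxiliary layers by concatenating them along the feature dimension, plausibly in support of the EAGLE3 draft model added in `sglang/srt/models/llama_eagle3.py`. A minimal CPU sketch of the two capture branches; the shapes and names below are illustrative only, not sglang API:

```python
import torch

# Hypothetical inputs: one hidden-state tensor per auxiliary layer.
num_tokens, hidden_size = 5, 8
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]

# "full" capture: concatenate along the feature dim, keep every token.
full_capture = torch.cat(aux_hidden_states, dim=-1)  # [5, 24]

# "last" capture: keep only the sampled positions, when given.
sample_indices = torch.tensor([1, 4])
aux_pruned_states = torch.cat(aux_hidden_states, dim=-1)
last_capture = aux_pruned_states[sample_indices]  # [2, 24]

print(full_capture.shape, last_capture.shape)
```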
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -5,6 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
@@ -16,6 +17,115 @@ if _is_cuda:
 logger = logging.getLogger(__name__)
 
 
+@triton.jit
+def deepep_permute_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
+
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+
+
+@triton.jit
+def deepep_post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offset
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
@@ -33,17 +143,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     tl.store(seg_indptr + expert + 1, target_location + 1)
 
 
-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
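The new `deepep_run_moe_deep_preprocess` sorts the flattened top-k expert assignments, derives per-expert segment offsets with `torch.searchsorted`, and builds a `src2dst` map from each original (token, k) slot to its row in the expert-sorted layout; slots with expert id -1 (dropped by the dispatcher) sort to the front and are sliced off, which is why the permute/reorder kernels guard on `dst_idx >= 0`. A CPU-only reference sketch of the same bookkeeping, with the Triton `src2dst` kernel replaced by plain indexing (illustrative only, not the shipped implementation):

```python
import torch


def deepep_preprocess_reference(topk_ids: torch.Tensor, num_experts: int):
    """CPU mimic of deepep_run_moe_deep_preprocess (for intuition only)."""
    flat = topk_ids.view(-1)
    reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)

    # Per-expert segment offsets into the sorted order.
    expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
    seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)

    # Slots with id < 0 sorted to the front; shift them out of the offsets.
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    # src2dst: sorted-layout row for each original (token, k) slot;
    # dropped slots end up with a negative destination.
    src2dst = torch.empty_like(reorder_ids)
    src2dst[reorder_ids] = torch.arange(flat.numel()) - num_minus_one

    return reorder_topk_ids[num_minus_one:], src2dst, seg_indptr


# Two tokens, top-2 routing over 4 experts; -1 marks a dropped slot.
print(deepep_preprocess_reference(torch.tensor([[2, 0], [2, -1]]), num_experts=4))
```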
sglang/srt/layers/moe/ep_moe/layer.py

@@ -2,8 +2,14 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+
+# TODO: use deep_gemm masked kernel after low latency dispatch
+# import deep_gemm
+# from deep_gemm import (
+#     get_col_major_tma_aligned_tensor,
+#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+# )
 from torch.nn import Module
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,18 +32,23 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
 
 
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
 
+_buffer = None
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -772,3 +783,264 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class DeepEPMoE(EPMoE):
+    """
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    """
+
+    _has_printed = False
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            prefix,
+            correction_bias,
+            custom_routing_function,
+            activation,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+        forward_mode: ForwardMode,
+    ):
+        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
+        if True:  # not forward_mode.is_decode():
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        else:
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )
+
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+            )
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if hidden_states.shape[0] > 0:
+            gateup_output = self.grouped_gemm_runner(
+                a=hidden_states,
+                b=self.w13_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w13_input_scale,
+                scale_b=(
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_output = self.grouped_gemm_runner(
+                a=down_input,
+                b=self.w2_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w2_input_scale,
+                scale_b=(
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+    def forward_deepgemm_masked(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            hidden_states = (
+                hidden_states[0],
+                get_col_major_tma_aligned_tensor(hidden_states[1]),
+            )
+            """
+            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                hidden_states, self.w13_weight, out, masked_m, expected_m
+            )
+            """
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            down_input = (
+                down_input[0],
+                get_col_major_tma_aligned_tensor(down_input[1]),
+            )
+            """
+            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                down_input, self.w2_weight, out, masked_m, expected_m
+            )
+            """
+
+        return down_output
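Taken together, the kernels and the new `DeepEPMoE` class implement: permute tokens into expert-contiguous order, push each expert's segment through GroupGemm-0, `silu_and_mul`, and GroupGemm-1, then scatter the results back weighted by the router outputs. A dense, single-device PyTorch sketch of that data flow, with toy shapes and none of the fp8 scaling or DeepEP communication machinery (`moe_reference` and all shapes here are hypothetical, for intuition only):

```python
import torch
import torch.nn.functional as F


def moe_reference(x, topk_ids, topk_w, w13, w2):
    """Toy EP-MoE data flow: permute -> per-expert MLP -> weighted scatter-back.

    x:        [num_tokens, hidden]
    topk_ids: [num_tokens, topk] expert ids, -1 = dropped slot
    topk_w:   [num_tokens, topk] routing weights
    w13:      [num_experts, 2 * inter, hidden]  fused gate+up projection
    w2:       [num_experts, hidden, inter]      down projection
    """
    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].T            # GroupGemm-0
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                  # silu_and_mul
        h = h @ w2[e].T                        # GroupGemm-1
        out.index_add_(0, token_idx, h * topk_w[token_idx, slot_idx, None])
    return out


x = torch.randn(5, 16)
topk_ids = torch.randint(-1, 4, (5, 2))   # -1 = dropped slot
topk_w = torch.rand(5, 2)
w13 = torch.randn(4, 2 * 32, 16)
w2 = torch.randn(4, 16, 32)
print(moe_reference(x, topk_ids, topk_w, w13, w2).shape)  # torch.Size([5, 16])
```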