sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        from sglang.srt.layers.parameter import _ColumnvLLMParameter
-
         if isinstance(param, _ColumnvLLMParameter):
             param.load_column_parallel_weight(
                 loaded_weight,
@@ -687,10 +686,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    shard_id=0,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
-                param.load_merged_column_weight(
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -719,6 +727,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset=shard_offset,
                 shard_size=shard_size,
                 use_presharded_weights=self.use_presharded_weights,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
             )
 
 
|
|
782
792
|
else:
|
783
793
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
784
794
|
self.num_kv_head_replicas = 1
|
795
|
+
self.q_proj_shard_size = self.num_heads * self.head_size
|
796
|
+
self.kv_proj_shard_size = self.num_kv_heads * self.head_size
|
785
797
|
input_size = self.hidden_size
|
786
798
|
output_size = (
|
787
799
|
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
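The new q_proj_shard_size and kv_proj_shard_size attributes record the per-rank output widths of the fused QKV projection. Worked out for one illustrative grouped-query-attention configuration (the numbers are chosen for the example, not taken from the diff):

    # 32 query heads, 8 KV heads, head_size 128, tensor-parallel size 4.
    total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4
    num_heads = total_num_heads // tp_size         # 8 query heads per rank
    num_kv_heads = total_num_kv_heads // tp_size   # 2 KV heads per rank
    q_proj_shard_size = num_heads * head_size      # 1024
    kv_proj_shard_size = num_kv_heads * head_size  # 256
    output_size_per_rank = (num_heads + 2 * num_kv_heads) * head_size  # 1536
    print(q_proj_shard_size, kv_proj_shard_size, output_size_per_rank)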
@@ -1234,7 +1246,7 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        if isinstance(param,
+        if isinstance(param, RowvLLMParameter):
             # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
             # It supports additional parameters like tp_rank and use_presharded_weights.
             param.load_row_parallel_weight(
sglang/srt/layers/logits_processor.py
CHANGED
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -223,16 +223,18 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        aux_hidden_states: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
             pruned_states = hidden_states
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         elif (
@@ -256,6 +258,8 @@ class LogitsProcessor(nn.Module):
                 - 1
             )
             pruned_states = hidden_states[last_index]
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         else:
@@ -319,13 +323,27 @@ class LogitsProcessor(nn.Module):
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
-
+                if aux_hidden_states is not None:
+                    aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
+                    hidden_states_to_store = aux_hidden_states
+                else:
+                    hidden_states_to_store = hidden_states
             elif logits_metadata.capture_hidden_mode.is_last():
                 # Get the last token hidden states. If sample_indices is None,
                 # pruned states only contain the last tokens already.
-
-
-
+                if aux_hidden_states is not None:
+                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
+                    hidden_states_to_store = (
+                        aux_pruned_states[sample_indices]
+                        if sample_indices
+                        else aux_pruned_states
+                    )
+                else:
+                    hidden_states_to_store = (
+                        pruned_states[sample_indices]
+                        if sample_indices
+                        else pruned_states
+                    )
             else:
                 assert False, "Should never reach"
 
@@ -410,7 +428,7 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
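The aux_hidden_states argument added above lets callers pass a list of extra per-layer hidden states; when hidden states are captured, these are concatenated along the feature dimension in place of the regular hidden states. A minimal sketch of that capture behavior, with toy shapes and illustrative names:

    import torch

    def capture_hidden(pruned_states, aux_hidden_states=None, sample_indices=None):
        # Prefer the concatenated auxiliary layer outputs when they are provided,
        # otherwise fall back to the regular hidden states.
        states = (
            torch.cat(aux_hidden_states, dim=-1)
            if aux_hidden_states is not None
            else pruned_states
        )
        return states[sample_indices] if sample_indices is not None else states

    h = torch.randn(4, 8)
    aux = [torch.randn(4, 8) for _ in range(3)]
    print(capture_hidden(h, aux).shape)  # torch.Size([4, 24])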
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -5,6 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
@@ -16,6 +17,115 @@ if _is_cuda:
 logger = logging.getLogger(__name__)
 
 
+@triton.jit
+def deepep_permute_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
+
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+
+
+@triton.jit
+def deepep_post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offet
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
@@ -33,17 +143,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     tl.store(seg_indptr + expert + 1, target_location + 1)
 
 
-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
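deepep_run_moe_deep_preprocess above sorts the flattened top-k expert ids, derives per-expert segment offsets with searchsorted, and builds a src2dst map in which slots whose expert id is -1 (tokens dropped by the DeepEP dispatch) end up with negative destinations. A plain-PyTorch reference of the same bookkeeping, runnable on CPU; the -1 convention is an assumption read off the kernel above:

    import torch

    def moe_preprocess_reference(topk_ids: torch.Tensor, num_experts: int):
        flat = topk_ids.view(-1)
        reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)
        seg_indptr = torch.searchsorted(
            reorder_topk_ids, torch.arange(num_experts + 1, dtype=flat.dtype)
        )
        num_invalid = int(seg_indptr[0])  # entries equal to -1 sort to the front
        seg_indptr = seg_indptr - num_invalid
        src2dst = torch.empty_like(flat)
        src2dst[reorder_ids] = torch.arange(flat.numel()) - num_invalid
        return reorder_topk_ids[num_invalid:], src2dst, seg_indptr

    ids = torch.tensor([[2, -1], [0, 2]])
    print(moe_preprocess_reference(ids, num_experts=4))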
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -2,8 +2,14 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+
+# TODO: use deep_gemm masked kernel after low latency dispatch
+# import deep_gemm
+# from deep_gemm import (
+#     get_col_major_tma_aligned_tensor,
+#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+# )
 from torch.nn import Module
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,18 +32,23 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
 
 
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
 
+_buffer = None
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -772,3 +783,264 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class DeepEPMoE(EPMoE):
+    """
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    """
+
+    _has_printed = False
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            prefix,
+            correction_bias,
+            custom_routing_function,
+            activation,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+        forward_mode: ForwardMode,
+    ):
+        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
+        if True:  # not forward_mode.is_decode():
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        else:
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )
+
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+            )
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if hidden_states.shape[0] > 0:
+            gateup_output = self.grouped_gemm_runner(
+                a=hidden_states,
+                b=self.w13_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w13_input_scale,
+                scale_b=(
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_output = self.grouped_gemm_runner(
+                a=down_input,
+                b=self.w2_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w2_input_scale,
+                scale_b=(
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+    def forward_deepgemm_masked(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            hidden_states = (
+                hidden_states[0],
+                get_col_major_tma_aligned_tensor(hidden_states[1]),
+            )
+            """
+            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                hidden_states, self.w13_weight, out, masked_m, expected_m
+            )
+            """
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            down_input = (
+                down_input[0],
+                get_col_major_tma_aligned_tensor(down_input[1]),
+            )
+            """
+            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                down_input, self.w2_weight, out, masked_m, expected_m
+            )
+            """
+
+        return down_output