sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -75,8 +75,9 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
+        skip_quant: Optional[bool] = False,
     ):
         super().__init__()

@@ -92,16 +93,13 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None

-        if
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False

-        self.
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
-            assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
             self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
             assert num_experts % self.ep_size == 0
-            self.
+            self.num_local_experts = num_experts // self.ep_size
             self.expert_map[
                 self.ep_rank
-                * self.
-                * self.
-            ] = torch.arange(0, self.
+                * self.num_local_experts : (self.ep_rank + 1)
+                * self.num_local_experts
+            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
         else:
             self.ep_size = 1
             self.ep_rank = 0
-            self.
+            self.num_local_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )

+        if skip_quant:
+            return
+
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -141,13 +142,15 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None

         self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts=self.
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -376,6 +379,23 @@ class FusedMoE(torch.nn.Module):
         if expert_id == -1:
             return

+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+
+    def _weight_loader_impl(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+
         # TP rank is set to 0 if EP is enabled
         tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

@@ -396,6 +416,10 @@ class FusedMoE(torch.nn.Module):
                 f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
             )

+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if getattr(self, "use_flashinfer_trtllm_moe", False):
+            shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
@@ -603,37 +627,3 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if (
-                param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
CHANGED
@@ -1,21 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional

 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx

 from sglang.srt.utils import direct_register_custom_op

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+

 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:

-
-
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output

     return triton_kernel_fused_experts(
         hidden_states,
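The new imports combine `from __future__ import annotations` with a `TYPE_CHECKING` guard so `TopKOutput` can appear in annotations without importing `sglang.srt.layers.moe.topk` at runtime. Below is a generic sketch of that pattern; `heavy_module`, `HeavyType`, and `process` are illustrative names, not sglang identifiers.

```python
# Generic sketch of the lazy-annotation pattern used above; `heavy_module` is a
# placeholder for a module you only need for type checking.
from __future__ import annotations  # annotations are stored as strings, not evaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only by type checkers (mypy/pyright); skipped at runtime, which
    # avoids import cycles and import-time overhead.
    from heavy_module import HeavyType


def process(x: HeavyType) -> int:
    # The annotation is never evaluated at runtime, so heavy_module is not needed
    # unless the function body itself uses it.
    return len(str(x))
```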
sglang/srt/layers/moe/topk.py
CHANGED
@@ -15,7 +15,8 @@
 from __future__ import annotations

 import math
-from
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable

 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,12 +39,18 @@ from sglang.srt.utils import (
     is_npu,
 )

+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -54,20 +62,62 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
 if _is_npu:
     import torch_npu


-
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor

+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD
+
+
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------

-class TopK(CustomOp):

-
+class TopK(CustomOp):

     def __init__(
         self,
@@ -98,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor

+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -132,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )

     def forward_cpu(
         self,
@@ -218,6 +276,9 @@ class TopK(CustomOp):
         )


+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -387,6 +448,7 @@ def grouped_topk_cpu(
     )


+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -482,7 +544,6 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -535,14 +596,7 @@ def biased_grouped_topk_gpu(
         )
         return topk_weights, topk_ids
     else:
-
-            torch.compile(
-                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-            )
-            if compiled
-            else biased_grouped_topk_impl
-        )
-        return biased_grouped_topk_fn(
+        return biased_grouped_topk_impl(
             hidden_states,
             gating_output,
             correction_bias,
@@ -688,4 +742,4 @@ def select_experts(

     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

-    return
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
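The topk.py change introduces a `TopKOutput` Protocol with two NamedTuple implementations that advertise their format, so downstream MoE kernels can check which routing representation they received before unpacking it (the `assert topk_output.format.is_triton_kernel()` in triton_kernels_moe.py). Below is a simplified, self-contained sketch of that format-dispatch pattern; the class and field names mirror the diff, but this is not the sglang implementation itself.

```python
# Simplified sketch of the format-dispatch pattern introduced above (illustrative,
# not sglang code): NamedTuples that report their format via a shared Protocol.
from enum import Enum, auto
from typing import List, NamedTuple, Protocol, runtime_checkable


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()


@runtime_checkable
class TopKOutput(Protocol):
    @property
    def format(self) -> TopKOutputFormat: ...


class StandardTopKOutput(NamedTuple):
    topk_weights: List[float]
    topk_ids: List[int]

    @property
    def format(self) -> TopKOutputFormat:
        return TopKOutputFormat.STANDARD


def run_experts(topk_output: TopKOutput) -> str:
    # A kernel entry point checks the format before unpacking, mirroring the
    # assert in triton_kernel_moe_forward.
    if topk_output.format is TopKOutputFormat.STANDARD:
        weights, ids = topk_output  # NamedTuple unpacks positionally
        return f"standard routing over experts {ids} with weights {weights}"
    raise NotImplementedError("triton-kernel format is handled by a different path")


print(run_experts(StandardTopKOutput([0.6, 0.4], [3, 7])))
```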
sglang/srt/layers/multimodal.py
CHANGED
@@ -55,14 +55,17 @@ def gpu_tensor_hash(tensor: torch.Tensor) -> int:

     intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)

-
-
-
-
-
-
-
-
+    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
+    with torch.cuda.device(tensor.device):
+        hash_kernel[grid](
+            tensor,
+            intermediate_hashes,
+            n,
+            BLOCK_SIZE=BLOCK_SIZE,
+            PRIME=PRIME_1,
+            XCONST=PRIME_2,
+        )

     # TODO: threads can't be synced on triton kernel
     final_hash = intermediate_hashes.sum().item()
|
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
|
|
28
28
|
CompressedTensorsConfig,
|
29
29
|
)
|
30
30
|
|
31
|
-
_is_cuda = is_cuda()
|
32
|
-
_is_npu = is_npu()
|
33
|
-
_is_cpu_amx_available = cpu_has_amx_support()
|
34
|
-
_is_cpu = is_cpu()
|
35
|
-
_is_hip = is_hip()
|
36
|
-
|
37
|
-
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
38
|
-
from vllm import _custom_ops as vllm_ops
|
39
|
-
from vllm._custom_ops import scaled_fp8_quant
|
40
31
|
|
41
32
|
try:
|
42
33
|
import vllm
|
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|
568
559
|
requires_grad=False,
|
569
560
|
)
|
570
561
|
|
562
|
+
from vllm import _custom_ops as vllm_ops
|
563
|
+
|
571
564
|
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
572
565
|
layer.w13_weight_packed,
|
573
566
|
layer.w13_g_idx_sort_indices,
|