sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,566 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from: https://github.com/vllm-project/vllm/blob/ab3e80042eac24dd362408e6d63ad98768046359/vllm/model_executor/layers/quantization/gguf.py
+from __future__ import annotations
+
+import logging
+import warnings
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import gguf
+import torch
+from gguf import GGMLQuantizationType as WeightType
+from torch.nn.parameter import Parameter, UninitializedParameter
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.utils import is_cuda, is_hip, is_xpu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_xpu = is_xpu()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, moe_align_block_size, moe_sum, silu_and_mul
+    from sgl_kernel.quantization import (
+        ggml_dequantize,
+        ggml_moe_a8,
+        ggml_moe_a8_vec,
+        ggml_moe_get_block_size,
+        ggml_mul_mat_a8,
+        ggml_mul_mat_vec_a8,
+    )
+else:
+ warnings.warn(f"Only CUDA support GGUF q uantization currently.")
+
+logger = logging.getLogger(__name__)
+
+
+class GGUFConfig(QuantizationConfig):
+    """Config class for GGUF."""
+
+    def __init__(self, modules_to_not_convert: list[str] | None = None) -> None:
+        super().__init__()
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+    def __repr__(self) -> str:
+        return "GGUFConfig()"
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> "str":
+        return "gguf"
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16, torch.float32]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []  # no extra configs.
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_gguf(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
+        return None
+
+
+def is_layer_skipped_gguf(prefix: str, modules_to_not_convert: list[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+STANDARD_QUANT_TYPES = {
+    WeightType.Q4_0,
+    WeightType.Q4_1,
+    WeightType.Q5_0,
+    WeightType.Q5_1,
+    WeightType.Q8_0,
+    WeightType.Q8_1,
+}
+KQUANT_TYPES = {
+    WeightType.Q2_K,
+    WeightType.Q3_K,
+    WeightType.Q4_K,
+    WeightType.Q5_K,
+    WeightType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    WeightType.IQ1_M,
+    WeightType.IQ1_S,
+    WeightType.IQ2_XXS,
+    WeightType.IQ2_XS,
+    WeightType.IQ2_S,
+    WeightType.IQ3_XXS,
+    WeightType.IQ3_S,
+    WeightType.IQ4_XS,
+    WeightType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def fused_mul_mat_gguf(
+    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+) -> torch.Tensor:
+    if qweight_type in IMATRIX_QUANT_TYPES:
+        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
+    else:
+        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
+    # HACK: when doing chunked prefill we don't generate output tokens
+    # so input to logits generator is empty which causes invalid parameter
+    if x.shape[0] == 0:
+        return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+    # enable MMVQ in contiguous batching with batch_size=1
+    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
+        y = ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # Use MMQ Kernel if it's available (standard + k-quants)
+    elif qweight_type in MMQ_QUANT_TYPES:
+        y = ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        y = x @ weight.T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y
+
+
+def fused_moe_gguf(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    qweight_type: int,
+    qweight_type2: int,
+    activation: str,
+) -> torch.Tensor:
+    def act(x: torch.Tensor):
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if activation == "silu":
+            silu_and_mul(out, x)
+        elif activation == "gelu":
+            gelu_and_mul(out, x)
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+        return out
+
+    out_hidden_states = torch.empty_like(x)
+    # unless we have decent expert reuse, we are better off running the moe_vec kernel
+    if (
+        qweight_type2 in MMQ_QUANT_TYPES
+        and qweight_type in MMQ_QUANT_TYPES
+        and x.shape[0] > 64
+    ):
+        num_tokens, _ = x.shape
+        E, N, _ = w1.shape
+        top_k = topk_ids.shape[1]
+        BLOCK_SIZE = ggml_moe_get_block_size(qweight_type)
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            topk_ids, BLOCK_SIZE, E
+        )
+        out = ggml_moe_a8(
+            x,
+            w1,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            qweight_type,
+            N,
+            top_k,
+            num_tokens,
+        )
+        out = act(out)
+        out = ggml_moe_a8(
+            out,
+            w2,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            qweight_type2,
+            w2.shape[1],
+            1,
+            num_tokens * top_k,
+        )
+        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+            topk_weights.view(num_tokens, top_k, 1)
+        )
+        # TODO(FlamingoPg): maybe we can use moe_sum_reduce here?
+        moe_sum(out, out_hidden_states)
+    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
+        num_tokens, _ = x.shape
+        E, N, _ = w1.shape
+        top_k = topk_ids.shape[1]
+
+        out = ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
+        out = act(out)
+
+        out = ggml_moe_a8_vec(
+            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
+        )
+        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+            topk_weights.view(num_tokens, top_k, 1)
+        )
+        moe_sum(out, out_hidden_states)
+    else:
+        logger.warning_once(
+            "There is no support for fast MoE kernel "
+            "for current quantization method. "
+            "Falling back to slow implementation. "
+        )
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1,) + x.shape[1:])
+            current_hidden_state = None
+            for ww, ii in zip(w, idx):
+                expert_up = w1[ii]
+
+                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
+                out = act(out)
+
+                expert_down = w2[ii]
+                current_state = fused_mul_mat_gguf(
+                    out, expert_down, qweight_type2
+                ).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            out_hidden_states[tok] = current_hidden_state
+    return out_hidden_states
+
+
+def apply_gguf_embedding(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: torch.dtype | None = None,
+) -> torch.Tensor:
+    if qweight_type in UNQUANTIZED_TYPES:
+        return torch.embedding(qweight, x)
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        x_flat = x.flatten()
+        assert hidden_size == qweight.shape[1] // type_size * block_size
+        quant = torch.index_select(qweight, dim=0, index=x_flat)
+        dequant = ggml_dequantize(
+            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
+        )
+        return dequant.view(*x.shape, hidden_size)
+    else:
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+class GGUFLinearMethod(LinearMethodBase):
+    """Linear method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        self.params_dtype = params_dtype
+        output_size_per_partition = sum(output_partition_sizes)
+
+        tensor_shape = (output_size_per_partition, input_size_per_partition)
+        qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+                "shard_id": [],
+                "shard_id_map": {},
+            },
+        )
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qweight", qweight)
+
+        qweight_type = Parameter(
+            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight_type,
+            {
+                "is_gguf_weight_type": True,
+                "weight_type": 0,
+                "shard_weight_type": {},
+                "ignore_warning": True,
+            },
+        )
+        set_weight_attrs(qweight_type, extra_weight_attrs)
+        layer.register_parameter("qweight_type", qweight_type)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        qweight_type = layer.qweight_type.weight_type
+        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
+            qweight_type = WeightType(qweight_type)
+            raise ValueError(
+                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
+            )
+        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+        # materialize the padded weight parameter for CUDA Graph compatibility.
+        self._create_padded_weight_param(layer)
+
+    def _create_padded_weight_param(self, layer: torch.nn.Module):
+        """Create padded weight parameter for GGUF MergedLinear layer."""
+        qweight = layer.qweight
+        shard_id_map = qweight.shard_id_map
+        shard_id = qweight.shard_id
+        if len(data_container := qweight.data_container) > 1:
+            dtype = {data.dtype for data in data_container}
+            assert len(dtype) == 1, ValueError(
+                f"Data container has mixed dtypes: {dtype}"
+            )
+            dtype = next(iter(dtype))
+            # concat dim0 and pad dim1
+            padded_side = max(x.size(1) for x in data_container)
+            concat_side = sum(x.size(0) for x in data_container)
+            # Pad the quantized weights to dense tensor, and create a map
+            # with the location of each shard in the padded tensor.
+            padded_data = torch.zeros(
+                (concat_side, padded_side), dtype=dtype, device=qweight.device
+            )
+            # (dim0_start, dim0_end, dim1_size)
+            shard_offset_map = dict[str, tuple[int, int, int]]()
+            for idx in shard_id:
+                id_in_container = shard_id_map[idx]
+                start = sum(x.size(0) for x in data_container[:id_in_container])
+                end = start + data_container[id_in_container].size(0)
+                size = data_container[id_in_container].size(1)
+                padded_data[start:end, :size] = data_container[id_in_container]
+                shard_offset_map[idx] = (start, end, size)
+            qweight.data_container.clear()
+            padded_param = Parameter(padded_data, requires_grad=False)
+            set_weight_attrs(padded_param, vars(qweight))
+            set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
+            layer.register_parameter("qweight", padded_param)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        shard_id = layer.qweight.shard_id
+
+        if shard_id:
+            # dequantize shard weights respectively
+            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
+            qweight = layer.qweight
+            result = []
+            for idx in shard_id:
+                start, end, offset = layer.qweight.shard_offset_map[idx]
+                qweight_type = layer.qweight_type.shard_weight_type[idx]
+                result.append(
+                    fused_mul_mat_gguf(
+                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                    )
+                )
+            out = torch.cat(result, axis=1)
+        else:
+            qweight = layer.qweight
+            qweight_type = layer.qweight_type.weight_type
+            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+        if bias is not None:
+            out.add_(bias)
+        return out
+
+
+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # gate up proj
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            },
+        )
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w13_qweight_type = Parameter(
+            torch.empty(1, dtype=torch.uint8), requires_grad=False
+        )
+        set_weight_attrs(
+            w13_qweight_type,
+            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+        )
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
+        # gate down proj
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight,
+            {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            },
+        )
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(
+            torch.empty(1, dtype=torch.uint8), requires_grad=False
+        )
+        set_weight_attrs(
+            w2_qweight_type,
+            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+        )
+
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        assert self.fused_experts is None
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
+        topk_weights, topk_ids, _ = topk_output
+        output = fused_moe_gguf(
+            x=x,
+            w1=layer.w13_qweight,
+            w2=layer.w2_qweight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            qweight_type=layer.w13_qweight_type.weight_type,
+            qweight_type2=layer.w2_qweight_type.weight_type,
+            activation=moe_runner_config.activation,
+        )
+        return StandardCombineInput(hidden_states=output)
+
+
+class GGUFEmbeddingMethod(GGUFLinearMethod):
+    """Embedding method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
+        qweight = layer.qweight
+        qweight_type = layer.qweight_type.weight_type
+        hidden_size = qweight.tensor_shape[1]
+
+        return apply_gguf_embedding(
+            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
+        )
+
+
+class GGUFUninitializedParameter(UninitializedParameter):
+    cls_to_become = Parameter
+    data_container: list[torch.Tensor]
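
The GGUF backend above is selected per layer by GGUFConfig.get_quant_method, and any layer whose prefix matches an entry in modules_to_not_convert keeps the unquantized linear method. A minimal sketch of exercising that config-level dispatch, assuming sglang 0.5.4.post1 and the gguf package are installed (the "lm_head" skip entry and the layer prefixes are illustrative values, not settings shipped by this release; without CUDA the module only emits the warning shown above):

# Sketch only: config-level dispatch of the new GGUF quantization backend.
from sglang.srt.layers.quantization.gguf import GGUFConfig, is_layer_skipped_gguf

cfg = GGUFConfig.from_config({"modules_to_not_convert": ["lm_head"]})
print(cfg.get_name())            # "gguf"
print(cfg.get_min_capability())  # 60

# Prefix matching decides which layers skip GGUFLinearMethod.
print(is_layer_skipped_gguf("model.lm_head", cfg.modules_to_not_convert))                   # True
print(is_layer_skipped_gguf("model.layers.0.self_attn.qkv_proj", cfg.modules_to_not_convert))  # False
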
@@ -261,26 +261,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.flashinfer_mxfp4_moe_precision = (
             get_global_server_args().flashinfer_mxfp4_moe_precision
         )
 
-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -600,7 +587,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = (
+            MoeRunnerBackend.TRITON_KERNELS
+            if self.use_triton_kernels
+            else MoeRunnerBackend.TRITON
+        )
+        self.runner = MoeRunner(backend, moe_runner_config)
 
     def apply(
         self,
@@ -677,31 +669,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)
 
-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
@@ -709,7 +701,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):