sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gguf.py (new file, +566 lines)
@@ -0,0 +1,566 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # Adapted from: https://github.com/vllm-project/vllm/blob/ab3e80042eac24dd362408e6d63ad98768046359/vllm/model_executor/layers/quantization/gguf.py
+ from __future__ import annotations
+
+ import logging
+ import warnings
+ from typing import TYPE_CHECKING, Any, List, Optional
+
+ import gguf
+ import torch
+ from gguf import GGMLQuantizationType as WeightType
+ from torch.nn.parameter import Parameter, UninitializedParameter
+
+ from sglang.srt.layers.linear import LinearBase
+ from sglang.srt.layers.moe import MoeRunnerConfig
+ from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+ from sglang.srt.utils import is_cuda, is_hip, is_xpu, set_weight_attrs
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         StandardDispatchOutput,
+     )
+
+ _is_cuda = is_cuda()
+ _is_hip = is_hip()
+ _is_xpu = is_xpu()
+
+ if _is_cuda:
+     from sgl_kernel import gelu_and_mul, moe_align_block_size, moe_sum, silu_and_mul
+     from sgl_kernel.quantization import (
+         ggml_dequantize,
+         ggml_moe_a8,
+         ggml_moe_a8_vec,
+         ggml_moe_get_block_size,
+         ggml_mul_mat_a8,
+         ggml_mul_mat_vec_a8,
+     )
+ else:
+     warnings.warn("Only CUDA supports GGUF quantization currently.")
+
+ logger = logging.getLogger(__name__)
+
+
+ class GGUFConfig(QuantizationConfig):
+     """Config class for GGUF."""
+
+     def __init__(self, modules_to_not_convert: list[str] | None = None) -> None:
+         super().__init__()
+         self.modules_to_not_convert = modules_to_not_convert or []
+
+     def __repr__(self) -> str:
+         return "GGUFConfig()"
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+     def get_name(self) -> "str":
+         return "gguf"
+
+     def get_supported_act_dtypes(self) -> list[torch.dtype]:
+         return [torch.half, torch.bfloat16, torch.float32]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 60
+
+     @classmethod
+     def get_config_filenames(cls) -> list[str]:
+         return []  # no extra configs.
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
+         modules_to_not_convert = cls.get_from_keys_or(
+             config, ["modules_to_not_convert"], None
+         )
+         return cls(modules_to_not_convert)
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped_gguf(prefix, self.modules_to_not_convert):
+                 return UnquantizedLinearMethod()
+             return GGUFLinearMethod(self)
+         elif isinstance(layer, VocabParallelEmbedding):
+             return GGUFEmbeddingMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return GGUFMoEMethod(self)
+         return None
+
+
+ def is_layer_skipped_gguf(prefix: str, modules_to_not_convert: list[str]):
+     return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+ UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+ STANDARD_QUANT_TYPES = {
+     WeightType.Q4_0,
+     WeightType.Q4_1,
+     WeightType.Q5_0,
+     WeightType.Q5_1,
+     WeightType.Q8_0,
+     WeightType.Q8_1,
+ }
+ KQUANT_TYPES = {
+     WeightType.Q2_K,
+     WeightType.Q3_K,
+     WeightType.Q4_K,
+     WeightType.Q5_K,
+     WeightType.Q6_K,
+ }
+ IMATRIX_QUANT_TYPES = {
+     WeightType.IQ1_M,
+     WeightType.IQ1_S,
+     WeightType.IQ2_XXS,
+     WeightType.IQ2_XS,
+     WeightType.IQ2_S,
+     WeightType.IQ3_XXS,
+     WeightType.IQ3_S,
+     WeightType.IQ4_XS,
+     WeightType.IQ4_NL,
+ }
+ # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+ # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+ # MMQ kernel for I-Matrix quantization.
+ DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+ def fused_mul_mat_gguf(
+     x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+ ) -> torch.Tensor:
+     if qweight_type in IMATRIX_QUANT_TYPES:
+         mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
+     else:
+         mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
+     # HACK: when doing chunked prefill we don't generate output tokens
+     # so input to logits generator is empty which causes invalid parameter
+     if x.shape[0] == 0:
+         return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
+     # there is no need to call any kernel for fp16/bf16
+     if qweight_type in UNQUANTIZED_TYPES:
+         return x @ qweight.T
+     # enable MMVQ in contiguous batching with batch_size=1
+     if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
+         y = ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+     # Use MMQ Kernel if it's available (standard + k-quants)
+     elif qweight_type in MMQ_QUANT_TYPES:
+         y = ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+     # If there is no available MMQ kernel, fallback to dequantize
+     elif qweight_type in DEQUANT_TYPES:
+         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+         weight = ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+         y = x @ weight.T
+     else:
+         # Raise an error if the quantization type is not supported.
+         # Might be useful if llama.cpp adds a new quantization type.
+         # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+         qweight_type = WeightType(qweight_type)
+         raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+     return y
+
+
+ def fused_moe_gguf(
+     x: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     qweight_type: int,
+     qweight_type2: int,
+     activation: str,
+ ) -> torch.Tensor:
+     def act(x: torch.Tensor):
+         d = x.shape[-1] // 2
+         output_shape = x.shape[:-1] + (d,)
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         if activation == "silu":
+             silu_and_mul(out, x)
+         elif activation == "gelu":
+             gelu_and_mul(out, x)
+         else:
+             raise ValueError(f"Unsupported activation: {activation}")
+         return out
+
+     out_hidden_states = torch.empty_like(x)
+     # unless we have decent expert reuse, we are better off running the moe_vec kernel
+     if (
+         qweight_type2 in MMQ_QUANT_TYPES
+         and qweight_type in MMQ_QUANT_TYPES
+         and x.shape[0] > 64
+     ):
+         num_tokens, _ = x.shape
+         E, N, _ = w1.shape
+         top_k = topk_ids.shape[1]
+         BLOCK_SIZE = ggml_moe_get_block_size(qweight_type)
+
+         sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+             topk_ids, BLOCK_SIZE, E
+         )
+         out = ggml_moe_a8(
+             x,
+             w1,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             qweight_type,
+             N,
+             top_k,
+             num_tokens,
+         )
+         out = act(out)
+         out = ggml_moe_a8(
+             out,
+             w2,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             qweight_type2,
+             w2.shape[1],
+             1,
+             num_tokens * top_k,
+         )
+         out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+             topk_weights.view(num_tokens, top_k, 1)
+         )
+         # TODO(FlamingoPg): maybe we can use moe_sum_reduce here?
+         moe_sum(out, out_hidden_states)
+     elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
+         num_tokens, _ = x.shape
+         E, N, _ = w1.shape
+         top_k = topk_ids.shape[1]
+
+         out = ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
+         out = act(out)
+
+         out = ggml_moe_a8_vec(
+             out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
+         )
+         out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
+             topk_weights.view(num_tokens, top_k, 1)
+         )
+         moe_sum(out, out_hidden_states)
+     else:
+         logger.warning_once(
+             "There is no support for fast MoE kernel "
+             "for current quantization method. "
+             "Falling back to slow implementation. "
+         )
+         for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+             inp = x[tok].reshape((1,) + x.shape[1:])
+             current_hidden_state = None
+             for ww, ii in zip(w, idx):
+                 expert_up = w1[ii]
+
+                 out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
+                 out = act(out)
+
+                 expert_down = w2[ii]
+                 current_state = fused_mul_mat_gguf(
+                     out, expert_down, qweight_type2
+                 ).mul_(ww)
+                 if current_hidden_state is None:
+                     current_hidden_state = current_state
+                 else:
+                     current_hidden_state.add_(current_state)
+             out_hidden_states[tok] = current_hidden_state
+     return out_hidden_states
+
+
+ def apply_gguf_embedding(
+     x: torch.Tensor,
+     qweight: torch.Tensor,
+     qweight_type: int,
+     hidden_size: int,
+     dtype: torch.dtype | None = None,
+ ) -> torch.Tensor:
+     if qweight_type in UNQUANTIZED_TYPES:
+         return torch.embedding(qweight, x)
+     elif qweight_type in DEQUANT_TYPES:
+         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+         x_flat = x.flatten()
+         assert hidden_size == qweight.shape[1] // type_size * block_size
+         quant = torch.index_select(qweight, dim=0, index=x_flat)
+         dequant = ggml_dequantize(
+             quant, qweight_type, hidden_size, x_flat.shape[0], dtype
+         )
+         return dequant.view(*x.shape, hidden_size)
+     else:
+         qweight_type = WeightType(qweight_type)
+         raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+ class GGUFLinearMethod(LinearMethodBase):
+     """Linear method for GGUF.
+
+     Args:
+         quant_config: The GGUF quantization config.
+     """
+
+     def __init__(self, quant_config: GGUFConfig):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: list[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         self.params_dtype = params_dtype
+         output_size_per_partition = sum(output_partition_sizes)
+
+         tensor_shape = (output_size_per_partition, input_size_per_partition)
+         qweight = GGUFUninitializedParameter(requires_grad=False)
+         set_weight_attrs(
+             qweight,
+             {
+                 "input_dim": 1,
+                 "output_dim": 0,
+                 "tensor_shape": tensor_shape,
+                 "is_gguf_weight": True,
+                 "data_container": [],
+                 "shard_id": [],
+                 "shard_id_map": {},
+             },
+         )
+         set_weight_attrs(qweight, extra_weight_attrs)
+         layer.register_parameter("qweight", qweight)
+
+         qweight_type = Parameter(
+             torch.empty(len(output_partition_sizes), dtype=torch.uint8),
+             requires_grad=False,
+         )
+         set_weight_attrs(
+             qweight_type,
+             {
+                 "is_gguf_weight_type": True,
+                 "weight_type": 0,
+                 "shard_weight_type": {},
+                 "ignore_warning": True,
+             },
+         )
+         set_weight_attrs(qweight_type, extra_weight_attrs)
+         layer.register_parameter("qweight_type", qweight_type)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module):
+         qweight_type = layer.qweight_type.weight_type
+         if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
+             qweight_type = WeightType(qweight_type)
+             raise ValueError(
+                 f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
+             )
+         # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+         # materialize the padded weight parameter for CUDA Graph compatibility.
+         self._create_padded_weight_param(layer)
+
+     def _create_padded_weight_param(self, layer: torch.nn.Module):
+         """Create padded weight parameter for GGUF MergedLinear layer."""
+         qweight = layer.qweight
+         shard_id_map = qweight.shard_id_map
+         shard_id = qweight.shard_id
+         if len(data_container := qweight.data_container) > 1:
+             dtype = {data.dtype for data in data_container}
+             assert len(dtype) == 1, ValueError(
+                 f"Data container has mixed dtypes: {dtype}"
+             )
+             dtype = next(iter(dtype))
+             # concat dim0 and pad dim1
+             padded_side = max(x.size(1) for x in data_container)
+             concat_side = sum(x.size(0) for x in data_container)
+             # Pad the quantized weights to dense tensor, and create a map
+             # with the location of each shard in the padded tensor.
+             padded_data = torch.zeros(
+                 (concat_side, padded_side), dtype=dtype, device=qweight.device
+             )
+             # (dim0_start, dim0_end, dim1_size)
+             shard_offset_map = dict[str, tuple[int, int, int]]()
+             for idx in shard_id:
+                 id_in_container = shard_id_map[idx]
+                 start = sum(x.size(0) for x in data_container[:id_in_container])
+                 end = start + data_container[id_in_container].size(0)
+                 size = data_container[id_in_container].size(1)
+                 padded_data[start:end, :size] = data_container[id_in_container]
+                 shard_offset_map[idx] = (start, end, size)
+             qweight.data_container.clear()
+             padded_param = Parameter(padded_data, requires_grad=False)
+             set_weight_attrs(padded_param, vars(qweight))
+             set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
+             layer.register_parameter("qweight", padded_param)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         shard_id = layer.qweight.shard_id
+
+         if shard_id:
+             # dequantize shard weights respectively
+             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
+             qweight = layer.qweight
+             result = []
+             for idx in shard_id:
+                 start, end, offset = layer.qweight.shard_offset_map[idx]
+                 qweight_type = layer.qweight_type.shard_weight_type[idx]
+                 result.append(
+                     fused_mul_mat_gguf(
+                         x, qweight[start:end, :offset].contiguous(), qweight_type
+                     )
+                 )
+             out = torch.cat(result, axis=1)
+         else:
+             qweight = layer.qweight
+             qweight_type = layer.qweight_type.weight_type
+             out = fused_mul_mat_gguf(x, qweight, qweight_type)
+         if bias is not None:
+             out.add_(bias)
+         return out
+
+
+ class GGUFMoEMethod(FusedMoEMethodBase):
+     """MoE method for GGUF.
+
+     Args:
+         quant_config: The GGUF quantization config.
+     """
+
+     def __init__(self, quant_config: GGUFConfig):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size_per_partition: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+         # gate up proj
+         w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+         set_weight_attrs(
+             w13_qweight,
+             {
+                 "input_dim": 1,
+                 "output_dim": 0,
+                 "tensor_shape": tensor_shape,
+                 "is_gguf_weight": True,
+                 "data_container": [],
+             },
+         )
+         set_weight_attrs(w13_qweight, extra_weight_attrs)
+         layer.register_parameter("w13_qweight", w13_qweight)
+
+         w13_qweight_type = Parameter(
+             torch.empty(1, dtype=torch.uint8), requires_grad=False
+         )
+         set_weight_attrs(
+             w13_qweight_type,
+             {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+         )
+         set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+         layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+         tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
+         # gate down proj
+         w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+         set_weight_attrs(
+             w2_qweight,
+             {
+                 "input_dim": 1,
+                 "output_dim": 0,
+                 "tensor_shape": tensor_shape,
+                 "is_gguf_weight": True,
+                 "data_container": [],
+             },
+         )
+         set_weight_attrs(w2_qweight, extra_weight_attrs)
+         layer.register_parameter("w2_qweight", w2_qweight)
+
+         w2_qweight_type = Parameter(
+             torch.empty(1, dtype=torch.uint8), requires_grad=False
+         )
+         set_weight_attrs(
+             w2_qweight_type,
+             {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
+         )
+
+         set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+         layer.register_parameter("w2_qweight_type", w2_qweight_type)
+
+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+         assert self.fused_experts is None
+
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         assert (
+             self.moe_runner_config.activation == "silu"
+         ), "Only SiLU activation is supported."
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output
+
+         moe_runner_config = self.moe_runner_config
+
+         topk_weights, topk_ids, _ = topk_output
+         output = fused_moe_gguf(
+             x=x,
+             w1=layer.w13_qweight,
+             w2=layer.w2_qweight,
+             topk_weights=topk_weights,
+             topk_ids=topk_ids,
+             qweight_type=layer.w13_qweight_type.weight_type,
+             qweight_type2=layer.w2_qweight_type.weight_type,
+             activation=moe_runner_config.activation,
+         )
+         return StandardCombineInput(hidden_states=output)
+
+
+ class GGUFEmbeddingMethod(GGUFLinearMethod):
+     """Embedding method for GGUF.
+
+     Args:
+         quant_config: The GGUF quantization config.
+     """
+
+     def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
+         qweight = layer.qweight
+         qweight_type = layer.qweight_type.weight_type
+         hidden_size = qweight.tensor_shape[1]
+
+         return apply_gguf_embedding(
+             x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
+         )
+
+
+ class GGUFUninitializedParameter(UninitializedParameter):
+     cls_to_become = Parameter
+     data_container: list[torch.Tensor]
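
For reference, the dequantize fallback in fused_mul_mat_gguf() and the lookup path in apply_gguf_embedding() both recover the logical weight width from the byte-packed GGUF layout via gguf.GGML_QUANT_SIZES. The short sketch below is not part of the wheel diff; it only assumes the standalone gguf Python package and reproduces that shape arithmetic:

```python
# Minimal sketch (not from the diff): how a GGUF byte layout maps to a logical
# hidden size. gguf.GGML_QUANT_SIZES maps each quantization type to
# (block_size, type_size), i.e. (elements per block, bytes per block), so a row
# of packed bytes dequantizes to bytes_per_row // type_size * block_size
# elements, the same expression used in fused_mul_mat_gguf() and asserted in
# apply_gguf_embedding().
import gguf
from gguf import GGMLQuantizationType as WeightType


def dequantized_width(bytes_per_row: int, qweight_type: WeightType) -> int:
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    return bytes_per_row // type_size * block_size


if __name__ == "__main__":
    for qtype in (WeightType.Q4_0, WeightType.Q8_0, WeightType.Q4_K):
        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
        # A row holding 16 blocks unpacks to 16 * block_size elements.
        assert dequantized_width(16 * type_size, qtype) == 16 * block_size
        print(qtype.name, block_size, type_size)
```

Everything else in the module (the MMVQ/MMQ kernel selection and the fused MoE paths) depends on the CUDA-only sgl_kernel extension and is only exercised when _is_cuda is true.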