sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -5,20 +5,21 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -27,6 +28,8 @@ else:
 
 import logging
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -97,6 +100,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,17 +165,52 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
 
-        return fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
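For context, the dispatch added above is driven by two runtime checks: the platform (is_hip()) and the CK_MOE environment variable read through get_bool_env_var. A minimal standalone sketch of that selection logic, assuming get_bool_env_var simply treats values such as "1" or "true" as truthy (the helper's exact semantics live in sglang/srt/utils.py and are not shown in this hunk):

import os
import torch

def _get_bool_env_var(name: str, default: str = "false") -> bool:
    # Hypothetical stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).lower() in ("1", "true")

def select_moe_backend() -> str:
    # Mirrors the branch added in UnquantizedFusedMoEMethod.apply:
    # ROCm with CK_MOE set routes to ater's fused_experts_ck,
    # otherwise the Triton fused_experts kernel is used.
    on_rocm = torch.version.hip is not None
    if on_rocm and _get_bool_env_var("CK_MOE"):
        return "ater.fused_experts_ck"
    return "triton.fused_experts"

print(select_moe_backend())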
sglang/srt/layers/moe/topk.py

@@ -24,7 +24,9 @@ def fused_topk_native(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert (
+        hidden_states.shape[0] == gating_output.shape[0]
+    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
         M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
         )
-    elif torch_native:
+    elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
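The reworked assertion relies on the Python 3.8+ f-string "=" specifier to embed both tensor shapes in the failure message. A quick illustration of what the new message looks like:

import torch

hidden_states = torch.zeros(4, 16)
gating_output = torch.zeros(3, 8)
# f"{hidden_states.shape=}" renders as "hidden_states.shape=torch.Size([4, 16])",
# so a mismatch now reports both shapes instead of a bare "Number of tokens mismatch".
msg = f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
print(msg)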
sglang/srt/layers/parameter.py

@@ -1,7 +1,4 @@
-"""
-Adapted from vLLM (0.6.4.post1).
-https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
-"""
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
 
 import logging
 from fractions import Fraction
@@ -9,7 +6,8 @@ from typing import Callable, Optional, Union
 
 import torch
 from torch.nn import Parameter
-from vllm.distributed import get_tensor_model_parallel_rank
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 
 __all__ = [
     "BasevLLMParameter",
@@ -88,12 +86,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     def output_dim(self):
         return self._output_dim
 
-    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.output_dim]
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, tp_rank * shard_size, shard_size
-        )
+    def load_column_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -121,7 +124,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -137,7 +140,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         )
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.narrow(
@@ -164,11 +166,14 @@ class RowvLLMParameter(BasevLLMParameter):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        use_presharded_weights = kwargs.get("use_presharded_weights")
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.input_dim]
+    def load_row_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
         if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
             loaded_weight = loaded_weight.narrow(
                 self.input_dim, tp_rank * shard_size, shard_size
             )
@@ -238,6 +243,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
        super().load_row_parallel_weight(*args, **kwargs)
 
    def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +254,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
        self._load_into_shard_id(*args, **kwargs)
 
    def load_column_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
        super().load_row_parallel_weight(*args, **kwargs)
 
    def _load_into_shard_id(
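The key behavioural change here is that the tensor-parallel rank is now an explicit argument and use_presharded_weights=True skips the narrow step entirely. A self-contained sketch of the same sharding arithmetic on plain tensors (the parameter classes and their constructors are omitted; this only mirrors the logic shown in the hunk):

import torch

def shard_row_weight(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    input_dim: int,
    tp_rank: int,
    use_presharded_weights: bool = False,
) -> None:
    # Same flow as RowvLLMParameter.load_row_parallel_weight after the change:
    # only slice out this rank's shard when the checkpoint is not pre-sharded.
    if not use_presharded_weights:
        shard_size = param.shape[input_dim]
        loaded_weight = loaded_weight.narrow(input_dim, tp_rank * shard_size, shard_size)
    assert param.shape == loaded_weight.shape
    param.copy_(loaded_weight)

full = torch.arange(32.0).reshape(4, 8)   # full (unsharded) weight
local = torch.empty(4, 4)                 # this rank's slice, tp_size = 2
shard_row_weight(local, full, input_dim=1, tp_rank=1, use_presharded_weights=False)
print(local)                              # columns 4..7 of `full`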
sglang/srt/layers/quantization/__init__.py

@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -42,6 +43,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "w8a8_int8": W8A8Int8Config,
 }
 
 
@@ -54,33 +56,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_get_quant_method(self, layer, prefix):
-    """Enhanced get_quant_method for FP8 config."""
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.utils.quant_utils import (
-        is_layer_skipped,
-    )
-
-    from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
-
-    if isinstance(layer, LinearBase):
-        if is_layer_skipped(prefix, self.ignored_layers):
-            return UnquantizedLinearMethod()
-        return Fp8LinearMethod(self)
-    elif isinstance(layer, FusedMoE):
-        return Fp8MoEMethod(self)
-    return None
-
-
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -91,12 +73,12 @@ def gptq_get_quant_method(self, layer, prefix):
 
 
 def awq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -106,13 +88,30 @@ def awq_get_quant_method(self, layer, prefix):
     return None
 
 
+def patch_vllm_linear_base_isinstance():
+    import builtins
+
+    from vllm.model_executor.layers.linear import LinearBase
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+
+    original_isinstance = builtins.isinstance
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
+patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
 apply_monkey_patches()
 
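The patch_vllm_linear_base_isinstance helper above rebinds builtins.isinstance so that checks against vLLM's LinearBase succeed for sglang's own LinearBase hierarchy inside vLLM code paths. A toy reproduction of that technique with stand-in classes (vLLM itself is not imported here):

import builtins

class VendorLinearBase:        # stand-in for vllm.model_executor.layers.linear.LinearBase
    pass

class LocalLinearBase:         # stand-in for sglang.srt.layers.linear.LinearBase
    pass

original_isinstance = builtins.isinstance

def patched_isinstance(obj, classinfo):
    # Redirect checks against the vendor class to the local class hierarchy.
    if classinfo is VendorLinearBase:
        return original_isinstance(obj, LocalLinearBase)
    return original_isinstance(obj, classinfo)

builtins.isinstance = patched_isinstance

layer = LocalLinearBase()
print(isinstance(layer, VendorLinearBase))   # True after the patch
builtins.isinstance = original_isinstance    # restore for the rest of the process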
sglang/srt/layers/quantization/fp8.py

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -9,8 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear,
@@ -26,7 +23,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )
 
-from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -40,12 +42,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -161,7 +166,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +278,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +335,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
 
            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=weight_scale,
@@ -567,7 +572,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +599,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -616,18 +621,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +675,7 @@ class Fp8MoEMethod:
             )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +725,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
            )
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +781,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
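Both MOE_PADDING branches call torch.nn.functional.pad with a (0, padding_size) spec, which pads only the last dimension on the right. A quick check of that behaviour (the tensor shape and padding_size below are illustrative only; the module's actual padding_size constant is defined outside this hunk):

import torch
import torch.nn.functional as F

w = torch.randn(8, 4, 16)          # illustrative (experts, out, in) shape
padding_size = 128                  # illustrative value, not the module's constant
padded = F.pad(w, (0, padding_size), "constant", 0)
print(padded.shape)                 # torch.Size([8, 4, 144]) -- only the last dim grows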
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
 
+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py

@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
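Assuming a CUDA (or ROCm) device with Triton available, the new kernel can be exercised directly: the quantized tensor is int8 with per-row float32 scales of shape x.shape[:-1] + (1,), and x_q * scales approximately reconstructs the input.

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(4, 128, device="cuda", dtype=torch.float16).contiguous()
x_q, scales = per_token_quant_int8(x)

print(x_q.dtype, x_q.shape)        # torch.int8, (4, 128)
print(scales.dtype, scales.shape)  # torch.float32, (4, 1)
# Dequantize and compare against the original activations.
x_hat = x_q.float() * scales
print((x_hat - x.float()).abs().max())   # small per-token quantization error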
sglang/srt/layers/quantization/modelopt_quant.py

@@ -5,15 +5,14 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,