sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/ep_moe/layer.py (new file)
@@ -0,0 +1,661 @@
+ import logging
+ from typing import Callable, List, Optional, Tuple
+
+ import torch
+ from torch.nn import Module
+ from vllm import _custom_ops as ops
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
+ from vllm.model_executor.custom_op import CustomOp
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+
+ from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.layers.ep_moe.kernels import (
+     grouped_gemm_triton,
+     post_reorder_triton_kernel,
+     pre_reorder_triton_kernel,
+     run_moe_ep_preproess,
+     silu_and_mul_triton_kernel,
+ )
+ from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
+ from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.utils import is_hip, set_weight_attrs
+
+ logger = logging.getLogger(__name__)
+
+
+ class GroupedGemmRunner(torch.nn.Module):
+     flashinfer_gemm_warpper = None
+
+     def __init__(self, device, use_flashinfer: bool = False):
+         super().__init__()
+         self.device = device
+         self.use_flashinfer = use_flashinfer
+         if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
+             GroupedGemmRunner._init_flashinfer_wrapper(device)
+
+     @classmethod
+     def _init_flashinfer_wrapper(cls, device):
+         from flashinfer import SegmentGEMMWrapper
+
+         workspace_buffer = torch.empty(
+             128 * 1024 * 1024, dtype=torch.int8, device=device
+         )
+         cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
+
+     # c = a * b
+     def forward(
+         self,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         c: torch.Tensor,
+         batch_size: int,
+         weight_column_major: bool,
+         seg_indptr: Optional[torch.Tensor] = None,
+         weight_indices: Optional[torch.Tensor] = None,
+         use_fp8_w8a8: bool = False,
+         scale_a: torch.Tensor = None,
+         scale_b: torch.Tensor = None,
+     ):
+         if self.use_flashinfer:
+             # TODO: flashinfer
+             assert False
+             assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
+             c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
+                 x=a,
+                 weights=b,
+                 batch_size=batch_size,
+                 weight_column_major=weight_column_major,
+                 seg_indptr=seg_indptr,
+                 weight_indices=weight_indices,
+             )
+         else:
+             assert weight_column_major == True
+             c = grouped_gemm_triton(
+                 a,
+                 b,
+                 c,
+                 batch_size,
+                 weight_column_major,
+                 seg_indptr,
+                 weight_indices,
+                 use_fp8_w8a8,
+                 scale_a,
+                 scale_b,
+             )
+         return c
+
+
+ class EPMoE(torch.nn.Module):
+     """
+     MoE Expert Parallel Impl
+
+
+     """
+
+     def __init__(
+         self,
+         num_experts: int,
+         top_k: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: Optional[torch.dtype] = None,
+         renormalize: bool = True,
+         use_grouped_topk: bool = False,
+         num_expert_group: Optional[int] = None,
+         topk_group: Optional[int] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         tp_size: Optional[int] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+
+         if params_dtype is None:
+             params_dtype = torch.get_default_dtype()
+
+         self.tp_size = (
+             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
+         )
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.num_experts = num_experts
+         assert self.num_experts % self.tp_size == 0
+         self.num_experts_per_partition = self.num_experts // self.tp_size
+         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
+         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+
+         self.top_k = top_k
+         self.intermediate_size = intermediate_size
+         self.renormalize = renormalize
+         self.use_grouped_topk = use_grouped_topk
+         if self.use_grouped_topk:
+             assert num_expert_group is not None and topk_group is not None
+         self.num_expert_group = num_expert_group
+         self.topk_group = topk_group
+
+         if quant_config is None:
+             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
+             self.use_fp8_w8a8 = False
+             self.activation_scheme = None
+         else:
+             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
+                 quant_config
+             )
+             self.use_fp8_w8a8 = True
+             self.fp8_dtype = torch.float8_e4m3fn
+             self.activation_scheme = quant_config.activation_scheme
+
+         self.quant_method.create_weights(
+             layer=self,
+             num_experts_per_partition=self.num_experts_per_partition,
+             hidden_size=hidden_size,
+             intermediate_size=self.intermediate_size,
+             params_dtype=params_dtype,
+             weight_loader=self.weight_loader,
+         )
+
+         self.grouped_gemm_runner = None
+
+     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+         assert self.quant_method is not None
+
+         if self.grouped_gemm_runner is None:
+             self.grouped_gemm_runner = GroupedGemmRunner(
+                 hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+             )
+
+         topk_weights, topk_ids = self.select_experts(
+             hidden_states,
+             router_logits,
+             self.top_k,
+             self.renormalize,
+             self.topk_group,
+             self.num_expert_group,
+         )
+
+         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
+             topk_ids, self.num_experts
+         )
+
+         gateup_input = torch.empty(
+             (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
+             device=hidden_states.device,
+             dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
+         )
+         if self.activation_scheme == "dynamic":
+             max_value = (
+                 torch.max(hidden_states)
+                 .repeat(self.num_experts_per_partition)
+                 .to(torch.float32)
+             )
+             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+
+         # PreReorder
+         pre_reorder_triton_kernel[(hidden_states.shape[0],)](
+             hidden_states,
+             gateup_input,
+             src2dst,
+             topk_ids,
+             self.w13_input_scale,
+             self.start_expert_id,
+             self.end_expert_id,
+             self.top_k,
+             hidden_states.shape[1],
+             BLOCK_SIZE=512,
+         )
+
+         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
+         weight_indices_cur_rank = torch.arange(
+             0,
+             self.num_experts_per_partition,
+             device=hidden_states.device,
+             dtype=torch.int64,
+         )
+         # GroupGemm-0
+         gateup_output = torch.empty(
+             gateup_input.shape[0],
+             self.w13_weight.shape[1],
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+         gateup_output = self.grouped_gemm_runner(
+             a=gateup_input,
+             b=self.w13_weight,
+             c=gateup_output,
+             batch_size=self.num_experts_per_partition,
+             weight_column_major=True,
+             seg_indptr=seg_indptr_cur_rank,
+             weight_indices=weight_indices_cur_rank,
+             use_fp8_w8a8=self.use_fp8_w8a8,
+             scale_a=self.w13_input_scale,
+             scale_b=self.w13_weight_scale,
+         )
+
+         # Act
+         down_input = torch.empty(
+             gateup_output.shape[0],
+             gateup_output.shape[1] // 2,
+             device=gateup_output.device,
+             dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
+         )
+         if self.w2_input_scale is None:
+             self.w2_input_scale = torch.ones(
+                 self.num_experts_per_partition,
+                 dtype=torch.float32,
+                 device=hidden_states.device,
+             )
+         silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+             gateup_output,
+             down_input,
+             gateup_output.shape[1],
+             reorder_topk_ids,
+             self.w2_input_scale,
+             self.start_expert_id,
+             self.end_expert_id,
+             BLOCK_SIZE=512,
+         )
+
+         # GroupGemm-1
+         down_output = torch.empty(
+             down_input.shape[0],
+             self.w2_weight.shape[1],
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+         down_output = self.grouped_gemm_runner(
+             a=down_input,
+             b=self.w2_weight,
+             c=down_output,
+             batch_size=self.num_experts_per_partition,
+             weight_column_major=True,
+             seg_indptr=seg_indptr_cur_rank,
+             weight_indices=weight_indices_cur_rank,
+             use_fp8_w8a8=self.use_fp8_w8a8,
+             scale_a=self.w2_input_scale,
+             scale_b=self.w2_weight_scale,
+         )
+
+         # PostReorder
+         output = torch.empty_like(hidden_states)
+         post_reorder_triton_kernel[(hidden_states.size(0),)](
+             down_output,
+             output,
+             src2dst,
+             topk_ids,
+             topk_weights,
+             self.start_expert_id,
+             self.end_expert_id,
+             self.top_k,
+             hidden_states.size(1),
+             BLOCK_SIZE=512,
+         )
+         return output
+
+     def select_experts(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+     ):
+         if self.use_grouped_topk:
+             assert topk_group is not None
+             assert num_expert_group is not None
+             topk_weights, topk_ids = grouped_topk(
+                 hidden_states=hidden_states,
+                 gating_output=router_logits,
+                 topk=top_k,
+                 renormalize=renormalize,
+                 num_expert_group=num_expert_group,
+                 topk_group=topk_group,
+             )
+         else:
+             topk_weights, topk_ids = fused_topk(
+                 hidden_states=hidden_states,
+                 gating_output=router_logits,
+                 topk=top_k,
+                 renormalize=renormalize,
+             )
+         return topk_weights, topk_ids.to(torch.int32)
+
+     @classmethod
+     def make_expert_params_mapping(
+         cls,
+         ckpt_gate_proj_name: str,
+         ckpt_down_proj_name: str,
+         ckpt_up_proj_name: str,
+         num_experts: int,
+     ) -> List[Tuple[str, str, int, str]]:
+
+         return [
+             # (param_name, weight_name, expert_id, shard_id)
+             (
+                 (
+                     "experts.w13_"
+                     if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                     else "experts.w2_"
+                 ),
+                 f"experts.{expert_id}.{weight_name}.",
+                 expert_id,
+                 shard_id,
+             )
+             for expert_id in range(num_experts)
+             for shard_id, weight_name in [
+                 ("w1", ckpt_gate_proj_name),
+                 ("w2", ckpt_down_proj_name),
+                 ("w3", ckpt_up_proj_name),
+             ]
+         ]
+
+     def weight_loader(
+         self,
+         param: torch.nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         shard_id: str,
+         expert_id: int,
+     ) -> None:
+         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
+             return
+         expert_id = expert_id - self.start_expert_id
+
+         if shard_id not in ("w1", "w2", "w3"):
+             raise ValueError(
+                 f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
+             )
+
+         # Special case for fp8 scales.
+         if "scale" in weight_name:
+             self._load_fp8_scale(
+                 param.data, loaded_weight, weight_name, shard_id, expert_id
+             )
+             return
+
+         expert_data = param.data[expert_id]
+         if shard_id == "w2":
+             param.data[expert_id] = loaded_weight
+         elif shard_id == "w1":
+             param.data[expert_id][: self.intermediate_size, :] = loaded_weight
+         elif shard_id == "w3":
+             param.data[expert_id][self.intermediate_size :, :] = loaded_weight
+         else:
+             raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+
+     def _load_fp8_scale(
+         self,
+         param: torch.nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         shard_id: str,
+         expert_id: int,
+     ) -> None:
+         param_data = param.data
+
+         # Input scales can be loaded directly and should be equal.
+         if "input_scale" in weight_name:
+             if (
+                 param_data[expert_id] != 1
+                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+             ):
+                 raise ValueError(
+                     "input_scales of w1 and w3 of a layer "
+                     f"must be equal. But got {param_data[expert_id]} "
+                     f"vs. {loaded_weight}"
+                 )
+             param_data[expert_id] = loaded_weight
+         # Weight scales
+         elif "weight_scale" in weight_name:
+             # If we are in merged column case (gate_up_proj)
+             if shard_id in ("w1", "w3"):
+                 # We have to keep the weight scales of w1 and w3 because
+                 # we need to re-quantize w1/w3 weights after weight loading.
+                 idx = 0 if shard_id == "w1" else 1
+                 param_data[expert_id][idx] = loaded_weight
+             # If we are in the row parallel case (down_proj)
+             else:
+                 param_data[expert_id] = loaded_weight
+
+
+ @register_custom_op("sglang_unquantized_ep_moe")
+ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts_per_partition: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         # Fused gate_up_proj (column parallel)
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 2 * intermediate_size,
+                 hidden_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         # down_proj (row parallel)
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # scale
+         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
+         w13_input_scale = torch.nn.Parameter(
+             ones_tensor,
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_input_scale", w13_input_scale)
+         set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+         w2_input_scale = torch.nn.Parameter(
+             ones_tensor,
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_input_scale", w2_input_scale)
+         set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+         w13_weight_scale = torch.nn.Parameter(
+             ones_tensor,
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale", w13_weight_scale)
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+         w2_weight_scale = torch.nn.Parameter(
+             ones_tensor,
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+     ) -> torch.Tensor:
+         raise NotImplementedError
+
+
+ class Fp8EPMoEMethod(Fp8MoEMethod):
+     """MoE method for FP8.
+     Supports loading FP8 checkpoints with static weight scale and
+     dynamic/static activation scale.
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __init__(self, quant_config: Fp8Config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts_per_partition: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+
+         if self.quant_config.is_checkpoint_fp8_serialized:
+             params_dtype = torch.float8_e4m3fn
+
+         # WEIGHTS
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 2 * intermediate_size,
+                 hidden_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts_per_partition,
+                 hidden_size,
+                 intermediate_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # WEIGHT_SCALES
+         # Allocate 2 scales for w1 and w3 respectively.
+         w13_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+         w2_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts_per_partition, dtype=torch.float32),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         # Add the quantization method used (per tensor/grouped/channel)
+         # to ensure the weight scales are loaded in properly
+         extra_weight_attrs.update({"quant_method": "tensor"})
+         # If loading fp8 checkpoint, pass the weight loaders.
+         # If loading an fp16 checkpoint, do not (we will quantize in
+         # process_weights_after_loading()
+         if self.quant_config.is_checkpoint_fp8_serialized:
+             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # INPUT_SCALES
+         if self.quant_config.activation_scheme == "static":
+             if not self.quant_config.is_checkpoint_fp8_serialized:
+                 raise ValueError(
+                     "Found static activation scheme for checkpoint that "
+                     "was not serialized fp8."
+                 )
+
+             w13_input_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_input_scale", w13_input_scale)
+             set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+             w2_input_scale = torch.nn.Parameter(
+                 torch.ones(num_experts_per_partition, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w2_input_scale", w2_input_scale)
+             set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+         else:
+             layer.w13_input_scale = None
+             layer.w2_input_scale = None
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+
+         # If checkpoint is fp16, quantize in place.
+         if not self.quant_config.is_checkpoint_fp8_serialized:
+             # If rocm, use float8_e4m3fnuz as dtype
+             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+             layer.w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     layer.num_experts_per_partition,
+                     dtype=torch.float32,
+                     device=w13_weight.device,
+                 ),
+                 requires_grad=False,
+             )
+
+             for expert in range(layer.num_experts_per_partition):
+                 w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                     ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                 )
+                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                 )
+             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+             return
+
+         # If checkpoint is fp8, we need to handle that the
+         # MoE kernels require single activation scale and single weight
+         # scale for w13 per expert.
+         else:
+             if self.quant_config.activation_scheme == "static":
+                 if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                     raise ValueError(
+                         "QuantConfig has static quantization, but found "
+                         "activation scales are None."
+                     )
+             return
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+     ) -> torch.Tensor:
+         raise NotImplementedError
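
The new EPMoE layer shards whole experts across tensor-parallel ranks rather than splitting each expert's weights: with num_experts=64 and tp_size=8, every rank owns 8 experts and rank 3 covers expert ids 24-31, so the pre/post-reorder kernels only move tokens routed to that slice. As a minimal sketch of the per-token math the Triton path above implements (top-k routing, SiLU-gated gate/up projection, down projection, weighted sum), the plain-PyTorch function below omits expert parallelism and FP8 scaling; the name dense_moe_reference and its exact signature are illustrative, not part of sglang.

import torch
import torch.nn.functional as F


def dense_moe_reference(
    x: torch.Tensor,              # [tokens, hidden]
    router_logits: torch.Tensor,  # [tokens, num_experts]
    w13: torch.Tensor,            # [num_experts, 2 * intermediate, hidden], fused gate/up
    w2: torch.Tensor,             # [num_experts, hidden, intermediate]
    top_k: int,
    renormalize: bool = True,
) -> torch.Tensor:
    # Route each token to its top-k experts (softmax scores, as in fused_topk-style routing).
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    inter = w2.shape[-1]
    for e in range(w13.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        # Fused gate/up projection: rows [:inter] hold w1 (gate), rows [inter:] hold w3 (up),
        # matching the w13 layout used by EPMoE.weight_loader.
        gate_up = xe @ w13[e].t()
        h = F.silu(gate_up[:, :inter]) * gate_up[:, inter:]
        # Down projection, then accumulate scaled by the routing weight.
        out.index_add_(0, token_idx, (h @ w2[e].t()) * topk_weights[token_idx, slot, None])
    return out
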
sglang/srt/layers/fused_moe_patch.py
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
      num_expert_group: Optional[int] = None,
      custom_routing_function: Optional[Callable] = None,
  ) -> torch.Tensor:
-     assert custom_routing_function is None
-     topk_weights, topk_ids = select_experts_native(
-         hidden_states=x,
-         router_logits=router_logits,
-         use_grouped_topk=use_grouped_topk,
-         top_k=top_k,
-         renormalize=renormalize,
-         topk_group=topk_group,
-         num_expert_group=num_expert_group,
-     )
+
+     if use_grouped_topk:
+         assert num_expert_group is not None and topk_group is not None
+         topk_weights, topk_ids = grouped_topk(
+             x,
+             router_logits,
+             top_k,
+             renormalize,
+             num_expert_group,
+             topk_group,
+         )
+     elif custom_routing_function is None:
+         topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+     else:
+         topk_weights, topk_ids = custom_routing_function(
+             x, router_logits, top_k, renormalize
+         )
+
      w13_weights = layer.w13_weight[topk_ids]
      w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
      w2_weights = layer.w2_weight[topk_ids]
-     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+     x1 = F.silu(x1)
      x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
      expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
      return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
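
The rewritten dispatch above chooses between grouped_topk, fused_topk_native, and a caller-supplied custom_routing_function. fused_topk_native itself is not shown in this diff; as a rough sketch only, softmax-then-top-k routing of roughly the following shape produces the topk_weights/topk_ids pair consumed by the einsum expert computation that follows it (the helper name and body here are illustrative, not sglang's implementation).

import torch


def topk_softmax_routing_sketch(
    gating_output: torch.Tensor,  # [tokens, num_experts] router logits
    topk: int,
    renormalize: bool,
):
    # Softmax over experts, then keep the k largest probabilities per token.
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Make the kept weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
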
sglang/srt/layers/linear.py
@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
      "Fp8LinearMethod",
      "MarlinLinearMethod",
      "GPTQLinearMethod",
+     "QQQLinearMethod",
  ]
 