sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff reflects the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (102)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +2 -4
  4. sglang/bench_serving.py +75 -26
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +2 -2
  7. sglang/srt/configs/model_config.py +13 -14
  8. sglang/srt/constrained/__init__.py +13 -14
  9. sglang/srt/constrained/base_grammar_backend.py +13 -15
  10. sglang/srt/constrained/outlines_backend.py +13 -15
  11. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  12. sglang/srt/constrained/xgrammar_backend.py +38 -57
  13. sglang/srt/conversation.py +13 -15
  14. sglang/srt/hf_transformers_utils.py +13 -15
  15. sglang/srt/layers/activation.py +13 -13
  16. sglang/srt/layers/attention/flashinfer_backend.py +13 -6
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  18. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  19. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  20. sglang/srt/layers/custom_op_util.py +13 -14
  21. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  22. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  23. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  24. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  25. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  26. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  27. sglang/srt/layers/layernorm.py +13 -15
  28. sglang/srt/layers/logits_processor.py +13 -15
  29. sglang/srt/layers/quantization/__init__.py +77 -17
  30. sglang/srt/layers/radix_attention.py +13 -15
  31. sglang/srt/layers/rotary_embedding.py +13 -13
  32. sglang/srt/lora/lora.py +13 -14
  33. sglang/srt/lora/lora_config.py +13 -14
  34. sglang/srt/lora/lora_manager.py +22 -24
  35. sglang/srt/managers/data_parallel_controller.py +25 -19
  36. sglang/srt/managers/detokenizer_manager.py +13 -16
  37. sglang/srt/managers/io_struct.py +43 -28
  38. sglang/srt/managers/schedule_batch.py +55 -26
  39. sglang/srt/managers/schedule_policy.py +13 -15
  40. sglang/srt/managers/scheduler.py +89 -70
  41. sglang/srt/managers/session_controller.py +14 -15
  42. sglang/srt/managers/tokenizer_manager.py +29 -22
  43. sglang/srt/managers/tp_worker.py +13 -15
  44. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  45. sglang/srt/metrics/collector.py +13 -15
  46. sglang/srt/metrics/func_timer.py +13 -15
  47. sglang/srt/mm_utils.py +13 -14
  48. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  49. sglang/srt/model_executor/forward_batch_info.py +19 -17
  50. sglang/srt/model_executor/model_runner.py +42 -30
  51. sglang/srt/models/chatglm.py +15 -16
  52. sglang/srt/models/commandr.py +15 -16
  53. sglang/srt/models/dbrx.py +15 -16
  54. sglang/srt/models/deepseek.py +15 -15
  55. sglang/srt/models/deepseek_v2.py +15 -15
  56. sglang/srt/models/exaone.py +14 -15
  57. sglang/srt/models/gemma.py +14 -14
  58. sglang/srt/models/gemma2.py +24 -19
  59. sglang/srt/models/gemma2_reward.py +13 -14
  60. sglang/srt/models/gpt_bigcode.py +14 -14
  61. sglang/srt/models/grok.py +15 -15
  62. sglang/srt/models/internlm2.py +13 -15
  63. sglang/srt/models/internlm2_reward.py +13 -14
  64. sglang/srt/models/llama.py +21 -21
  65. sglang/srt/models/llama_classification.py +13 -14
  66. sglang/srt/models/llama_reward.py +13 -14
  67. sglang/srt/models/llava.py +13 -15
  68. sglang/srt/models/llavavid.py +13 -15
  69. sglang/srt/models/minicpm.py +13 -15
  70. sglang/srt/models/minicpm3.py +13 -15
  71. sglang/srt/models/mistral.py +13 -15
  72. sglang/srt/models/mixtral.py +15 -15
  73. sglang/srt/models/mixtral_quant.py +14 -14
  74. sglang/srt/models/olmo.py +21 -19
  75. sglang/srt/models/olmoe.py +23 -20
  76. sglang/srt/models/qwen.py +14 -14
  77. sglang/srt/models/qwen2.py +22 -19
  78. sglang/srt/models/qwen2_moe.py +17 -18
  79. sglang/srt/models/stablelm.py +18 -16
  80. sglang/srt/models/torch_native_llama.py +15 -17
  81. sglang/srt/models/xverse.py +13 -14
  82. sglang/srt/models/xverse_moe.py +15 -16
  83. sglang/srt/models/yivl.py +13 -15
  84. sglang/srt/openai_api/adapter.py +13 -15
  85. sglang/srt/openai_api/protocol.py +13 -15
  86. sglang/srt/sampling/sampling_batch_info.py +4 -1
  87. sglang/srt/sampling/sampling_params.py +13 -15
  88. sglang/srt/server.py +59 -34
  89. sglang/srt/server_args.py +22 -22
  90. sglang/srt/utils.py +196 -17
  91. sglang/test/few_shot_gsm8k.py +8 -4
  92. sglang/test/runners.py +13 -14
  93. sglang/test/test_utils.py +1 -1
  94. sglang/version.py +1 -1
  95. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  96. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
  97. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  98. sglang/srt/layers/fused_moe/__init__.py +0 -1
  99. sglang-0.3.6.dist-info/RECORD +0 -161
  100. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
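The fused-MoE entries above (items 21-26 and 98-100) amount to a module move: the Grok-specific kernels go to fused_moe_grok, the old fused_moe package is removed, and a new fused_moe_triton package carries the generic Triton MoE kernels plus a FusedMoE layer. A hedged sketch of the resulting import change for downstream code follows; the 0.3.6 path is inferred from the rename entries rather than quoted from the old wheel, while the new symbols are taken from the added layer.py shown further below.

# 0.3.6 (removed; path inferred from the rename of fused_moe/ to fused_moe_grok/):
# from sglang.srt.layers.fused_moe.fused_moe import fused_experts

# 0.3.6.post1: generic Triton MoE lives in the new fused_moe_triton package.
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts, fused_topk
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase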
sglang/srt/layers/fused_moe_triton/layer.py (new file)
@@ -0,0 +1,633 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+
+from abc import abstractmethod
+from enum import Enum
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import set_weight_attrs
+
+if torch.cuda.is_available() or torch.hip.is_available():
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+else:
+    fused_experts = None  # type: ignore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FusedMoeWeightScaleSupported(Enum):
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+@register_custom_op("sglang_unquantized_fused_moe")
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x=x,
+            layer=layer,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        return fused_experts(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+        )
+
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError("The CPU backend currently does not support MoE.")
+
+    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
+
+    forward_native = forward_cuda
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj/ w2).
+
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all all_reduce on the output of the layer
+        renomalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (
+            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
+        )
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
+        assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
+
+    def _load_per_tensor_weight_scale(
+        self,
+        shard_id: str,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        expert_id: int,
+    ):
+        param_data = param.data
+        # for per tensor weight quantization
+        if shard_id in ("w1", "w3"):
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            idx = 0 if shard_id == "w1" else 1
+            param_data[expert_id][idx] = loaded_weight
+        # If we are in the row parallel case (down_proj)
+        elif shard_id == "w2":
+            param_data[expert_id] = loaded_weight
+
+    def _load_model_weight_or_group_weight_scale(
+        self,
+        shard_dim: int,
+        expert_data: torch.Tensor,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+        # Load grouped weight scales for group quantization
+        # or model weights
+        if shard_id == "w2":
+            self._load_w2(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+        elif shard_id in ("w1", "w3"):
+            self._load_w13(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+
+    def _load_per_channel_weight_scale(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+        # for per channel weight quantization
+        if shard_id == "w2":
+            expert_data.copy_(loaded_weight)
+        elif shard_id in ("w1", "w3"):
+            self._load_w13(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+
+    def _load_w13(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        # Index the loaded weight for tp sharding.
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        shard_size = expert_data.shape[shard_dim] // 2
+        loaded_weight = loaded_weight.narrow(
+            shard_dim, shard_size * tp_rank, shard_size
+        )
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+        # w3, up_proj: Load into second logical weight of w13.
+        else:
+            assert shard_id == "w3"
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+        expert_data.copy_(loaded_weight)
+
+    def _load_w2(
+        self,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        shard_id: str,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel" so tp sharding on input_dim
+        # Narrow parameter and load.
+        shard_size = expert_data.shape[shard_dim]
+        loaded_weight = loaded_weight.narrow(
+            shard_dim, shard_size * tp_rank, shard_size
+        )
+        # w2, down_proj: Load into only logical weight of w2.
+        expert_data.copy_(loaded_weight)
+
+    def _load_single_value(
+        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
+    ):
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        param_data[expert_id] = loaded_weight
+
+    def _load_g_idx(
+        self,
+        shard_id: str,
+        expert_data: torch.Tensor,
+        shard_dim: int,
+        loaded_weight: torch.tensor,
+        tp_rank: int,
+    ):
+
+        if shard_id == "w2":
+            self._load_w2(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+        else:
+            assert shard_id in ("w1", "w3")
+            expert_data.copy_(loaded_weight)
+
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO (mgoin): check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w1", "w2", "w3"):
+            raise ValueError(
+                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
+            )
+
+        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
+
+        expert_data = param.data[expert_id]
+        tp_rank = get_tensor_model_parallel_rank()
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if is_transposed:
+            shard_dim = ~shard_dim
+
+        # Case input scale: input_scale loading is only supported for fp8
+        if "input_scale" in weight_name:
+            # this is needed for compressed-tensors only
+            loaded_weight = loaded_weight.to(param.data.device)
+
+            if (
+                param.data[expert_id] != 1
+                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param.data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+
+            self._load_single_value(
+                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+            )
+            return
+
+        # Case g_idx
+        if "g_idx" in weight_name:
+            self._load_g_idx(
+                shard_dim=0,
+                shard_id=shard_id,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+            return
+
+        # Case weight scales and zero_points
+        if "scale" in weight_name or "zero" in weight_name:
+            # load the weight scales and zp based on the quantization scheme
+            # supported weight scales/zp can be found in
+            # FusedMoeWeightScaleSupported
+            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
+            # specific to each case
+            quant_method = getattr(param, "quant_method", None)
+            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+                self._load_per_channel_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
+            else:
+                raise ValueError(
+                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
+                )
+            return
+
+        # Case weight_shape
+        if "weight_shape" in weight_name:
+            # only required by compressed-tensors
+            self._load_single_value(
+                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+            )
+            return
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+            )
+            return
+
+    @staticmethod
+    def select_experts(
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        use_grouped_topk: bool,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ):
+        from sglang.srt.layers.fused_moe_triton.fused_moe import (
+            fused_topk,
+            grouped_topk,
+        )
+
+        # DeekSeekv2 uses grouped_top_k
+        if use_grouped_topk:
+            assert topk_group is not None
+            assert num_expert_group is not None
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        elif custom_routing_function is None:
+            topk_weights, topk_ids = fused_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+
+        return topk_weights, topk_ids
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function,
+        )
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                (
+                    "experts.w13_"
+                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                    else "experts.w2_"
+                ),
+                f"experts.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
+    def _load_fp8_scale(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if (
+                param_data[expert_id] != 1
+                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in merged column case (gate_up_proj)
+            if shard_id in ("w1", "w3"):
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == "w1" else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            else:
+                param_data[expert_id] = loaded_weight
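A minimal usage sketch (not part of the diff) of how the new FusedMoE layer is typically wired into an MoE block, e.g. in sglang/srt/models/mixtral.py. The shapes, dtype, and tp_size=1 are illustrative assumptions, and a CUDA build of sglang/vLLM is required because the unquantized path dispatches to the Triton fused_experts kernel.

import torch
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

hidden_size, intermediate_size, num_experts, top_k = 4096, 14336, 8, 2

# Router producing per-token expert logits.
gate = torch.nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float16).cuda()

experts = FusedMoE(
    num_experts=num_experts,
    top_k=top_k,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    params_dtype=torch.float16,
    renormalize=True,     # renormalize the selected top-k routing weights
    reduce_results=True,  # all-reduce across TP ranks; a no-op when tp_size == 1
    tp_size=1,            # assumption: single GPU, so intermediate_size is not sharded
).cuda()
# Note: w13_weight / w2_weight are allocated with torch.empty above; real models fill
# them through FusedMoE.weight_loader using make_expert_params_mapping(...).

x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")  # [num_tokens, hidden]
router_logits = gate(x)           # [num_tokens, num_experts]
out = experts(x, router_logits)   # routed through UnquantizedFusedMoEMethod.apply
print(out.shape)                  # torch.Size([16, 4096])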
sglang/srt/layers/layernorm.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Fused operators for normalization layers."""
 
 import logging
sglang/srt/layers/logits_processor.py
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Logits processing."""
 
 import dataclasses