sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,587 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+from abc import abstractmethod
+from typing import List, Optional, Tuple
+
+import torch
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.utils import set_weight_attrs
+
+logger = init_logger(__name__)
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            router_logits,
+            top_k,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+        )
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+
+        return fused_moe(
+            x,
+            w1,
+            w2,
+            router_logits,
+            top_k,
+            renormalize=renormalize,
+            inplace=True,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+
+    def forward_cpu(self, *args, **kwargs):
+        raise NotImplementedError("The CPU backend currently does not support MoE.")
+
+    def forward_tpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj / w2).
+
+    Note: Mixtral uses w1, w3, and w2 for the gate, up, and down projections.
+    We copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model.
+        top_k: Number of experts selected for each token.
+        hidden_size: Input hidden state size of the transformer.
+        intermediate_size: Intermediate size of the experts.
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all-reduce the output of the layer.
+        renormalize: Whether to renormalize the logits in the fused_moe kernel.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (
+            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
+        )
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
+        else:
+            if isinstance(quant_config, Fp8Config):
+                self.quant_method = Fp8MoEMethod(quant_config)
+            else:
+                self.quant_method = quant_config.get_quant_method(self, prefix)
+        assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
+
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: int,
+        expert_id: int,
+        pre_sharded: bool,
+    ):
+        param_data = param.data
+
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
+            if (
+                param_data[expert_id] != 1
+                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
+            ):
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}"
+                )
+            param_data[expert_id] = loaded_weight
+        # Weight scales
+        elif "weight_scale" in weight_name:
+            # If we are in the merged column case (gate_up_proj):
+            #   shard_id 0 == gate_proj / w1
+            #   shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj):
+            #   shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = self.intermediate_size_per_partition
+            if pre_sharded:
+                shard = slice(None)
+            else:
+                shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+
+            # w1, gate_proj case: Load into first shard of w13.
+            if shard_id == 0:
+                param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+            # w3, up_proj case: Load into second shard of w13.
+            elif shard_id == 2:
+                param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                    shard, :
+                ]
+            # w2, down_proj case: Load into only shard of w2.
+            elif shard_id == 1:
+                param_data[expert_id, :, :] = loaded_weight[:, shard]
+            else:
+                raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group,
+        )
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
+
+        return (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.w13_scale"
+                        if weight_name in gate_up
+                        else "experts.w2_scale"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.w13_weight"
+                        if weight_name in gate_up
+                        else "experts.w2_weight"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+            + [
+                # These are the activation (input) scales for the experts
+                # (param_name, weight_name, expert_id, shard_id)
+                (
+                    (
+                        "experts.a13_scale"
+                        if weight_name in gate_up
+                        else "experts.a2_scale"
+                    ),
+                    f"experts.{expert_id}.{weight_name}.input_scale",
+                    expert_id,
+                    shard_id,
+                )
+                for expert_id in range(num_experts)
+                for shard_id, weight_name in enumerate(gate_down_up)
+            ]
+        )
+
+
+import torch
+from torch.nn import Module
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    per_tensor_dequantize,
+)
+from vllm.utils import print_warning_once
+
+
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_scale", w2_scale)
+
+        # If loading an fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()).
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            a13_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If the checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(
+                layer.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(
+                layer.w2_weight.data, dtype=torch.float8_e4m3fn
+            )
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
+                    layer.w2_weight.data[expert, :, :]
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            return
+
+        # If the checkpoint is fp8, we need to handle the fact that the
+        # MoE kernels require a single activation scale and a single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.a13_scale) or not all_close_1d(
+                    layer.a2_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.a13_scale = torch.nn.Parameter(
+                    layer.a13_scale.max(), requires_grad=False
+                )
+                layer.a2_scale = torch.nn.Parameter(
+                    layer.a2_scale.max(), requires_grad=False
+                )
+
+            # The fp8 moe kernel needs a single weight scale for w13 per expert.
+            # We take the max, then dequantize and requantize each expert.
+            assert layer.w13_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.fused_moe.fused_moe import fused_moe
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            router_logits,
+            top_k,
+            renormalize=renormalize,
+            inplace=True,
+            use_fp8=True,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            a1_scale=layer.a13_scale,
+            a2_scale=layer.a2_scale,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
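For orientation only (not part of the diff): a minimal sketch of how a model's MoE block could wire a router into the new FusedMoE layer added above. The import path is the file introduced in this release; the block name and config fields (SparseMoeBlockSketch, num_local_experts, num_experts_per_tok) are illustrative assumptions in the spirit of the refactored mixtral.py.

import torch
from torch import nn

from sglang.srt.layers.fused_moe.layer import FusedMoE


class SparseMoeBlockSketch(nn.Module):
    """Hypothetical MoE block: a linear router feeding FusedMoE."""

    def __init__(self, config, quant_config=None):
        super().__init__()
        # Router: one logit per expert for every token.
        self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,        # all-reduce across TP ranks inside FusedMoE.forward
            renormalize=True,
            quant_config=quant_config,  # an Fp8Config here selects Fp8MoEMethod
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        router_logits = self.gate(hidden_states)
        return self.experts(hidden_states, router_logits)

Checkpoint loading would pair this with FusedMoE.make_expert_params_mapping(...) so that per-expert gate/up/down tensors from the checkpoint land in the fused w13_weight and w2_weight parameters via FusedMoE.weight_loader.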
@@ -164,9 +164,9 @@ class LogitsProcessor(nn.Module):
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
-            last_logits /= self.config.final_logit_softcapping
+            last_logits.div_(self.config.final_logit_softcapping)
             last_logits = torch.tanh(last_logits)
-            last_logits *= self.config.final_logit_softcapping
+            last_logits.mul_(self.config.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
@@ -209,9 +209,9 @@ class LogitsProcessor(nn.Module):
         all_logits = all_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
-            all_logits /= self.config.final_logit_softcapping
+            all_logits.div_(self.config.final_logit_softcapping)
             all_logits = torch.tanh(all_logits)
-            all_logits *= self.config.final_logit_softcapping
+            all_logits.mul_(self.config.final_logit_softcapping)
 
         all_logprobs = all_logits
         del all_logits, hidden_states
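For reference (not from the package): the two LogitsProcessor hunks above keep the same final-logit soft-capping math, cap * tanh(logits / cap); they only replace the augmented assignments with explicit in-place tensor ops. A small standalone check of that equivalence, with an illustrative cap value:

import torch


def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Out-of-place reference: cap * tanh(logits / cap)
    return cap * torch.tanh(logits / cap)


logits = torch.randn(4, 8) * 50.0
cap = 30.0  # illustrative stand-in for config.final_logit_softcapping

x = logits.clone().float()
x.div_(cap)         # explicit in-place divide, as in the new code
x = torch.tanh(x)
x.mul_(cap)         # explicit in-place multiply
assert torch.allclose(x, softcap(logits, cap))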