sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
@@ -0,0 +1,607 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d,
+    apply_fp8_linear,
+    convert_to_channelwise,
+    cutlass_fp8_supported,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
+from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_hip,
+    print_warning_once,
+    set_weight_attrs,
+)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = logging.getLogger(__name__)
+
+
+class Fp8Config(QuantizationConfig):
+    """Config class for FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected fp8 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "fp8" in quant_method
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(
+            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support per-tensor quantization due to torch._scaled_mm support.
+    2. Only support float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        # Disable marlin for ROCm
+        if is_hip():
+            self.use_marlin = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # If checkpoint is serialized fp8, load them.
+        # Otherwise, wait until process_weights_after_loading.
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+            scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("weight_scale", scale)
+
+            # INPUT ACTIVATION SCALE
+            if self.quant_config.activation_scheme == "static":
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("input_scale", scale)
+            else:
+                layer.register_parameter("input_scale", None)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+        # If checkpoint not serialized fp8, quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
+
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+                )
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
+
+        # If checkpoint is fp8, handle that there are N scales for N
+        # shards in a fused module
+        else:
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(
+                    layer.input_scale.data, requires_grad=False
+                )
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                weight = layer.weight
+                weight_scale = convert_to_channelwise(
+                    layer.weight_scale, layer.logical_widths
+                )
+
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            else:
+                # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if is_hip():
+                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale,
+                    )
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+                weight_scale, weight = requantize_with_max_scale(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    logical_widths=layer.logical_widths,
+                )
+
+            # Update layer with new values.
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
+
+
+class Fp8MoEMethod:
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8."
+                )
+
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # If checkpoint is fp16 or bfloat16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
+                ),
+                requires_grad=False,
+            )
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None."
+                    )
+                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                    layer.w2_input_scale
+                ):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. "
+                    )
+                layer.w13_input_scale = torch.nn.Parameter(
+                    layer.w13_input_scale.max(), requires_grad=False
+                )
+                layer.w2_input_scale = torch.nn.Parameter(
+                    layer.w2_input_scale.max(), requires_grad=False
+                )
+
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                    )
+                )
+                w2_weight, w2_weight_scale, w2_input_scale = (
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+                    )
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False
+                    )
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False
+                    )
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+            return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+        # Expert selection
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+        )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class Fp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        super().__init__(quant_config)
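The "Dequant -> Quant with max scale" step used above for fused shards (in both Fp8LinearMethod and Fp8MoEMethod) can be illustrated with plain float tensors. The sketch below uses made-up numbers and skips the fp8 rounding that the real path performs via per_tensor_dequantize and ops.scaled_fp8_quant:

import torch

# Pure-float sketch of requantizing fused shards with a single max scale.
shard_scales = torch.tensor([0.10, 0.25])                         # one scale per fused shard
q_shards = [torch.tensor([3.0, -7.0]), torch.tensor([1.0, 4.0])]  # "quantized" shard values

# Dequantize each shard with its own scale ...
deq = [q * s for q, s in zip(q_shards, shard_scales)]
# ... then requantize everything with the single max scale, which is what a
# per-tensor torch._scaled_mm call expects.
max_scale = shard_scales.max()
requant = [d / max_scale for d in deq]

# Dequantizing with the shared max scale recovers the original values.
for r, d in zip(requant, deq):
    assert torch.allclose(r * max_scale, d)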
sglang/srt/layers/quantization/fp8_utils.py
@@ -0,0 +1,27 @@
+from typing import Optional, Tuple
+
+import torch
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
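A minimal sketch of why the scale doubles, assuming a PyTorch build that exposes both float8 dtypes on the host (values are illustrative):

import torch

# The same bit pattern decodes to half the magnitude under e4m3fnuz because
# its exponent bias is 8 instead of e4m3fn's 7.
w_fn = torch.tensor([0.5, -1.0]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.25)

w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)  # reinterpret the bits
# Doubling the scale therefore keeps the dequantized result unchanged.
assert torch.allclose(w_fn.float() * scale, w_fnuz.float() * (scale * 2.0))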
sglang/srt/layers/radix_attention.py
@@ -48,11 +48,20 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
 
-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
-        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+        return forward_batch.attn_backend.forward(
+            q, k, v, self, forward_batch, save_kv_cache
+        )
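The new save_kv_cache flag defaults to True, so existing callers keep their behavior; a caller that reuses KV entries written elsewhere can pass False to skip the redundant cache write. A hypothetical caller-side sketch (only RadixAttention.forward's signature is taken from the diff; the wrapper and its names are illustrative):

def attend_with_shared_kv(attn, q, k, v, forward_batch, writes_cache: bool):
    # Only the layer that owns the KV entries persists them to the pool;
    # layers that reuse those entries skip the write.
    return attn.forward(q, k, v, forward_batch, save_kv_cache=writes_cache)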