sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
  import torch
  from torch.nn.parameter import Parameter

- from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.linear import LinearBase, LinearMethodBase
  from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
  from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,11 @@ from sglang.srt.layers.quantization.utils import (
      convert_to_channelwise,
      requantize_with_max_scale,
  )
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.utils import is_cuda_available
+
+ if is_cuda_available():
+     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

  # Initialize logger for the module
  logger = logging.getLogger(__name__)
@@ -33,12 +37,19 @@ ACTIVATION_SCHEMES = ["static"]
  class ModelOptFp8Config(QuantizationConfig):
      """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-     def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+     def __init__(
+         self,
+         is_checkpoint_fp8_serialized: bool = False,
+         kv_cache_quant_method: Optional[str] = None,
+         exclude_modules: Optional[List[str]] = None,
+     ) -> None:
          """
          Args:
              is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
          """
          self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+         self.kv_cache_quant_method = kv_cache_quant_method
+         self.exclude_modules = exclude_modules
          if is_checkpoint_fp8_serialized:
              logger.warning(
                  "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +74,12 @@ class ModelOptFp8Config(QuantizationConfig):
      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
          quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+         kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+             "kv_cache_quant_algo"
+         )
+         exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+             "exclude_modules"
+         )

          if "FP8" not in quant_method:
              raise ValueError(
@@ -70,15 +87,23 @@ class ModelOptFp8Config(QuantizationConfig):
                  "Check the `hf_quant_config.json` file for your model's configuration."
              )

-         return cls(is_checkpoint_fp8_serialized=True)
+         return cls(
+             is_checkpoint_fp8_serialized=True,
+             kv_cache_quant_method=kv_cache_quant_method,
+             exclude_modules=exclude_modules,
+         )

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
      ) -> Optional["QuantizeMethodBase"]:
+         if self.exclude_modules and any(
+             module in prefix for module in self.exclude_modules
+         ):
+             return None

          if isinstance(layer, LinearBase):
              return ModelOptFp8LinearMethod(self)
-         if isinstance(layer, AttentionBackend):
+         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
              return ModelOptFp8KVCacheMethod(self)

          return None
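
A rough sketch of how the new fields flow end to end (the `hf_quant_config.json` contents below are an assumed example, not taken from this package): `from_config` now also reads `kv_cache_quant_algo` and `exclude_modules`, and `get_quant_method` skips quantization for any layer whose prefix contains an excluded module name.

```python
# Assumed example of the "quantization" block in hf_quant_config.json; the
# field names match the keys read by from_config(), the values are made up.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

quant = hf_quant_config["quantization"]
kv_cache_quant_method = quant.get("kv_cache_quant_algo")
exclude_modules = quant.get("exclude_modules")


def is_excluded(prefix: str) -> bool:
    # get_quant_method() matches excluded modules by substring against the
    # layer prefix, so "lm_head" leaves the output head unquantized.
    return bool(exclude_modules) and any(m in prefix for m in exclude_modules)


assert is_excluded("lm_head")
assert not is_excluded("model.layers.0.self_attn.qkv_proj")
```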
@@ -194,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):

      def __init__(self, quant_config: ModelOptFp8Config):
          super().__init__(quant_config)
+
+
+ class ModelOptFp4Config(QuantizationConfig):
+     """Config class for FP4."""
+
+     def __init__(
+         self,
+         is_checkpoint_nvfp4_serialized: bool = False,
+         kv_cache_quant_algo: str = None,
+         group_size: int = None,
+         exclude_modules: List[str] = None,
+     ) -> None:
+         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+         if is_checkpoint_nvfp4_serialized:
+             logger.warning(
+                 "Detected nvfp4 checkpoint. Please note that the "
+                 "format is experimental and subject to change."
+             )
+         self.group_size = group_size
+         self.kv_cache_quant_algo = kv_cache_quant_algo
+         self.exclude_modules = exclude_modules
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "modelopt_fp4"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 100
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return ["hf_quant_config.json"]
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+         quant_config = cls.get_from_keys(config, ["quantization"])
+         quant_method = quant_config["quant_algo"]
+         if not quant_method in ["FP8", "NVFP4"]:
+             raise ValueError(
+                 f"ModelOpt currently only supports: FP8, NVFP4"
+                 " quantizations in sglang. Please check the "
+                 "`hf_quant_config.json` file for your model's "
+                 "quant configuration."
+             )
+         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+         kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+         group_size = quant_config["group_size"]
+         exclude_modules = quant_config["exclude_modules"]
+         if not (group_size and kv_cache_quant_algo and exclude_modules):
+             raise ValueError(
+                 "NVFP4 quantization requires group size and "
+                 "kv_cache_quant_algo specified in "
+                 "hf_quant_config.json"
+             )
+         return cls(
+             is_checkpoint_nvfp4_serialized,
+             kv_cache_quant_algo,
+             group_size,
+             exclude_modules,
+         )
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         if self.exclude_modules and any(
+             module in prefix for module in self.exclude_modules
+         ):
+             return None
+
+         if isinstance(layer, LinearBase):
+             return ModelOptFp4LinearMethod(self)
+         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+             return ModelOptFp8KVCacheMethod(self)
+
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class ModelOptFp4LinearMethod(LinearMethodBase):
+     """Linear method for NVFP4.
+     Supports loading NVFP4 checkpoints with the following structure:
+
+     |Tensor Name | datatype | shape |
+     |----------------------------------------------------|
+     |input_scale | torch.float32 | scalar |
+     |weight | NVFP4(SE2M1) | [1, X, y/2] |
+     |weight_scale | FP8-E4M3 | [X, Y] |
+     |weight_scale_2 | torch.float32 | scalar |
+
+     The weights are quantized per block of 16 elements.
+     Args: quant_config: The ModelOpt quantization config.
+     """
+
+     def __init__(self, quant_config: ModelOptFp4Config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         del input_size, output_size
+         if not self.quant_config.is_checkpoint_nvfp4_serialized:
+             raise ValueError(
+                 "NVFP4 quantization was selected, "
+                 " dynamic quantization is not supported."
+             )
+
+         output_size_per_partition = sum(output_partition_sizes)
+         weight_loader = extra_weight_attrs.get("weight_loader")
+
+         layer.logical_widths = output_partition_sizes
+
+         layer.input_size_per_partition = input_size_per_partition
+         layer.output_size_per_partition = output_size_per_partition
+         if input_size_per_partition % 16 != 0:
+             raise ValueError(
+                 "Unsupported model when in features size is " "not multiple of 16"
+             )
+
+         weight_dtype = (
+             torch.float8_e4m3fn
+             if self.quant_config.is_checkpoint_nvfp4_serialized
+             else params_dtype
+         )
+
+         weight = ModelWeightParameter(
+             data=torch.empty(
+                 # 2 fp4 data is packed in one uint8 in the input dimension
+                 output_size_per_partition,
+                 input_size_per_partition // 2,
+                 dtype=torch.uint8,
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("weight", weight)
+
+         input_scale = PerTensorScaleParameter(
+             data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+             weight_loader=weight_loader,
+         )
+
+         layer.register_parameter("input_scale", input_scale)
+
+         weight_scale_2 = PerTensorScaleParameter(
+             data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("weight_scale_2", weight_scale_2)
+
+         weight_scale = ModelWeightParameter(
+             data=torch.empty(
+                 output_size_per_partition,
+                 input_size_per_partition // self.quant_config.group_size,
+                 dtype=weight_dtype,
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+
+         layer.register_parameter("weight_scale", weight_scale)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         input_scale_2 = layer.input_scale.max().to(torch.float32)
+         weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+         layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+         layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+         layer.alpha = Parameter(
+             layer.input_scale * layer.weight_scale_2, requires_grad=False
+         )
+
+         # Pad and blockwise interleave weight_scale
+         scales = layer.weight_scale
+         scale_ndim = scales.ndim
+         if scale_ndim == 2:
+             scales = scales.unsqueeze(0)
+         assert scales.ndim == 3
+         B, M, K = scales.shape
+         round_up_multiple = lambda x, m: (x + m - 1) // m * m
+         M_padded = round_up_multiple(M, 128)
+         K_padded = round_up_multiple(K, 4)
+         padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
+         padded_scales[:B, :M, :K] = scales
+         batches, rows, cols = padded_scales.shape
+         assert rows % 128 == 0
+         assert cols % 4 == 0
+         padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
+         padded_scales = padded_scales.contiguous().cuda()
+         padded_scales = (
+             padded_scales.reshape(M, K)
+             if scale_ndim == 2
+             else padded_scales.reshape(B, M, K)
+         )
+         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         output_dtype = x.dtype
+         x_m, _ = x.shape
+         w_n, _ = layer.weight.shape
+         output_shape = [x_m, w_n]
+
+         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
+         x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+
+         assert x_fp4.dtype == torch.uint8
+         assert x_scale_interleaved.dtype == torch.float8_e4m3fn
+         assert layer.weight.dtype == torch.uint8
+         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
+         assert layer.alpha.dtype == torch.float32
+
+         out = cutlass_scaled_fp4_mm(
+             x_fp4,
+             layer.weight,
+             x_scale_interleaved,
+             layer.weight_scale_interleaved,
+             layer.alpha,
+             output_dtype,
+         )
+         if bias is not None:
+             out = out + bias
+         return out.view(*output_shape)
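
The scale swizzle in `process_weights_after_loading` above is compact enough to sketch in isolation. The helper below is a CPU-only approximation for illustration (it drops the `.cuda()` call and returns the padded shape instead of reshaping back to the original dims), showing the padding to multiples of (128, 4) and the block interleave.

```python
# CPU-only sketch of the weight_scale padding + interleave above; unlike the
# real code it keeps the padded shape and never moves the tensor to CUDA.
import torch


def interleave_block_scales(scales: torch.Tensor) -> torch.Tensor:
    ndim = scales.ndim
    if ndim == 2:
        scales = scales.unsqueeze(0)
    B, M, K = scales.shape
    round_up = lambda x, m: (x + m - 1) // m * m
    M_pad, K_pad = round_up(M, 128), round_up(K, 4)
    padded = torch.zeros((B, M_pad, K_pad), dtype=scales.dtype)
    padded[:, :M, :K] = scales
    # Split rows into 128-row tiles (viewed as 4 x 32) and columns into 4-wide
    # tiles, then permute so each 128 x 4 tile is stored contiguously.
    padded = padded.reshape(B, M_pad // 128, 4, 32, K_pad // 4, 4)
    padded = padded.permute(0, 1, 4, 3, 2, 5).contiguous()
    return padded.reshape(M_pad, K_pad) if ndim == 2 else padded.reshape(B, M_pad, K_pad)


# Per-block scales for a 256 x 128 weight quantized with group_size=16.
scales = torch.rand(256, 128 // 16)
print(interleave_block_scales(scales).shape)  # torch.Size([256, 8])
```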
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional
+ from typing import Any, Callable, Dict, List, Optional

  import torch
  from torch.nn.parameter import Parameter
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
      input_to_float8,
      normalize_e4m3fn_to_e4m3fnuz,
  )
- from sglang.srt.utils import is_hip
+ from sglang.srt.utils import is_hip, set_weight_attrs

  _is_hip = is_hip()

@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
      @classmethod
      def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
          quant_method = cls.get_from_keys(config, ["quant_method"])
-         is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+         is_checkpoint_fp8_serialized = (
+             "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
+         )
          return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

      def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
          prefix: str,
      ) -> Optional["QuantizeMethodBase"]:
          from sglang.srt.layers.linear import LinearBase
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

          if isinstance(layer, LinearBase):
              return W8A8Fp8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return W8A8FP8MoEMethod(self)
          return None

      def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
          input_size: int,
          output_size: int,
          params_dtype: torch.dtype,
-         **extra_weight_attrs
+         **extra_weight_attrs,
      ):
          weight_dtype = (
              torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
              bias=bias,
              cutlass_fp8_supported=self.cutlass_fp8_supported,
          )
+
+
+ class W8A8FP8MoEMethod:
+     """MoE method for FP8.
+     Supports loading FP8 checkpoints with static weight scale and
+     dynamic/static activation scale.
+     Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+     activation scaling. The weight scaling factor will be initialized after
+     the model weights are loaded.
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __new__(cls, *args, **kwargs):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if not hasattr(cls, "_initialized"):
+             original_init = cls.__init__
+             new_cls = type(
+                 cls.__name__,
+                 (FusedMoEMethodBase,),
+                 {
+                     "__init__": original_init,
+                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                 },
+             )
+             obj = super(new_cls, new_cls).__new__(new_cls)
+             obj.__init__(*args, **kwargs)
+             return obj
+         return super().__new__(cls)
+
+     def __init__(self, quant_config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+         # WEIGHTS
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w2_weight = torch.nn.Parameter(
+             torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         w13_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+             requires_grad=False,
+         )
+         w2_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale", w13_weight_scale)
+         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+         extra_weight_attrs.update(
+             {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+         )
+
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         w13_input_scale = None
+         layer.register_parameter("w13_input_scale", w13_input_scale)
+
+         w2_input_scale = None
+         layer.register_parameter("w2_input_scale", w2_input_scale)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+         layer.w13_weight_scale = Parameter(
+             layer.w13_weight_scale.data, requires_grad=False
+         )
+         layer.w2_weight_scale = Parameter(
+             layer.w2_weight_scale.data, requires_grad=False
+         )
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
+         activation: str = "silu",
+         inplace: bool = True,
+         no_combine: bool = False,
+     ) -> torch.Tensor:
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts
+
+         # Expert selection
+         topk_weights, topk_ids = select_experts(
+             hidden_states=x,
+             router_logits=router_logits,
+             use_grouped_topk=use_grouped_topk,
+             top_k=top_k,
+             renormalize=renormalize,
+             topk_group=topk_group,
+             num_expert_group=num_expert_group,
+             custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
+         )
+
+         return fused_experts(
+             x,
+             layer.w13_weight,
+             layer.w2_weight,
+             topk_weights=topk_weights,
+             topk_ids=topk_ids,
+             inplace=inplace,
+             activation=activation,
+             use_fp8_w8a8=True,
+             per_channel_quant=True,
+             w1_scale=(layer.w13_weight_scale),
+             w2_scale=(layer.w2_weight_scale),
+             a1_scale=layer.w13_input_scale,
+             a2_scale=layer.w2_input_scale,
+             no_combine=no_combine,
+         )
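
`W8A8FP8MoEMethod` cannot inherit from `FusedMoEMethodBase` at class-definition time without a circular import, so `__new__` rebuilds the class on top of the lazily imported base at first instantiation. A stripped-down sketch of that pattern, with stand-in names (`LazyBase` and `Method` are illustrative only):

```python
# Stand-in names only: LazyBase plays the role of FusedMoEMethodBase, which in
# the real code can only be imported lazily inside __new__.
class LazyBase:
    pass


class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with LazyBase injected as the base, then construct
        # and initialize an instance of that rebuilt class.
        new_cls = type(
            cls.__name__,
            (LazyBase,),
            {k: v for k, v in cls.__dict__.items() if k != "__dict__"},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config=None):
        self.quant_config = quant_config


m = Method("cfg")
assert isinstance(m, LazyBase) and m.quant_config == "cfg"
```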
sglang/srt/layers/quantization/w8a8_int8.py

@@ -260,6 +260,7 @@ class W8A8Int8MoEMethod:
              activation=activation,
              apply_router_weight_on_input=apply_router_weight_on_input,
              use_int8_w8a8=True,
+             per_channel_quant=True,
              w1_scale=(layer.w13_weight_scale),
              w2_scale=(layer.w2_weight_scale),
              a1_scale=layer.w13_input_scale,
sglang/srt/layers/radix_attention.py

@@ -13,8 +13,12 @@
  # ==============================================================================
  """Radix attention."""

+ from typing import Optional
+
  from torch import nn

+ from sglang.srt.layers.linear import UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
          v_head_dim: int = -1,
          sliding_window_size: int = -1,
          is_cross_attention: bool = False,
+         quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
          use_irope: bool = False,
      ):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
          self.logit_cap = logit_cap
          self.sliding_window_size = sliding_window_size or -1
          self.is_cross_attention = is_cross_attention
+         self.use_irope = use_irope
          self.k_scale = None
          self.v_scale = None
-         self.use_irope = use_irope
+         self.k_scale_float = None
+         self.v_scale_float = None
+         self.quant_method = None
+         if quant_config is not None:
+             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+         if self.quant_method is not None:
+             self.quant_method.create_weights(self)

      def forward(
          self,
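
In rough terms, the constructor change above is what lets a quantization config attach a KV-cache quantization method to each attention layer. A mock of the wiring (the `Mock*` classes below are illustrative, not sglang APIs):

```python
from typing import Optional


class MockKVCacheMethod:
    def create_weights(self, layer):
        # In sglang this is where KV-cache scale attributes get attached so the
        # checkpoint's scaling factors have somewhere to load into.
        layer.k_scale = None
        layer.v_scale = None


class MockQuantConfig:
    def get_quant_method(self, layer, prefix: str):
        return MockKVCacheMethod()


class MockAttention:
    def __init__(self, quant_config: Optional[MockQuantConfig] = None, prefix: str = ""):
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if self.quant_method is not None:
            self.quant_method.create_weights(self)


layer = MockAttention(MockQuantConfig(), prefix="model.layers.0.self_attn.attn")
assert layer.quant_method is not None and layer.k_scale is None
```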
sglang/srt/layers/rotary_embedding.py

@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
          cache = torch.cat((cos, sin), dim=-1)
          return cache

-     def forward(
+     def forward_hip(self, *args, **kwargs):
+         return self.forward_native(*args, **kwargs)
+
+     def forward(self, *args, **kwargs):
+         if torch.compiler.is_compiling():
+             return self.forward_native(*args, **kwargs)
+         if _is_cuda_available:
+             return self.forward_cuda(*args, **kwargs)
+         else:
+             return self.forward_native(*args, **kwargs)
+
+     def forward_native(
          self,
          positions: torch.Tensor,
          query: torch.Tensor,
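
The replacement `forward` above is a thin dispatcher around the renamed `forward_native`. A standalone sketch of the same pattern (a mock class, not the sglang `DeepseekScalingRotaryEmbedding`; `_is_cuda_available` stands in for the module-level flag in `rotary_embedding.py`):

```python
# Mock class showing the dispatch order: native while torch.compile is tracing,
# otherwise the CUDA path if available, else native.
import torch

_is_cuda_available = torch.cuda.is_available()


class MockRope:
    def forward_native(self, *args, **kwargs):
        return "native"

    def forward_cuda(self, *args, **kwargs):
        return "cuda"

    def forward_hip(self, *args, **kwargs):
        # The ROCm path simply reuses the native implementation.
        return self.forward_native(*args, **kwargs)

    def forward(self, *args, **kwargs):
        if torch.compiler.is_compiling():
            return self.forward_native(*args, **kwargs)
        if _is_cuda_available:
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)


print(MockRope().forward())  # "cuda" on a CUDA machine, "native" otherwise
```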