sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -25,9 +24,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -40,12 +39,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -161,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -567,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -616,18 +618,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +672,7 @@ class Fp8MoEMethod:
             )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +722,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +778,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
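Note: the hunks above cache is_hip() once in a module-level is_hip_ flag and replace the old bool(int(os.getenv("MOE_PADDING", "0"))) check with get_bool_env_var, adding a CK_MOE branch that permutes the MoE weights for AMD's CK kernels (ater.fused_moe.fused_experts_ck). Below is a minimal sketch of the resulting environment-variable gate; get_bool_env_var_sketch is a stand-in written only for illustration and assumes the usual "1"/"true" convention, it is not the sglang implementation.

import os


def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Illustrative stand-in for sglang.srt.utils.get_bool_env_var; assumed to
    # treat "1"/"true" (case-insensitive) as enabled.
    return os.getenv(name, default).lower() in ("1", "true")


def select_rocm_moe_weight_path() -> str:
    # Mirrors the branch order added in the diff: CK_MOE takes priority over MOE_PADDING.
    if get_bool_env_var_sketch("CK_MOE"):
        return "permute w13/w2 weights for ater's fused_experts_ck"
    if get_bool_env_var_sketch("MOE_PADDING"):
        return "F.pad w13/w2 weights (reduces memory-channel contention)"
    return "leave weights unchanged (Triton fused_experts path)"


if __name__ == "__main__":
    os.environ["CK_MOE"] = "1"
    print(select_rocm_moe_weight_path())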
sglang/srt/layers/quantization/fp8_utils.py
@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
 
+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py
@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
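The new int8_kernel.py implements per-token symmetric INT8 quantization in Triton: each row is scaled by absmax / 127, rounded to int8, and the per-row float32 scale is returned alongside the quantized tensor. A small usage sketch (requires a CUDA GPU with Triton installed; the shapes are arbitrary):

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

# Four "tokens" with a hidden size of 128 (shapes chosen only for illustration).
x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

x_q, scales = per_token_quant_int8(x)
assert x_q.dtype == torch.int8
assert scales.shape == (4, 1)

# Round-trip check: dequantize and compare against the original activations.
x_hat = x_q.to(torch.float32) * scales
print((x_hat - x.float()).abs().max())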
sglang/srt/layers/quantization/modelopt_quant.py
@@ -0,0 +1,174 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    requantize_with_max_scale,
+)
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+
+# Initialize logger for the module
+logger = logging.getLogger(__name__)
+
+# Supported activation schemes for the current configuration
+ACTIVATION_SCHEMES = ["static"]
+
+
+class ModelOptFp8Config(QuantizationConfig):
+    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
+
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+        """
+        Args:
+            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
+        """
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
+            )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89  # Minimum hardware capability (e.g., Hopper GPUs).
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+
+        if "FP8" not in quant_method:
+            raise ValueError(
+                "ModelOpt only supports static FP8 quantization in SGLang. "
+                "Check the `hf_quant_config.json` file for your model's configuration."
+            )
+
+        return cls(is_checkpoint_fp8_serialized=True)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8LinearMethod(LinearMethodBase):
+    """Linear method for ModelOpt static FP8 quantization.
+
+    Supports loading FP8 checkpoints with static weight and activation scales.
+    Future support may include dynamic scales.
+
+    **Limitations**:
+    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
+    2. Only supports the `float8_e4m3fn` data type.
+
+    Args:
+        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__()
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        """Creates and registers weights, weight scales, and input scales for FP8 quantization."""
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        # Set layer attributes
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Register weight
+        layer.register_parameter(
+            "weight",
+            ModelWeightParameter(
+                data=torch.empty(
+                    output_size_per_partition,
+                    input_size_per_partition,
+                    dtype=weight_dtype,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            ),
+        )
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # Register weight and input scales
+            for scale_name in ["weight_scale", "input_scale"]:
+                layer.register_parameter(
+                    scale_name,
+                    PerTensorScaleParameter(
+                        data=torch.full(
+                            (len(output_partition_sizes),),
+                            torch.finfo(torch.float32).min,
+                            dtype=torch.float32,
+                        ),
+                        weight_loader=weight_loader,
+                    ),
+                )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Requantizes weights after loading using the maximum scale."""
+        max_w_scale, quantized_weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, layer.logical_widths
+        )
+        layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Applies FP8 linear transformation."""
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
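ModelOptFp8Config.from_config only inspects the quantization.quant_algo field of hf_quant_config.json and accepts FP8 algorithms. A hedged sketch of that lookup follows; the config dict is a minimal stand-in rather than a real ModelOpt checkpoint file, and running it assumes an environment with sglang[srt] and vllm installed.

from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

# Minimal stand-in for the parsed hf_quant_config.json contents.
hf_quant_config = {"quantization": {"quant_algo": "FP8"}}

cfg = ModelOptFp8Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_fp8_serialized  # always True for accepted configs

# Non-FP8 algorithms are rejected with a ValueError.
try:
    ModelOptFp8Config.from_config({"quantization": {"quant_algo": "W4A16_AWQ"}})
except ValueError as err:
    print(err)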
sglang/srt/layers/quantization/w8a8_int8.py
@@ -0,0 +1,117 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+class W8A8Int8Config(QuantizationConfig):
+    """Config class for W8A8 Int8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_int8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Int8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Int8Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
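W8A8Int8LinearMethod expects checkpoints that already contain an int8 weight of shape [out_features, in_features] and a per-output-channel float32 weight_scale of shape [out_features, 1]; activations are quantized per token at runtime with per_token_quant_int8 and multiplied via sgl_kernel's int8_scaled_mm. The sketch below shows one way to produce such per-channel symmetric int8 weights offline; it is illustrative only, real checkpoints come from a quantization toolchain.

import torch


def quantize_weight_per_channel_int8(w: torch.Tensor):
    # w: [out_features, in_features] in fp16/bf16/fp32.
    absmax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-10)
    scale = (absmax / 127.0).to(torch.float32)  # matches weight_scale's [out, 1] layout
    w_q = torch.round(w.float() / scale).clamp(-127, 127).to(torch.int8)
    return w_q, scale


w = torch.randn(256, 512, dtype=torch.float16)
w_q, w_scale = quantize_weight_per_channel_int8(w)
print(w_q.dtype, w_scale.shape)  # torch.int8 torch.Size([256, 1])

# Reference result that an int8 GEMM plus per-channel rescale should approximate.
x = torch.randn(4, 512, dtype=torch.float16)
y_ref = x.float() @ (w_q.float() * w_scale).t()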
sglang/srt/layers/radix_attention.py
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0
 
     def forward(
         self,
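RadixAttention now carries k_scale and v_scale attributes that default to 1.0; they appear intended as per-layer scaling factors for a quantized (e.g., FP8) KV cache, where values are scaled on write and rescaled on read. The snippet below only illustrates that idea and is not the actual backend code; dequantize_k_cache is a made-up helper.

import torch


def dequantize_k_cache(k_cache_fp8: torch.Tensor, k_scale: float, dtype=torch.bfloat16):
    # With the default k_scale of 1.0 this reduces to a plain dtype cast;
    # calibrated checkpoints would supply a non-trivial scale.
    return k_cache_fp8.to(dtype) * k_scale


k_fp8 = torch.randn(8, 128).to(torch.float8_e4m3fn)
k = dequantize_k_cache(k_fp8, k_scale=1.0)
print(k.dtype, k.shape)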
sglang/srt/layers/vocab_parallel_embedding.py
@@ -12,8 +12,8 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parameter import BasevLLMParameter
 
+from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
            start_idx = start_idx // packed_factor
            shard_size = shard_size // packed_factor
        else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )
 
        # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
        param[loaded_weight.shape[0] :].data.fill_(0)
 
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
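The use_presharded_weights flag lets VocabParallelEmbedding and ParallelLMHead consume checkpoints whose vocab dimension is already split per tensor-parallel rank: the shape assertion divides org_vocab_size by tp_size and the narrow() call is skipped. A hedged sketch of that loader difference; shard_for_rank is a made-up helper and the shapes are illustrative.

import torch


def shard_for_rank(weight, tp_rank, tp_size, use_presharded_weights=False):
    # Mirrors the branch added to weight_loader: a presharded checkpoint already
    # stores [vocab // tp_size, hidden] per rank, so it is taken as-is; otherwise
    # this rank's slice is narrowed out of the full vocab dimension.
    if use_presharded_weights:
        return weight
    shard = weight.shape[0] // tp_size
    return weight.narrow(0, tp_rank * shard, shard)


full = torch.randn(32000, 4096)
local = shard_for_rank(full, tp_rank=1, tp_size=4)  # -> [8000, 4096]
pre = shard_for_rank(torch.randn(8000, 4096), tp_rank=1, tp_size=4, use_presharded_weights=True)
print(local.shape, pre.shape)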