sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py (new file)
@@ -0,0 +1,56 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+__all__ = ["CompressedTensorsScheme"]
+
+
+class CompressedTensorsScheme(ABC):
+    """
+    Abstract class used to describe the weight creation and forward pass
+    of different quantization schemes supported by CompressedTensors.
+    """
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_weights(self, *args, **kwargs):
+        """
+        Weight creation for the particular scheme. Inputs to this function
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ):
+        """
+        Run the forward pass for the particular scheme. This is where
+        scheme-specific dequant/quant steps/kernels should be applied.
+
+        :param layer: torch.nn.Module with the registered weights and
+            other parameters relevant to the particular scheme.
+        :param x: input to the layer
+        :param bias: bias parameter
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
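
The abstract base above fixes the lifecycle every compressed-tensors scheme follows: create_weights registers parameters on the layer, process_weights_after_loading repacks them once the checkpoint is loaded, and apply_weights runs the quantized forward pass. A minimal sketch of a hypothetical pass-through scheme, to illustrate the contract only (the UnquantizedScheme class below is not part of this diff):

from typing import Optional

import torch

from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)


class UnquantizedScheme(CompressedTensorsScheme):
    """Illustrative scheme that keeps weights in the checkpoint dtype."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # no special hardware needed for this sketch

    def create_weights(
        self, layer, output_partition_sizes, input_size_per_partition,
        params_dtype, weight_loader, **kwargs
    ):
        # Register a plain weight; real schemes use ModelWeightParameter so the
        # provided weight_loader can shard checkpoints correctly.
        weight = torch.nn.Parameter(
            torch.empty(
                sum(output_partition_sizes), input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer):
        pass  # nothing to repack for unquantized weights

    def apply_weights(self, layer, x, bias: Optional[torch.Tensor] = None):
        return torch.nn.functional.linear(x, layer.weight, bias)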
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (new file)
@@ -0,0 +1,162 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from torch.nn import Parameter
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+    Fp8LinearOp,
+    maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+
+__all__ = ["CompressedTensorsW8A8Fp8"]
+
+
+class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # lovelace and up
+        return 89
+
+    def process_weights_after_loading(self, layer) -> None:
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per tensor
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+
+            if is_fp8_fnuz():
+                input_scale = getattr(layer, "input_scale", None)
+
+                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=max_w_scale, input_scale=input_scale
+                )
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale, requires_grad=False)
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            weight = layer.weight
+
+            if is_fp8_fnuz():
+                input_scale = getattr(layer, "input_scale", None)
+
+                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight,
+                    weight_scale=layer.weight_scale,
+                    input_scale=input_scale,
+                )
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale, requires_grad=False)
+            else:
+                weight_scale = layer.weight_scale.data
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        else:
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
+            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+        else:
+            layer.input_scale = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        maybe_create_device_identity()
+
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        # TODO: update create_xxx_parameter functions to return
+        # the newly added parameters
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+
+        # min requirement for fp8 kernels
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            input_scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        return self.fp8_linear.apply(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+        )
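
For per-tensor checkpoints, process_weights_after_loading above collapses the N per-shard scales of a fused module (e.g. QKV) into a single scale via requantize_with_max_scale, so the kernel can run one per-tensor GEMM. A conceptual sketch of that step, assuming e4m3 weights; the helper below is illustrative and is not the library implementation:

from typing import List, Tuple

import torch


def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # fused fp8 weight, shape (sum(logical_widths), k)
    weight_scale: torch.Tensor,  # one per-tensor scale per fused shard
    logical_widths: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    max_scale = weight_scale.max()
    requantized = torch.empty_like(weight)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # Dequantize this shard with its own scale ...
        dq = weight[start:end].to(torch.float32) * weight_scale[idx]
        # ... and requantize it with the shared max scale.
        requantized[start:end] = (
            (dq / max_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        )
        start = end
    return max_scale, requantized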
sglang/srt/layers/quantization/compressed_tensors/utils.py (new file)
@@ -0,0 +1,218 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from types import MappingProxyType
+from typing import Iterable, List, Mapping, Optional
+
+from compressed_tensors import CompressionFormat
+from torch.nn import Module
+
+
+def is_activation_quantization_format(format: str) -> bool:
+    _ACTIVATION_QUANTIZATION_FORMATS = [
+        CompressionFormat.naive_quantized.value,
+        CompressionFormat.int_quantized.value,
+        CompressionFormat.float_quantized.value,
+    ]
+    return format in _ACTIVATION_QUANTIZATION_FORMATS
+
+
+def should_ignore_layer(
+    layer_name: Optional[str],
+    ignore: Iterable[str] = tuple(),
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    if layer_name is None:
+        return False
+
+    # layer_name = model.layers.0.self_attn.qkv_proj
+    # proj_name = qkv_proj
+    proj_name = layer_name.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping and layer_name not in ignore:
+        shard_proj_names = fused_mapping[proj_name]
+
+        # Convert fused_name --> [shard_names]
+        shard_names = [
+            layer_name.replace(proj_name, shard_proj_name)
+            for shard_proj_name in shard_proj_names
+        ]
+
+        # Layer should be ignored if shards are ignored.
+        should_ignore_layer = None
+        for shard_name in shard_names:
+            should_ignore_shard = check_equal_or_regex_match(
+                layer_name=shard_name, targets=ignore
+            )
+
+            # If shard_idx=0, set layer ignore to match shard.
+            if should_ignore_layer is None:
+                should_ignore_layer = should_ignore_shard
+
+            # If shard_idx=1+ confirm scheme matches prior shards.
+            elif should_ignore_shard != should_ignore_layer:
+                raise ValueError(
+                    f"Found a different quantization schemes for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+
+    # Unfused layers like down_proj and o_proj will match
+    # the safetensors checkpoint already.
+    else:
+        should_ignore_layer = check_equal_or_regex_match(
+            layer_name=layer_name, targets=ignore
+        )
+
+    assert should_ignore_layer is not None
+    return should_ignore_layer
+
+
+def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
+    """
+    Checks whether a layer_name is exactly equal or a regex match for
+    if target starts with 're:' to any target in list.
+    """
+    for target in targets:
+        if _is_equal_or_regex_match(layer_name, target):
+            return True
+    return False
+
+
+def find_matched_target(
+    layer_name: Optional[str],
+    module: Module,
+    targets: Iterable[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> str:
+    """
+    Helper function to look up which "target" in the compressed-tensors
+    config that a layer corresponds to.
+
+    Recall that a compressed-tensors configs has a concept of
+    config_groups, where each layer can be quantized with with a different
+    scheme.
+
+    targets in each config_group will be a list of either layer names
+    (or regexes corresponding to layer names) or names of torch Modules.
+
+    First, we try to match the layer_name with a target
+    Second, we try to match the module's name with a target
+    Third, we try to map the layer_name to a list of fused module names.
+    *All* component module names must match in order for a match to be
+    successful. A successful match returns the first component target
+
+    :param layer_name: layer name
+    :param module: torch.nn.Module
+    :param targets: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+    :param fused_strategy: either "all" or "any". If using "all", fused
+        layers match if "all" of its components match
+    """
+
+    if layer_name is None:
+        layer_name = ""
+
+    matched_target = (
+        _find_first_match(layer_name, targets)
+        or _find_first_match(module.__class__.__name__, targets, True)
+        or _match_fused_layer(layer_name, targets, fused_mapping)
+    )
+
+    if matched_target is None:
+        raise ValueError(
+            f"Unable to find matching target for {layer_name} in the "
+            "compressed-tensors config."
+        )
+
+    return matched_target
+
+
+def _find_first_match(
+    value: str, targets: Iterable[str], check_contains: bool = False
+) -> Optional[str]:
+    """
+    Returns first element of target that matches value either
+    exactly or as a regex after 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+
+    :param value: string to compare the list of targets against
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    for target in targets:
+        if _is_equal_or_regex_match(value, target, check_contains=check_contains):
+            return target
+    return None
+
+
+def _is_equal_or_regex_match(
+    value: str, target: str, check_contains: bool = False
+) -> bool:
+    """
+    Checks whether a value is exactly equal or a regex match for target
+    if target starts with 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+    """
+
+    if target.startswith("re:"):
+        pattern = target[3:]
+        if re.match(pattern, value):
+            return True
+    elif check_contains:
+        if target.lower() in value.lower():
+            return True
+    elif target == value:
+        return True
+    return False
+
+
+def _match_fused_layer(
+    layer_name: str,
+    target_layers: Iterable[str],
+    fused_mapping: Mapping[str, List[str]],
+) -> Optional[str]:
+    """
+    Match a fused layer name to its corresponding individual layer in
+    target_layers. Returns first value in fused_mapping which matches targets
+
+    Implements an "all" matching strategy where a fused layer matches iff
+    "all" of its components match
+
+    :param layer_name: layer name
+    :param target_layers: list of targets to match the layer against
+    :param fused_mapping: map from fused layer names to its components
+
+    Examples:
+        layer_name = "model.layers.0.self_attn.qkv_proj"
+        target_layers = ["model.layers.0.self_attn.q_proj",
+                         "model.layers.0.self_attn.k_proj",
+                         "model.layers.0.self_attn.v_proj"]
+    """
+    # find layer_name in mapping
+    fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
+    if fused is None:
+        return None
+
+    # expand path of unfused components
+    unfused_paths = [
+        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
+    ]
+
+    # for each unfused component, find a match in targets
+    unfused_matches: List[Optional[str]] = []
+    for unfused in unfused_paths:
+        for target in target_layers:
+            if _is_equal_or_regex_match(unfused, target):
+                unfused_matches.append(target)
+                break
+        else:
+            unfused_matches.append(None)
+
+    return unfused_matches[0] if all(unfused_matches) else None
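
A short usage sketch of the matching helpers above; the layer names, ignore list, and fused mapping are made-up examples rather than values from this diff:

from types import MappingProxyType

from sglang.srt.layers.quantization.compressed_tensors.utils import (
    check_equal_or_regex_match,
    should_ignore_layer,
)

fused_mapping = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
ignore = [
    "re:.*lm_head",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# Exact names and "re:"-prefixed regexes both count as matches.
assert check_equal_or_regex_match("lm_head", ["re:.*lm_head"])

# A fused layer is skipped only if every unfused shard is in the ignore list;
# layer 0's q/k/v are all listed, layer 1's are not.
assert should_ignore_layer(
    "model.layers.0.self_attn.qkv_proj", ignore=ignore, fused_mapping=fused_mapping
)
assert not should_ignore_layer(
    "model.layers.1.self_attn.qkv_proj", ignore=ignore, fused_mapping=fused_mapping
)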
sglang/srt/layers/quantization/fp8.py
@@ -7,20 +7,33 @@ import torch
 import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
-    prepare_fp8_layer_for_marlin,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     convert_to_channelwise,
+    is_layer_skipped,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
 
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     LinearBase,
@@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 )
 from sglang.srt.utils import (
     get_bool_env_var,
+    is_cuda,
     is_hip,
     permute_weight,
     print_warning_once,
@@ -60,6 +74,13 @@ if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
 
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
 logger = logging.getLogger(__name__)
 
 
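The two hunks above make vLLM an optional dependency for the FP8 path: the Marlin helpers are imported inside try/except with stub fallbacks, and the scaled-FP8 quant op is chosen at import time based on the platform. A generic sketch of the same pattern; optional_backend and fast_kernel are placeholders, not real sglang or vLLM modules:

try:
    from optional_backend import fast_kernel  # may not be installed

    FAST_KERNEL_AVAILABLE = True
except ImportError:
    FAST_KERNEL_AVAILABLE = False

    def fast_kernel(*args, **kwargs):
        # Keep the symbol defined so call sites still import; fail only if used.
        raise ImportError("optional_backend is not installed")


def matmul(a, b, prefer_fast: bool = True):
    # Dispatch at call time and fall back to a reference path when the
    # optional backend is missing or disabled.
    if prefer_fast and FAST_KERNEL_AVAILABLE:
        return fast_kernel(a, b)
    return a @ b
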
@@ -131,8 +152,6 @@ class Fp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -141,8 +160,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, Attention):
-            return Fp8KVCacheMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -173,7 +190,9 @@ class Fp8LinearMethod(LinearMethodBase):
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+        self.use_marlin = (
+            get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE
+        )
         # Disable marlin for ROCm
         if _is_hip:
             self.use_marlin = False
@@ -371,9 +390,12 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.use_marlin:
-            prepare_fp8_layer_for_marlin(layer)
-            # Activations not quantized for marlin.
-            del layer.input_scale
+            try:
+                prepare_fp8_layer_for_marlin(layer)
+                # Activations not quantized for marlin.
+                del layer.input_scale
+            except ImportError:
+                self.use_marlin = False
 
     def apply(
         self,
@@ -383,15 +405,18 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                bias=bias,
-            )
+            try:
+                return apply_fp8_marlin_linear(
+                    input=x,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    workspace=layer.workspace,
+                    size_n=layer.output_size_per_partition,
+                    size_k=layer.input_size_per_partition,
+                    bias=bias,
+                )
+            except ImportError:
+                self.use_marlin = False
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -680,12 +705,20 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
@@ -762,9 +795,18 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
                     start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
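
The MoE hunks above route per-expert weight quantization through sgl_scaled_fp8_quant on CUDA builds and vllm_ops.scaled_fp8_quant otherwise. A conceptual, unfused reference for what such a per-tensor FP8 quant returns (quantized tensor plus scale); this is an illustration, not either kernel's implementation:

from typing import Optional, Tuple

import torch


def scaled_fp8_quant_reference(
    x: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    if scale is None:
        # Dynamic per-tensor scale: map the max magnitude onto the fp8 range.
        scale = x.abs().max().to(torch.float32) / finfo.max
    q = (x.to(torch.float32) / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn), scale


# Each expert's weight slice is quantized independently, so every expert
# keeps its own scale (illustrative shapes).
w = torch.randn(4, 128, 256)  # (num_experts, n, k)
q = torch.empty_like(w, dtype=torch.float8_e4m3fn)
scales = torch.empty(4, dtype=torch.float32)
for e in range(w.shape[0]):
    q[e], scales[e] = scaled_fp8_quant_reference(w[e])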
sglang/srt/layers/quantization/fp8_kernel.py
@@ -26,11 +26,14 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
+    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
 )
 
+_enable_jit_deepgemm = False
+
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
 
@@ -39,9 +42,12 @@ if _is_cuda:
     import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
-logger = logging.getLogger(__name__)
+    sm_version = get_device_sm()
+    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+        _enable_jit_deepgemm = True
 
-_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
+logger = logging.getLogger(__name__)
 
 if supports_custom_op():
 
@@ -168,6 +174,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -200,11 +207,20 @@ def per_token_group_quant_fp8(
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
-        x_s = torch.empty(
-            (x.shape[-1] // group_size,) + x.shape[:-1],
-            device=x.device,
-            dtype=torch.float32,
-        ).permute(-1, -2)
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
     else:
         x_s = torch.empty(
             x.shape[:-1] + (x.shape[-1] // group_size,),
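
The new scale_tma_aligned branch pads the token dimension of the column-major scale buffer up to a multiple of 4 floats (16 bytes) so its leading stride is TMA-friendly, then views back only the real rows. A standalone sketch of the shape arithmetic with illustrative sizes:

import torch

tokens, hidden, group_size = 10, 512, 128
x = torch.randn(tokens, hidden)

# Pad 10 tokens up to 12, allocate column-major, then keep the first 10 rows.
aligned = (x.shape[-2] + 3) // 4 * 4  # 10 -> 12
x_s = torch.empty(
    x.shape[:-2] + (x.shape[-1] // group_size, aligned),
    dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]

print(x_s.shape)    # torch.Size([10, 4]): one scale per (token, 128-wide group)
print(x_s.stride()) # (1, 12): column-major with the padded leading dimension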
@@ -761,7 +777,7 @@ def w8a8_block_fp8_matmul(
     )
 
     # deepgemm only support bf16
-    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
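
Taken together, the fp8_kernel.py hunks enable JIT DeepGEMM by default, but only on SM90+ GPUs and only for bf16 outputs; otherwise the existing block-FP8 matmul path is used. A small sketch of the gating logic, assuming the env-var semantics shown in the diff:

import os


def jit_deepgemm_enabled(sm_version: int) -> bool:
    # SGL_ENABLE_JIT_DEEPGEMM defaults to "1"; any non-zero value keeps it on.
    return sm_version >= 90 and bool(int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")))


print(jit_deepgemm_enabled(90))  # True unless SGL_ENABLE_JIT_DEEPGEMM=0
print(jit_deepgemm_enabled(89))  # False: devices below SM90 use the fallback path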