sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/utils.py (+153 -0)
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from types import MappingProxyType
+from typing import List, Mapping, Tuple, Union
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+
+def is_fp8_fnuz() -> bool:
+    # Only device 0 is checked; this assumes MI300 platforms are homogeneous.
+    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+
+def is_layer_skipped(
+    prefix: str,
+    ignored_layers: List[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
+
+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def convert_to_channelwise(
+    weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty(
+        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
+    )
+
+    # Handle scalar tensor case: broadcast same scale to all channels
+    if weight_scale.dim() == 0:
+        weight_scale_channel.fill_(weight_scale.item())
+        return weight_scale_channel
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on-disk checkpoint if any of the
+    # weight scales are still set to the default, since we initialize
+    # N weight scales for N shards but only load 1 weight scale
+    # from disk in this case. Skip requantization in this case
+    # (since we are already quantized with the single scale).
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (
+        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
+    )
+
+    # If unfused checkpoint, we need to requantize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+            if _is_cuda:
+                weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+            else:
+                weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
+                    weight_dq, max_w_scale
+                )
+            start = end
+
+    return max_w_scale, weight
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed).
+def replace_parameter(
+    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
+) -> None:
+
+    old = getattr(mod, name)
+    if (
+        type(old) is type(new)
+        and old.dtype == new.dtype
+        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
+    ):
+        # If we can, just update in-place to avoid re-registering;
+        # this can be faster if the underlying storage is the same.
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback: re-register the parameter, converting to a Parameter if necessary.
+        # This not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility.
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
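
For orientation, a minimal usage sketch of the new per-shard-to-per-channel scale helper added above (the shard widths and scale values below are made up; only the import path and function come from the module in this diff):

    import torch

    from sglang.srt.layers.quantization.utils import convert_to_channelwise

    # Hypothetical fused QKV projection: three shards with these output widths.
    logical_widths = [4096, 1024, 1024]
    # One per-shard scale, e.g. as loaded from a checkpoint (values are illustrative).
    per_shard_scales = torch.tensor([0.02, 0.01, 0.015])

    # Expand the three per-shard scales into one scale per output channel.
    channelwise = convert_to_channelwise(per_shard_scales, logical_widths)
    assert channelwise.shape == (sum(logical_widths), 1)
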

sglang/srt/layers/quantization/w8a8_fp8.py (+70 -19)
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import is_hip
@@ -22,12 +24,24 @@ _is_hip = is_hip()
 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
 
-    - Weight: static, per-channel, symmetric
-    - Activation: dynamic, per-token, symmetric
+    Weight Quantization:
+        - Method: Static quantization
+        - Granularity: Per-channel
+        - Type: Symmetric
+
+    Activation Quantization:
+        - Method: Dynamic quantization
+        - Granularity: Per-token
+        - Type: Symmetric
+
+    Note:
+        - For models without offline quantization, weights will be quantized during model loading
+        - If CUTLASS is supported: Per-channel weight quantization is used
+        - If CUTLASS is not supported: Falls back to per-tensor weight quantization
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
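
The note in the docstring above contrasts two weight-quantization granularities. As a rough, self-contained illustration of that distinction in plain PyTorch (these helpers are ours, not the sglang kernels; only the FP8 dtype and the per-channel/per-tensor idea come from the diff):

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def quantize_per_channel(w: torch.Tensor):
        # One symmetric scale per output channel (row): the CUTLASS-capable path.
        scale = (w.abs().amax(dim=1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
        q = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q, scale

    def quantize_per_tensor(w: torch.Tensor):
        # A single symmetric scale for the whole tensor: the fallback path.
        scale = (w.abs().amax().float() / FP8_MAX).clamp(min=1e-12)
        q = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q, scale
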
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
-        return cls()
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
         self,
@@ -72,13 +88,40 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
-        weight_scale = layer.weight_scale.detach()
-        if _is_hip:
-            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                weight=weight, weight_scale=weight_scale
-            )
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = layer.weight_scale.detach()
+            # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            # If checkpoint not offline quantized, quantize the weights with per-channel quantization.
+            if self.cutlass_fp8_supported:
+                # if cutlass supported, we use cutlass_scaled_mm
+                # which requires per-channel quantization on weight
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+                if _is_hip:
+                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight, weight_scale=weight_scale
+                    )
+            else:
+                # if cutlass not supported, we fall back to use torch._scaled_mm
+                # which requires per tensor quantization on weight
+                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+                qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
 
     def create_weights(
         self,
@@ -90,6 +133,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs
     ):
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quantization_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
 
         weight_loader = extra_weight_attrs.get("weight_loader")
         self.logical_widths = output_partition_sizes
@@ -98,7 +146,7 @@
             data=torch.empty(
                 sum(output_partition_sizes),
                 input_size_per_partition,
-                dtype=torch.float8_e4m3fn,
+                dtype=weight_dtype,
             ),
             input_dim=1,
             output_dim=0,
@@ -106,12 +154,15 @@
         )
         layer.register_parameter("weight", weight)
 
-        weight_scale = ChannelQuantScaleParameter(
-            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight_scale", weight_scale)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("weight_scale", weight_scale)
+        else:
+            layer.weight_scale = None
 
     def apply(
         self,

sglang/srt/layers/rotary_embedding.py (+78 -87)
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
 _is_cuda_available = is_cuda_available()
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+else:
+    from vllm import _custom_ops as ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -147,7 +148,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available:
+        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -168,76 +169,6 @@
             )
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
-        return query, key
-
-    def forward_hpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from habana_frameworks.torch.hpex.kernels import (
-            RotaryPosEmbeddingMode,
-            apply_rotary_pos_emb,
-        )
-
-        positions = positions.flatten()
-        if offsets is not None:
-            positions = positions + offsets
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # to query hidden dimension, so the original tensors need to be
-        # expanded
-        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # and expansion of cos/sin tensors via concatenation
-        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # and expansion of cos/sin tensors via repeat_interleave
-        rope_mode: RotaryPosEmbeddingMode
-        if self.is_neox_style:
-            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
-            cos = torch.cat((cos, cos), dim=-1)
-            sin = torch.cat((sin, sin), dim=-1)
-        else:
-            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
-            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
-            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
-
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -510,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     ):
         super().__init__()
 
-        if rotary_dim != head_size:
-            raise ValueError(
-                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
-                    rotary_dim != head_size ({rotary_dim}!={head_size})."
-            )
         if is_neox_style is False:
             raise ValueError(
                 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
             )
 
+        self.rotary_dim = rotary_dim
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -568,8 +495,8 @@
             * (
                 self.base
                 ** (
-                    torch.arange(0, self.head_size, 2, dtype=torch.float)
-                    / self.head_size
+                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                    / self.rotary_dim
                 )
             )
         )
@@ -618,8 +545,15 @@
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)
 
-        query = query * cos + _rotate_neox(query) * sin
-        key = key * cos + _rotate_neox(key) * sin
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
+        query = torch.cat((query_rot, query_pass), dim=-1)
+
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
+        key = torch.cat((key_rot, key_pass), dim=-1)
 
         return query.flatten(-2), key.flatten(-2)
 
@@ -717,6 +651,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if _is_cuda_available:
+            return self.forward_cuda(positions, query, key, offsets)
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
+    def forward_native(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
         query_rot = query[..., : self.rotary_dim]
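
Both the Phi3LongRoPE change above and this native path rotate only the first rotary_dim features of each head and pass the rest through unchanged. A rough standalone sketch of that partial-rotary pattern (the helper name and argument shapes are ours; the GPT-NeoX-style rotation from _rotate_neox is inlined):

    import torch

    def apply_partial_neox_rope(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int
    ) -> torch.Tensor:
        # Rotate only the first `rotary_dim` features; leave the rest untouched.
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        half = rotary_dim // 2
        rotated = torch.cat((-x_rot[..., half:], x_rot[..., :half]), dim=-1)
        x_rot = x_rot * cos + rotated * sin
        return torch.cat((x_rot, x_pass), dim=-1)
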
@@ -879,8 +825,17 @@ class MRotaryEmbedding(RotaryEmbedding):
         spatial_merge_size: int,
         context_len: int = 0,
         seq_len: Optional[int] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        tokens_per_second: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
-        """Get mrope input positions and delta value."""
+        """
+        Get mrope input positions and delta value.
+
+        :arg
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+        """
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
@@ -917,6 +872,7 @@
                 )
                 image_index += 1
                 remain_images -= 1
+                second_per_grid_t = 0
                 ed = ed_image
             else:
                 t, h, w = (
@@ -924,6 +880,10 @@
                     video_grid_thw[video_index][1],
                     video_grid_thw[video_index][2],
                 )
+                if second_per_grid_ts is not None:
+                    second_per_grid_t = second_per_grid_ts[video_index]
+                else:
+                    second_per_grid_t = 1.0
                 video_index += 1
                 remain_videos -= 1
                 ed = ed_video
@@ -940,11 +900,11 @@
             )
 
             t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
+                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+                * second_per_grid_t
+                * tokens_per_second
+            ).flatten()
+
             h_index = (
                 torch.arange(llm_grid_h)
                 .view(1, -1, 1)
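
To see what the new temporal scaling does, here is the t_index expression from the hunk above evaluated on made-up numbers (the grid sizes, second_per_grid_t, and tokens_per_second are illustrative only):

    import torch

    llm_grid_t, llm_grid_h, llm_grid_w = 3, 2, 2   # toy merged video grid
    second_per_grid_t = 0.5                        # illustrative seconds per temporal grid step
    tokens_per_second = 2                          # illustrative temporal resolution

    t_index = (
        torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
        * second_per_grid_t
        * tokens_per_second
    ).flatten()
    # -> tensor([0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2., 2.])
    # Without the scaling, each temporal grid step advanced the position index by exactly 1.
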
@@ -1172,6 +1132,37 @@ def get_rope(
     return rotary_emb
 
 
+# Copied from transformers
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
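
A quick way to sanity-check the new module-level helpers above (the toy shapes and the hand-rolled NeoX-style cos/sin table are assumptions for illustration, not sglang's cache layout):

    import torch
    from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb

    batch, heads, seq, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, head_dim, dtype=torch.bfloat16)
    k = torch.randn(batch, heads, seq, head_dim, dtype=torch.bfloat16)

    # Toy cos/sin table of shape (batch, seq, head_dim), duplicated across the two halves.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.arange(seq).float()[:, None] * inv_freq[None, :]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None]

    # unsqueeze_dim=1 broadcasts cos/sin across the head dimension; input dtypes are preserved.
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
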

sglang/srt/layers/sampler.py (+1 -1)
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
                 group=self.tp_sync_group,
             )
 
-        return batch_next_token_ids.to(torch.int32)
+        return batch_next_token_ids
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo

sglang/srt/lora/backend/base_backend.py (+4 -4)
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+def get_fuse_output_add_from_name(name: str) -> bool:
     mapping = {
         "triton": True,
         "flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
     Args:
         name: name of backend
        batch_info: information of current batch for use
-        fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-            and the operation of scaling and adding will be fused into kernel
+        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+            and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+        self.fuse_output_add = get_fuse_output_add_from_name(name)
         self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
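
For reference, the quantity every LoRA backend ultimately produces is the base output plus the scaled low-rank update; the hunks above and below only move where the scaling and the add are applied. A dense PyTorch sketch (shapes and the scaling value are illustrative):

    import torch

    def lora_forward_reference(x, base_output, lora_a, lora_b, scaling):
        # y = base_output + scaling * (x @ A^T @ B^T); the backends differ only in
        # where the add and the scaling happen, not in the quantity computed.
        return base_output + scaling * (x @ lora_a.t() @ lora_b.t())

    s, in_dim, rank, out_dim = 4, 16, 8, 16
    x = torch.randn(s, in_dim)
    base_output = torch.randn(s, out_dim)
    lora_a = torch.randn(rank, in_dim)      # A: (r, input_dim)
    lora_b = torch.randn(out_dim, rank)     # B: (output_dim, r)
    y = lora_forward_reference(x, base_output, lora_a, lora_b, scaling=0.5)
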

sglang/srt/lora/backend/flashinfer_backend.py (+12 -9)
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
 
-        return self.segment_gemm.run(
-            x=x,
-            weights=weights,
-            batch_size=self.batch_info.bs,
-            weight_column_major=True,
-            seg_indptr=self.batch_info.seg_indptr,
-            weight_indices=self.batch_info.weight_indices,
+        return (
+            self.segment_gemm.run(
+                x=x,
+                weights=weights,
+                batch_size=self.batch_info.bs,
+                weight_column_major=True,
+                seg_indptr=self.batch_info.seg_indptr,
+                weight_indices=self.batch_info.weight_indices,
+            )
+            * self.batch_info.scalings[0]
         )
 
     def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=kv_lora_b[1],
         )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]
 
     def run_gate_up_lora(
         self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
             weights=gate_up_lora_b[1],
        )
 
-        return lora_output
+        return lora_output * self.batch_info.scalings[0]

sglang/srt/lora/backend/triton_backend.py (+5 -8)
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
         x: torch.Tensor,
         weights: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
-        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+        return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
 
 
     def run_qkv_lora(
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
         output_offset: torch.Tensor,
         max_qkv_out_dim: int,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -49,7 +47,7 @@
         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
         assert isinstance(qkv_lora_b, torch.Tensor)
 
-        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
         lora_output = qkv_lora_b_fwd(
             lora_a_output,
             qkv_lora_b,
@@ -57,7 +55,6 @@
             output_offset,
             max_qkv_out_dim,
             base_output,
-            scaling,
         )
         return lora_output
 
@@ -67,7 +64,6 @@
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: torch.Tensor,
         base_output: torch.Tensor = None,
-        scaling: float = 1.0,
         *args,
         **kwargs
     ) -> torch.Tensor:
@@ -79,13 +75,14 @@
         output_dim = gate_up_lora_b.shape[-2] // 2
 
         # lora_a_output: (s, 2 * r)
-        lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+        lora_a_output = sgemm_lora_a_fwd(
+            x, gate_up_lora_a, self.batch_info, stack_num=2
+        )
         lora_output = gate_up_lora_b_fwd(
             lora_a_output,
             gate_up_lora_b,
             self.batch_info,
             output_dim,
             base_output,
-            scaling,
         )
         return lora_output