sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/utils.py ADDED
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from types import MappingProxyType
+from typing import List, Mapping, Tuple, Union
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+
+def is_fp8_fnuz() -> bool:
+    # only device 0 is checked, this assumes MI300 platforms are homogeneous
+    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+
+def is_layer_skipped(
+    prefix: str,
+    ignored_layers: List[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "to have the same precision."
+                )
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
+
+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def convert_to_channelwise(
+    weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty(
+        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
+    )
+
+    # Handle scalar tensor case: broadcast same scale to all channels
+    if weight_scale.dim() == 0:
+        weight_scale_channel.fill_(weight_scale.item())
+        return weight_scale_channel
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requanitzation.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on disk checkpoint if any of the
+    # weight scales are still set to the default since we initialize
+    # N weight scales for N shards but we only load 1 weight scale
+    # from disk in this case. Skip requantization in this case (since)
+    # we already are quantized with the single scale.
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (
+        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
+    )
+
+    # If unfused checkpoint, need requanize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+            if _is_cuda:
+                weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+            else:
+                weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
+                    weight_dq, max_w_scale
+                )
+            start = end
+
+    return max_w_scale, weight
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed)
+def replace_parameter(
+    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
+) -> None:
+
+    old = getattr(mod, name)
+    if (
+        type(old) is type(new)
+        and old.dtype == new.dtype
+        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
+    ):
+        # If we can just update in-place to avoid re-registering
+        # can be faster if the underlying storage is the same
+        update_tensor_inplace(old, new)
+    else:
+        # Fallback re-register parameter, convert to Parameter if necessary
+        # this not only ensures we don't register a tensor as a parameter, but
+        # also ensures that all parameter subclasses get re-registered as
+        # parameters for `torch.compile` compatibility
+        if not isinstance(new, torch.nn.Parameter):
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
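As a quick, stand-alone illustration of what `convert_to_channelwise` above does with the per-shard scales of a fused layer, here is a minimal sketch in plain PyTorch (the toy widths and scale values are made up; the real helper writes into a preallocated buffer instead of concatenating):

```python
import torch

# Toy fused QKV layout: three logical shards with one per-tensor scale each.
logical_widths = [1024, 256, 256]               # q, k, v output rows
weight_scale = torch.tensor([0.01, 0.02, 0.04])

# Expand each shard's scale to one entry per output channel, mirroring
# convert_to_channelwise: the result has shape (sum(logical_widths), 1).
channelwise = torch.cat(
    [scale.expand(width, 1) for scale, width in zip(weight_scale, logical_widths)]
)
print(channelwise.shape)                 # torch.Size([1536, 1])
print(channelwise[0], channelwise[-1])   # first row uses 0.01, last row 0.04
```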
sglang/srt/layers/quantization/w8a8_fp8.py CHANGED
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import is_hip
@@ -22,12 +24,24 @@ _is_hip = is_hip()
 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
 
-    - Weight: static, per-channel, symmetric
-    - Activation: dynamic, per-token, symmetric
+    Weight Quantization:
+    - Method: Static quantization
+    - Granularity: Per-channel
+    - Type: Symmetric
+
+    Activation Quantization:
+    - Method: Dynamic quantization
+    - Granularity: Per-token
+    - Type: Symmetric
+
+    Note:
+    - For models without offline quantization, weights will be quantized during model loading
+    - If CUTLASS is supported: Per-channel weight quantization is used
+    - If CUTLASS is not supported: Falls back to per-token weight quantization
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
-        return cls()
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
         self,
@@ -72,13 +88,40 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
-        weight_scale = layer.weight_scale.detach()
-        if _is_hip:
-            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                weight=weight, weight_scale=weight_scale
-            )
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = layer.weight_scale.detach()
+            # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            # If checkpoint not offline quantized, quantize the weights with per-channel quantization.
+            if self.cutlass_fp8_supported:
+                # if cutlass supported, we use cutlass_scaled_mm
+                # which requires per-channel quantization on weight
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+                if _is_hip:
+                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight, weight_scale=weight_scale
+                    )
+            else:
+                # if cutlass not supported, we fall back to use torch._scaled_mm
+                # which requires per tensor quantization on weight
+                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+                qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
 
     def create_weights(
         self,
@@ -90,6 +133,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs
     ):
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quantization_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
 
         weight_loader = extra_weight_attrs.get("weight_loader")
         self.logical_widths = output_partition_sizes
@@ -98,7 +146,7 @@
             data=torch.empty(
                 sum(output_partition_sizes),
                 input_size_per_partition,
-                dtype=torch.float8_e4m3fn,
+                dtype=weight_dtype,
             ),
             input_dim=1,
             output_dim=0,
@@ -106,12 +154,15 @@
         )
         layer.register_parameter("weight", weight)
 
-        weight_scale = ChannelQuantScaleParameter(
-            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight_scale", weight_scale)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("weight_scale", weight_scale)
+        else:
+            layer.weight_scale = None
 
     def apply(
         self,
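To make the scheme described in the `W8A8Fp8Config` docstring concrete, here is a rough sketch of per-channel, symmetric FP8 weight quantization in plain PyTorch (illustrative only, assuming a PyTorch build with `torch.float8_e4m3fn`; the real path goes through sglang's fused helpers such as `per_token_group_quant_fp8` or `input_to_float8`):

```python
import torch

def per_channel_symmetric_fp8_quant(weight: torch.Tensor):
    # One symmetric scale per output channel (row) of a [out, in] weight.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale.float()

w = torch.randn(128, 64)
qw, s = per_channel_symmetric_fp8_quant(w)
w_hat = qw.float() * s                # dequantize to inspect the rounding error
print((w - w_hat).abs().max())        # stays small relative to w's magnitude
```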
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
 _is_cuda_available = is_cuda_available()
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+else:
+    from vllm import _custom_ops as ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -147,7 +148,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available:
+        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -168,76 +169,6 @@
             )
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
-        return query, key
-
-    def forward_hpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from habana_frameworks.torch.hpex.kernels import (
-            RotaryPosEmbeddingMode,
-            apply_rotary_pos_emb,
-        )
-
-        positions = positions.flatten()
-        if offsets is not None:
-            positions = positions + offsets
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
-        # to query hidden dimension, so the original tensors need to be
-        # expanded
-        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
-        # and expansion of cos/sin tensors via concatenation
-        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
-        # and expansion of cos/sin tensors via repeat_interleave
-        rope_mode: RotaryPosEmbeddingMode
-        if self.is_neox_style:
-            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
-            cos = torch.cat((cos, cos), dim=-1)
-            sin = torch.cat((sin, sin), dim=-1)
-        else:
-            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
-            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
-            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
-
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -510,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     ):
         super().__init__()
 
-        if rotary_dim != head_size:
-            raise ValueError(
-                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
-                rotary_dim != head_size ({rotary_dim}!={head_size})."
-            )
         if is_neox_style is False:
             raise ValueError(
                 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
             )
 
+        self.rotary_dim = rotary_dim
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -568,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
             * (
                 self.base
                 ** (
-                    torch.arange(0, self.head_size, 2, dtype=torch.float)
-                    / self.head_size
+                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                    / self.rotary_dim
                 )
             )
         )
@@ -618,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)
 
-        query = query * cos + _rotate_neox(query) * sin
-        key = key * cos + _rotate_neox(key) * sin
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
+        query = torch.cat((query_rot, query_pass), dim=-1)
+
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
+        key = torch.cat((key_rot, key_pass), dim=-1)
 
         return query.flatten(-2), key.flatten(-2)
 
@@ -879,8 +813,17 @@ class MRotaryEmbedding(RotaryEmbedding):
         spatial_merge_size: int,
         context_len: int = 0,
         seq_len: Optional[int] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        tokens_per_second: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
-        """Get mrope input positions and delta value."""
+        """
+        Get mrope input positions and delta value.
+
+        :arg
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+
+        """
 
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
@@ -917,6 +860,7 @@
                 )
                 image_index += 1
                 remain_images -= 1
+                second_per_grid_t = 0
                 ed = ed_image
             else:
                 t, h, w = (
@@ -924,6 +868,10 @@
                     video_grid_thw[video_index][1],
                     video_grid_thw[video_index][2],
                 )
+                if second_per_grid_ts is not None:
+                    second_per_grid_t = second_per_grid_ts[video_index]
+                else:
+                    second_per_grid_t = 1.0
                 video_index += 1
                 remain_videos -= 1
                 ed = ed_video
@@ -940,11 +888,11 @@
             )
 
             t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
+                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
+                * second_per_grid_t
+                * tokens_per_second
+            ).flatten()
+
             h_index = (
                 torch.arange(llm_grid_h)
                 .view(1, -1, 1)
@@ -1172,6 +1120,37 @@ def get_rope(
     return rotary_emb
 
 
+# Copied from transformers
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
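The newly added module-level `rotate_half` / `apply_rotary_pos_emb` helpers follow the standard neox-style rotation from `transformers`. A small self-contained sketch of the same rotation (toy shapes and an ad-hoc inverse-frequency table, not sglang's cached cos/sin) that also checks the rotation is norm-preserving:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Same "rotate half the hidden dims" trick as the helper above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, dim)

positions = torch.arange(seq, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.outer(positions, inv_freq)          # (seq, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)           # (seq, dim), neox layout
cos, sin = emb.cos()[None], emb.sin()[None]       # (batch, seq, dim)

# Broadcast over the head dim, as apply_rotary_pos_emb does with unsqueeze_dim=1.
q_rot = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
# Each (x_i, x_{i + dim//2}) pair is rotated by the same angle, so norms match.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))
```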
sglang/srt/layers/sampler.py CHANGED
@@ -168,7 +168,7 @@ class Sampler(nn.Module):
                 group=self.tp_sync_group,
             )
 
-        return batch_next_token_ids.to(torch.int32)
+        return batch_next_token_ids
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
sglang/srt/lora/layers.py CHANGED
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import torch
 from torch import nn
 
@@ -38,8 +40,22 @@ class BaseLayerWithLoRA(nn.Module):
     def set_lora_info(self, *args):
         pass
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        pass
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        pass
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    """
+    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
+
+    Note: The current version does not yet implement the LoRA functionality.
+    This class behaves exactly the same as the base VocabParallelEmbedding.
+    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
+    """
+
    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
@@ -101,6 +117,16 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        B = B[start_idx:end_idx, :]
+        return B
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
@@ -120,6 +146,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
+            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
             self.B_buffer_gate_up = torch.cat(
                 (B_buffer[0], B_buffer[1]), dim=-2
@@ -142,6 +169,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        # Since the outputs for both gate and up are identical, we use a random one.
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        return B[:, start_idx:end_idx, :]
+
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(
@@ -210,6 +247,27 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             else base_output + lora_output * self.scaling
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(
+        self, B: List[torch.Tensor], tp_rank: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B_q, B_kv = B
+        base_layer = self.base_layer
+        q_proj_shard_size = base_layer.q_proj_shard_size
+        kv_proj_shard_size = base_layer.kv_proj_shard_size
+        num_kv_head_replicas = base_layer.num_kv_head_replicas
+
+        q_start_idx = q_proj_shard_size * tp_rank
+        q_end_idx = q_start_idx + q_proj_shard_size
+
+        kv_shard_id = tp_rank // num_kv_head_replicas
+        kv_start_idx = kv_proj_shard_size * kv_shard_id
+        kv_end_idx = kv_start_idx + kv_proj_shard_size
+
+        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
@@ -274,6 +332,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        A = A[:, start_idx:end_idx].contiguous()
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        return B
+
 
 def get_lora_layer(
     layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
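For context on the new `slice_lora_a_weights` / `slice_lora_b_weights` hooks, here is a minimal stand-alone sketch of the column-parallel case (the helper name and the fixed `tp_size` are illustrative; in the diff the shard size comes from `base_layer.output_partition_sizes`):

```python
import torch

def slice_lora_b_column_parallel(B: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Each tensor-parallel rank keeps only its slice of the LoRA B output rows.
    shard_size = B.shape[0] // tp_size
    start = tp_rank * shard_size
    return B[start : start + shard_size, :]

rank, output_dim, tp_size = 16, 4096, 4
B = torch.randn(output_dim, rank)                 # LoRA B: (output_dim, r)
shards = [slice_lora_b_column_parallel(B, r, tp_size) for r in range(tp_size)]
# Concatenating the per-rank shards along the output dim recovers the full B.
print(torch.equal(torch.cat(shards, dim=0), B))
```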
sglang/srt/lora/lora.py CHANGED
@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
         super().__init__()
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.weight_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
 
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = None
+        # lora weights in cpu. The weights are loaded from checkpoint.
+        self.weights: Dict[str, torch.Tensor] = {}
 
 
 class LoRAAdapter(nn.Module):
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
         )
 
         self.weights: Dict[str, torch.Tensor] = {}
-        self.weights_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
-        for layer in self.layers:
-            layer.load_to_gpu()
-
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = None
-        for layer in self.layers:
-            layer.offload_from_gpu()
 
     # initialize the LoRA weights to cpu
     def initialize_weights(self):
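Taken together with the `slice_lora_*` hooks above, the `lora.py` change keeps adapter weights on CPU after loading, with each rank slicing and copying only its shard to the device (presumably via the reworked LoRA memory pool listed as `sglang/srt/lora/mem_pool.py` in this release). A rough sketch of that flow; the helper, the key naming, and the fixed sharding rule are assumptions for illustration, not the actual `LoRAMemoryPool` API:

```python
import torch

def load_adapter_shard(cpu_weights: dict, tp_rank: int, tp_size: int, device: str) -> dict:
    # Keep the checkpoint tensors on CPU; copy only this rank's slice to the device.
    shard = {}
    for name, w in cpu_weights.items():
        if "lora_B" in name:                      # output rows are sharded across ranks
            shard_size = w.shape[0] // tp_size
            w = w[tp_rank * shard_size : (tp_rank + 1) * shard_size]
        shard[name] = w.to(device, torch.float16)
    return shard

cpu_weights = {
    "layers.0.q_proj.lora_A": torch.randn(16, 4096),
    "layers.0.q_proj.lora_B": torch.randn(4096, 16),
}
device = "cuda" if torch.cuda.is_available() else "cpu"
shard = load_adapter_shard(cpu_weights, tp_rank=0, tp_size=2, device=device)
print({k: tuple(v.shape) for k, v in shard.items()})
```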