sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -1,6 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
 
 """Rotary Positional Embeddings."""
+import itertools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -221,6 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -230,8 +232,17 @@ class RotaryEmbedding(CustomOp):
                 head_size=self.head_size,
                 cos_sin_cache=self.cos_sin_cache,
                 is_neox=self.is_neox_style,
+                # Compatible with old sgl-kernel
+                **(
+                    dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
+                    if fused_set_kv_buffer_arg is not None
+                    else {}
+                ),
             )
         else:
+            assert (
+                fused_set_kv_buffer_arg is None
+            ), "save kv cache is not supported for vllm_rotary_embedding."
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
             self.vllm_rotary_embedding(
                 positions,
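The `**( ... )` expansion above forwards `fused_set_kv_buffer_arg` only when it is set, so the call keeps working against older sgl-kernel builds whose `apply_rope_with_cos_sin_cache_inplace` does not accept that keyword. A minimal sketch of the same conditional-kwargs pattern, using hypothetical `legacy_kernel`/`new_kernel` stand-ins rather than the real sgl-kernel API:

# Illustrative sketch only; `legacy_kernel` / `new_kernel` are made-up stand-ins.
def legacy_kernel(q, k):  # older build: no fused_set_kv_buffer_arg parameter
    return q, k

def new_kernel(q, k, fused_set_kv_buffer_arg=None):  # newer build accepts it
    return q, k

def call_rope(kernel, q, k, fused_set_kv_buffer_arg=None):
    # Only pass the kwarg when it is actually needed, so the legacy
    # signature never sees an unexpected keyword argument.
    return kernel(
        q,
        k,
        **(
            dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
            if fused_set_kv_buffer_arg is not None
            else {}
        ),
    )

call_rope(legacy_kernel, 1, 2)                                  # works on the old signature
call_rope(new_kernel, 1, 2, fused_set_kv_buffer_arg=object())   # uses the new path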
@@ -679,7 +690,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip or _is_npu:
+        if _is_hip:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -764,6 +775,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
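In the new `forward_npu` above, `query` is laid out as (num_tokens, num_q_heads, head_size), so `query.shape[1] * query.shape[2]` is the `numQHeads*headSize` product mentioned in the NOTE; anything above 4096 falls back to `forward_native`. A small sketch of that dispatch check with illustrative shapes only:

# Illustrative only: the npu_mrope fast path is taken when
# num_q_heads * head_size <= 4096; otherwise forward_native is used.
import torch

def picks_npu_mrope(query: torch.Tensor) -> bool:
    # query is (num_tokens, num_q_heads, head_size)
    return query.shape[1] * query.shape[2] <= 4096

print(picks_npu_mrope(torch.empty(8, 32, 128)))  # 32 * 128 = 4096 -> True
print(picks_npu_mrope(torch.empty(8, 64, 128)))  # 64 * 128 = 8192 -> False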
@@ -946,7 +997,37 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         self.mrope_section = mrope_section
         if self.mrope_section:
-            assert sum(self.mrope_section) == rotary_dim // 2
+            expected_sum = rotary_dim // 2
+            actual_sum = sum(self.mrope_section)
+            if actual_sum != expected_sum:
+                print(
+                    f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
+                    f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
+                )
+                # Auto-correct by scaling the mrope_section proportionally
+                if actual_sum > 0:
+                    scale_factor = expected_sum / actual_sum
+                    self.mrope_section = [
+                        max(1, int(section * scale_factor))
+                        for section in self.mrope_section
+                    ]
+                    # Ensure the sum exactly matches by adjusting the last element
+                    current_sum = sum(self.mrope_section)
+                    if current_sum != expected_sum:
+                        self.mrope_section[-1] += expected_sum - current_sum
+                else:
+                    # If all sections are 0, create a default distribution
+                    self.mrope_section = [
+                        expected_sum // len(self.mrope_section)
+                    ] * len(self.mrope_section)
+                    # Handle remainder
+                    remainder = expected_sum % len(self.mrope_section)
+                    for i in range(remainder):
+                        self.mrope_section[i] += 1
+
+                print(
+                    f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
+                )
 
     def forward(
         self,
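The hard assert is replaced above by proportional rescaling of `mrope_section`. A standalone sketch of the same correction logic (an illustrative re-implementation, not the `MRotaryEmbedding` class itself) makes the behavior easy to trace by hand:

# Standalone sketch of the mrope_section auto-correction introduced above.
def correct_mrope_section(mrope_section, rotary_dim):
    expected_sum = rotary_dim // 2
    actual_sum = sum(mrope_section)
    if actual_sum == expected_sum:
        return list(mrope_section)
    if actual_sum > 0:
        scale = expected_sum / actual_sum
        section = [max(1, int(s * scale)) for s in mrope_section]
        section[-1] += expected_sum - sum(section)  # force an exact match
    else:
        n = len(mrope_section)
        section = [expected_sum // n] * n
        for i in range(expected_sum % n):  # distribute the remainder
            section[i] += 1
    return section

# rotary_dim = 128 -> the three sections must sum to 64
print(correct_mrope_section([22, 22, 22], 128))  # [21, 21, 22], sum 64
print(correct_mrope_section([16, 24, 24], 128))  # already consistent: unchanged
print(correct_mrope_section([0, 0, 0], 128))     # degenerate input -> [22, 21, 21]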
@@ -1153,6 +1234,204 @@ class MRotaryEmbedding(RotaryEmbedding):
             mrope_position_deltas = max_position_ids + 1 - s
         return position_ids, mrope_position_deltas
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
+    @staticmethod
+    def get_rope_index_glm4v(
+        input_ids: torch.Tensor,
+        hf_config: Any,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        attention_mask: torch.Tensor,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Get mrope input positions and delta value for GLM4V."""
+        image_token_id = hf_config.image_token_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            video_group_index = 0
+            attention_mask = attention_mask.to(total_input_ids.device)
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = input_ids[attention_mask[i] == 1]
+                input_tokens = input_ids.tolist()
+
+                input_token_type = []
+                video_check_flg = False
+                for token in input_tokens:
+                    if token == video_start_token_id:
+                        video_check_flg = True
+                    elif token == video_end_token_id:
+                        video_check_flg = False
+
+                    if token == image_token_id and not video_check_flg:
+                        input_token_type.append("image")
+                    elif token == image_token_id and video_check_flg:
+                        input_token_type.append("video")
+                    else:
+                        input_token_type.append("text")
+
+                input_type_group = []
+                for key, group in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]
+                ):
+                    group = list(group)
+                    start_index = group[0][0]
+                    end_index = group[-1][0] + 1
+                    input_type_group.append((key, start_index, end_index))
+
+                llm_pos_ids_list = []
+                video_frame_num = 1
+                for modality_type, start_idx, end_idx in input_type_group:
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+
+                    if modality_type == "image":
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t.item(),
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                        h_index = (
+                            torch.arange(llm_grid_h)
+                            .view(1, -1, 1)
+                            .expand(llm_grid_t, -1, llm_grid_w)
+                            .flatten()
+                        )
+                        w_index = (
+                            torch.arange(llm_grid_w)
+                            .view(1, 1, -1)
+                            .expand(llm_grid_t, llm_grid_h, -1)
+                            .flatten()
+                        )
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx
+                        )
+
+                        image_index += 1
+                        video_frame_num = 1
+
+                    elif modality_type == "video":
+                        t, h, w = (
+                            video_frame_num,
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t,
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        for t_idx in range(llm_grid_t):
+                            t_index = (
+                                torch.tensor(t_idx)
+                                .view(-1, 1)
+                                .expand(-1, llm_grid_h * llm_grid_w)
+                                .flatten()
+                            )
+
+                            h_index = (
+                                torch.arange(llm_grid_h)
+                                .view(1, -1, 1)
+                                .expand(1, -1, llm_grid_w)
+                                .flatten()
+                            )
+                            w_index = (
+                                torch.arange(llm_grid_w)
+                                .view(1, 1, -1)
+                                .expand(1, llm_grid_h, -1)
+                                .flatten()
+                            )
+                            llm_pos_ids_list.append(
+                                torch.stack([t_index, h_index, w_index]) + st_idx
+                            )
+
+                        video_group_index += 1
+
+                        if video_group_index >= video_grid_thw[video_index][0]:
+                            video_index += 1
+                            video_group_index = 0
+
+                        video_frame_num += 1
+
+                    else:
+                        text_len = end_idx - start_idx
+                        llm_pos_ids_list.append(
+                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+
+                        video_frame_num = 1
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
+                    position_ids.device
+                )
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                position_ids = (
+                    position_ids.unsqueeze(0)
+                    .expand(3, -1, -1)
+                    .to(attention_mask.device)
+                )
+                max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                    -1, keepdim=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = (
+                    torch.arange(input_ids.shape[1], device=input_ids.device)
+                    .view(1, 1, -1)
+                    .expand(3, input_ids.shape[0], -1)
+                )
+                mrope_position_deltas = torch.zeros(
+                    [input_ids.shape[0], 1],
+                    device=input_ids.device,
+                    dtype=input_ids.dtype,
+                )
+
+        return position_ids, mrope_position_deltas
+
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
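In the image branch of `get_rope_index_glm4v` above, the t/h/w index triples enumerate the merged patch grid of each image. A small illustrative trace for a single image with `grid_thw = (1, 4, 4)` and `spatial_merge_size = 2` (values chosen for the example, not taken from the diff):

# Illustrative trace of the image-branch position construction above:
# grid_thw (1, 4, 4) with spatial_merge_size 2 collapses to a 1 x 2 x 2 llm grid.
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 1, 4 // 2, 4 // 2
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
print(torch.stack([t_index, h_index, w_index]))
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 1],
#         [0, 1, 0, 1]])
# A preceding text span of length 3 would shift all four columns by st_idx = 3.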
sglang/srt/layers/sampler.py CHANGED
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
-        if global_server_args_dict["enable_dp_attention"]:
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -5,22 +5,6 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def get_fuse_output_add_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
-def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
 class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-            and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_add = get_fuse_output_add_from_name(name)
-        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
 
         return TritonLoRABackend
     elif name == "flashinfer":
-        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
-
-        return FlashInferLoRABackend
+        raise ValueError(
+            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
+        )
     else:
         raise ValueError(f"Invalid backend: {name}")
sglang/srt/lora/layers.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List, Tuple
-
 import torch
 from torch import nn
 
@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
     def forward(self, input_: torch.Tensor):
         # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
-        if self.lora_backend.fuse_stacked_lora_b:
-            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            if getattr(self, "B_buffer_gate_up", None) is None:
-                self.B_buffer_gate_up = torch.empty(
-                    (
-                        B_buffer[0].shape[0],
-                        2 * B_buffer[0].shape[1],
-                        B_buffer[0].shape[2],
-                    ),
-                    dtype=B_buffer[0].dtype,
-                    device=B_buffer[0].device,
-                )
-            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
-            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
-        else:
-            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
+        self.B_buffer_gate_up = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-
         lora_output = self.lora_backend.run_gate_up_lora(
-            x,
-            self.A_buffer_gate_up,
-            self.B_buffer_gate_up,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            gate_up_lora_a=self.A_buffer_gate_up,
+            gate_up_lora_b=self.B_buffer_gate_up,
+            base_output=base_output,
         )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
         # Since the outputs for both gate and up are identical, we use a random one.
         shard_size = self.base_layer.output_partition_sizes[0]
+        gate_size = self.base_layer.output_sizes[0]
         start_idx = tp_rank * shard_size
         end_idx = (tp_rank + 1) * shard_size
-        return B[:, start_idx:end_idx, :]
+        return torch.concat(
+            (
+                B[start_idx:end_idx, :],
+                B[gate_size + start_idx : gate_size + end_idx],
+            ),
+            dim=0,
+        )
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
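The new `slice_lora_b_weights` above assumes the gate and up lora_b weights are stacked along the output dimension, i.e. `B` has shape (gate_size + up_size, r), and each tensor-parallel rank takes its gate shard followed by its up shard. A small numeric sketch with made-up sizes:

# Illustrative trace of the stacked gate_up lora_b slicing above:
# gate_size = up_size = 4, rank r = 2, tensor parallel size 2.
import torch

gate_size, r, tp_size = 4, 2, 2
B = torch.arange((2 * gate_size) * r).reshape(2 * gate_size, r)  # rows 0-3 gate, 4-7 up
shard_size = gate_size // tp_size

def slice_gate_up(B, tp_rank):
    start, end = tp_rank * shard_size, (tp_rank + 1) * shard_size
    return torch.concat(
        (B[start:end, :], B[gate_size + start : gate_size + end]), dim=0
    )

print(slice_gate_up(B, 0))  # gate rows 0-1 followed by up rows 4-5
print(slice_gate_up(B, 1))  # gate rows 2-3 followed by up rows 6-7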
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        q_proj_shard_size = self.base_layer.q_proj_shard_size
+        kv_proj_shard_size = self.base_layer.kv_proj_shard_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                q_proj_shard_size,
+                q_proj_shard_size + kv_proj_shard_size,
+                q_proj_shard_size + 2 * kv_proj_shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
+        # For computing number of launched blocks
+        self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
 
     def set_lora_info(
         self,
         A_buffer_qkv: torch.Tensor,
-        B_buffer_q: torch.Tensor,
-        B_buffer_kv: torch.Tensor,
+        B_buffer_qkv: torch.Tensor,
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_stacked_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            if getattr(self, "B_buffer_qkv", None) is None:
-                self.B_buffer_qkv = torch.empty(
-                    (
-                        B_buffer_q[0].shape[0],
-                        output_dim_q + 2 * output_dim_kv,
-                        B_buffer_q[0].shape[2],
-                    ),
-                    dtype=B_buffer_q[0].dtype,
-                    device=B_buffer_q[0].device,
-                )
-            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
-            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
-                B_buffer_kv[0]
-            )
-            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
-                B_buffer_kv[1]
-            )
-
-            # Offsets of q/k/v in output dimension
-            if getattr(self, "output_offset", None) is None:
-                self.output_offset = torch.tensor(
-                    [
-                        0,
-                        output_dim_q,
-                        output_dim_q + output_dim_kv,
-                        output_dim_q + 2 * output_dim_kv,
-                    ],
-                    dtype=torch.int32,
-                    device=B_buffer_q.device,
-                )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
+        self.B_buffer_qkv = B_buffer_qkv
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-        if self.lora_backend.fuse_stacked_lora_b:
-            backend_kwargs["output_offset"] = self.output_offset
-            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
-
         lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            qkv_lora_a=self.A_buffer_qkv,
+            qkv_lora_b=self.B_buffer_qkv,
+            base_output=base_output,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
         )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
 
-    def slice_lora_b_weights(
-        self, B: List[torch.Tensor], tp_rank: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        B_q, B_kv = B
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
         base_layer = self.base_layer
         q_proj_shard_size = base_layer.q_proj_shard_size
         kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         kv_start_idx = kv_proj_shard_size * kv_shard_id
         kv_end_idx = kv_start_idx + kv_proj_shard_size
 
-        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+        q_size, k_size, _ = base_layer.output_sizes
+        B_q_shard = B[q_start_idx:q_end_idx, :]
+        B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
+        B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
+
+        return torch.concat(
+            (
+                B_q_shard,
+                B_k_shard,
+                B_v_shard,
+            ),
+            dim=0,
+        )
 
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
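The QKV variant uses the same stacked layout: lora_b rows are ordered q, then k, then v, and the precomputed `output_offset` marks where each projection starts in the fused output while `max_qkv_out_dim` sizes the kernel launch grid. An illustrative computation with made-up per-rank shard sizes:

# Illustrative output_offset for the stacked QKV lora_b layout above
# (hypothetical per-rank sizes: q = 1024, k = v = 256).
import torch

q_proj_shard_size, kv_proj_shard_size = 1024, 256
output_offset = torch.tensor(
    [
        0,
        q_proj_shard_size,
        q_proj_shard_size + kv_proj_shard_size,
        q_proj_shard_size + 2 * kv_proj_shard_size,
    ],
    dtype=torch.int32,
)
print(output_offset)  # tensor([   0, 1024, 1280, 1536], dtype=torch.int32)
max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)  # 1024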
@@ -294,20 +245,15 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -324,7 +270,11 @@
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)
 
-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel