sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/rotary_embedding.py
+++ b/sglang/srt/layers/rotary_embedding.py
@@ -1,6 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py

 """Rotary Positional Embeddings."""
+import itertools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -221,6 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -230,8 +232,17 @@ class RotaryEmbedding(CustomOp):
                 head_size=self.head_size,
                 cos_sin_cache=self.cos_sin_cache,
                 is_neox=self.is_neox_style,
+                # Compatible with old sgl-kernel
+                **(
+                    dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
+                    if fused_set_kv_buffer_arg is not None
+                    else {}
+                ),
             )
         else:
+            assert (
+                fused_set_kv_buffer_arg is None
+            ), "save kv cache is not supported for vllm_rotary_embedding."
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
             self.vllm_rotary_embedding(
                 positions,
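
The `**(dict(...) if ... else {})` splat in the hunk above forwards the new `fused_set_kv_buffer_arg` keyword only when it is actually set, so older sgl-kernel builds whose `apply_rope_with_cos_sin_cache_inplace` does not accept the parameter keep working. A minimal sketch of the same compatibility pattern; `call_kernel_compat` and `new_arg` are illustrative names, not part of this diff:

```python
def call_kernel_compat(kernel_fn, *args, new_arg=None, **kwargs):
    # Forward `new_arg` only when the caller supplied it, so a kernel built
    # against the older signature (without `new_arg`) still accepts the call.
    extra = {"new_arg": new_arg} if new_arg is not None else {}
    return kernel_fn(*args, **kwargs, **extra)
```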
@@ -679,7 +690,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )

         # Re-dispatch
-        if _is_hip or _is_npu:
+        if _is_hip:
             self._forward_method = self.forward_native

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -764,6 +775,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -946,7 +997,37 @@ class MRotaryEmbedding(RotaryEmbedding):

         self.mrope_section = mrope_section
         if self.mrope_section:
-            assert sum(self.mrope_section) == rotary_dim // 2
+            expected_sum = rotary_dim // 2
+            actual_sum = sum(self.mrope_section)
+            if actual_sum != expected_sum:
+                print(
+                    f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
+                    f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
+                )
+                # Auto-correct by scaling the mrope_section proportionally
+                if actual_sum > 0:
+                    scale_factor = expected_sum / actual_sum
+                    self.mrope_section = [
+                        max(1, int(section * scale_factor))
+                        for section in self.mrope_section
+                    ]
+                    # Ensure the sum exactly matches by adjusting the last element
+                    current_sum = sum(self.mrope_section)
+                    if current_sum != expected_sum:
+                        self.mrope_section[-1] += expected_sum - current_sum
+                else:
+                    # If all sections are 0, create a default distribution
+                    self.mrope_section = [
+                        expected_sum // len(self.mrope_section)
+                    ] * len(self.mrope_section)
+                    # Handle remainder
+                    remainder = expected_sum % len(self.mrope_section)
+                    for i in range(remainder):
+                        self.mrope_section[i] += 1
+
+                print(
+                    f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
+                )

     def forward(
         self,
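
The block above replaces the hard assert with an auto-correction that rescales a mismatched `mrope_section` so it sums to `rotary_dim // 2`. A condensed, standalone sketch of that arithmetic (illustrative helper, not the code in the diff; the all-zero fallback is simplified away):

```python
def rescale_mrope_section(section: list[int], rotary_dim: int) -> list[int]:
    expected = rotary_dim // 2
    actual = sum(section)
    if actual in (0, expected):
        return list(section)
    scale = expected / actual
    section = [max(1, int(s * scale)) for s in section]
    # Absorb any rounding drift in the last entry so the sum matches exactly.
    section[-1] += expected - sum(section)
    return section

# Example: rescale_mrope_section([16, 24, 24], 96) -> [12, 18, 18], which sums to 96 // 2 = 48.
```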
@@ -1153,6 +1234,204 @@ class MRotaryEmbedding(RotaryEmbedding):
             mrope_position_deltas = max_position_ids + 1 - s
         return position_ids, mrope_position_deltas

+    # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
+    @staticmethod
+    def get_rope_index_glm4v(
+        input_ids: torch.Tensor,
+        hf_config: Any,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        attention_mask: torch.Tensor,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Get mrope input positions and delta value for GLM4V."""
+        image_token_id = hf_config.image_token_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            video_group_index = 0
+            attention_mask = attention_mask.to(total_input_ids.device)
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = input_ids[attention_mask[i] == 1]
+                input_tokens = input_ids.tolist()
+
+                input_token_type = []
+                video_check_flg = False
+                for token in input_tokens:
+                    if token == video_start_token_id:
+                        video_check_flg = True
+                    elif token == video_end_token_id:
+                        video_check_flg = False
+
+                    if token == image_token_id and not video_check_flg:
+                        input_token_type.append("image")
+                    elif token == image_token_id and video_check_flg:
+                        input_token_type.append("video")
+                    else:
+                        input_token_type.append("text")
+
+                input_type_group = []
+                for key, group in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]
+                ):
+                    group = list(group)
+                    start_index = group[0][0]
+                    end_index = group[-1][0] + 1
+                    input_type_group.append((key, start_index, end_index))
+
+                llm_pos_ids_list = []
+                video_frame_num = 1
+                for modality_type, start_idx, end_idx in input_type_group:
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+
+                    if modality_type == "image":
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t.item(),
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                        h_index = (
+                            torch.arange(llm_grid_h)
+                            .view(1, -1, 1)
+                            .expand(llm_grid_t, -1, llm_grid_w)
+                            .flatten()
+                        )
+                        w_index = (
+                            torch.arange(llm_grid_w)
+                            .view(1, 1, -1)
+                            .expand(llm_grid_t, llm_grid_h, -1)
+                            .flatten()
+                        )
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx
+                        )
+
+                        image_index += 1
+                        video_frame_num = 1
+
+                    elif modality_type == "video":
+                        t, h, w = (
+                            video_frame_num,
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t,
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        for t_idx in range(llm_grid_t):
+                            t_index = (
+                                torch.tensor(t_idx)
+                                .view(-1, 1)
+                                .expand(-1, llm_grid_h * llm_grid_w)
+                                .flatten()
+                            )
+
+                            h_index = (
+                                torch.arange(llm_grid_h)
+                                .view(1, -1, 1)
+                                .expand(1, -1, llm_grid_w)
+                                .flatten()
+                            )
+                            w_index = (
+                                torch.arange(llm_grid_w)
+                                .view(1, 1, -1)
+                                .expand(1, llm_grid_h, -1)
+                                .flatten()
+                            )
+                            llm_pos_ids_list.append(
+                                torch.stack([t_index, h_index, w_index]) + st_idx
+                            )
+
+                        video_group_index += 1
+
+                        if video_group_index >= video_grid_thw[video_index][0]:
+                            video_index += 1
+                            video_group_index = 0
+
+                        video_frame_num += 1
+
+                    else:
+                        text_len = end_idx - start_idx
+                        llm_pos_ids_list.append(
+                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+
+                        video_frame_num = 1
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
+                    position_ids.device
+                )
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                position_ids = (
+                    position_ids.unsqueeze(0)
+                    .expand(3, -1, -1)
+                    .to(attention_mask.device)
+                )
+                max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                    -1, keepdim=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = (
+                    torch.arange(input_ids.shape[1], device=input_ids.device)
+                    .view(1, 1, -1)
+                    .expand(3, input_ids.shape[0], -1)
+                )
+                mrope_position_deltas = torch.zeros(
+                    [input_ids.shape[0], 1],
+                    device=input_ids.device,
+                    dtype=input_ids.dtype,
+                )
+
+            return position_ids, mrope_position_deltas
+
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
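
For orientation, a toy-sized and entirely hypothetical call to the new `get_rope_index_glm4v` helper: the returned `position_ids` carries one channel each for temporal/height/width positions, i.e. shape `[3, batch, seq]`, and `mrope_position_deltas` has shape `[batch, 1]`. The token ids, grid sizes, and the `SimpleNamespace` stand-in for `hf_config` below are made up for illustration:

```python
import types

import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

# Fake GLM4V-style config: one 1x4x4 image patch grid with spatial merge size 2,
# so the image contributes (4 // 2) * (4 // 2) = 4 placeholder tokens.
hf_config = types.SimpleNamespace(
    image_token_id=3,
    video_start_token_id=4,
    video_end_token_id=5,
    vision_config=types.SimpleNamespace(spatial_merge_size=2),
)
input_ids = torch.tensor([[1, 2, 3, 3, 3, 3, 6]])  # text, text, 4x image token, text

position_ids, deltas = MRotaryEmbedding.get_rope_index_glm4v(
    input_ids=input_ids,
    hf_config=hf_config,
    image_grid_thw=torch.tensor([[1, 4, 4]]),
    video_grid_thw=None,
    attention_mask=torch.ones_like(input_ids),
)
print(position_ids.shape, deltas.shape)  # torch.Size([3, 1, 7]) torch.Size([1, 1])
```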
@@ -1172,6 +1451,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )


+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


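
Note that `DualChunkRotaryEmbedding.forward` returns the query with five rotated variants (intra-chunk, successive, inter-chunk, and the two "critical" variants) concatenated along the last dimension, which is why the new dual-chunk attention backend has to unpack it. A hedged sketch of that unpacking on the consumer side (the helper name is illustrative, not from this diff):

```python
import torch


def split_dual_chunk_query(q: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # q has shape [num_tokens, 5 * num_heads * head_size]; chunking by 5 along the
    # last dim recovers (q, q_succ, q_inter, q_succ_critical, q_inter_critical).
    return torch.chunk(q, 5, dim=-1)
```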
@@ -1184,6 +1659,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -1195,6 +1671,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
@@ -1204,12 +1691,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]

-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
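
A hedged usage sketch of the extended `get_rope` entry point (not part of the diff): the `dual_chunk_attention_config` dict is what routes construction to `DualChunkRotaryEmbedding`, and everything in it except `sparse_attention_config` participates in the `_ROPE_DICT` cache key. The config values below are illustrative, and the constructor expects a CUDA device:

```python
import torch

from sglang.srt.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=1000000,
    is_neox_style=True,
    rope_scaling=None,
    dtype=torch.bfloat16,
    dual_chunk_attention_config={"chunk_size": 262144, "local_size": 8192},
)
```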
--- a/sglang/srt/layers/utils.py
+++ b/sglang/srt/layers/utils.py
@@ -1,5 +1,6 @@
 import logging
 import re
+from functools import lru_cache

 import torch

@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input


+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
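
Both capability probes are now memoized with `lru_cache`, so the `torch.cuda` query runs once per process. A hedged sketch of how such checks are typically consumed to pick a kernel path (the gating function and return strings are illustrative, not from this diff; a CUDA build is assumed):

```python
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported


def pick_fp8_kernel_path() -> str:
    if is_sm100_supported():  # Blackwell-class GPU with CUDA >= 12.8
        return "sm100"
    if is_sm90_supported():  # Hopper-class GPU with CUDA >= 12.3
        return "sm90"
    return "fallback"
```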
--- a/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/sglang/srt/layers/vocab_parallel_embedding.py
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
     method_has_implemented_embedding,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    set_weight_attrs,
+)

 DEFAULT_VOCAB_PADDING_SIZE = 64

@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded


-@torch.jit.script
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def get_masked_input_and_mask(
     input_: torch.Tensor,
     org_vocab_start_index: int,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
     added_vocab_start_index: int,
     added_vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.jit.script will fuse all of the pointwise ops below
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
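
The decorator swap above moves the masking helper from `torch.jit.script` to `torch.compile`, which still fuses the chain of pointwise comparisons and arithmetic into a single kernel. A minimal sketch of the same idea on a toy function (default backend, without sglang's `get_compiler_backend()` indirection):

```python
import torch


@torch.compile(dynamic=True)
def mask_to_range(x: torch.Tensor, start: int, end: int) -> torch.Tensor:
    # The comparisons, the &, and the arithmetic compile into one fused kernel.
    in_range = (x >= start) & (x < end)
    return (x - start) * in_range
```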
--- a/sglang/srt/lora/backend/base_backend.py
+++ b/sglang/srt/lora/backend/base_backend.py
@@ -5,22 +5,6 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo


-def get_fuse_output_add_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
-def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
 class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-                         and the operation of adding will be fused into kernel
     """

     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_add = get_fuse_output_add_from_name(name)
-        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:

         return TritonLoRABackend
     elif name == "flashinfer":
-        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
-
-        return FlashInferLoRABackend
+        raise ValueError(
+            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
+        )
     else:
         raise ValueError(f"Invalid backend: {name}")
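
With `flashinfer_backend.py` deleted (entry 170 in the file list), requesting the FlashInfer LoRA backend now fails fast instead of importing a missing module. A small sketch of the resulting behavior:

```python
from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # still returns TritonLoRABackend

try:
    get_backend_from_name("flashinfer")
except ValueError as e:
    print(e)  # FlashInfer LoRA backend has been deprecated, please use `triton` instead.
```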