sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1,6 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py

 """Rotary Positional Embeddings."""
+import itertools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -221,6 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -230,8 +232,17 @@ class RotaryEmbedding(CustomOp):
                 head_size=self.head_size,
                 cos_sin_cache=self.cos_sin_cache,
                 is_neox=self.is_neox_style,
+                # Compatible with old sgl-kernel
+                **(
+                    dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
+                    if fused_set_kv_buffer_arg is not None
+                    else {}
+                ),
             )
         else:
+            assert (
+                fused_set_kv_buffer_arg is None
+            ), "save kv cache is not supported for vllm_rotary_embedding."
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
             self.vllm_rotary_embedding(
                 positions,
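The `**( ... if ... else {})` spread above is a compatibility shim: the new `fused_set_kv_buffer_arg` keyword is only forwarded when it is actually set, so older sgl-kernel builds whose `apply_rope_with_cos_sin_cache_inplace` lacks the parameter keep working. A generic sketch of that pattern with placeholder function names (not the sgl-kernel API):

```python
def call_with_optional_kwarg(fn, *args, new_arg=None, **kwargs):
    # Forward new_arg only when the caller set it, so an fn whose signature
    # predates the parameter keeps working unchanged.
    extra = {"new_arg": new_arg} if new_arg is not None else {}
    return fn(*args, **kwargs, **extra)


def old_kernel(x):  # older build: no `new_arg` parameter
    return x


def new_kernel(x, new_arg=None):  # newer build: accepts `new_arg`
    return (x, new_arg)


print(call_with_optional_kwarg(old_kernel, 1))                # -> 1
print(call_with_optional_kwarg(new_kernel, 1, new_arg="kv"))  # -> (1, 'kv')
```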
@@ -679,7 +690,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )

         # Re-dispatch
-        if _is_hip
+        if _is_hip:
             self._forward_method = self.forward_native

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -764,6 +775,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
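`forward_npu` above only rotates the first `rotary_dim` channels of each head and passes the remainder through, falling back to `forward_native` for large head layouts. A backend-agnostic sketch of that split/rotate/concat pattern in plain PyTorch (the identity `rotate_fn` is a placeholder for the real kernel):

```python
import torch

def apply_partial_rotary(x: torch.Tensor, rotary_dim: int, rotate_fn):
    # Rotate only the first `rotary_dim` channels; pass the rest through untouched.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = rotate_fn(x_rot)
    if x_pass.shape[-1] == 0:
        return x_rot
    return torch.cat((x_rot, x_pass), dim=-1)

# Toy check with head_size=8, rotary_dim=4 and an identity "rotation".
q = torch.randn(2, 3, 8)
assert torch.equal(apply_partial_rotary(q, 4, lambda t: t), q)
```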
@@ -946,7 +997,37 @@ class MRotaryEmbedding(RotaryEmbedding):

         self.mrope_section = mrope_section
         if self.mrope_section:
-
+            expected_sum = rotary_dim // 2
+            actual_sum = sum(self.mrope_section)
+            if actual_sum != expected_sum:
+                print(
+                    f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
+                    f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
+                )
+                # Auto-correct by scaling the mrope_section proportionally
+                if actual_sum > 0:
+                    scale_factor = expected_sum / actual_sum
+                    self.mrope_section = [
+                        max(1, int(section * scale_factor))
+                        for section in self.mrope_section
+                    ]
+                    # Ensure the sum exactly matches by adjusting the last element
+                    current_sum = sum(self.mrope_section)
+                    if current_sum != expected_sum:
+                        self.mrope_section[-1] += expected_sum - current_sum
+                else:
+                    # If all sections are 0, create a default distribution
+                    self.mrope_section = [
+                        expected_sum // len(self.mrope_section)
+                    ] * len(self.mrope_section)
+                    # Handle remainder
+                    remainder = expected_sum % len(self.mrope_section)
+                    for i in range(remainder):
+                        self.mrope_section[i] += 1
+
+                print(
+                    f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
+                )

     def forward(
         self,
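The hunk above replaces a hard failure with a proportional rescale of `mrope_section` so that it sums to `rotary_dim // 2`. A minimal standalone sketch of the same correction, with an illustrative function name:

```python
from typing import List

def rescale_mrope_section(mrope_section: List[int], rotary_dim: int) -> List[int]:
    """Scale mrope_section so its entries sum to rotary_dim // 2."""
    expected = rotary_dim // 2
    actual = sum(mrope_section)
    if actual == expected:
        return mrope_section
    if actual > 0:
        # Proportional rescale, then fix any rounding drift on the last entry.
        section = [max(1, int(s * expected / actual)) for s in mrope_section]
        section[-1] += expected - sum(section)
    else:
        # Degenerate input: distribute evenly and spread the remainder.
        n = len(mrope_section)
        section = [expected // n] * n
        for i in range(expected % n):
            section[i] += 1
    return section

# Example: a config declaring [16, 24, 24] for rotary_dim=64 is rescaled to sum to 32.
print(rescale_mrope_section([16, 24, 24], 64))  # [8, 12, 12]
```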
@@ -1153,6 +1234,204 @@ class MRotaryEmbedding(RotaryEmbedding):
         mrope_position_deltas = max_position_ids + 1 - s
         return position_ids, mrope_position_deltas

+    # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
+    @staticmethod
+    def get_rope_index_glm4v(
+        input_ids: torch.Tensor,
+        hf_config: Any,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        attention_mask: torch.Tensor,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Get mrope input positions and delta value for GLM4V."""
+        image_token_id = hf_config.image_token_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            video_group_index = 0
+            attention_mask = attention_mask.to(total_input_ids.device)
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = input_ids[attention_mask[i] == 1]
+                input_tokens = input_ids.tolist()
+
+                input_token_type = []
+                video_check_flg = False
+                for token in input_tokens:
+                    if token == video_start_token_id:
+                        video_check_flg = True
+                    elif token == video_end_token_id:
+                        video_check_flg = False
+
+                    if token == image_token_id and not video_check_flg:
+                        input_token_type.append("image")
+                    elif token == image_token_id and video_check_flg:
+                        input_token_type.append("video")
+                    else:
+                        input_token_type.append("text")
+
+                input_type_group = []
+                for key, group in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]
+                ):
+                    group = list(group)
+                    start_index = group[0][0]
+                    end_index = group[-1][0] + 1
+                    input_type_group.append((key, start_index, end_index))
+
+                llm_pos_ids_list = []
+                video_frame_num = 1
+                for modality_type, start_idx, end_idx in input_type_group:
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+
+                    if modality_type == "image":
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t.item(),
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                        h_index = (
+                            torch.arange(llm_grid_h)
+                            .view(1, -1, 1)
+                            .expand(llm_grid_t, -1, llm_grid_w)
+                            .flatten()
+                        )
+                        w_index = (
+                            torch.arange(llm_grid_w)
+                            .view(1, 1, -1)
+                            .expand(llm_grid_t, llm_grid_h, -1)
+                            .flatten()
+                        )
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx
+                        )
+
+                        image_index += 1
+                        video_frame_num = 1
+
+                    elif modality_type == "video":
+                        t, h, w = (
+                            video_frame_num,
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t,
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        for t_idx in range(llm_grid_t):
+                            t_index = (
+                                torch.tensor(t_idx)
+                                .view(-1, 1)
+                                .expand(-1, llm_grid_h * llm_grid_w)
+                                .flatten()
+                            )
+
+                            h_index = (
+                                torch.arange(llm_grid_h)
+                                .view(1, -1, 1)
+                                .expand(1, -1, llm_grid_w)
+                                .flatten()
+                            )
+                            w_index = (
+                                torch.arange(llm_grid_w)
+                                .view(1, 1, -1)
+                                .expand(1, llm_grid_h, -1)
+                                .flatten()
+                            )
+                            llm_pos_ids_list.append(
+                                torch.stack([t_index, h_index, w_index]) + st_idx
+                            )
+
+                        video_group_index += 1
+
+                        if video_group_index >= video_grid_thw[video_index][0]:
+                            video_index += 1
+                            video_group_index = 0
+
+                        video_frame_num += 1
+
+                    else:
+                        text_len = end_idx - start_idx
+                        llm_pos_ids_list.append(
+                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+
+                        video_frame_num = 1
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
+                    position_ids.device
+                )
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                position_ids = (
+                    position_ids.unsqueeze(0)
+                    .expand(3, -1, -1)
+                    .to(attention_mask.device)
+                )
+                max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                    -1, keepdim=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = (
+                    torch.arange(input_ids.shape[1], device=input_ids.device)
+                    .view(1, 1, -1)
+                    .expand(3, input_ids.shape[0], -1)
+                )
+                mrope_position_deltas = torch.zeros(
+                    [input_ids.shape[0], 1],
+                    device=input_ids.device,
+                    dtype=input_ids.dtype,
+                )
+
+            return position_ids, mrope_position_deltas
+
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
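For an image block, `get_rope_index_glm4v` builds a 3 x (t*h*w) grid of (temporal, height, width) positions offset by the running text position `st_idx`. A small self-contained reproduction of just that index construction (the grid sizes below are made up for illustration):

```python
import torch

def image_position_grid(llm_grid_t: int, llm_grid_h: int, llm_grid_w: int, st_idx: int):
    # Same arange/view/expand/flatten construction as in the hunk above.
    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
    return torch.stack([t_index, h_index, w_index]) + st_idx

# A 1x2x2 merged grid starting at text offset 5:
print(image_position_grid(1, 2, 2, st_idx=5))
# tensor([[5, 5, 5, 5],
#         [5, 5, 6, 6],
#         [5, 6, 5, 6]])
```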
sglang/srt/layers/sampler.py
CHANGED
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn

 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group

-        if
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

     def forward(
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -5,22 +5,6 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo


-def get_fuse_output_add_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
-def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
 class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
     Args:
        name: name of backend
        batch_info: information of current batch for use
-       fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-        and the operation of adding will be fused into kernel
     """

     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_add = get_fuse_output_add_from_name(name)
-        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:

         return TritonLoRABackend
     elif name == "flashinfer":
-
-
-
+        raise ValueError(
+            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
+        )
     else:
         raise ValueError(f"Invalid backend: {name}")
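With the FlashInfer LoRA backend removed, `get_backend_from_name` now only returns the Triton backend and raises for the old name. A hedged sketch of the resulting dispatch behaviour (the returned string stands in for the real `TritonLoRABackend` class):

```python
def get_lora_backend(name: str):
    if name == "triton":
        # Placeholder return value; the real class lives under sglang.srt.lora.backend.
        return "TritonLoRABackend"
    elif name == "flashinfer":
        raise ValueError(
            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
        )
    else:
        raise ValueError(f"Invalid backend: {name}")

try:
    get_lora_backend("flashinfer")
except ValueError as e:
    print(e)  # FlashInfer LoRA backend has been deprecated, please use `triton` instead.
```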
sglang/srt/lora/layers.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import List, Tuple
-
 import torch
 from torch import nn

@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer
-
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output

     def forward(self, input_: torch.Tensor):
         # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
-
-        # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-        if getattr(self, "B_buffer_gate_up", None) is None:
-            self.B_buffer_gate_up = torch.empty(
-                (
-                    B_buffer[0].shape[0],
-                    2 * B_buffer[0].shape[1],
-                    B_buffer[0].shape[2],
-                ),
-                dtype=B_buffer[0].dtype,
-                device=B_buffer[0].device,
-            )
-            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
-            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
-        else:
-            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
+        self.B_buffer_gate_up = B_buffer

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-
         lora_output = self.lora_backend.run_gate_up_lora(
-            x,
-            self.A_buffer_gate_up,
-            self.B_buffer_gate_up,
-
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            gate_up_lora_a=self.A_buffer_gate_up,
+            gate_up_lora_b=self.B_buffer_gate_up,
+            base_output=base_output,
         )
+        return lora_output

     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
         # Since the outputs for both gate and up are identical, we use a random one.
         shard_size = self.base_layer.output_partition_sizes[0]
+        gate_size = self.base_layer.output_sizes[0]
         start_idx = tp_rank * shard_size
         end_idx = (tp_rank + 1) * shard_size
-        return
+        return torch.concat(
+            (
+                B[start_idx:end_idx, :],
+                B[gate_size + start_idx : gate_size + end_idx],
+            ),
+            dim=0,
+        )


 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        q_proj_shard_size = self.base_layer.q_proj_shard_size
+        kv_proj_shard_size = self.base_layer.kv_proj_shard_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                q_proj_shard_size,
+                q_proj_shard_size + kv_proj_shard_size,
+                q_proj_shard_size + 2 * kv_proj_shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
+        # For computing number of launched blocks
+        self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)

     def set_lora_info(
         self,
         A_buffer_qkv: torch.Tensor,
-
-        B_buffer_kv: torch.Tensor,
+        B_buffer_qkv: torch.Tensor,
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_stacked_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            if getattr(self, "B_buffer_qkv", None) is None:
-                self.B_buffer_qkv = torch.empty(
-                    (
-                        B_buffer_q[0].shape[0],
-                        output_dim_q + 2 * output_dim_kv,
-                        B_buffer_q[0].shape[2],
-                    ),
-                    dtype=B_buffer_q[0].dtype,
-                    device=B_buffer_q[0].device,
-                )
-            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
-            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
-                B_buffer_kv[0]
-            )
-            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
-                B_buffer_kv[1]
-            )
-
-            # Offsets of q/k/v in output dimension
-            if getattr(self, "output_offset", None) is None:
-                self.output_offset = torch.tensor(
-                    [
-                        0,
-                        output_dim_q,
-                        output_dim_q + output_dim_kv,
-                        output_dim_q + 2 * output_dim_kv,
-                    ],
-                    dtype=torch.int32,
-                    device=B_buffer_q.device,
-                )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
+        self.B_buffer_qkv = B_buffer_qkv

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-        if self.lora_backend.fuse_stacked_lora_b:
-            backend_kwargs["output_offset"] = self.output_offset
-            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
-
         lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-
-
-
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            qkv_lora_a=self.A_buffer_qkv,
+            qkv_lora_b=self.B_buffer_qkv,
+            base_output=base_output,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
         )
+        return lora_output

     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A

-    def slice_lora_b_weights(
-        self, B: List[torch.Tensor], tp_rank: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        B_q, B_kv = B
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
         base_layer = self.base_layer
         q_proj_shard_size = base_layer.q_proj_shard_size
         kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         kv_start_idx = kv_proj_shard_size * kv_shard_id
         kv_end_idx = kv_start_idx + kv_proj_shard_size

-
+        q_size, k_size, _ = base_layer.output_sizes
+        B_q_shard = B[q_start_idx:q_end_idx, :]
+        B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
+        B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
+
+        return torch.concat(
+            (
+                B_q_shard,
+                B_k_shard,
+                B_v_shard,
+            ),
+            dim=0,
+        )


 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -294,20 +245,15 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer
-
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output

-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -324,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)

-        if
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
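The reworked `QKVParallelLinearWithLoRA.slice_lora_b_weights` treats the LoRA-B weight as one stacked `[q_size + 2 * kv_size, r]` tensor and concatenates the per-rank q, k and v shards. A standalone sketch of that slicing on dummy tensors (the helper name and shard sizes are illustrative, not the sglang API):

```python
import torch

def slice_stacked_qkv_lora_b(
    B: torch.Tensor,   # stacked LoRA-B: [q_size + 2 * kv_size, r]
    q_size: int,
    kv_size: int,
    q_shard: int,      # q rows owned by this TP rank
    kv_shard: int,     # k/v rows owned by this TP rank
    tp_rank: int,
    kv_shard_id: int,
) -> torch.Tensor:
    q_start = tp_rank * q_shard
    kv_start = kv_shard_id * kv_shard
    # Take this rank's q slice, then its k and v slices from the stacked tensor.
    return torch.concat(
        (
            B[q_start : q_start + q_shard, :],
            B[q_size + kv_start : q_size + kv_start + kv_shard, :],
            B[q_size + kv_size + kv_start : q_size + kv_size + kv_start + kv_shard, :],
        ),
        dim=0,
    )

# Toy example: q_size=8, kv_size=4, rank r=2, two TP ranks.
B = torch.arange(16 * 2, dtype=torch.float32).reshape(16, 2)
shard = slice_stacked_qkv_lora_b(
    B, q_size=8, kv_size=4, q_shard=4, kv_shard=2, tp_rank=1, kv_shard_id=1
)
print(shard.shape)  # torch.Size([8, 2])
```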