sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1,6 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
 
 """Rotary Positional Embeddings."""
+import itertools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -221,6 +222,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -230,8 +232,17 @@ class RotaryEmbedding(CustomOp):
                 head_size=self.head_size,
                 cos_sin_cache=self.cos_sin_cache,
                 is_neox=self.is_neox_style,
+                # Compatible with old sgl-kernel
+                **(
+                    dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
+                    if fused_set_kv_buffer_arg is not None
+                    else {}
+                ),
             )
         else:
+            assert (
+                fused_set_kv_buffer_arg is None
+            ), "save kv cache is not supported for vllm_rotary_embedding."
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
             self.vllm_rotary_embedding(
                 positions,
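The `**(dict(...) if ... else {})` spread above keeps the call working against older sgl-kernel builds whose `apply_rope_with_cos_sin_cache_inplace` does not accept `fused_set_kv_buffer_arg`. A minimal standalone sketch of that compatibility idiom, with hypothetical stub functions standing in for the kernel entry points:

    def apply_rope_old(q, k):
        # Older signature: no fused-KV argument at all.
        return q, k

    def apply_rope_new(q, k, fused_set_kv_buffer_arg=None):
        # Newer signature: may also write the KV cache inside the same kernel.
        return q, k

    def call_compat(kernel, q, k, fused_set_kv_buffer_arg=None):
        # Forward the keyword only when it is set, so either signature works.
        return kernel(
            q,
            k,
            **(
                dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg)
                if fused_set_kv_buffer_arg is not None
                else {}
            ),
        )

    call_compat(apply_rope_old, 1, 2)                                # kwarg omitted
    call_compat(apply_rope_new, 1, 2, fused_set_kv_buffer_arg=object())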
@@ -679,7 +690,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip
+        if _is_hip:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -764,6 +775,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -946,7 +997,37 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         self.mrope_section = mrope_section
         if self.mrope_section:
-            assert sum(self.mrope_section) == rotary_dim // 2
+            expected_sum = rotary_dim // 2
+            actual_sum = sum(self.mrope_section)
+            if actual_sum != expected_sum:
+                print(
+                    f"MRoPE section sum mismatch: expected {expected_sum}, got {actual_sum}. "
+                    f"Adjusting mrope_section to match rotary_dim // 2 = {expected_sum}"
+                )
+                # Auto-correct by scaling the mrope_section proportionally
+                if actual_sum > 0:
+                    scale_factor = expected_sum / actual_sum
+                    self.mrope_section = [
+                        max(1, int(section * scale_factor))
+                        for section in self.mrope_section
+                    ]
+                    # Ensure the sum exactly matches by adjusting the last element
+                    current_sum = sum(self.mrope_section)
+                    if current_sum != expected_sum:
+                        self.mrope_section[-1] += expected_sum - current_sum
+                else:
+                    # If all sections are 0, create a default distribution
+                    self.mrope_section = [
+                        expected_sum // len(self.mrope_section)
+                    ] * len(self.mrope_section)
+                    # Handle remainder
+                    remainder = expected_sum % len(self.mrope_section)
+                    for i in range(remainder):
+                        self.mrope_section[i] += 1
+
+                print(
+                    f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
+                )
 
     def forward(
         self,
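The block above replaces a hard check with a proportional rescale of `mrope_section` so that the sections sum to `rotary_dim // 2`. A small self-contained sketch of the same arithmetic (the section values and `rotary_dim` below are made up; the all-zero fallback from the diff is omitted):

    def correct_mrope_section(mrope_section, rotary_dim):
        # Scale each section so the total matches rotary_dim // 2, then absorb
        # any rounding error into the last entry.
        expected = rotary_dim // 2
        actual = sum(mrope_section)
        if actual in (0, expected):
            return list(mrope_section)
        scale = expected / actual
        corrected = [max(1, int(s * scale)) for s in mrope_section]
        corrected[-1] += expected - sum(corrected)
        return corrected

    # rotary_dim = 128 -> expected sum 64, but [16, 24, 28] sums to 68.
    print(correct_mrope_section([16, 24, 28], 128))  # [15, 22, 27], sum == 64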
@@ -1153,6 +1234,204 @@ class MRotaryEmbedding(RotaryEmbedding):
             mrope_position_deltas = max_position_ids + 1 - s
         return position_ids, mrope_position_deltas
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
+    @staticmethod
+    def get_rope_index_glm4v(
+        input_ids: torch.Tensor,
+        hf_config: Any,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        attention_mask: torch.Tensor,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Get mrope input positions and delta value for GLM4V."""
+        image_token_id = hf_config.image_token_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = torch.ones_like(total_input_ids)
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            video_group_index = 0
+            attention_mask = attention_mask.to(total_input_ids.device)
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = input_ids[attention_mask[i] == 1]
+                input_tokens = input_ids.tolist()
+
+                input_token_type = []
+                video_check_flg = False
+                for token in input_tokens:
+                    if token == video_start_token_id:
+                        video_check_flg = True
+                    elif token == video_end_token_id:
+                        video_check_flg = False
+
+                    if token == image_token_id and not video_check_flg:
+                        input_token_type.append("image")
+                    elif token == image_token_id and video_check_flg:
+                        input_token_type.append("video")
+                    else:
+                        input_token_type.append("text")
+
+                input_type_group = []
+                for key, group in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]
+                ):
+                    group = list(group)
+                    start_index = group[0][0]
+                    end_index = group[-1][0] + 1
+                    input_type_group.append((key, start_index, end_index))
+
+                llm_pos_ids_list = []
+                video_frame_num = 1
+                for modality_type, start_idx, end_idx in input_type_group:
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+
+                    if modality_type == "image":
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t.item(),
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                        h_index = (
+                            torch.arange(llm_grid_h)
+                            .view(1, -1, 1)
+                            .expand(llm_grid_t, -1, llm_grid_w)
+                            .flatten()
+                        )
+                        w_index = (
+                            torch.arange(llm_grid_w)
+                            .view(1, 1, -1)
+                            .expand(llm_grid_t, llm_grid_h, -1)
+                            .flatten()
+                        )
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx
+                        )
+
+                        image_index += 1
+                        video_frame_num = 1
+
+                    elif modality_type == "video":
+                        t, h, w = (
+                            video_frame_num,
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+
+                        llm_grid_t, llm_grid_h, llm_grid_w = (
+                            t,
+                            h.item() // spatial_merge_size,
+                            w.item() // spatial_merge_size,
+                        )
+
+                        for t_idx in range(llm_grid_t):
+                            t_index = (
+                                torch.tensor(t_idx)
+                                .view(-1, 1)
+                                .expand(-1, llm_grid_h * llm_grid_w)
+                                .flatten()
+                            )
+
+                            h_index = (
+                                torch.arange(llm_grid_h)
+                                .view(1, -1, 1)
+                                .expand(1, -1, llm_grid_w)
+                                .flatten()
+                            )
+                            w_index = (
+                                torch.arange(llm_grid_w)
+                                .view(1, 1, -1)
+                                .expand(1, llm_grid_h, -1)
+                                .flatten()
+                            )
+                            llm_pos_ids_list.append(
+                                torch.stack([t_index, h_index, w_index]) + st_idx
+                            )
+
+                        video_group_index += 1
+
+                        if video_group_index >= video_grid_thw[video_index][0]:
+                            video_index += 1
+                            video_group_index = 0
+
+                        video_frame_num += 1
+
+                    else:
+                        text_len = end_idx - start_idx
+                        llm_pos_ids_list.append(
+                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+
+                        video_frame_num = 1
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
+                    position_ids.device
+                )
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                position_ids = (
+                    position_ids.unsqueeze(0)
+                    .expand(3, -1, -1)
+                    .to(attention_mask.device)
+                )
+                max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                    -1, keepdim=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = (
+                    torch.arange(input_ids.shape[1], device=input_ids.device)
+                    .view(1, 1, -1)
+                    .expand(3, input_ids.shape[0], -1)
+                )
+                mrope_position_deltas = torch.zeros(
+                    [input_ids.shape[0], 1],
+                    device=input_ids.device,
+                    dtype=input_ids.dtype,
+                )
+
+        return position_ids, mrope_position_deltas
+
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
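The GLM4V index builder above leans on `itertools.groupby` (hence the new `import itertools`) to collapse per-token modality labels into contiguous (type, start, end) runs before assigning 3D positions. The grouping step in isolation, with made-up labels:

    import itertools

    # Pretend each position was already classified as text / image / video.
    input_token_type = ["text", "text", "image", "image", "image", "text", "video", "video"]

    input_type_group = []
    for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
        group = list(group)
        input_type_group.append((key, group[0][0], group[-1][0] + 1))

    print(input_type_group)
    # [('text', 0, 2), ('image', 2, 5), ('text', 5, 6), ('video', 6, 8)]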
@@ -1172,6 +1451,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )
 
 
+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -1184,6 +1659,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -1195,6 +1671,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
@@ -1204,12 +1691,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
        dtype,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
 
-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
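With these changes, `get_rope` returns a `DualChunkRotaryEmbedding` whenever a dual-chunk config is supplied, and only its `chunk_size` and `local_size` keys are forwarded to the constructor. A usage sketch, assuming sglang 0.5.0rc1 is installed and a CUDA device is available; the numeric values below are illustrative, not defaults:

    from sglang.srt.layers.rotary_embedding import get_rope

    rope = get_rope(
        head_size=128,
        rotary_dim=128,
        max_position=131072,
        base=10000,
        is_neox_style=True,
        rope_scaling=None,
        dual_chunk_attention_config={
            "chunk_size": 8192,   # hypothetical values
            "local_size": 1024,
            "sparse_attention_config": {},  # excluded from the RoPE cache key
        },
    )
    # forward() returns the rotated key plus a query tensor that stacks
    # (query, query_succ, query_inter, query_succ_critical, query_inter_critical).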
sglang/srt/layers/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 import re
+from functools import lru_cache
 
 import torch
 
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input
 
 
+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
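`lru_cache(maxsize=1)` memoizes the capability probe, so `torch.cuda.get_device_capability` is only queried once for the most recent `device` argument. A tiny illustration of the caching behavior using a stub probe in place of the CUDA call:

    from functools import lru_cache

    probe_calls = 0

    @lru_cache(maxsize=1)
    def is_sm90_supported_stub(device=None) -> bool:
        # Stand-in for torch.cuda.get_device_capability(device); the real helper
        # also requires torch.version.cuda >= "12.3".
        global probe_calls
        probe_calls += 1
        return True

    is_sm90_supported_stub()
    is_sm90_supported_stub()
    print(probe_calls)  # 1 -- the second call is served from the cache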
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
     method_has_implemented_embedding,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def get_masked_input_and_mask(
     input_: torch.Tensor,
     org_vocab_start_index: int,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
     added_vocab_start_index: int,
     added_vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
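The restored decorator lets `torch.compile` fuse the pointwise vocab-masking ops into a single kernel (the backend selection via `get_compiler_backend()` is specific to sglang). A reduced sketch of the same idea using the default Inductor backend, with made-up vocab ranges:

    import torch

    @torch.compile(dynamic=True)
    def masked_input_and_mask(input_: torch.Tensor, vocab_start: int, vocab_end: int):
        # The comparisons, subtraction, and where() below are all pointwise and
        # can be fused by torch.compile into one kernel.
        valid = (input_ >= vocab_start) & (input_ < vocab_end)
        masked = torch.where(valid, input_ - vocab_start, torch.zeros_like(input_))
        return masked, ~valid

    ids = torch.tensor([3, 10, 150, 20])
    print(masked_input_and_mask(ids, 0, 100))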
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -5,22 +5,6 @@ import torch
 from sglang.srt.lora.utils import LoRABatchInfo
 
 
-def get_fuse_output_add_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
-def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
-    mapping = {
-        "triton": True,
-        "flashinfer": False,
-    }
-    return mapping.get(name, False)
-
-
 class BaseLoRABackend:
     """Base class for different Lora backends.
     Each backend has its own implementation of Lora kernels.
@@ -28,15 +12,11 @@ class BaseLoRABackend:
     Args:
         name: name of backend
         batch_info: information of current batch for use
-        fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-            and the operation of adding will be fused into kernel
     """
 
     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
         self.name = name
         self.batch_info = batch_info
-        self.fuse_output_add = get_fuse_output_add_from_name(name)
-        self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
 
         return TritonLoRABackend
     elif name == "flashinfer":
-
-
-
+        raise ValueError(
+            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
+        )
     else:
         raise ValueError(f"Invalid backend: {name}")