sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)

         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-
-
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)

-
-
-            layer.
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class NPU_W8A8LinearMethodImpl:
@@ -900,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -994,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
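The MoE-method refactor above replaces the old direct `fused_experts(..., topk_output, moe_runner_config)` call path with a dispatcher/runner contract. A minimal sketch of that contract follows; it reuses only names introduced in this diff (`MoeRunner`, `TritonMoeQuantInfo`, `StandardDispatchOutput`, `StandardCombineInput`), while the `ToyMoEMethod` class itself is hypothetical and the quant-info field list is trimmed for brevity.

```python
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    StandardDispatchOutput,
)


class ToyMoEMethod:
    """Illustrative only: the create_moe_runner()/apply() flow used by the real methods."""

    def create_moe_runner(self, layer, moe_runner_config: MoeRunnerConfig):
        # Each quant method now builds a MoeRunner once, instead of calling
        # fused_experts directly inside apply().
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> CombineInput:
        # Hidden states and top-k routing arrive pre-packaged in the dispatch output;
        # weights (and, in the real methods, scales and quantization flags) are handed
        # to the runner through a quant-info struct.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            # ...remaining fields as in W8A8Int8MoEMethod above...
        )
        return self.runner.run(dispatch_output, quant_info)
```

The CPU (AMX) and NPU branches bypass the runner and instead wrap their kernel output in `StandardCombineInput`, as the hunks above show.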
sglang/srt/layers/rocm_linear_utils.py
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
sglang/srt/layers/rotary_embedding.py
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):

         return position_ids, mrope_position_deltas

-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-

 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Tuple

 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits

-        #
-
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -80,9 +90,9 @@ class Sampler(nn.Module):
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

         else:
-            #
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)

             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -123,9 +133,10 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
             if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(
-                    min=torch.finfo(
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
             else:
                 logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

@@ -164,6 +175,54 @@ class Sampler(nn.Module):

         return batch_next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -233,10 +292,95 @@ def get_top_logprobs(
     )


-def
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+            [] for _ in token_ids_logprobs
+        ]
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor [6] into chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # Example: Convert split tensors back to list format with proper empty handling
+    # i=0: [1,3] -> append split_logprobs[0] and [1,3]
+    # i=1: [2] -> append split_logprobs[1] and [2]
+    # i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
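A small sanity sketch for the vectorized helper added above, mirroring the example in its own docstring; it assumes an environment where `sglang.srt.layers.sampler` imports cleanly.

```python
import torch

from sglang.srt.layers.sampler import get_token_ids_logprobs_batch_optimized

logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
token_ids = [[1, 3], [2], [0, 2, 4]]

values, indices = get_token_ids_logprobs_batch_optimized(logprobs, token_ids)

assert indices == token_ids
# Each entry matches a plain per-row gather, but only one indexing kernel is launched.
for row, ids, vals in zip(logprobs, token_ids, values):
    torch.testing.assert_close(vals, row[ids])
```

`Sampler.compute_logprobs_only` feeds this helper on prefill-only scoring requests and writes the results onto the `LogitsProcessorOutput` fields in place.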
sglang/srt/lora/backend/base_backend.py
@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.

     Args:
-
-
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """

-    def __init__(self,
-        self.
-        self.
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass

-    def
-        self
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass


 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
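For illustration, a skeletal backend overriding the two new hooks might look as follows; only the signatures come from the diff above, while the `NoopLoRABackend` class and its bookkeeping are hypothetical.

```python
from typing import Optional

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NoopLoRABackend(BaseLoRABackend):
    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
    ):
        # One-time setup: keep a reference to the persistent batch info reused
        # across CUDA Graph replays.
        self.cuda_graph_batch_info = cuda_graph_batch_info
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        # Per-step hook: populate `batch_info` (or the stored cuda_graph_batch_info)
        # from the adapter indices, ranks, and scalings chosen for this batch.
        pass
```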