sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-
-
-            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-
-
-            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
-
-
-            layer.
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class NPU_W8A8LinearMethodImpl:
@@ -551,7 +569,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict
 
     @staticmethod
@@ -582,11 +600,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +626,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
@@ -896,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -910,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -937,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -969,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -990,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
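The hunks above change every fused-MoE quantization method in this file to take a single `dispatch_output` object (hidden states plus top-k routing) instead of separate `topk_output`/`moe_runner_config` arguments, to hand results back as a combine input, and to build the Triton runner once in `create_moe_runner`. A minimal, self-contained sketch of that calling convention follows; the container classes here are simplified stand-ins for illustration, not the real sglang types.

```python
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class StandardDispatchOutput:
    # Stand-in for sglang's dispatch output: activations plus top-k routing.
    hidden_states: torch.Tensor
    topk_output: Tuple[torch.Tensor, torch.Tensor, object]


@dataclass
class StandardCombineInput:
    # Stand-in for the combine-side container returned by apply().
    hidden_states: torch.Tensor


class ToyW8A8MoEMethod:
    def create_moe_runner(self, layer, moe_runner_config):
        # The real method also builds MoeRunner(MoeRunnerBackend.TRITON, ...).
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        # Placeholder arithmetic standing in for the fused expert kernels.
        output = x * topk_weights.sum(dim=-1, keepdim=True)
        return StandardCombineInput(hidden_states=output)


method = ToyW8A8MoEMethod()
method.create_moe_runner(layer=None, moe_runner_config=None)
dispatch = StandardDispatchOutput(
    hidden_states=torch.randn(4, 8),
    topk_output=(torch.rand(4, 2), torch.randint(0, 4, (4, 2)), None),
)
print(method.apply(None, dispatch).hidden_states.shape)  # torch.Size([4, 8])
```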
sglang/srt/layers/rocm_linear_utils.py
ADDED
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
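The new `aiter_dsv3_router_gemm` above picks between a plain a16w16 GEMM and an atomic variant that accumulates into a pre-zeroed fp32 buffer (optionally served from a `BumpAllocator`) when the token count is small. Below is a plain-PyTorch reference for the shape contract it implies, with `M = hidden_states.shape[0]` and `N = weight.shape[0]`; it only illustrates the intended output shape, and the exact numerics and buffer handling of the aiter kernels are assumptions here, not verified.

```python
import torch


def router_gemm_reference(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Router logits for an (M, K) activation and an (N, K) router weight,
    # accumulated in fp32 and cast back, mirroring the (M, N) output shape
    # implied by the kernel wrapper above.
    return (hidden_states.float() @ weight.float().t()).to(hidden_states.dtype)


x = torch.randn(16, 7168, dtype=torch.bfloat16)   # M=16 tokens, hidden size 7168
w = torch.randn(256, 7168, dtype=torch.bfloat16)  # N=256 routed experts
print(router_gemm_reference(x, w).shape)  # torch.Size([16, 256])
```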
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
@@ -1876,7 +1858,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1881,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
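The rotary-embedding hunks rename the pure-PyTorch path to `apply_rotary_pos_emb_native` and, on Ascend builds, rebind `apply_rotary_pos_emb` to an NPU variant that transposes into the layout `torch_npu.npu_apply_rotary_pos_emb` expects. For reference, here is the standard rotate-half formulation that the native path follows, as a small self-contained sketch; it assumes an HF-style `(batch, heads, seq, head_dim)` layout and is not the sglang implementation itself.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin come in as (batch, seq, head_dim); unsqueeze to broadcast over heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


b, h, s, d = 1, 2, 4, 8
q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
angles = torch.randn(b, s, d)
q_rot, k_rot = apply_rotary_reference(q, k, angles.cos(), angles.sin())
print(q_rot.shape, k_rot.shape)
```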
sglang/srt/layers/sampler.py
CHANGED
@@ -27,6 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)
 
 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 
 
 class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
         else:
+            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+                logprobs = torch.softmax(logits, dim=-1)
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):
 
             if return_logprob:
                 # clamp to avoid -inf
-
+                if RETURN_ORIGINAL_LOGPROB:
+                    logprobs = torch.log(logprobs).clamp(
+                        min=torch.finfo(logprobs.dtype).min
+                    )
+                else:
+                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
 
 
-def get_top_logprobs(
+def get_top_logprobs(
+    logprobs: torch.Tensor,
+    top_logprobs_nums: List[int],
+):
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     for i, k in enumerate(top_logprobs_nums):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
-
+
+    return (
+        output_top_logprobs_val,
+        output_top_logprobs_idx,
+    )
 
 
-def get_token_ids_logprobs(
+def get_token_ids_logprobs(
+    logprobs: torch.Tensor,
+    token_ids_logprobs: List[List[int]],
+):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_val.append([])
             output_token_ids_logprobs_idx.append([])
 
-    return
+    return (
+        output_token_ids_logprobs_val,
+        output_token_ids_logprobs_idx,
+    )
 
 
 def apply_custom_logit_processor(
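The sampler change reads a new `RETURN_ORIGINAL_LOGPROB` environment variable: by default the returned logprobs still come from the temperature-scaled distribution the sampler actually draws from, while the flag switches them to the softmax of the unscaled logits with a clamped log. A toy comparison of the two definitions, illustrative only and using a made-up logit vector:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

# Default path: logprobs of the temperature-scaled distribution used for sampling.
scaled_logprobs = torch.log_softmax(logits / temperature, dim=-1)

# RETURN_ORIGINAL_LOGPROB path: softmax of the raw logits, then a clamped log,
# so the reported logprobs reflect the unscaled model output.
original_probs = torch.softmax(logits, dim=-1)
original_logprobs = torch.log(original_probs).clamp(min=torch.finfo(original_probs.dtype).min)

print(scaled_logprobs)
print(original_logprobs)
```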
sglang/srt/layers/utils.py
CHANGED
@@ -34,17 +34,3 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
-
-
-@lru_cache(maxsize=1)
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
-@lru_cache(maxsize=1)
-def is_sm90_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.
 
     Args:
-
-
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """
 
-    def __init__(self,
-        self.
-        self.
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass
 
-    def
-        self
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass
 
 
 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
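The two hooks added to `BaseLoRABackend` above are intentionally no-ops: `init_cuda_graph_batch_info` runs once when CUDA-graph buffers are created, and `prepare_lora_batch` runs per forward batch, either filling a caller-provided `LoRABatchInfo` (CUDA-graph mode) or building its own (eager mode). Below is a bare-bones, hypothetical subclass sketch of that contract with simplified argument types and no real kernel work; it is only meant to show the hook shapes, not how sglang's manager actually drives them.

```python
import torch


class RecordingLoRABackend:
    """Toy backend that only records what the two new hooks receive."""

    def __init__(self, max_loras_per_batch: int, device: torch.device):
        self.max_loras_per_batch = max_loras_per_batch
        self.device = device
        self.batch_info = None

    def init_cuda_graph_batch_info(self, cuda_graph_batch_info, max_bs_in_cuda_graph: int):
        # One-time setup of buffers that CUDA-graph replays will reuse.
        cuda_graph_batch_info["max_bs"] = max_bs_in_cuda_graph

    def prepare_lora_batch(self, forward_batch, weight_indices, lora_ranks, scalings, batch_info=None):
        # Per-step hook: stage per-request LoRA metadata for the kernels.
        self.batch_info = batch_info or {
            "weight_indices": torch.tensor(weight_indices, dtype=torch.int32),
            "lora_ranks": torch.tensor(lora_ranks, dtype=torch.int32),
            "scalings": torch.tensor(scalings, dtype=torch.float32),
        }


backend = RecordingLoRABackend(max_loras_per_batch=4, device=torch.device("cpu"))
backend.prepare_lora_batch(forward_batch=None, weight_indices=[0, 1], lora_ranks=[16, 8], scalings=[1.0, 0.5])
print(backend.batch_info["weight_indices"])
```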
sglang/srt/lora/backend/triton_backend.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
     sgemm_lora_b_fwd,
 )
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class TritonLoRABackend(BaseLoRABackend):
+    name = "triton"
 
-    def __init__(self,
-        super().__init__(
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        super().__init__(max_loras_per_batch, device)
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
             base_output,
         )
         return lora_output
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+        self.batch_info = batch_info
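The Triton backend's `prepare_lora_batch` above stages the small per-batch metadata tensors in pinned host memory and issues `non_blocking=True` copies, so the host-to-device transfer can overlap with other work instead of forcing a synchronization. A generic sketch of that transfer pattern (not sglang code; it degrades to a plain copy on CPU-only machines):

```python
import torch


def async_host_to_device(values, dtype, device: torch.device) -> torch.Tensor:
    # Stage in pinned (page-locked) host memory when a GPU is present, then
    # copy with non_blocking=True so the copy is asynchronous with respect
    # to the host thread.
    use_cuda = device.type == "cuda" and torch.cuda.is_available()
    host = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=use_cuda)
    dst = torch.empty(len(values), dtype=dtype, device=device)
    dst.copy_(host, non_blocking=use_cuda)
    return dst


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(async_host_to_device([0, 1, 1, 2], torch.int32, device))
```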