sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
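The last hunk above tightens the CPU/AMX weight post-processing: the transpose now sits in an explicit `else:` branch instead of running unconditionally after the AMX path. A condensed, self-contained sketch of the resulting control flow; the function name and boolean flags are illustrative stand-ins, and the AMX repacking call is only noted in a comment:

```python
from torch.nn import Parameter


def process_weights_after_loading_sketch(layer, is_cpu: bool, amx_available: bool) -> None:
    # Sketch of the new branch structure in W8A8Int8LinearMethod.process_weights_after_loading.
    if is_cpu:
        assert amx_available, "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
        # _amx_process_weight_after_loading(layer, ["weight"]) repacks the weight in place here.
    else:
        # New in 0.5.3: the transpose only happens on the non-AMX path.
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
```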
@@ -390,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )
-
         x_q, x_scale = per_token_quant_int8(x)

-
-
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
         )

+        return output.view(output_shape)
+

 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
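The new `apply` body flattens the quantized activation to 2-D before calling `int8_scaled_mm` and restores the leading dimensions afterwards. A small, self-contained shape check of that round trip (a plain float matmul stands in for `int8_scaled_mm`, which needs sgl-kernel; only the reshape logic is shown):

```python
import torch

x_q = torch.randn(2, 5, 16)        # e.g. [batch, seq, hidden]
weight = torch.randn(16, 32)       # [hidden, out_features], like layer.weight after .t()

x_q_2d = x_q.view(-1, x_q.shape[-1])                # [batch * seq, hidden]
output_shape = [*x_q.shape[:-1], weight.shape[1]]   # [batch, seq, out_features]

output = torch.matmul(x_q_2d, weight)               # stand-in for int8_scaled_mm
assert output.view(output_shape).shape == torch.Size([2, 5, 32])
```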
@@ -417,7 +430,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +441,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +452,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)

         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +495,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-
-
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +505,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +544,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)

-
-
-        layer.
-        layer.w2_weight,
-        topk_output=topk_output,
-        moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class NPU_W8A8LinearMethodImpl:
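Taken together, the MoE hunks above replace the old `apply(..., topk_output, moe_runner_config)` signature with a dispatcher-centric contract: `create_moe_runner` is called once to store the config and build a Triton `MoeRunner`, and `apply` receives a `StandardDispatchOutput` and hands back a combine input. A minimal stand-alone sketch of that call flow; the dataclasses below are hypothetical stand-ins for the real `StandardDispatchOutput`/`StandardCombineInput` types, and only the attribute names come from the diff:

```python
from dataclasses import dataclass
from typing import Any, Tuple

import torch


@dataclass
class DispatchOutputStandIn:
    hidden_states: torch.Tensor              # [num_tokens, hidden_size]
    topk_output: Tuple[torch.Tensor, ...]    # (topk_weights, topk_ids, ...)


@dataclass
class CombineInputStandIn:
    hidden_states: torch.Tensor


class ToyInt8MoEMethod:
    """Mirrors the create_moe_runner()/apply() split introduced in 0.5.3."""

    def create_moe_runner(self, layer: Any, moe_runner_config: Any) -> None:
        # The real method also builds MoeRunner(MoeRunnerBackend.TRITON, config)
        # and later calls self.runner.run(dispatch_output, quant_info).
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: Any, dispatch_output: DispatchOutputStandIn) -> CombineInputStandIn:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids = dispatch_output.topk_output[:2]
        output = x  # placeholder for the quantized expert computation
        return CombineInputStandIn(hidden_states=output)
```

The NPU MoE method in the following hunks follows the same split: `create_moe_runner` stores the config, and `apply` unpacks `dispatch_output`, calls `npu_fused_experts`, and wraps the result in `StandardCombineInput`.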
@@ -620,6 +648,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8LinearMethodMTImpl:
@@ -812,6 +841,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
@@ -900,7 +930,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +944,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +981,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1015,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -994,3 +1043,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/rotary_embedding.py

@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
@@ -26,13 +27,19 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
+
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope

 if is_npu():
     import torch_npu

+    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -142,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -172,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"

         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
@@ -202,7 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )

     def forward_cuda(
         self,
@@ -222,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
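The hunks above thread a new optional `fused_set_kv_buffer_arg` parameter through every `forward_*` variant of `RotaryEmbedding`: only the CUDA path can act on it, the native/NPU/CPU paths assert it is `None`, and the delegating calls to `forward_native` now forward it explicitly. A compact sketch of that argument-threading pattern; `RopeStandIn` is a hypothetical class, and only the forwarding pattern is taken from the diff:

```python
from typing import Optional, Tuple

import torch


class RopeStandIn:
    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        fused_set_kv_buffer_arg: Optional[object] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Non-CUDA backends cannot fuse the KV-cache write, so they reject the argument.
        assert fused_set_kv_buffer_arg is None
        return query, key

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
        fused_set_kv_buffer_arg: Optional[object] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Delegating paths must pass the new argument through, as the diff now does.
        return self.forward_native(positions, query, key, offsets, fused_set_kv_buffer_arg)
```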
@@ -782,27 +806,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-
-
-        return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]

-        query_rot
-
-
-
-
-
-
-
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
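The rewritten NPU path looks up one cache row per token, splits it into cos/sin halves, and expands each into a 4-D tensor that broadcasts against queries reshaped to `[num_tokens, num_heads, 1, rotary_dim]`. A quick shape check of that expansion (pure torch, no torch_npu required):

```python
import torch

num_tokens, rotary_dim = 4, 64
cos_sin = torch.randn(num_tokens, rotary_dim)        # one cos_sin_cache row per token

cos, sin = cos_sin.chunk(2, dim=-1)                  # each [num_tokens, rotary_dim // 2]
cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)   # [num_tokens, 1, 1, rotary_dim]
sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)

assert cos.shape == (num_tokens, 1, 1, rotary_dim)
# Broadcasts against query_rot reshaped to [num_tokens, num_heads, 1, rotary_dim].
```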
@@ -1029,12 +1059,13 @@ class MRotaryEmbedding(RotaryEmbedding):
             f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
         )

-    @torch.compile(dynamic=True)
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().

@@ -1045,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2

         num_tokens = positions.shape[-1]
@@ -1177,7 +1211,7 @@ class MRotaryEmbedding(RotaryEmbedding):

             time_tensor_long = time_tensor.long()
             t_index = time_tensor_long.flatten()
-        elif model_type
+        elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
             t_index = (
                 torch.arange(llm_grid_t)
                 .view(-1, 1)
@@ -1888,17 +1922,30 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-
-
-
-
-
-    q_embed
-
-    k_embed = torch.transpose(k_embed, 1, 2)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed

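The new `apply_rotary_pos_emb_npu` only takes the fused `torch_npu.npu_rotary_mul` path when the inputs fit the kernel's limits (2-D cos, 3-D q, fewer than 1000 heads, head size under 896), and otherwise falls back to the native implementation; the fused path adds a leading batch dimension before the kernel call and strips it afterwards. A pure-torch sketch of that guard and unsqueeze/squeeze round trip, with a no-op standing in for the NPU kernel; the constants are the ones added in the diff:

```python
import torch

NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896


def rotary_npu_sketch(q, k, cos, sin, unsqueeze_dim=1):
    if (
        cos.dim() != 2
        or q.dim() != 3
        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
    ):
        return q, k  # the real code falls back to apply_rotary_pos_emb_native here

    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)  # [1, num_tokens, 1, head_size]
    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
    q, k = q.unsqueeze(0), k.unsqueeze(0)            # add the batch dim the kernel expects
    q_embed, k_embed = q, k                          # stand-in for torch_npu.npu_rotary_mul(q, cos, sin)
    return q_embed.squeeze(0), k_embed.squeeze(0)    # drop the batch dim again


q = torch.randn(8, 16, 128)   # [num_tokens, num_heads, head_size]
k = torch.randn(8, 4, 128)
cos = torch.randn(8, 128)
sin = torch.randn(8, 128)
q_out, k_out = rotary_npu_sketch(q, k, cos, sin)
assert q_out.shape == q.shape and k_out.shape == k.shape
```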