sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/attention_registry.py
@@ -0,0 +1,206 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
+ATTENTION_BACKENDS = {}
+
+
+def register_attention_backend(name):
+    def decorator(fn):
+        ATTENTION_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_attention_backend("flashinfer")
+def create_flashinfer_backend(runner):
+    import torch
+
+    if not runner.use_mla_backend:
+        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+        # Init streams
+        if runner.server_args.speculative_algorithm == "EAGLE":
+            if (
+                not hasattr(runner, "plan_stream_for_flashinfer")
+                or not runner.plan_stream_for_flashinfer
+            ):
+                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+        return FlashInferAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.flashinfer_mla_backend import (
+            FlashInferMLAAttnBackend,
+        )
+
+        return FlashInferMLAAttnBackend(runner)
+
+
+@register_attention_backend("trtllm_mla")
+def create_trtllm_mla_backend(runner):
+    if not runner.use_mla_backend:
+        raise ValueError("trtllm_mla backend can only be used with MLA models.")
+    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+    return TRTLLMMLABackend(runner)
+
+
+@register_attention_backend("aiter")
+def create_aiter_backend(runner):
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+    return AiterAttnBackend(runner)
+
+
+@register_attention_backend("wave")
+def create_wave_backend(runner):
+    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+    return WaveAttnBackend(runner)
+
+
+@register_attention_backend("ascend")
+def create_ascend_backend(runner):
+    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+    return AscendAttnBackend(runner)
+
+
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
+@register_attention_backend("triton")
+def create_triton_backend(runner):
+    assert not runner.model_config.is_encoder_decoder, (
+        "Cross attention is not supported in the triton attention backend. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    if runner.server_args.enable_double_sparsity:
+        from sglang.srt.layers.attention.double_sparsity_backend import (
+            DoubleSparseAttnBackend,
+        )
+
+        return DoubleSparseAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(runner)
+
+
+@register_attention_backend("torch_native")
+def create_torch_native_backend(runner):
+    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+    return TorchNativeAttnBackend(runner)
+
+
+@register_attention_backend("flex_attention")
+def create_flex_attention_backend(runner):
+    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+    return TorchFlexAttnBackend(runner)
+
+
+@register_attention_backend("flashmla")
+def create_flashmla_backend(runner):
+    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+    return FlashMLABackend(runner)
+
+
+@register_attention_backend("fa3")
+def create_flashattention_v3_backend(runner):
+    import torch
+
+    assert (
+        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+    ) or torch.cuda.get_device_capability()[0] == 9, (
+        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner)
+
+
+@register_attention_backend("fa4")
+def create_flashattention_v4_backend(runner):
+    assert (
+        runner.use_mla_backend
+    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+@register_attention_backend("cutlass_mla")
+def create_cutlass_mla_backend(runner):
+    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+    return CutlassMLABackend(runner)
+
+
+@register_attention_backend("trtllm_mha")
+def create_trtllm_mha_backend(runner):
+    if runner.use_mla_backend:
+        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+    return TRTLLMHAAttnBackend(runner)
+
+
+@register_attention_backend("intel_amx")
+def create_intel_amx_backend(runner):
+    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+    return IntelAMXAttnBackend(runner)
+
+
+@register_attention_backend("dual_chunk_flash_attn")
+def create_dual_chunk_flash_attn_backend(runner):
+    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+        DualChunkFlashAttentionBackend,
+    )
+
+    return DualChunkFlashAttentionBackend(runner)
+
+
+def attn_backend_wrapper(runner, full_attn_backend):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.is_hybrid_gdn and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    # wrap for hybrid GDN models
+    if runner.is_hybrid_gdn:
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if is_blackwell():
+            assert (
+                runner.server_args.attention_backend == "triton"
+                or runner.server_args.attention_backend == "trtllm_mha"
+            ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
+        if is_npu():
+            assert (
+                runner.server_args.attention_backend == "ascend"
+            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            HybridLinearAttnBackend,
+            MambaAttnBackend,
+        )
+
+        linear_attn_backend = MambaAttnBackend(runner)
+        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+    return full_attn_backend
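Note: the new registry replaces hard-coded backend selection with a name-keyed table of factory functions. A minimal sketch of how such a table is typically consumed is shown below; the resolve_attention_backend helper and the runner argument are illustrative assumptions, not code from this diff.

    # Illustrative sketch only, not part of the diff.
    from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS

    def resolve_attention_backend(name, runner):
        # Each registered factory takes the model runner and returns a backend instance.
        if name not in ATTENTION_BACKENDS:
            raise ValueError(f"Unknown attention backend: {name}")
        return ATTENTION_BACKENDS[name](runner)

    # e.g. backend = resolve_attention_backend("triton", runner)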
sglang/srt/layers/attention/base_attn_backend.py
@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +45,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
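Note: get_indexer_metadata defaults to None, which signals that a backend does not support the NSA indexer. A hypothetical override sketch is shown below; MySparseBackend and the _build_indexer_metadata helper are assumptions for illustration, not code from this diff.

    # Hypothetical sketch only, not part of the diff.
    class MySparseBackend(AttentionBackend):
        def get_indexer_metadata(self, layer_id, forward_batch):
            # Returning a metadata object instead of None advertises indexer support.
            return self._build_indexer_metadata(layer_id, forward_batch)  # assumed helper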
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
 
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
             query_inter,
             key_cache,
             value_cache,
-            block_table
+            block_table,
             decode_meta.seq_lens_inter,
             softmax_scale,
             causal=False,
sglang/srt/layers/attention/fla/chunk.py
@@ -0,0 +1,242 @@
+# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import warnings
+from typing import Optional
+
+import torch
+from einops import rearrange
+
+from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
+from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
+from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
+    chunk_scaled_dot_kkt_fwd,
+)
+from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
+from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
+from sglang.srt.layers.attention.fla.solve_tril import solve_tril
+from sglang.srt.layers.attention.fla.utils import (
+    SUPPRESS_LEVEL,
+    autocast_custom_fwd,
+    input_guard,
+)
+from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd
+
+
+def chunk_gated_delta_rule_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+):
+    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+    # obtain WY representation. u is actually the new v.
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
+    )
+    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+    w, u = recompute_w_u_fwd(
+        k=k,
+        v=v,
+        beta=beta,
+        A=A,
+        g_cumsum=g,
+        cu_seqlens=cu_seqlens,
+    )
+    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
+        k=k,
+        w=w,
+        u=u,
+        g=g,
+        initial_state=initial_state,
+        output_final_state=output_final_state,
+        cu_seqlens=cu_seqlens,
+    )
+    o = chunk_fwd_o(
+        q=q,
+        k=k,
+        v=v_new,
+        h=h,
+        g=g,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+    )
+    if SUPPRESS_LEVEL < 3:
+        return g, o, A, final_state, None, None, None
+    elif SUPPRESS_LEVEL >= 3:
+        return g, o, A, final_state, w, h, v_new
+
+
+class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        use_qk_l2norm_in_kernel: bool = False,
+    ):
+        q_orig = q
+        k_orig = k
+
+        if use_qk_l2norm_in_kernel:
+            q = l2norm_fwd(q)
+            k = l2norm_fwd(k)
+
+        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+        )
+        return o.to(q.dtype), final_state
+
+
+@torch.compiler.disable
+def chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+):
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        scale (Optional[int]):
+            Scale factor for the RetNet attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> o, ht = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert (
+        q.dtype != torch.float32
+    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+    assert (
+        len(beta.shape) == 3
+    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead."
+        )
+        q, k, v, beta, g = map(
+            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
+        )
+    # if not head_first and q.shape[1] < q.shape[2]:
+    #     warnings.warn(
+    #         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+    #         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+    #         "when head_first=False was specified. "
+    #         "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+    #     )
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel,
+    )
+    if head_first:
+        o = rearrange(o, "b t h ... -> b h t ...")
+    return o, final_state