sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba2_metadata.py
@@ -0,0 +1,211 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@dataclass(kw_only=True)
+class ForwardMetadata:
+    query_start_loc: torch.Tensor
+    mamba_cache_indices: torch.Tensor
+
+
+@dataclass(kw_only=True)
+class Mamba2Metadata(ForwardMetadata):
+    """stable metadata across all mamba2 layers in the forward pass"""
+
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+
+    @dataclass(kw_only=True, frozen=True)
+    class MixedMetadata:
+        has_initial_states: torch.Tensor
+        prep_initial_states: bool
+
+        chunk_size: int
+        seq_idx: torch.Tensor
+        chunk_indices: torch.Tensor
+        chunk_offsets: torch.Tensor
+
+        extend_seq_lens_cpu: list[int]
+
+    mixed_metadata: MixedMetadata | None = None
+    """`mixed_metadata` is used for extend/mixed requests"""
+
+    @staticmethod
+    def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
+                lengths, shape (num_seqs + 1,).
+                The first element should be 0. Each entry represents the starting
+                index of a sequence in the flattened token array.
+            chunk_size (int): The size of each physical mamba chunk
+                (number of tokens per chunk).
+            total_seqlens (int): The total number of tokens in the batch.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - chunk_indices (torch.Tensor): 1D tensor of indices
+                    indicating the physical chunk for each logical chunk.
+                - chunk_offsets (torch.Tensor): 1D tensor of offsets
+                    indicating the starting index of each logical chunk within
+                    its physical chunk.
+
+        This function computes the chunk indices and offsets for the given
+        query_start_loc and chunk_size. Both are tensors of integers with length N,
+        where N is the number of logical (pseudo) chunks.
+        A logical chunk is a sequence of tokens that are all part of the same
+        sequence and are all in the same physical mamba chunk.
+        In other words, a logical chunk changes every time we cross a sequence
+        boundary or a physical mamba chunk boundary.
+        Logical chunks are needed to handle batched requests with initial states
+        (see _state_passing_fwd and _chunk_scan_fwd).
+        The chunk_indices tensor contains the index of the physical chunk for each
+        logical chunk.
+        The chunk_offsets tensor contains the offset (AKA starting index) of the
+        logical chunk in the physical chunk.
+
+        Example:
+            query_start_loc = [0, 5, 10]
+            chunk_size = 8
+            total_seqlens = 10
+            -> chunk_indices = [0, 0, 1]
+            -> chunk_offsets = [0, 5, 0]
+
+        In this example, we have 2 sequences, each with 5 tokens. The physical
+        chunk size is 8 tokens.
+        We have three logical chunks:
+        - the first logical chunk starts at token 0 in the first physical chunk
+          and contains all 5 tokens from the first sequence
+        - the second logical chunk starts at token 5 in the first physical chunk
+          and contains first 3 tokens from the second sequence
+        - the third logical chunk starts at token 0 in the second physical chunk
+          and contains the remaining 2 tokens from the second sequence
+        """
+
+        cu_seqlens = query_start_loc[1:]  # remove prepended 0
+
+        # outputs will have length expansion of chunks that do not divide
+        # chunk_size
+        N = (
+            math.ceil(total_seqlens / chunk_size)
+            + (cu_seqlens[:-1] % chunk_size > 0).sum()
+        )
+        chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
+        chunk_offsets = torch.zeros(
+            (N,), dtype=torch.int, device=query_start_loc.device
+        )
+
+        p = 0  # num of insertions
+        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+            # if does not divide chunk_size, then there is one chunk insertion
+            p += s % chunk_size > 0
+
+            # get the dimensions
+            # - the + 1 for _e is to shift the boundary by one chunk
+            # - this shifting is not needed if chunk_size divides e
+            _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
+
+            # adjust indices and offsets
+            chunk_indices[_s:_e] -= p
+            chunk_offsets[_s] = s % chunk_size
+
+        return chunk_indices, chunk_offsets
+
+    @staticmethod
+    def prepare_decode(
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ) -> "Mamba2Metadata":
+        """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_decodes=len(seq_lens),
+            num_prefills=0,
+            num_prefill_tokens=0,
+        )
+
+    @classmethod
+    def prepare_mixed(
+        cls,
+        query_start_loc: torch.Tensor,
+        mamba_cache_indices: torch.Tensor,
+        chunk_size: int,
+        forward_batch: ForwardBatch,
+    ) -> "Mamba2Metadata":
+        """This path cannot run with CUDA graph, as it contains extend requests."""
+        if forward_batch.extend_num_tokens is None:
+            return cls.prepare_decode(
+                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
+            )
+        num_prefills = len(forward_batch.extend_seq_lens)
+        num_prefill_tokens = forward_batch.extend_num_tokens
+        num_decodes = len(forward_batch.seq_lens) - num_prefills
+        context_lens_tensor = forward_batch.extend_prefix_lens
+        assert context_lens_tensor is not None
+        # precompute flag to avoid device syncs later
+        has_initial_states = context_lens_tensor > 0
+        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
+
+        query_start_loc = query_start_loc[: num_prefills + 1]
+        seq_idx = torch.repeat_interleave(
+            torch.arange(
+                num_prefills, dtype=torch.int32, device=query_start_loc.device
+            ),
+            query_start_loc.diff(),
+            output_size=num_prefill_tokens,
+        )
+        seq_idx.unsqueeze_(0)
+
+        # We compute metadata for chunked prefill once at the top level model
+        # forward and reuse them in mamba layers. If not needed, they will be
+        # ignored inside mamba kernels.
+        chunk_offsets, chunk_indices = None, None
+        if prep_initial_states:
+            chunk_indices, chunk_offsets = (
+                cls._query_start_loc_to_chunk_indices_offsets(
+                    query_start_loc, chunk_size, num_prefill_tokens
+                )
+            )
+
+        return Mamba2Metadata(
+            query_start_loc=query_start_loc,
+            mamba_cache_indices=mamba_cache_indices,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            mixed_metadata=cls.MixedMetadata(
+                has_initial_states=has_initial_states,
+                prep_initial_states=prep_initial_states,
+                chunk_size=chunk_size,
+                seq_idx=seq_idx,
+                chunk_indices=chunk_indices,
+                chunk_offsets=chunk_offsets,
+                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            ),
+        )
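As a quick sanity check, the snippet below replays the worked example from the docstring by calling the helper directly. It is a minimal sketch, assuming this hunk is the new sglang/srt/layers/attention/mamba/mamba2_metadata.py (its +211 line count matches that entry in the file list above) and that the module imports cleanly outside a running server:

    import torch

    from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata

    # Two sequences of 5 tokens each, physical mamba chunk size 8
    # (the example from the docstring above).
    query_start_loc = torch.tensor([0, 5, 10], dtype=torch.int32)
    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            query_start_loc, chunk_size=8, total_seqlens=10
        )
    )
    assert chunk_indices.tolist() == [0, 0, 1]  # physical chunk per logical chunk
    assert chunk_offsets.tolist() == [0, 5, 0]  # start offset inside that chunk

The second sequence straddles the physical chunk boundary at token 8, so it splits into two logical chunks, which is exactly the length expansion the `N` computation accounts for.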
sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py
@@ -0,0 +1,120 @@
+from typing import Union
+
+import torch
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed.communication_op import (
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.distributed.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
+from sglang.srt.model_loader.weight_utils import sharded_weight_loader
+from sglang.srt.utils.common import set_weight_attrs
+
+
+class Mixer2RMSNormGated(CustomOp):
+    def __init__(
+        self,
+        full_hidden_size: int,
+        full_n_groups: int,
+        use_rms_norm: bool = True,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.full_hidden_size = full_hidden_size
+        self.group_size = full_hidden_size // full_n_groups
+        self.per_rank_hidden_size = full_hidden_size // self.tp_size
+        self.n_groups = full_hidden_size // self.group_size
+
+        self.variance_epsilon = eps
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (
+            self.full_hidden_size % self.tp_size == 0
+        ), "Tensor parallel world size must divide hidden size."
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ):
+        # Three tensor-parallel cases:
+        #   1. n_groups is 1
+        #      In this case we parallelize along the reduction dim.
+        #      Each rank computes a local sum of squares followed by AllReduce
+        #   2. tp_size divides n_groups
+        #      Each rank only reduces within its local group(s).
+        #      No collective ops necessary.
+        #   3. The general case can be pretty complicated so we AllGather
+        #      the input and then redundantly compute the RMSNorm.
+        input_dtype = x.dtype
+        x = x * torch.nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x.to(input_dtype)
+
+        if self.n_groups == 1:
+            if self.tp_size > 1:
+                # Compute local sum and then reduce to obtain global sum
+                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
+                global_sums = tensor_model_parallel_all_reduce(local_sums)
+                # Calculate the variance
+                count = self.tp_size * x.shape[-1]
+                variance = global_sums / count
+
+            else:
+                variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.variance_epsilon)
+        else:
+            redundant_tp: bool = self.n_groups % self.tp_size != 0
+            if redundant_tp:
+                # To handle the general case, redundantly apply the variance
+                x = tensor_model_parallel_all_gather(x, -1)
+
+            *prefix_dims, hidden_dim = x.shape
+            group_count = hidden_dim // self.group_size
+            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
+            variance = x_grouped.pow(2).mean(-1, keepdim=True)
+            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
+            x = x_grouped.view(*prefix_dims, hidden_dim)
+
+            if redundant_tp:
+                start = self.per_rank_hidden_size * self.tp_rank
+                end = start + self.per_rank_hidden_size
+                x = x[..., start:end]
+
+        return self.weight * x.to(input_dtype)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        input_dtype = x.dtype
+        if not self.use_rms_norm:
+            # Keep gate in float32 for numerical stability during silu
+            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
+
+        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
+            return self.forward_native(x, gate)
+
+        return rms_norm_gated(
+            x=x,
+            weight=self.weight.data,
+            bias=None,
+            z=gate,
+            eps=self.variance_epsilon,
+            norm_before_gate=False,
+            is_rms_norm=True,
+        )
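For reference, the snippet below sketches what `forward_native` computes on a single rank (`tp_size == 1`), without the tensor-parallel collectives: gate with SiLU (computed in float32), then RMS-normalize each group independently. The helper name `gated_group_rms_norm_ref` is ours, not part of the diff; it deliberately avoids constructing `Mixer2RMSNormGated`, which requires an initialized tensor-parallel group:

    import torch
    import torch.nn.functional as F

    def gated_group_rms_norm_ref(x, gate, weight, group_size, eps=1e-6):
        # Single-rank reference for Mixer2RMSNormGated.forward_native:
        # multiply by silu(gate) in float32, then RMS-normalize per group.
        input_dtype = x.dtype
        x = x * F.silu(gate.to(torch.float32))
        *prefix, hidden = x.shape
        xg = x.view(*prefix, hidden // group_size, group_size)
        xg = xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
        return weight * xg.view(*prefix, hidden).to(input_dtype)

    x, gate = torch.randn(2, 16), torch.randn(2, 16)
    y = gated_group_rms_norm_ref(x, gate, torch.ones(16), group_size=4)

With `group_size` equal to the full hidden size this reduces to the `n_groups == 1` branch; the tensor-parallel branches only change where the sum of squares is reduced, not the math.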
sglang/srt/layers/attention/mamba/ops/layernorm_gated.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2024, Tri Dao.
+# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+@triton.jit
+def _layer_norm_fwd_1pass_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Z,  # pointer to the other branch
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row: tl.int64,
+    stride_y_row: tl.int64,
+    stride_z_row: tl.int64,
+    M: tl.int64,  # number of rows in X
+    N: tl.int64,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_Z: tl.constexpr,
+    NORM_BEFORE_GATE: tl.constexpr,
+    IS_RMS_NORM: tl.constexpr,
+):
+    # Map the program id to the row of X and Y it should compute.
+    row = tl.program_id(0)
+    group = tl.program_id(1)
+    X += row * stride_x_row + group * N
+    Y += row * stride_y_row + group * N
+    if HAS_Z:
+        Z += row * stride_z_row + group * N
+    if not IS_RMS_NORM:
+        Mean += group * M
+    Rstd += group * M
+    W += group * N
+    if HAS_BIAS:
+        B += group * N
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_Z and not NORM_BEFORE_GATE:
+        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+        x *= z * tl.sigmoid(z)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    mask = cols < N
+    w = tl.load(W + cols, mask=mask).to(tl.float32)
+    if HAS_BIAS:
+        b = tl.load(B + cols, mask=mask).to(tl.float32)
+    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+    y = x_hat * w + b if HAS_BIAS else x_hat * w
+    if HAS_Z and NORM_BEFORE_GATE:
+        z = tl.load(Z + cols, mask=mask).to(tl.float32)
+        y *= z * tl.sigmoid(z)
+    # Write output
+    tl.store(Y + cols, y, mask=mask)
+
+
+def _layer_norm_fwd(
+    x,
+    weight,
+    bias,
+    eps,
+    z=None,
+    out=None,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+):
+    M, N = x.shape
+    if group_size is None:
+        group_size = N
+    assert N % group_size == 0
+    ngroups = N // group_size
+    assert x.stride(-1) == 1
+    if z is not None:
+        assert z.stride(-1) == 1
+        assert z.shape == (M, N)
+    assert weight.shape == (N,)
+    assert weight.stride(-1) == 1
+    if bias is not None:
+        assert bias.stride(-1) == 1
+        assert bias.shape == (N,)
+    # allocate output
+    if out is not None:
+        assert out.shape == x.shape
+    else:
+        out = torch.empty_like(x)
+    assert out.stride(-1) == 1
+    mean = (
+        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+        if not is_rms_norm
+        else None
+    )
+    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+    if group_size > BLOCK_N:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK_N // 256, 1), 8)
+    grid = (M, ngroups)
+    with torch.cuda.device(x.device.index):
+        _layer_norm_fwd_1pass_kernel[grid](
+            x,
+            out,
+            weight,
+            bias,
+            z,
+            mean,
+            rstd,
+            x.stride(0),
+            out.stride(0),
+            z.stride(0) if z is not None else 0,
+            M,
+            group_size,
+            eps,
+            BLOCK_N=BLOCK_N,
+            NORM_BEFORE_GATE=norm_before_gate,
+            IS_RMS_NORM=is_rms_norm,
+            num_warps=num_warps,
+        )
+    return out, mean, rstd
+
+
+def rms_norm_gated(
+    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
+):
+    x_shape_og = x.shape
+    # reshape input data into 2D tensor
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if z is not None:
+        assert z.shape == x_shape_og
+        z = z.reshape(-1, z.shape[-1])
+        if z.stride(-1) != 1:
+            z = z.contiguous()
+    weight = weight.contiguous()
+    if bias is not None:
+        bias = bias.contiguous()
+    y, _, _ = _layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=True,
+    )
+
+    return y.reshape(x_shape_og)
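A hedged usage sketch for the `rms_norm_gated` wrapper added here (the `ops` variant; `Mixer2RMSNormGated` above imports the separate `fla` variant, which exposes the same `norm_before_gate` flag per this diff). It assumes a CUDA device with Triton available; the tolerances are illustrative for float16 storage:

    import torch
    import torch.nn.functional as F

    from sglang.srt.layers.attention.mamba.ops.layernorm_gated import rms_norm_gated

    x = torch.randn(4, 64, device="cuda", dtype=torch.float16)
    z = torch.randn_like(x)  # the gate ("other branch") input
    w = torch.ones(64, device="cuda", dtype=torch.float16)

    # norm_before_gate=False: multiply by silu(z) first, then RMS-normalize,
    # matching the HAS_Z / not NORM_BEFORE_GATE path in the kernel above.
    y = rms_norm_gated(x, w, None, z=z, eps=1e-6, norm_before_gate=False)

    # float32 eager reference for the same ordering
    ref = x.float() * F.silu(z.float())
    ref = ref * torch.rsqrt(ref.pow(2).mean(-1, keepdim=True) + 1e-6) * w.float()
    torch.testing.assert_close(y.float(), ref, rtol=2e-2, atol=2e-2)

Leaving `group_size=None` treats the whole feature dimension as one group; passing a divisor of N launches one program per (row, group) pair, as in the `grid = (M, ngroups)` line above.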