sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/sglang/srt/layers/attention/mamba/ops/ssd_combined.py
@@ -0,0 +1,262 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from packaging import version
+
+from .ssd_bmm import _bmm_chunk_fwd
+from .ssd_chunk_scan import _chunk_scan_fwd
+from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen
+from .ssd_state_passing import _state_passing_fwd
+
+TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
+
+
+def is_int_pow_2(n):
+    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
+
+
+def _mamba_chunk_scan_combined_fwd(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    state_dtype=None,
+    out=None,
+):
+    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
+    batch, seqlen, nheads, headdim = x.shape
+    _, _, ngroups, dstate = B.shape
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert dt.shape == (batch, seqlen, nheads)
+    assert A.shape == (nheads,)
+    assert C.shape == B.shape
+    if z is not None:
+        assert z.shape == x.shape
+    if D is not None:
+        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    if B.stride(-1) != 1:
+        B = B.contiguous()
+    if C.stride(-1) != 1:
+        C = C.contiguous()
+    if (
+        x.stride(-1) != 1 and x.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        x = x.contiguous()
+    if (
+        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        z = z.contiguous()
+    if D is not None and D.stride(-1) != 1:
+        D = D.contiguous()
+    if initial_states is not None:
+        if cu_seqlens is None:
+            assert initial_states.shape == (batch, nheads, headdim, dstate)
+        else:
+            assert initial_states.shape == (
+                len(cu_seqlens) - 1,
+                nheads,
+                headdim,
+                dstate,
+            )
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to understand the below operations
+    # - as explained by the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
+    dA_cumsum, dt = _chunk_cumsum_fwd(
+        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
+    )
+
+    # 2. Compute the state for each intra-chunk
+    # (right term of low-rank factorization of off-diagonal blocks; B terms)
+    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    # (middle term of factorization of off-diag blocks; A terms)
+    # - for handling chunked prefill, this requires i) initial_states
+    #   ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
+    # - When a new seq_idx is detected, we will stop passing the prev_state
+    #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
+    # - this will ensure that states will be updated with the rightmost flushed seq_idx
+    #   of the previous chunk. This implies that the first chunk of states is either 0
+    #   or equal to init_states of the first example.
+    states, final_states = _state_passing_fwd(
+        rearrange(states, "... p n -> ... (p n)"),
+        dA_cumsum,
+        initial_states=(
+            rearrange(initial_states, "... p n -> ... (p n)")
+            if initial_states is not None
+            else None
+        ),
+        seq_idx=seq_idx,
+        chunk_size=chunk_size,
+        out_dtype=state_dtype if state_dtype is not None else C.dtype,
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets,
+    )
+    states, final_states = (
+        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
+    )
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
+    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
+    # - if initial states are provided, then states information will be
+    #   augmented with initial_states.
+    # - to do this properly, we need to account for example changes in
+    #   the continuous batch, therefore we introduce pseudo chunks, which is
+    #   a chunk that is split up each time an example changes.
+    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
+    #   a seq_idx change, in which case we take states information from
+    #   init_states.
+    out_x = _chunk_scan_fwd(
+        CB,
+        x,
+        dt,
+        dA_cumsum,
+        C,
+        states,
+        D=D,
+        z=z,
+        seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        initial_states=initial_states,
+        out=out,
+    )
+    if cu_seqlens is None:
+        return out_x, dt, dA_cumsum, states, final_states
+    else:
+        assert (
+            batch == 1
+        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
+        varlen_states = chunk_state_varlen(
+            B.squeeze(0),
+            x.squeeze(0),
+            dt.squeeze(0),
+            dA_cumsum.squeeze(0),
+            cu_seqlens,
+            states.squeeze(0),
+            initial_states=initial_states,
+        )
+        return out_x, dt, dA_cumsum, states, final_states, varlen_states
+
+
+def mamba_chunk_scan_combined(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    out=None,
+    return_final_states=False,
+    return_varlen_states=False,
+    state_dtype=None,
+):
+    """
+    Argument:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+        initial_states: (batch, nheads, headdim, dstate)
+        seq_idx: (batch, seqlen)
+        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
+        dt_softplus: Whether to apply softplus to dt
+        out: Preallocated output tensor
+        state_dtype: The data type of the ssm state
+    """
+
+    if not return_varlen_states:
+        cu_seqlens = None
+    else:
+        assert (
+            cu_seqlens is not None
+        ), "cu_seqlens must be provided if return_varlen_states is True"
+    out_x, dt_out, dA_cumsum, states, final_states, *rest = (
+        _mamba_chunk_scan_combined_fwd(
+            x,
+            dt,
+            A,
+            B,
+            C,
+            chunk_size,
+            D=D,
+            z=z,
+            dt_bias=dt_bias,
+            initial_states=initial_states,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
+            cu_seqlens=cu_seqlens,
+            dt_softplus=dt_softplus,
+            dt_limit=dt_limit,
+            out=out,
+            state_dtype=state_dtype,
+        )
+    )
+    if not return_varlen_states:
+        if not return_final_states:
+            return
+        else:
+            return final_states
+    else:
+        varlen_states = rest[0]
+        return (
+            (varlen_states)
+            if not return_final_states
+            else (final_states, varlen_states)
+        )
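The docstring above pins down the expected tensor shapes, so a minimal call is easy to sketch. The snippet below is an illustrative usage example, not part of the diff: it assumes a CUDA device with Triton available, imports from the module path given in the file list above, and picks arbitrary toy dimensions. Note that `out` must be preallocated by the caller; with `return_final_states=True` the call also returns the final SSM state.

```python
# Hypothetical usage sketch (not from the package); assumes CUDA + Triton.
import torch

from sglang.srt.layers.attention.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined,
)

batch, seqlen, nheads, headdim = 1, 256, 8, 64
ngroups, dstate, chunk_size = 1, 16, 64  # chunk_size must be a power of 2

device, dtype = "cuda", torch.bfloat16
x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=torch.float32)
A = -torch.rand(nheads, device=device, dtype=torch.float32)  # negative => decay
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)

out = torch.empty_like(x)  # "out: Preallocated output tensor"
final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    dt_softplus=True,
    out=out,
    return_final_states=True,
)
# out now holds y with shape (batch, seqlen, nheads, headdim);
# final_states has shape (batch, nheads, headdim, dstate).
```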
--- /dev/null
+++ b/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py
@@ -0,0 +1,275 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+
+
+# @triton.autotune(
+#     configs=[
+#         triton.Config({"BLOCK_SIZE": 64}),
+#         triton.Config({"BLOCK_SIZE": 128}),
+#         triton.Config({"BLOCK_SIZE": 256}),
+#         triton.Config({"BLOCK_SIZE": 512}),
+#         triton.Config({"BLOCK_SIZE": 1024}),
+#         triton.Config({"BLOCK_SIZE": 2048}),
+#     ],
+#     key=["dim"],
+# )
+@triton.jit
+def _state_passing_fwd_kernel(
+    # Pointers to matrices
+    states_ptr,
+    out_ptr,
+    final_states_ptr,
+    dA_cs_ptr,
+    initstates_ptr,
+    seq_idx_ptr,
+    chunk_offsets_ptr,
+    chunk_meta_num,
+    # Matrix dimensions
+    dim,
+    nchunks,
+    seqlen,
+    chunk_size,
+    # Strides
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_dim,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_out_dim,
+    stride_final_states_batch,
+    stride_final_states_head,
+    stride_final_states_dim,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_initstates_batch,
+    stride_initstates_head,
+    stride_initstates_dim,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
+    # Meta-parameters
+    HAS_INITSTATES: tl.constexpr,
+    HAS_SEQ_IDX: tl.constexpr,
+    IS_CONT_BATCHED: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr = 16,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    pid_m = tl.program_id(axis=0)
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    dA_cs_ptr += (
+        pid_b * stride_dA_cs_batch
+        + pid_h * stride_dA_cs_head
+        + (chunk_size - 1) * stride_dA_cs_csize
+    )
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    final_states_ptr += (
+        pid_b * stride_final_states_batch + pid_h * stride_final_states_head
+    )
+    if HAS_INITSTATES:
+        initstates_ptr += pid_h * stride_initstates_head
+        if not IS_CONT_BATCHED:
+            initstates_ptr += pid_b * stride_initstates_batch
+
+    if HAS_SEQ_IDX:
+        seq_idx_ptr += pid_b * stride_seq_idx_batch
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    states_ptrs = states_ptr + offs_m * stride_states_dim
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
+
+    # - states will be the past state of the sequence that continues on the current check
+    if not HAS_INITSTATES:
+        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    else:
+        initstates_ptr += offs_m * stride_initstates_dim
+        initstates_ptrs = initstates_ptr
+        # - for cont batches, for the first chunk mean it will be the first batch's
+        #   init state
+        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+
+    tl.store(out_ptrs, states, mask=offs_m < dim)
+    out_ptrs += stride_out_chunk
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
+    for c in range(nchunks):
+        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
+        scale_mask = True
+        if HAS_SEQ_IDX:
+            # - the seq to pass forward is the one that is flushed to the right
+            #   boundary.
+            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+            seq_idx_chunk_end = tl.load(
+                seq_idx_ptr
+                + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen
+            )
+            if HAS_INITSTATES:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
+                    # this means in the current chunk the rightmost flushed seq
+                    # has changed.
+                    # - so we do not propagate the state from previous chunk
+                    # - but rather we load that sequence's init state
+                    initstates_ptrs = (
+                        initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
+                    )
+
+                    # - update state with seq_idx_new's init state
+                    states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
+                        tl.float32
+                    )
+
+                    # - we need to consider the cumsum only of the last sequence in the chunk
+                    # - find its starting position (given by c_off of the logical chunk index)
+                    # - and subtract the cumsum just before that position from the total cumsum
+                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+                    #   sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(
+                        seq_idx_ptr
+                        + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen
+                    )
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(
+                        chunk_offsets_ptr + logical_chunk_idx,
+                        mask=logical_chunk_idx < chunk_meta_num,
+                        other=0,
+                    )
+                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr
+                            - (chunk_size - 1) * stride_dA_cs_csize
+                            + (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0,
+                        )
+                        dA_cs -= dA_cs_boundary
+
+                # - increment logical chunk index for every physical chunk
+                logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
+        states = scale * states + new_states
+        if c < nchunks - 1:
+            tl.store(out_ptrs, states, mask=offs_m < dim)
+        else:
+            tl.store(final_states_ptrs, states, mask=offs_m < dim)
+        states_ptrs += stride_states_chunk
+        dA_cs_ptr += stride_dA_cs_chunk
+        out_ptrs += stride_out_chunk
+
+
+def _state_passing_fwd(
+    states,
+    dA_cumsum,
+    initial_states=None,
+    seq_idx=None,
+    chunk_size=None,
+    out_dtype=None,
+    is_cont_batched=False,
+    chunk_offsets=None,
+):
+    batch, nchunks, nheads, dim = states.shape
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+    if initial_states is not None:
+        if is_cont_batched:
+            # - if cu_seqlens is provided, then the initial states
+            #   are used for continuous batching. In which case we
+            #   require seq_idx to be provided
+            assert (
+                seq_idx is not None
+            ), "seq_idx must be provided for continuous batching"
+            # - we also need chunk_offsets to be provided, to account
+            #   for computation of dA_cumsum from the start of the
+            #   sequence
+            assert (
+                chunk_offsets is not None
+            ), "chunk_offsets must be provided for continuous batching"
+        else:
+            # - this is the regular batching case, where initial
+            #   states are used are for each example of the batch.
+            assert initial_states.shape == (batch, nheads, dim)
+
+    if seq_idx is not None:
+        seqlen = seq_idx.shape[-1]
+        assert seq_idx.shape == (batch, seqlen)
+    out_dtype = states.dtype if out_dtype is None else out_dtype
+    out = torch.empty(
+        (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype
+    )
+    final_states = torch.empty(
+        (batch, nheads, dim), device=states.device, dtype=torch.float32
+    )
+    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
+    with torch.cuda.device(states.device.index):
+        _state_passing_fwd_kernel[grid](
+            states,
+            out,
+            final_states,
+            dA_cumsum,
+            initial_states,
+            seq_idx,
+            chunk_offsets,
+            len(chunk_offsets) if chunk_offsets is not None else 0,
+            dim,
+            nchunks,
+            seqlen if seq_idx is not None else 0,
+            chunk_size,
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            out.stride(0),
+            out.stride(1),
+            out.stride(2),
+            out.stride(3),
+            final_states.stride(0),
+            final_states.stride(1),
+            final_states.stride(2),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
+            *(
+                (
+                    initial_states.stride(0),
+                    initial_states.stride(1),
+                    initial_states.stride(2),
+                )
+                if initial_states is not None
+                else (0, 0, 0)
+            ),
+            *(
+                (seq_idx.stride(0), seq_idx.stride(1))
+                if seq_idx is not None
+                else (0, 0)
+            ),
+            HAS_INITSTATES=initial_states is not None,
+            HAS_SEQ_IDX=seq_idx is not None,
+            IS_CONT_BATCHED=is_cont_batched,
+        )
+    return out, final_states
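To see what the kernel above computes, here is an illustrative pure-PyTorch reference for the simple case only (`HAS_SEQ_IDX=False`, `IS_CONT_BATCHED=False`); the function name and loop are ours, not part of the package. It mirrors `_state_passing_fwd`'s contract: `out[:, c]` is the state carried into chunk `c`, each step decays the running state by `exp(dA_cumsum)` at the chunk's last position and adds that chunk's local state, and the second return value is the final state after the last chunk.

```python
import torch


def state_passing_fwd_reference(states, dA_cumsum, initial_states=None):
    """Illustrative reference for _state_passing_fwd (no seq_idx, no cont. batching).

    states:     (batch, nchunks, nheads, dim)  local end-of-chunk states
    dA_cumsum:  (batch, nheads, nchunks, chunk_size)
    Returns out (batch, nchunks, nheads, dim): the state entering each chunk,
    and final   (batch, nheads, dim): the state after the last chunk.
    """
    batch, nchunks, nheads, dim = states.shape
    out = torch.empty(batch, nchunks, nheads, dim,
                      dtype=torch.float32, device=states.device)
    running = (
        initial_states.to(torch.float32)
        if initial_states is not None
        else torch.zeros(batch, nheads, dim,
                         dtype=torch.float32, device=states.device)
    )
    for c in range(nchunks):
        out[:, c] = running  # state carried into chunk c
        # decay across chunk c = exp(cumsum of A*dt at the chunk's last position)
        scale = torch.exp(dA_cumsum[:, :, c, -1]).unsqueeze(-1)  # (batch, nheads, 1)
        running = scale * running + states[:, c].to(torch.float32)
    return out, running
```

The Triton kernel runs the same sequential loop over chunks inside each program; the launch grid `(cdiv(dim, BLOCK_SIZE), batch, nheads)` parallelizes only over batch, heads, and `BLOCK_SIZE`-wide slices of the state dimension.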