sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
```diff
sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
@@ -12,17 +12,20 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
 
 if is_cuda():
-    import deep_gemm
+    try:
+        import deep_gemm
+    except ImportError as e:
+        deep_gemm = e
 
-from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
+from sglang.srt.layers import deep_gemm_wrapper
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.linear import ReplicatedLinear
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -71,7 +74,7 @@ class BaseIndexerMetadata(ABC):
 
 def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     assert x.dtype == torch.bfloat16
-    from
+    from sgl_kernel import hadamard_transform
 
     hidden_size = x.size(-1)
     assert (
@@ -159,49 +162,13 @@ class Indexer(CustomOp):
             base=rope_theta,  # type: ignore
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )
         self.block_size = block_size
         self.scale_fmt = scale_fmt
         self.softmax_scale = self.head_dim**-0.5
 
-    def _forward_fake(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ):
-        bs = x.shape[0]
-        assert self.index_topk == 2048
-        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
-            None, ...
-        ].repeat(bs, 1)
-        if forward_batch.forward_mode.is_extend():
-            assert (
-                forward_batch.extend_seq_lens_cpu is not None
-                and forward_batch.seq_lens_cpu is not None
-            )
-            which = 0
-            for i, (kv_len, qo_len) in enumerate(
-                zip(
-                    forward_batch.seq_lens_cpu.tolist(),
-                    forward_batch.extend_seq_lens_cpu,
-                    strict=True,
-                )
-            ):
-                for j in range(kv_len - qo_len, kv_len):
-                    ans[which, j + 1 :] = -1
-                which += 1
-            assert which == ans.shape[0]
-        else:
-            assert forward_batch.seq_lens_cpu is not None
-            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
-                ans[i, seq_len:] = -1
-
-        return ans
-
+    @torch.compile(dynamic=True)
     def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
         weights, _ = self.weights_proj(x)
         weights = weights * self.n_heads**-0.5
@@ -299,7 +266,10 @@ class Indexer(CustomOp):
         )
 
         blocksize = page_size
-        seqlens_32 = metadata.get_seqlens_int32()
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
         schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
             seqlens_32, blocksize, self.sm_count
@@ -350,8 +320,9 @@
         k_fp8_list = []
         k_scale_list = []
         ks_list = []
+        ke_list = []
         offset = 0
-
+        seq_lens_expanded = metadata.get_seqlens_expanded()
         block_tables = metadata.get_page_table_64()
 
         assert (
@@ -374,33 +345,37 @@
             )
             extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
             ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
             k_fp8_list.append(k_fp8)
             k_scale_list.append(k_scale)
             ks_list.append(ks)
+            ke_list.append(ke)
             offset += extend_seq_len
 
         k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
         k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
         kv_fp8 = (k_fp8, k_scale)
         ks = torch.cat(ks_list, dim=0)
-
-        ke = ks + seq_lens_expanded
+        ke = torch.cat(ke_list, dim=0)
 
         logits = deep_gemm.fp8_mqa_logits(
-            q_fp8,
+            q_fp8[:offset],
             kv_fp8,
-            weights,
+            weights[:offset],
             ks,
             ke,
             clean_logits=False,
         )
-
+        token_nums, _, _ = q_fp8.shape
         assert logits.shape[0] == len(seq_lens_expanded)
-
-
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
         return topk_result
 
-    def forward_indexer_bs_1(
+    def forward_indexer(
         self,
         q_fp8: torch.Tensor,
         weights: torch.Tensor,
@@ -481,20 +456,9 @@
             q_len_start = q_len_end
 
         topk_indices = torch.cat(topk_indices_list, dim=0)
-
        return topk_indices
 
-    def forward_indexer(
-        self,
-        q_fp8: torch.Tensor,
-        weights: torch.Tensor,
-        forward_batch: ForwardBatch,
-        topk: int,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
-
-    def _forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         q_lora: torch.Tensor,
@@ -502,8 +466,10 @@
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> Optional[torch.Tensor]:
-        if
+        if is_hip():
             from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+        elif not is_npu():
+            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
 
         if TYPE_CHECKING:
             assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
@@ -524,9 +490,6 @@
         if metadata is None:
             return None
 
-        if not NSA_USE_REAL_INDEXER:  # temporary
-            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
-
         query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
 
         if enable_dual_stream:
@@ -545,6 +508,8 @@
         # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
         # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
         forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
             layer_id=layer_id,
             loc=forward_batch.out_cache_loc,
@@ -566,7 +531,10 @@
             (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
         )
 
-        if forward_batch.forward_mode.is_decode_or_idle():
+        if (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             topk_result = self._get_topk_paged(
                 forward_batch, layer_id, q_fp8, weights, metadata
             )
@@ -582,19 +550,8 @@
                 topk=self.index_topk,
                 layer_id=layer_id,
             )
-
         return topk_result
 
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        q_lora: torch.Tensor,
-        positions: torch.Tensor,
-        forward_batch: ForwardBatch,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        return self._forward(x, q_lora, positions, forward_batch, layer_id)
-
     def forward_npu(
         self,
         x: torch.Tensor,
@@ -603,7 +560,7 @@
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> torch.Tensor:
-        import custom_ops
+        import custom_ops  # noqa: F401
         import torch_npu
 
         from sglang.srt.layers.dp_attention import (
```
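For orientation, the `ks`/`ke` bookkeeping added above gives each query token a `[start, end)` window into the concatenated FP8 key sequence before `deep_gemm.fp8_mqa_logits` is called. A minimal standalone sketch of that loop, with made-up request lengths (the tensors and values below are illustrative, not taken from the diff):

```python
import torch

# Hypothetical batch: two extend requests contributing 3 and 2 query tokens.
extend_seq_lens = [3, 2]
# One visible-context length per query token (what get_seqlens_expanded()
# supplies in the real code); the values here are invented for illustration.
seq_lens_expanded = torch.tensor([4, 5, 6, 2, 3], dtype=torch.int32)

ks_list, ke_list = [], []
offset = 0
for extend_seq_len in extend_seq_lens:
    # Every token of a request starts its key range at the request's offset...
    ks = torch.full((extend_seq_len,), offset, dtype=torch.int32)
    # ...and ends after its own visible context, giving a causal window.
    ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
    ks_list.append(ks)
    ke_list.append(ke)
    offset += extend_seq_len

ks = torch.cat(ks_list)  # tensor([0, 0, 0, 3, 3], dtype=torch.int32)
ke = torch.cat(ke_list)  # tensor([4, 5, 6, 5, 6], dtype=torch.int32)
```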
```diff
sglang/srt/layers/attention/nsa/triton_kernel.py (new file)
@@ -0,0 +1,136 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+# Triton implementation
+@triton.jit
+def _act_quant_kernel(
+    X_ptr,
+    Y_ptr,
+    S_ptr,
+    M,
+    N,
+    group_size: tl.constexpr,
+    round_scale: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """
+    Triton kernel for activation quantization.
+
+    Each block processes BLOCK_M rows and group_size columns.
+    """
+    # Get block IDs
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    # FP8 constants
+    fp8_min = -448.0
+    fp8_max = 448.0
+    fp8_max_inv = 1.0 / fp8_max
+
+    # Calculate row and column offsets
+    row_start = pid_m * BLOCK_M
+    col_start = pid_n * group_size
+
+    # Create offset arrays
+    rows = row_start + tl.arange(0, BLOCK_M)
+    cols = col_start + tl.arange(0, BLOCK_N)
+
+    # Mask for valid rows and columns
+    row_mask = rows < M
+    col_mask = cols < N
+    mask = row_mask[:, None] & col_mask[None, :]
+
+    # Load input data
+    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]
+    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
+
+    # Compute absolute max along columns (group_size dimension) for each row
+    x_abs = tl.abs(x)
+    amax = tl.max(x_abs, axis=1)  # Shape: (BLOCK_M,)
+
+    # Clamp amax to avoid division by zero
+    amax = tl.maximum(amax, 1e-4)
+
+    # Compute scale
+    if round_scale:
+        # Fast round scale using bit manipulation approximation
+        # This is a simplified version - the exact bit manipulation is harder in Triton
+        # Using log2 + ceil + pow2 as approximation
+        log_val = tl.log2(amax * fp8_max_inv)
+        log_ceil = tl.ceil(log_val)
+        scale = tl.exp2(log_ceil)
+    else:
+        scale = amax * fp8_max_inv
+
+    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
+    scale_broadcast = scale[:, None]
+    y = x / scale_broadcast
+    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)
+
+    # Store quantized output
+    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]
+    tl.store(y_ptrs, y, mask=mask)
+
+    # Store scales
+    s_cols = pid_n
+    s_ptrs = S_ptr + rows * (N // group_size) + s_cols
+    s_mask = row_mask
+    tl.store(s_ptrs, scale, mask=s_mask)
+
+
+def act_quant(
+    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes the input tensor `x` using block-wise quantization with Triton.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+        scale_fmt (Optional[str], optional): The format of the scale. Default is None.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+    """
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    assert (
+        x.size(-1) % block_size == 0
+    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
+
+    # Flatten all dims except last
+    N = x.size(-1)
+    x_flat = x.view(-1, N)
+    M = x_flat.size(0)
+
+    # Allocate output tensors
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    y_flat = y.view(-1, N)
+    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
+    s_flat = s.view(-1, N // block_size)
+
+    # Launch kernel
+    BLOCK_M = 32
+    BLOCK_N = block_size
+    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))
+    round_scale = scale_fmt is not None
+
+    _act_quant_kernel[grid](
+        x_flat,
+        y_flat,
+        s_flat,
+        M,
+        N,
+        group_size=block_size,
+        round_scale=round_scale,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        num_stages=0 if round_scale else 2,
+    )
+
+    return y, s
```
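A quick sanity check of the new `act_quant` entry point might look like the following; the shapes are arbitrary, and a CUDA device with Triton installed is assumed:

```python
import torch

from sglang.srt.layers.attention.nsa.triton_kernel import act_quant

# Illustrative input: 4 tokens with hidden size 256 (a multiple of the
# default block_size of 128, as required by the assertion in act_quant).
x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")

y, s = act_quant(x)
print(y.dtype, tuple(y.shape))  # torch.float8_e4m3fn (4, 256)
print(s.dtype, tuple(s.shape))  # torch.float32 (4, 2): one scale per 128-wide group
```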
```diff
sglang/srt/layers/attention/nsa/utils.py
@@ -1,7 +1,6 @@
 # temp NSA debugging environ
 from sglang.srt.utils import get_bool_env_var
 
-NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
 NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
 NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
 
```
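The surviving flags are module-level booleans read once at import time via `get_bool_env_var`, so they have to be set in the environment before sglang is imported. A hypothetical toggle:

```python
import os

# Must be set before the module is first imported; "false" disables the flag,
# which otherwise defaults to "true".
os.environ["SGLANG_NSA_FUSE_TOPK"] = "false"

from sglang.srt.layers.attention.nsa.utils import NSA_FUSE_TOPK

print(NSA_FUSE_TOPK)  # False
```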