sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,549 @@
|
|
1
|
+
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
from collections import namedtuple
|
5
|
+
from collections.abc import Callable
|
6
|
+
from typing import Any, Dict
|
7
|
+
|
8
|
+
import torch
|
9
|
+
import triton
|
10
|
+
import triton.language as tl
|
11
|
+
|
12
|
+
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "set_batch_invariant_mode",
    "is_batch_invariant_mode_enabled",
    "disable_batch_invariant_mode",
    "enable_batch_invariant_mode",
]
|
18
|
+
|
19
|
+
|
20
|
+
def _matmul_launch_metadata(
|
21
|
+
grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]
|
22
|
+
) -> Dict[str, Any]:
|
23
|
+
ret = {}
|
24
|
+
m, n, k = args["M"], args["N"], args["K"]
|
25
|
+
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
26
|
+
if "tiles_per_update" in args:
|
27
|
+
ret["name"] = (
|
28
|
+
f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
|
29
|
+
)
|
30
|
+
if "c_ptr" in args:
|
31
|
+
bytes_per_elem = args["c_ptr"].element_size()
|
32
|
+
else:
|
33
|
+
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
34
|
+
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
35
|
+
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
36
|
+
return ret
|
37
|
+
|
38
|
+
|
39
|
+
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    """Map a linear tile index to its (pid_m, pid_n) output-tile coordinates.

    Tiles are grouped into bands of GROUP_SIZE_M tile-rows. Within a band,
    consecutive tile ids walk down the rows (pid_m) before advancing to the
    next column (pid_n). The last band may have fewer than GROUP_SIZE_M rows
    when num_pid_m is not a multiple of GROUP_SIZE_M.

    NUM_SMS is accepted but not used by this function.
    """
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # Clamp the band height for the final, possibly partial, band.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
|
47
|
+
|
48
|
+
|
49
|
+
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """Persistent tiled matmul: C = A @ B, plus an optional per-column bias.

    A is addressed as (M, K), B as (K, N) and C as (M, N) through the given
    strides. Each program starts at tile ``program_id(0)`` and steps through
    the output tiles by NUM_SMS. Tile ids are mapped to (pid_m, pid_n) by
    ``_compute_pid``. For every tile, the K dimension is reduced in
    BLOCK_SIZE_K chunks, always in the order ki = 0 .. k_tiles - 1, into a
    float32 accumulator.

    A_LARGE / B_LARGE / C_LARGE promote the corresponding offsets to int64 so
    address arithmetic on large tensors does not overflow 32 bits.
    HAS_BIAS adds ``bias_ptr[n]`` (read with unit stride, upcast to float32)
    to every element of column n.
    The accumulator is converted directly to C's element type before storing.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # Tile id used by the epilogue. It is advanced by NUM_SMS once per
    # iteration, so it equals the current tile_id when the result is written.
    tile_id_c = start_pid - NUM_SMS

    # Only the last K tile can be partial; this is compared against the
    # remaining K extent to mask it.
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        # Out-of-range rows/cols are redirected to index 0 so the loads stay
        # in bounds; those lanes are discarded by c_mask at the store.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        # Convert straight to the output element type (fp8, fp16, bf16 or
        # fp32). Routing bf16/fp32 outputs through float16 would overflow to
        # inf beyond fp16's range (~65504) and throw away precision.
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)
|
143
|
+
|
144
|
+
|
145
|
+
def matmul_persistent(
|
146
|
+
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
|
147
|
+
):
|
148
|
+
# Check constraints.
|
149
|
+
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
150
|
+
assert a.dtype == b.dtype, "Incompatible dtypes"
|
151
|
+
assert (
|
152
|
+
bias is None or bias.dim() == 1
|
153
|
+
), "Currently assuming bias is 1D, let Horace know if you run into this"
|
154
|
+
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
155
|
+
M, K = a.shape
|
156
|
+
K, N = b.shape
|
157
|
+
dtype = a.dtype
|
158
|
+
# Allocates output.
|
159
|
+
c = torch.empty((M, N), device=a.device, dtype=dtype)
|
160
|
+
|
161
|
+
# 1D launch kernel where each block gets its own program.
|
162
|
+
def grid(META):
|
163
|
+
return (
|
164
|
+
min(
|
165
|
+
NUM_SMS,
|
166
|
+
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
167
|
+
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
168
|
+
),
|
169
|
+
)
|
170
|
+
|
171
|
+
configs = {
|
172
|
+
torch.bfloat16: {
|
173
|
+
"BLOCK_SIZE_M": 128,
|
174
|
+
"BLOCK_SIZE_N": 128,
|
175
|
+
"BLOCK_SIZE_K": 64,
|
176
|
+
"GROUP_SIZE_M": 8,
|
177
|
+
"num_stages": 3,
|
178
|
+
"num_warps": 8,
|
179
|
+
},
|
180
|
+
torch.float16: {
|
181
|
+
"BLOCK_SIZE_M": 128,
|
182
|
+
"BLOCK_SIZE_N": 256,
|
183
|
+
"BLOCK_SIZE_K": 64,
|
184
|
+
"GROUP_SIZE_M": 8,
|
185
|
+
"num_stages": 3,
|
186
|
+
"num_warps": 8,
|
187
|
+
},
|
188
|
+
torch.float32: {
|
189
|
+
"BLOCK_SIZE_M": 128,
|
190
|
+
"BLOCK_SIZE_N": 128,
|
191
|
+
"BLOCK_SIZE_K": 32,
|
192
|
+
"GROUP_SIZE_M": 8,
|
193
|
+
"num_stages": 3,
|
194
|
+
"num_warps": 8,
|
195
|
+
},
|
196
|
+
}
|
197
|
+
# print(a.device, b.device, c.device)
|
198
|
+
matmul_kernel_persistent[grid](
|
199
|
+
a,
|
200
|
+
b,
|
201
|
+
c, #
|
202
|
+
bias,
|
203
|
+
M,
|
204
|
+
N,
|
205
|
+
K, #
|
206
|
+
a.stride(0),
|
207
|
+
a.stride(1), #
|
208
|
+
b.stride(0),
|
209
|
+
b.stride(1), #
|
210
|
+
c.stride(0),
|
211
|
+
c.stride(1), #
|
212
|
+
NUM_SMS=NUM_SMS, #
|
213
|
+
A_LARGE=a.numel() > 2**31,
|
214
|
+
B_LARGE=b.numel() > 2**31,
|
215
|
+
C_LARGE=c.numel() > 2**31,
|
216
|
+
HAS_BIAS=bias is not None,
|
217
|
+
**configs[dtype],
|
218
|
+
)
|
219
|
+
return c
|
220
|
+
|
221
|
+
|
222
|
+
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each block handles one row of the input tensor.

    Each program makes three passes over its row in BLOCK_SIZE-wide chunks:
    (1) the row maximum, (2) the sum of exp(x - max), and (3) writing
    x - max - log(sum). Columns are addressed with unit stride
    (``row_start_ptr + col_idx``), so each row must be contiguous; only the
    row stride is configurable.
    """
    # Get the row index for this block.
    # It is widened to int64 so that row_idx * row_stride cannot overflow
    # 32-bit offsets.
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values; padding lanes read -inf so they never win the max.
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))

        # Update maximum
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)

        # Compute exp(x - max_val) and accumulate. Padding lanes were loaded
        # as 0.0, so exp() of them is non-zero; tl.where drops them from the sum.
        exp_vals = tl.exp(vals - max_val)
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    # Compute log(sum_exp)
    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask)

        # Compute log_softmax
        output = vals - max_val - log_sum_exp

        # Store results; padding lanes are masked out.
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
|
283
|
+
|
284
|
+
|
285
|
+
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax (only -1 or last dim supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension

    Raises:
        ValueError: if ``dim`` is not the last dimension.
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Nothing to compute for empty input. Returning early also avoids
    # reshape(-1, 0), which is ambiguous when the last dimension is 0, and a
    # kernel launch with a zero-sized grid.
    if input.numel() == 0:
        return torch.empty_like(input)

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    # The kernel walks each row with unit stride, so rows must be contiguous.
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor
    output = torch.empty_like(input_2d)

    # Fixed chunk width; the kernel loops over each row in BLOCK_SIZE pieces.
    BLOCK_SIZE = 1024

    # Launch kernel with one block per row
    grid = (n_rows,)
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Reshape output back to original shape
    return output.reshape(original_shape)
|
326
|
+
|
327
|
+
|
328
|
+
@triton.jit
|
329
|
+
def mean_kernel(
|
330
|
+
input_ptr,
|
331
|
+
output_ptr,
|
332
|
+
input_stride0,
|
333
|
+
input_stride1,
|
334
|
+
input_stride2,
|
335
|
+
output_stride0,
|
336
|
+
output_stride1,
|
337
|
+
M, # size before reduction dim
|
338
|
+
N, # size of reduction dim
|
339
|
+
K, # size after reduction dim
|
340
|
+
BLOCK_SIZE: tl.constexpr,
|
341
|
+
):
|
342
|
+
"""
|
343
|
+
Kernel for computing mean along a single dimension.
|
344
|
+
Input is viewed as (M, N, K) where N is the dimension being reduced.
|
345
|
+
"""
|
346
|
+
# Program ID gives us which output element we're computing
|
347
|
+
pid = tl.program_id(0)
|
348
|
+
|
349
|
+
# Compute output indices
|
350
|
+
m_idx = pid // K
|
351
|
+
k_idx = pid % K
|
352
|
+
|
353
|
+
# Bounds check
|
354
|
+
if m_idx >= M or k_idx >= K:
|
355
|
+
return
|
356
|
+
|
357
|
+
# Accumulate sum across reduction dimension
|
358
|
+
acc = 0.0
|
359
|
+
for n_start in range(0, N, BLOCK_SIZE):
|
360
|
+
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
|
361
|
+
mask = n_offsets < N
|
362
|
+
|
363
|
+
# Calculate input indices
|
364
|
+
input_idx = (
|
365
|
+
m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
|
366
|
+
)
|
367
|
+
|
368
|
+
# Load and accumulate
|
369
|
+
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
370
|
+
acc += tl.sum(vals)
|
371
|
+
|
372
|
+
# Compute mean and store
|
373
|
+
mean_val = acc / N
|
374
|
+
output_idx = m_idx * output_stride0 + k_idx * output_stride1
|
375
|
+
tl.store(output_ptr + output_idx, mean_val)
|
376
|
+
|
377
|
+
|
378
|
+
def mean_dim(
|
379
|
+
input: torch.Tensor,
|
380
|
+
dim: int,
|
381
|
+
keepdim: bool = False,
|
382
|
+
dtype: torch.dtype | None = None,
|
383
|
+
) -> torch.Tensor:
|
384
|
+
"""
|
385
|
+
Triton implementation of torch.mean with single dimension reduction.
|
386
|
+
|
387
|
+
Args:
|
388
|
+
input: Input tensor
|
389
|
+
dim: Single dimension along which to compute mean
|
390
|
+
keepdim: Whether to keep the reduced dimension
|
391
|
+
dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
|
392
|
+
|
393
|
+
Returns:
|
394
|
+
Tensor with mean values along specified dimension
|
395
|
+
"""
|
396
|
+
# Validate inputs
|
397
|
+
assert input.is_cuda, "Input must be a CUDA tensor"
|
398
|
+
assert (
|
399
|
+
-input.ndim <= dim < input.ndim
|
400
|
+
), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
|
401
|
+
|
402
|
+
# Handle negative dim
|
403
|
+
if dim < 0:
|
404
|
+
dim = dim + input.ndim
|
405
|
+
|
406
|
+
# Handle dtype
|
407
|
+
if dtype is None:
|
408
|
+
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
|
409
|
+
dtype = torch.float32
|
410
|
+
else:
|
411
|
+
dtype = input.dtype
|
412
|
+
|
413
|
+
# Convert input to appropriate dtype if needed
|
414
|
+
if input.dtype != dtype:
|
415
|
+
input = input.to(dtype)
|
416
|
+
|
417
|
+
# Get input shape and strides
|
418
|
+
shape = list(input.shape)
|
419
|
+
|
420
|
+
# Calculate dimensions for kernel
|
421
|
+
M = 1
|
422
|
+
for i in range(dim):
|
423
|
+
M *= shape[i]
|
424
|
+
|
425
|
+
N = shape[dim]
|
426
|
+
|
427
|
+
K = 1
|
428
|
+
for i in range(dim + 1, len(shape)):
|
429
|
+
K *= shape[i]
|
430
|
+
|
431
|
+
# Reshape input to 3D view (M, N, K)
|
432
|
+
input_3d = input.reshape(M, N, K)
|
433
|
+
|
434
|
+
# Create output shape
|
435
|
+
if keepdim:
|
436
|
+
output_shape = shape.copy()
|
437
|
+
output_shape[dim] = 1
|
438
|
+
else:
|
439
|
+
output_shape = shape[:dim] + shape[dim + 1 :]
|
440
|
+
|
441
|
+
# Create output tensor
|
442
|
+
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
443
|
+
|
444
|
+
# Reshape output for kernel
|
445
|
+
if keepdim:
|
446
|
+
output_2d = output.reshape(M, 1, K).squeeze(1)
|
447
|
+
else:
|
448
|
+
output_2d = output.reshape(M, K)
|
449
|
+
|
450
|
+
# Launch kernel
|
451
|
+
grid = (M * K,)
|
452
|
+
BLOCK_SIZE = 1024
|
453
|
+
|
454
|
+
mean_kernel[grid](
|
455
|
+
input_3d,
|
456
|
+
output_2d,
|
457
|
+
input_3d.stride(0),
|
458
|
+
input_3d.stride(1),
|
459
|
+
input_3d.stride(2),
|
460
|
+
output_2d.stride(0),
|
461
|
+
output_2d.stride(1) if output_2d.ndim > 1 else 0,
|
462
|
+
M,
|
463
|
+
N,
|
464
|
+
K,
|
465
|
+
BLOCK_SIZE,
|
466
|
+
)
|
467
|
+
|
468
|
+
return output
|
469
|
+
|
470
|
+
|
471
|
+
def mm_batch_invariant(a, b):
|
472
|
+
return matmul_persistent(a, b)
|
473
|
+
|
474
|
+
|
475
|
+
def addmm_batch_invariant(bias, a, b):
|
476
|
+
return matmul_persistent(a, b, bias=bias)
|
477
|
+
|
478
|
+
|
479
|
+
def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
480
|
+
assert not _half_to_float, "not implemented"
|
481
|
+
return log_softmax(input, dim=dim)
|
482
|
+
|
483
|
+
|
484
|
+
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
|
485
|
+
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
486
|
+
if len(dim) == 1:
|
487
|
+
return mean_dim(input, dim[0], keepdim=keepdim)
|
488
|
+
else:
|
489
|
+
assert input.dtype in {
|
490
|
+
torch.float16,
|
491
|
+
torch.bfloat16,
|
492
|
+
torch.float32,
|
493
|
+
}, "only float types supported for now"
|
494
|
+
n_elems = 1
|
495
|
+
for d in dim:
|
496
|
+
n_elems *= input.shape[d]
|
497
|
+
return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
|
498
|
+
|
499
|
+
|
500
|
+
_batch_invariant_MODE = False
|
501
|
+
_batch_invariant_LIB = None
|
502
|
+
|
503
|
+
|
504
|
+
def is_batch_invariant_mode_enabled():
|
505
|
+
return _batch_invariant_MODE
|
506
|
+
|
507
|
+
|
508
|
+
def enable_batch_invariant_mode():
|
509
|
+
global _batch_invariant_MODE, _batch_invariant_LIB
|
510
|
+
if _batch_invariant_MODE:
|
511
|
+
return
|
512
|
+
|
513
|
+
_batch_invariant_MODE = True
|
514
|
+
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
515
|
+
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
516
|
+
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
517
|
+
_batch_invariant_LIB.impl(
|
518
|
+
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
519
|
+
)
|
520
|
+
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
521
|
+
|
522
|
+
|
523
|
+
def disable_batch_invariant_mode():
|
524
|
+
global _batch_invariant_MODE, _batch_invariant_LIB
|
525
|
+
if _batch_invariant_LIB is not None:
|
526
|
+
_batch_invariant_LIB._destroy()
|
527
|
+
_batch_invariant_MODE = False
|
528
|
+
_batch_invariant_LIB = None
|
529
|
+
|
530
|
+
|
531
|
+
@contextlib.contextmanager
|
532
|
+
def set_batch_invariant_mode(enabled: bool = True):
|
533
|
+
global _batch_invariant_MODE, _batch_invariant_LIB
|
534
|
+
old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
|
535
|
+
if enabled:
|
536
|
+
enable_batch_invariant_mode()
|
537
|
+
else:
|
538
|
+
disable_batch_invariant_mode()
|
539
|
+
yield
|
540
|
+
if _batch_invariant_LIB is not None:
|
541
|
+
_batch_invariant_LIB._destroy()
|
542
|
+
_batch_invariant_MODE, _batch_invariant_LIB = old_data
|
543
|
+
|
544
|
+
|
545
|
+
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
|
546
|
+
|
547
|
+
|
548
|
+
def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
549
|
+
return AttentionBlockSize(block_m=16, block_n=16)
|
sglang/srt/configs/__init__.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1
1
|
from sglang.srt.configs.chatglm import ChatGLMConfig
|
2
2
|
from sglang.srt.configs.dbrx import DbrxConfig
|
3
3
|
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
4
|
+
from sglang.srt.configs.dots_ocr import DotsOCRConfig
|
5
|
+
from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
4
6
|
from sglang.srt.configs.exaone import ExaoneConfig
|
7
|
+
from sglang.srt.configs.falcon_h1 import FalconH1Config
|
5
8
|
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
6
9
|
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
7
10
|
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
8
11
|
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
|
12
|
+
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
|
9
13
|
from sglang.srt.configs.step3_vl import (
|
10
14
|
Step3TextConfig,
|
11
15
|
Step3VisionEncoderConfig,
|
@@ -24,4 +28,8 @@ __all__ = [
|
|
24
28
|
"Step3VLConfig",
|
25
29
|
"Step3TextConfig",
|
26
30
|
"Step3VisionEncoderConfig",
|
31
|
+
"Qwen3NextConfig",
|
32
|
+
"DotsVLMConfig",
|
33
|
+
"DotsOCRConfig",
|
34
|
+
"FalconH1Config",
|
27
35
|
]
|
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
|
|
8
8
|
|
9
9
|
class DeviceConfig:
|
10
10
|
device: Optional[torch.device]
|
11
|
+
gpu_id: Optional[int]
|
11
12
|
|
12
|
-
def __init__(self, device: str = "cuda") -> None:
|
13
|
+
def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
|
13
14
|
if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
|
14
15
|
self.device_type = device
|
15
16
|
else:
|
16
17
|
raise RuntimeError(f"Not supported device type: {device}")
|
17
18
|
self.device = torch.device(self.device_type)
|
19
|
+
self.gpu_id = gpu_id
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from transformers import AutoProcessor, Qwen2_5_VLProcessor
|
4
|
+
from transformers.image_processing_utils import BaseImageProcessor
|
5
|
+
from transformers.models.qwen2 import Qwen2Config
|
6
|
+
|
7
|
+
from sglang.srt.configs.dots_vlm import DotsVisionConfig
|
8
|
+
|
9
|
+
|
10
|
+
class DotsOCRConfig(Qwen2Config):
|
11
|
+
model_type = "dots_ocr"
|
12
|
+
|
13
|
+
def __init__(
|
14
|
+
self,
|
15
|
+
image_token_id=151665,
|
16
|
+
video_token_id=151656,
|
17
|
+
vision_config: Optional[dict] = None,
|
18
|
+
*args,
|
19
|
+
**kwargs
|
20
|
+
):
|
21
|
+
super().__init__(*args, **kwargs)
|
22
|
+
self.image_token_id = image_token_id
|
23
|
+
self.video_token_id = video_token_id
|
24
|
+
self.vision_config = DotsVisionConfig(**(vision_config or {}))
|
25
|
+
|
26
|
+
def save_pretrained(self, save_directory, **kwargs):
|
27
|
+
self._auto_class = None
|
28
|
+
super().save_pretrained(save_directory, **kwargs)
|
29
|
+
|
30
|
+
|
31
|
+
class DummyVideoProcessor(BaseImageProcessor):
|
32
|
+
model_input_names = ["pixel_values"]
|
33
|
+
|
34
|
+
def __call__(self, *args, **kwargs):
|
35
|
+
return None
|
36
|
+
|
37
|
+
|
38
|
+
class DotsVLProcessor(Qwen2_5_VLProcessor):
|
39
|
+
def __init__(
|
40
|
+
self,
|
41
|
+
image_processor=None,
|
42
|
+
tokenizer=None,
|
43
|
+
video_processor=None,
|
44
|
+
chat_template=None,
|
45
|
+
**kwargs
|
46
|
+
):
|
47
|
+
if video_processor is None:
|
48
|
+
video_processor = DummyVideoProcessor()
|
49
|
+
super().__init__(
|
50
|
+
image_processor, tokenizer, video_processor, chat_template=chat_template
|
51
|
+
)
|
52
|
+
self.image_token = (
|
53
|
+
"<|imgpad|>"
|
54
|
+
if not hasattr(tokenizer, "image_token")
|
55
|
+
else tokenizer.image_token
|
56
|
+
)
|
57
|
+
self.image_token_id = (
|
58
|
+
tokenizer.image_token_id
|
59
|
+
if getattr(tokenizer, "image_token_id", None) is not None
|
60
|
+
else tokenizer.convert_tokens_to_ids(self.image_token)
|
61
|
+
)
|
62
|
+
|
63
|
+
|
64
|
+
AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)
|