sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
|
|
1
|
+
"""
|
2
|
+
Configuration argument parser for command-line applications.
|
3
|
+
Handles merging of YAML configuration files with command-line arguments.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import logging
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Dict, List, Union
|
9
|
+
|
10
|
+
import yaml
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class ConfigArgumentMerger:
    """Merge arguments from a YAML config file into a command-line argument list.

    The config file is named on the command line as ``--config <path>``. Its
    top-level entries are converted to ``--key value`` style tokens and placed
    in front of the remaining command-line tokens.
    """

    def __init__(self, boolean_actions: Union[List[str], None] = None):
        """Initialize with list of boolean action destinations.

        Args:
            boolean_actions: Config keys whose boolean values are emitted as
                ``--key true`` / ``--key false`` rather than as a bare
                ``--key`` flag. ``None`` is treated as an empty list.
        """
        self.boolean_actions = boolean_actions or []

    def merge_config_with_args(self, cli_args: List[str]) -> List[str]:
        """
        Merge configuration file arguments with command-line arguments.

        Config-derived arguments are placed ahead of the command-line
        arguments, so for a parser in which a later occurrence of an option
        overrides an earlier one the precedence is: CLI > Config > Defaults

        Args:
            cli_args: List of command-line arguments

        Returns:
            Merged argument list with the ``--config <path>`` pair removed and
            the config values prepended. When ``--config`` is absent,
            ``cli_args`` itself is returned unchanged.

        Raises:
            ValueError: If ``--config`` is given more than once or without a
                path, or if the file is not a readable YAML mapping.
        """
        config_file_path = self._extract_config_file_path(cli_args)
        if not config_file_path:
            return cli_args

        config_args = self._parse_yaml_config(config_file_path)
        return self._insert_config_args(cli_args, config_args, config_file_path)

    def _extract_config_file_path(self, args: List[str]) -> Union[str, None]:
        """Extract the config file path from arguments.

        Returns:
            The token following ``--config``, or ``None`` when ``--config``
            does not appear in ``args``.

        Raises:
            ValueError: If ``--config`` appears more than once or is the last
                token (no path follows it).
        """
        config_indices = [i for i, arg in enumerate(args) if arg == "--config"]

        if len(config_indices) > 1:
            raise ValueError("Multiple config files specified! Only one allowed.")

        if not config_indices:
            return None

        config_index = config_indices[0]
        if config_index == len(args) - 1:
            raise ValueError("No config file specified after --config flag!")

        return args[config_index + 1]

    def _insert_config_args(
        self, cli_args: List[str], config_args: List[str], config_file_path: str
    ) -> List[str]:
        """Insert configuration arguments into the CLI argument list.

        Args:
            cli_args: Original command-line tokens; must contain ``--config``.
            config_args: Tokens produced from the config file.
            config_file_path: Unused here; kept for interface compatibility.

        Returns:
            ``config_args`` followed by every CLI token except the
            ``--config <path>`` pair.
        """
        config_index = cli_args.index("--config")

        # Split arguments around config file
        before_config = cli_args[:config_index]
        after_config = cli_args[config_index + 2 :]  # Skip --config and file path

        # Simple merge: config args + CLI args
        return config_args + before_config + after_config

    def _parse_yaml_config(self, file_path: str) -> List[str]:
        """
        Parse YAML configuration file and convert to argument list.

        Args:
            file_path: Path to the YAML configuration file

        Returns:
            List of arguments in format ['--key', 'value', ...]

        Raises:
            ValueError: If the file is not YAML, does not exist, or its root
                is not a mapping. Errors raised while opening or parsing the
                file are logged and re-raised unchanged.
        """
        self._validate_yaml_file(file_path)

        try:
            # Read as UTF-8 explicitly so parsing does not depend on the
            # platform's locale encoding.
            with open(file_path, "r", encoding="utf-8") as file:
                config_data = yaml.safe_load(file)
        except Exception as e:
            logger.error(f"Failed to read config file {file_path}: {e}")
            raise

        # Handle empty files or None content
        if config_data is None:
            config_data = {}

        if not isinstance(config_data, dict):
            raise ValueError("Config file must contain a dictionary at root level")

        return self._convert_config_to_args(config_data)

    def _validate_yaml_file(self, file_path: str) -> None:
        """Validate that the file is a YAML file.

        Raises:
            ValueError: If the suffix is not ``.yaml``/``.yml`` (checked
                case-insensitively) or the path does not exist.
        """
        path = Path(file_path)
        if path.suffix.lower() not in [".yaml", ".yml"]:
            raise ValueError(f"Config file must be YAML format, got: {path.suffix}")

        if not path.exists():
            raise ValueError(f"Config file not found: {file_path}")

    def _convert_config_to_args(self, config: Dict[str, Any]) -> List[str]:
        """Convert configuration dictionary to argument list.

        Booleans, lists and scalars are each rendered by their dedicated
        helper. Entries whose value is ``None`` (a YAML ``null`` or an empty
        value) are skipped so the literal string ``"None"`` is never passed
        on as an option value.
        """
        args = []

        for key, value in config.items():
            if value is None:
                # Treat null as "not set" and let the parser default apply.
                continue
            # bool must be tested before the scalar fallback.
            if isinstance(value, bool):
                self._add_boolean_arg(args, key, value)
            elif isinstance(value, list):
                self._add_list_arg(args, key, value)
            else:
                self._add_scalar_arg(args, key, value)

        return args

    def _add_boolean_arg(self, args: List[str], key: str, value: bool) -> None:
        """Add boolean argument to the list (appends to ``args`` in place)."""
        if key in self.boolean_actions:
            # For boolean actions, always add the flag and value
            args.extend([f"--{key}", str(value).lower()])
        else:
            # For regular booleans, only add flag if True
            if value:
                args.append(f"--{key}")

    def _add_list_arg(self, args: List[str], key: str, value: List[Any]) -> None:
        """Add list argument to the list as ``--key item1 item2 ...``."""
        if value:  # Only add if list is not empty
            args.append(f"--{key}")
            args.extend(str(item) for item in value)

    def _add_scalar_arg(self, args: List[str], key: str, value: Any) -> None:
        """Add scalar argument to the list as ``--key str(value)``."""
        args.extend([f"--{key}", str(value)])
|
@@ -0,0 +1,151 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
from sglang.srt.layers.moe import get_moe_runner_backend
|
7
|
+
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
8
|
+
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
9
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
10
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
11
|
+
from sglang.srt.utils import get_int_env_var
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
15
|
+
|
16
|
+
|
17
|
+
class SboFlags:
    """Class-level switches reporting which overlap modes are currently active."""

    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        """Tell whether combine may be overlapped with the down GEMM.

        Needs both ``is_sbo_enabled()`` and a MoE runner backend for which
        ``is_flashinfer_cutedsl()`` holds.
        """
        sbo_on = is_sbo_enabled()
        if not sbo_on:
            # Same result as `and` short-circuiting: hand back the falsy value
            return sbo_on
        # currently only cutedsl backend supports it
        return get_moe_runner_backend().is_flashinfer_cutedsl()

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        """Tell whether combine may be overlapped with the shared experts."""
        return is_sbo_enabled()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        """Tell whether the shared experts are run from inside the SBO path."""
        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
        shared_overlap = cls.enable_combine_shared_two_stream_overlap()
        return shared_overlap
|
36
|
+
|
37
|
+
|
38
|
+
@dataclass
class CombineOverlapArgs:
    """Settings handed to the combine step as its ``overlap_args``."""

    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
    overlap: bool
    # Alternate stream supplied by the caller
    stream: torch.cuda.Stream
    # Event shared with the down-GEMM side, or recorded after the down GEMM
    wait_event: torch.cuda.Event
    # Number of SMs set aside for communication
    num_sms: int
    # Per-local-expert signal tensor; only set when `overlap` is True
    signal: Optional[torch.Tensor] = None
    # Set to the compute SM count when `overlap` is True; -1 otherwise
    threshold: int = -1
|
47
|
+
|
48
|
+
|
49
|
+
@dataclass
class DownGemmOverlapArgs:
    """Settings handed to the expert computation as ``down_gemm_overlap_args``."""

    # Number of SMs left for computation
    num_sms: int
    # Signal tensor shared with the matching CombineOverlapArgs
    signal: torch.Tensor
    # Same event object as CombineOverlapArgs.wait_event
    start_event: torch.cuda.Event
|
54
|
+
|
55
|
+
|
56
|
+
def execute_sbo(
    forward_shared_experts: Callable[[], Any],
    experts: "DeepEPMoE",
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch: ForwardBatch,
    alt_stream: Optional = None,
):
    """Run dispatch, expert computation and combine, with optional overlap.

    Args:
        forward_shared_experts: Zero-argument callable computing the shared
            experts' output; invoked only when
            ``SboFlags.enable_combine_shared_two_stream_overlap()`` is true.
        experts: Object providing ``dispatch``, ``moe_impl`` and ``combine``.
        hidden_states: Input passed to ``experts.dispatch``.
        topk_idx: Passed to ``experts.dispatch``.
        topk_weights: Passed to ``experts.dispatch``.
        forward_batch: Passed to both ``dispatch`` and ``combine``.
        alt_stream: Forwarded to ``_compute_overlap_args``, which requires it
            to be non-None whenever an overlap mode is enabled.

    Returns:
        ``(hidden_states, shared_output)``: the result of ``experts.combine``
        and the shared experts' output, the latter being ``None`` when the
        shared experts were not run here.
    """
    dispatched = experts.dispatch(
        hidden_states, topk_idx, topk_weights, forward_batch
    )

    combine_args, down_gemm_args, meta_args = _compute_overlap_args(
        dispatched, alt_stream
    )

    expert_output = experts.moe_impl(
        dispatched, down_gemm_overlap_args=down_gemm_args
    )

    # Only present when combine is not overlapped with the down GEMM
    after_down_event = meta_args.get("record_event_after_down")
    if after_down_event is not None:
        after_down_event.record()

    shared_output = None
    if SboFlags.enable_combine_shared_two_stream_overlap():
        # TODO reduce sm for non-deepgemm
        num_sms_ctx = deep_gemm_wrapper.configure_deep_gemm_num_sms(
            meta_args["compute_num_sms"]
        )
        with num_sms_ctx:
            shared_output = forward_shared_experts()

    combined = experts.combine(
        expert_output,
        dispatched.topk_idx,
        dispatched.topk_weights,
        forward_batch,
        overlap_args=combine_args,
    )

    return combined, shared_output
|
97
|
+
|
98
|
+
|
99
|
+
def _compute_overlap_args(dispatch_output, alt_stream):
    """Build the overlap argument bundles for one dispatch/combine round.

    Args:
        dispatch_output: Dispatch result exposing ``hidden_states_fp8``, either
            a tensor or a tuple whose first element is the tensor. Its shape
            must unpack into exactly three dimensions.
        alt_stream: Stream stored in the combine arguments; asserted to be
            non-None once any overlap mode is enabled.

    Returns:
        ``(combine_overlap_args, down_gemm_overlap_args, meta_overlap_args)``.
        When no overlap mode is enabled this is ``(None, None, {})``.
        Otherwise ``meta_overlap_args`` always holds ``compute_num_sms`` and,
        when down-GEMM overlap is off, also ``record_event_after_down``.
    """
    overlap_requested = (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    )
    if not overlap_requested:
        return None, None, {}

    fp8_states = dispatch_output.hidden_states_fp8
    if isinstance(fp8_states, tuple):
        fp8_states = fp8_states[0]

    # Only the leading dimension is used; unpacking still requires a 3-D shape
    num_local_experts, _num_tokens_static, _hidden_dim = fp8_states.shape

    device_props = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = device_props.multi_processor_count
    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
    # Whatever is not reserved for communication is left for computation
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = {"compute_num_sms": compute_num_sms}

    if not SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # No down-GEMM overlap: the caller records the event after the down GEMM
        meta_overlap_args["record_event_after_down"] = combine_wait_event
        return combine_overlap_args, None, meta_overlap_args

    # TODO use zero_allocator to remove this `torch.zeros` call
    # NOTE ours v2 use uint32 not int32 currently
    combine_signal = torch.zeros(
        num_local_experts, dtype=torch.uint32, device=fp8_states.device
    )

    down_gemm_overlap_args = DownGemmOverlapArgs(
        signal=combine_signal,
        start_event=combine_wait_event,
        num_sms=compute_num_sms,
    )
    # The combine side shares the same signal tensor and event
    combine_overlap_args.overlap = True
    combine_overlap_args.signal = combine_signal
    combine_overlap_args.threshold = compute_num_sms

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
|
@@ -0,0 +1,374 @@
|
|
1
|
+
#include "ngram.h"
|
2
|
+
|
3
|
+
#include <algorithm>
|
4
|
+
#include <cstring>
|
5
|
+
#include <limits>
|
6
|
+
#include <queue>
|
7
|
+
#include <vector>
|
8
|
+
|
9
|
+
namespace ngram {
|
10
|
+
|
11
|
+
// One vertex of the compact draft tree assembled by the match routines.
// `next` maps a token id to the index, inside the enclosing std::vector<Node>,
// of the child vertex reached through that token.
struct Node {
  std::unordered_map<int32_t, int32_t> next{};
};
|
14
|
+
|
15
|
+
// Flattens the draft tree rooted at `tree[root]` into a Result.
//
// Tokens are emitted in breadth-first order with `last_token` at position 0,
// which acts as the ancestor of every drafted token. The output is zero-padded
// until it holds at least `draft_token_num` entries; padded entries are
// attached to position 0. `mask` is an n x n row-major matrix (n = number of
// emitted tokens) in which row i marks position i itself and every ancestor of
// position i.
Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
  Ngram::Result info;
  // prevs[i] is the output position of the parent of entry i (-1 for entry 0).
  std::vector<int32_t> prevs;
  info.token.reserve(draft_token_num);
  prevs.reserve(draft_token_num);

  info.token.emplace_back(last_token);
  prevs.emplace_back(-1);

  // Work list entries: (token, index of the vertex in `tree`, output position of the parent).
  std::queue<std::tuple<int32_t, int32_t, int32_t>> queue;
  for (auto [token, next] : tree[root].next) {
    queue.emplace(token, next, 0);
  }
  while (!queue.empty()) {
    auto [token, next, prev] = queue.front();
    queue.pop();
    info.token.emplace_back(token);
    prevs.emplace_back(prev);
    const auto self = static_cast<int32_t>(info.token.size() - 1);
    for (auto [t, n] : tree[next].next) {
      queue.emplace(t, n, self);
    }
  }

  // Zero padding to length.
  while (info.token.size() < static_cast<size_t>(draft_token_num)) {
    info.token.emplace_back(0);
    prevs.emplace_back(0);
  }

  const size_t n = info.token.size();
  info.mask.resize(n * n, 0);
  for (size_t i = 0; i < n; ++i) {
    auto* row = info.mask.data() + i * n;
    if (prevs[i] != -1) {
      // Inherit the parent's ancestor set: columns 0..parent of the parent's
      // row. A parent always precedes its children in the output, so the two
      // rows are distinct. std::copy_n counts elements, which keeps this
      // correct whatever the element size of `mask` is.
      const size_t parent = static_cast<size_t>(prevs[i]);
      std::copy_n(info.mask.data() + parent * n, parent + 1, row);
    }
    row[i] = 1;  // every position attends to itself
  }

  return info;
}
|
55
|
+
|
56
|
+
// Builds the node pool, takes the root node from it, validates `param`, and
// finally starts the background insert worker.
//
// Throws std::runtime_error when any of the parameter constraints checked
// below is violated. The worker thread is only started once every check has
// passed.
Ngram::Ngram(size_t capacity, const Param& param) {
  param_ = param;

  // Allocate all trie nodes up front; node_pool_ holds a pointer to each one.
  nodes_.resize(capacity);
  for (auto& slot : nodes_) {
    node_pool_.emplace_back(&slot);
  }
  free_node_count_ = node_pool_.size();
  root_ = getNode();

  const auto reject = [](const std::string& message) { throw std::runtime_error(message); };
  const auto& p = param_;

  // Scalar constraints.
  if (!(p.branch_length > 1)) {
    reject("param_.branch_length must be greater than 1, current value: " + std::to_string(p.branch_length));
  }
  if (!(p.min_match_window_size > 0)) {
    reject("min_match_window_size must be greater than 0, current value: " + std::to_string(p.min_match_window_size));
  }
  if (!(p.min_match_window_size <= p.max_match_window_size)) {
    reject(
        "min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
        std::to_string(p.min_match_window_size) +
        ", max_match_window_size: " + std::to_string(p.max_match_window_size));
  }
  if (!(p.max_match_window_size < p.branch_length)) {
    reject(
        "max_match_window_size must be less than branch_length, current max_match_window_size: " +
        std::to_string(p.max_match_window_size) + ", branch_length: " + std::to_string(p.branch_length));
  }
  if (!(p.min_bfs_breadth > 0)) {
    reject("min_bfs_breadth must be greater than 0, current value: " + std::to_string(p.min_bfs_breadth));
  }
  if (!(p.min_bfs_breadth <= p.max_bfs_breadth)) {
    reject(
        "min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
        std::to_string(p.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(p.max_bfs_breadth));
  }
  if (!(p.draft_token_num > 0)) {
    reject("draft_token_num must be greater than 0, current value: " + std::to_string(p.draft_token_num));
  }

  // Per-batch overrides. An entry equal to the type's maximum value is
  // exempt from the range checks.
  for (auto config : p.batch_draft_token_num) {
    if (config == std::numeric_limits<decltype(config)>::max()) {
      continue;
    }
    if (!(config <= p.draft_token_num)) {
      reject(
          "batch_draft_token_num config value " + std::to_string(config) +
          " must be less than or equal to draft_token_num: " + std::to_string(p.draft_token_num));
    }
  }
  for (auto config : p.batch_min_match_window_size) {
    if (config == std::numeric_limits<decltype(config)>::max()) {
      continue;
    }
    if (!(config >= p.min_match_window_size)) {
      reject(
          "batch_min_match_window_size config value " + std::to_string(config) +
          " must be greater than or equal to min_match_window_size: " + std::to_string(p.min_match_window_size));
    }
    if (!(config <= p.max_match_window_size)) {
      reject(
          "batch_min_match_window_size config value " + std::to_string(config) +
          " must be less than or equal to max_match_window_size: " + std::to_string(p.max_match_window_size));
    }
  }

  quit_flag_ = false;
  insert_worker_ = std::thread(&Ngram::insert, this);
}
|
124
|
+
|
125
|
+
// Stops the background insert worker and waits for it to exit.
Ngram::~Ngram() {
  quit_flag_ = true;  // makes the `while (!quit_flag_)` loop in insert() stop
  // Presumably wakes a worker blocked in dequeue(); confirm in the queue implementation.
  insert_queue_.close();
  // std::thread::join() throws std::system_error when the thread is not
  // joinable, and an exception escaping a destructor terminates the program,
  // so only join a thread that is actually running.
  if (insert_worker_.joinable()) {
    insert_worker_.join();
  }
}
|
130
|
+
|
131
|
+
std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
132
|
+
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
133
|
+
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
134
|
+
auto max_match_window_size = param_.max_match_window_size;
|
135
|
+
std::vector<std::pair<TrieNode*, int32_t>> result;
|
136
|
+
result.reserve(param_.max_match_window_size - param_.min_match_window_size);
|
137
|
+
for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
|
138
|
+
match_window_size >= param_.min_match_window_size;
|
139
|
+
--match_window_size) {
|
140
|
+
auto start = tokens.data() + tokens.size() - match_window_size;
|
141
|
+
auto end = start + match_window_size;
|
142
|
+
auto cursor = root_;
|
143
|
+
while (start != end) {
|
144
|
+
auto iter = cursor->child.find(*start);
|
145
|
+
if (iter == cursor->child.end()) {
|
146
|
+
cursor = nullptr;
|
147
|
+
break;
|
148
|
+
}
|
149
|
+
++start;
|
150
|
+
cursor = iter->second;
|
151
|
+
}
|
152
|
+
if (cursor) {
|
153
|
+
result.emplace_back(std::make_pair(cursor, match_window_size));
|
154
|
+
}
|
155
|
+
}
|
156
|
+
return result;
|
157
|
+
}
|
158
|
+
|
159
|
+
void Ngram::squeeze(size_t count) {
|
160
|
+
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
161
|
+
throw std::runtime_error(
|
162
|
+
"Insufficient node size to release required nodes. "
|
163
|
+
"available to release: " +
|
164
|
+
std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
|
165
|
+
}
|
166
|
+
while (count--) {
|
167
|
+
auto last = global_lru_.back();
|
168
|
+
global_lru_.pop_back();
|
169
|
+
|
170
|
+
if (!last->child.empty()) {
|
171
|
+
throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
|
172
|
+
}
|
173
|
+
|
174
|
+
last->parent->lru.erase(last->parent_lru_pos);
|
175
|
+
last->parent->sorted_children.erase(last);
|
176
|
+
last->parent->child.erase(last->token);
|
177
|
+
|
178
|
+
node_pool_[free_node_count_++] = last;
|
179
|
+
}
|
180
|
+
}
|
181
|
+
|
182
|
+
void Ngram::synchronize() const {
|
183
|
+
while (!insert_queue_.empty()) {
|
184
|
+
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
185
|
+
}
|
186
|
+
}
|
187
|
+
|
188
|
+
void Ngram::insert() {
|
189
|
+
while (!quit_flag_) {
|
190
|
+
std::vector<int32_t> data;
|
191
|
+
if (!insert_queue_.dequeue(data)) {
|
192
|
+
continue;
|
193
|
+
}
|
194
|
+
const auto* token = data.data();
|
195
|
+
size_t size = data.size();
|
196
|
+
std::unique_lock<std::mutex> lock(mutex_);
|
197
|
+
|
198
|
+
for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
|
199
|
+
auto start = token + i;
|
200
|
+
auto end = start + std::min(size - i, param_.branch_length);
|
201
|
+
|
202
|
+
if (end - start > free_node_count_) {
|
203
|
+
squeeze(end - start - free_node_count_);
|
204
|
+
}
|
205
|
+
|
206
|
+
TrieNode* cursor = root_;
|
207
|
+
path_.clear();
|
208
|
+
while (start != end) {
|
209
|
+
auto token = *start;
|
210
|
+
auto iter = cursor->child.find(token);
|
211
|
+
if (iter == cursor->child.end()) {
|
212
|
+
iter = cursor->child.insert({token, getNode()}).first;
|
213
|
+
auto node = iter->second;
|
214
|
+
|
215
|
+
cursor->lru.emplace_front(node);
|
216
|
+
global_lru_.emplace_back(node);
|
217
|
+
|
218
|
+
node->token = token;
|
219
|
+
node->parent = cursor;
|
220
|
+
node->parent_lru_pos = cursor->lru.begin();
|
221
|
+
node->global_lru_pos = --global_lru_.end();
|
222
|
+
node->freq = 1;
|
223
|
+
cursor->sorted_children.insert(node);
|
224
|
+
} else {
|
225
|
+
auto node = iter->second;
|
226
|
+
cursor->sorted_children.erase(node);
|
227
|
+
node->freq++;
|
228
|
+
cursor->sorted_children.insert(node);
|
229
|
+
cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
|
230
|
+
}
|
231
|
+
cursor = iter->second;
|
232
|
+
path_.emplace_back(cursor);
|
233
|
+
++start;
|
234
|
+
}
|
235
|
+
|
236
|
+
for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
|
237
|
+
TrieNode* node = *it;
|
238
|
+
global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
|
239
|
+
}
|
240
|
+
}
|
241
|
+
}
|
242
|
+
}
|
243
|
+
|
244
|
+
void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
245
|
+
for (auto&& token : tokens) {
|
246
|
+
insert_queue_.enqueue(std::move(token));
|
247
|
+
}
|
248
|
+
}
|
249
|
+
|
250
|
+
// Builds a draft tree by breadth-first expansion below every matched suffix.
//
// For a match of length `depth` the expansion starts with a breadth budget of
// (max_match_window_size - depth) * step + min_bfs_breadth, and the budget
// shrinks by `step` per level. At each trie node at most max(1, int(budget))
// children are taken from the front of the node's LRU list, which insert()
// keeps ordered most recently touched first. Expansion stops once
// draft_token_num tree slots are used. The tree is then flattened by
// fillResult().
Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
  const std::vector<std::pair<TrieNode*, int32_t>> matches = match(tokens, batch_size);

  const double breadth_step = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
                              (param_.max_match_window_size - param_.min_match_window_size + 1);

  const auto draft_token_num = param_.get_draft_token_num(batch_size);
  std::vector<Node> tree(draft_token_num + 1);
  int root = 0;
  int next_slot = 1;  // next unused index in `tree`

  // (index of the parent in `tree`, breadth budget, trie node to expand)
  using Pending = std::tuple<int32_t, double, const TrieNode*>;

  for (const auto& entry : matches) {
    const TrieNode* matched = entry.first;
    const int32_t depth = entry.second;

    std::queue<Pending> frontier;
    frontier.emplace(root, (param_.max_match_window_size - depth) * breadth_step + param_.min_bfs_breadth, matched);

    while (!frontier.empty() && next_slot <= draft_token_num) {
      const auto [parent, cur_breadth, trie_node] = frontier.front();
      frontier.pop();

      const auto breadth = std::max(1, int32_t(cur_breadth));
      auto& siblings = tree[parent].next;
      auto child_it = trie_node->lru.begin();
      for (int i = 0; i < breadth && child_it != trie_node->lru.end() && next_slot <= draft_token_num;
           ++i, ++child_it) {
        const auto token = (*child_it)->token;
        int pos = -1;
        const auto found = siblings.find(token);
        if (found == siblings.end()) {
          // First time this token appears under `parent`: claim a new slot.
          pos = siblings.insert(std::make_pair(token, next_slot++)).first->second;
        } else {
          // Already added via another matched suffix: reuse its slot.
          pos = found->second;
        }
        frontier.emplace(pos, cur_breadth - breadth_step, *child_it);
      }
    }
  }

  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
288
|
+
|
289
|
+
// Builds a draft tree by best-first expansion on estimated path probability.
//
// The probability of a child is its frequency divided by the frequency sum of
// the first max_bfs_breadth entries of its parent's sorted_children, times the
// probability of the path leading to the parent. Candidates are kept in a
// max-heap keyed on that probability; the most probable one is popped, placed
// into the tree, and its own children are pushed. Expansion stops once
// draft_token_num tree slots are used. The tree is then flattened by
// fillResult().
Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
  auto draft_token_num = param_.get_draft_token_num(batch_size);

  // (index of the parent in `tree`, trie node, probability of the path to it).
  // The parent index is an integer: it is used to subscript `tree`.
  using Candidate = std::tuple<int, const TrieNode*, double>;

  struct CompareByProb {
    bool operator()(const Candidate& a, const Candidate& b) const {
      return std::get<2>(a) < std::get<2>(b);
    }
  };

  std::priority_queue<Candidate, std::vector<Candidate>, CompareByProb> heap;

  std::vector<Node> tree(draft_token_num + 1);

  int root = 0;
  int cursor = 1;  // next unused index in `tree`
  int top_k = param_.max_bfs_breadth;

  // Pushes up to top_k children of `trie_node` as candidates below `parent`.
  auto addToHeap = [&heap, &top_k](int parent, const TrieNode* trie_node, double prob) -> void {
    double sum_freq = 0.0;
    int count = 0;
    std::vector<std::pair<TrieNode*, int32_t>> topk_children;
    for (auto* child : trie_node->sorted_children) {
      sum_freq += static_cast<double>(child->freq);
      topk_children.emplace_back(child, child->freq);
      if (++count >= top_k) break;
    }
    if (sum_freq <= 0) sum_freq = 1.0;  // avoid dividing by zero
    for (const auto& [child, freq] : topk_children) {
      double norm_freq = static_cast<double>(freq) / sum_freq * prob;
      heap.emplace(parent, child, norm_freq);
    }
  };

  for (const auto& entry : nodes) {
    addToHeap(root, entry.first, 1.0);

    while (!heap.empty() && cursor <= draft_token_num) {
      auto [parent, trie_node, prob] = heap.top();
      heap.pop();
      auto token = trie_node->token;
      int pos = -1;
      auto tit = tree[parent].next.find(token);
      if (tit != tree[parent].next.end()) {
        // Token already present under this parent: reuse its slot.
        pos = tit->second;
      } else {
        pos = cursor++;
        tree[parent].next[token] = pos;
      }
      addToHeap(pos, trie_node, prob);
    }
  }

  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
350
|
+
|
351
|
+
// Runs the configured matcher on every sequence in `tokens` while holding
// mutex_, and concatenates the per-sequence token and mask arrays in input
// order. matchBFS is used when param_.match_type is "BFS", matchProb for any
// other value. The batch size passed to the matcher is tokens.size().
Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
  std::unique_lock<std::mutex> lock(mutex_);

  const bool use_bfs = (param_.match_type == "BFS");
  const size_t batch_size = tokens.size();

  Result merged;
  for (const auto& sequence : tokens) {
    Result single = use_bfs ? matchBFS(sequence, batch_size) : matchProb(sequence, batch_size);
    merged.token.insert(merged.token.end(), single.token.begin(), single.token.end());
    merged.mask.insert(merged.mask.end(), single.mask.begin(), single.mask.end());
  }
  return merged;
}
|
362
|
+
|
363
|
+
void Ngram::Result::truncate(size_t n) {
|
364
|
+
if (n < token.size()) {
|
365
|
+
int full_n = token.size();
|
366
|
+
for (int i = 1; i < n; ++i) {
|
367
|
+
memcpy(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
|
368
|
+
}
|
369
|
+
token.resize(n);
|
370
|
+
mask.resize(n * n);
|
371
|
+
}
|
372
|
+
}
|
373
|
+
|
374
|
+
} // namespace ngram
|