sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <mutex>
#include <new>
#include <set>
#include <sstream>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "param.h"
#include "queue.h"
|
17
|
+
|
18
|
+
namespace ngram {
|
19
|
+
|
20
|
+
// A single node of the n-gram trie. Children are keyed by token id, and each
// node additionally keeps a frequency-ordered view of its children.
struct TrieNode {
  // Child nodes keyed by token id.
  std::unordered_map<int32_t, TrieNode*> child;
  // Iterators into std::list<TrieNode*> containers. Judging by the names they
  // locate this node inside a global LRU list and inside the parent's `lru`
  // list; the code that maintains them is not part of this header -- confirm
  // against the implementation file.
  std::list<TrieNode*>::const_iterator global_lru_pos;
  std::list<TrieNode*>::const_iterator parent_lru_pos;
  // Token id of this node. Explicitly initialised so that a default-initialised
  // node (`TrieNode n;`) never carries an indeterminate value into
  // CompareByFreq.
  int32_t token = 0;
  // Explicitly initialised for the same reason as `token`.
  TrieNode* parent = nullptr;
  // List of node pointers owned by this node (presumably its children in
  // recency order -- TODO confirm in ngram.cpp).
  std::list<TrieNode*> lru;
  // Frequency counter compared by CompareByFreq.
  int32_t freq = 0;

  // Strict weak ordering used by `sorted_children`:
  //   1. higher `freq` first,
  //   2. on equal `freq`, smaller `token` first,
  //   3. on equal `freq` and `token`, lower pointer value first.
  struct CompareByFreq {
    bool operator()(TrieNode* a, TrieNode* b) const {
      // Note the swapped `freq` operands: that is what makes the frequency
      // component descending while the remaining components stay ascending.
      return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
    }
  };
  // Node pointers ordered by CompareByFreq.
  std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
|
36
|
+
|
37
|
+
class Ngram {
|
38
|
+
std::vector<TrieNode> nodes_;
|
39
|
+
std::vector<TrieNode*> node_pool_;
|
40
|
+
size_t free_node_count_;
|
41
|
+
std::list<TrieNode*> global_lru_;
|
42
|
+
TrieNode* root_;
|
43
|
+
std::vector<TrieNode*> path_;
|
44
|
+
Param param_;
|
45
|
+
|
46
|
+
std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;
|
47
|
+
|
48
|
+
void squeeze(size_t count);
|
49
|
+
|
50
|
+
TrieNode* getNode() {
|
51
|
+
auto node = node_pool_[--free_node_count_];
|
52
|
+
node->~TrieNode();
|
53
|
+
new (node) TrieNode();
|
54
|
+
return node;
|
55
|
+
}
|
56
|
+
|
57
|
+
mutable std::mutex mutex_;
|
58
|
+
bool quit_flag_;
|
59
|
+
utils::Queue<std::vector<int32_t>> insert_queue_;
|
60
|
+
std::thread insert_worker_;
|
61
|
+
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
|
62
|
+
|
63
|
+
public:
|
64
|
+
Ngram(size_t capacity, const Param& param);
|
65
|
+
Ngram() = default;
|
66
|
+
~Ngram();
|
67
|
+
|
68
|
+
static Ngram& instance() {
|
69
|
+
static Ngram instance;
|
70
|
+
return instance;
|
71
|
+
}
|
72
|
+
|
73
|
+
void synchronize() const;
|
74
|
+
|
75
|
+
void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);
|
76
|
+
|
77
|
+
struct Result {
|
78
|
+
std::vector<int32_t> token;
|
79
|
+
std::vector<uint8_t> mask;
|
80
|
+
|
81
|
+
void truncate(size_t n);
|
82
|
+
};
|
83
|
+
|
84
|
+
Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;
|
85
|
+
|
86
|
+
void reset() {
|
87
|
+
std::unique_lock<std::mutex> lock(mutex_);
|
88
|
+
|
89
|
+
global_lru_.clear();
|
90
|
+
path_.clear();
|
91
|
+
node_pool_.clear();
|
92
|
+
for (auto& node : nodes_) {
|
93
|
+
node_pool_.emplace_back(&node);
|
94
|
+
}
|
95
|
+
free_node_count_ = node_pool_.size();
|
96
|
+
root_ = getNode();
|
97
|
+
}
|
98
|
+
|
99
|
+
const Param& param() const {
|
100
|
+
return param_;
|
101
|
+
}
|
102
|
+
|
103
|
+
private:
|
104
|
+
Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
|
105
|
+
Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;
|
106
|
+
|
107
|
+
void insert();
|
108
|
+
};
|
109
|
+
|
110
|
+
} // namespace ngram
|
@@ -0,0 +1,138 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from typing import List, Tuple
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from torch.utils.cpp_extension import load
|
9
|
+
|
10
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory that contains this file; the C++ sources sit next to it.
_abs_path = os.path.dirname(os.path.abspath(__file__))

# Build and import the pybind11 extension at import time via
# torch.utils.cpp_extension.load. The binding translation unit and the
# n-gram implementation are compiled together with -O3 as C++20.
ngram_cache_cpp = load(
    name="ngram_cache_cpp",
    extra_cflags=["-O3", "-std=c++20"],
    sources=[
        f"{_abs_path}/ngram_cache_binding.cpp",
        f"{_abs_path}/ngram.cpp",
    ],
)
|
21
|
+
|
22
|
+
|
23
|
+
class NgramCache:
    """Thin Python wrapper around the ``ngram_cache_cpp.Ngram`` extension object.

    The constructor copies its keyword arguments onto a
    ``ngram_cache_cpp.Param`` instance field by field and builds the C++
    cache from it; the remaining methods forward to the C++ object or help
    inspecting its output.
    """

    def __init__(
        self,
        branch_length=18,
        min_match_window_size=1,
        max_match_window_size=10,
        min_bfs_breadth=1,
        max_bfs_breadth=8,
        draft_token_num=8,
        match_type="BFS",
        capacity=1000000,
    ):
        """Create the C++ cache.

        Every argument except ``capacity`` is stored on the ``Param`` field
        of the same name; ``capacity`` is passed to the ``Ngram``
        constructor together with that ``Param``.
        """
        param = ngram_cache_cpp.Param()
        param.branch_length = branch_length
        param.min_match_window_size = min_match_window_size
        param.max_match_window_size = max_match_window_size
        param.min_bfs_breadth = min_bfs_breadth
        param.max_bfs_breadth = max_bfs_breadth
        param.draft_token_num = draft_token_num
        param.match_type = match_type
        self.cache = ngram_cache_cpp.Ngram(capacity, param)

        # Not read anywhere inside this class; kept because external code may
        # rely on the attribute.
        self.default_mask = np.ones((1, 1), dtype=np.int64)
        # Used by debug_result() to reshape the flat result buffers.
        self.draft_token_num = draft_token_num

    def batch_put(self, batch_tokens: List[List[int]]):
        """Forward ``batch_tokens`` to the C++ ``asyncInsert``."""
        self.cache.asyncInsert(batch_tokens)

    def synchronize(self):
        """Forward to the C++ ``synchronize`` (presumably waits for pending inserts)."""
        self.cache.synchronize()

    def reset(self):
        """Forward to the C++ ``reset``."""
        self.cache.reset()

    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Run the C++ ``batchMatch`` and return ``(token, mask)`` as numpy arrays."""
        result = self.cache.batchMatch(batch_tokens)
        return np.array(result.token), np.array(result.mask)

    def leaf_paths_from_mask(
        self, tokens: List[int], tree_mask: List[List[int]]
    ) -> List[List[int]]:
        """Return the token lists of all leaf paths encoded in ``tree_mask``.

        Each row of ``tree_mask`` marks with ``1`` the columns that belong to
        one path. A row is a leaf path when its column set is not contained
        in the column set of any later row that was itself kept; of several
        identical rows only the last one is kept.

        Args:
            tokens: Token list corresponding to the mask columns.
            tree_mask: n x n binary matrix, one row per path.

        Returns:
            Token lists of only the leaf paths, in their order of appearance.
            Columns without a matching token (or tokens without a matching
            column) are ignored instead of raising ``IndexError``.
        """
        row_sets = [
            {col for col, flag in enumerate(row) if flag == 1} for row in tree_mask
        ]

        kept_sets = []
        leaf_rows = []
        # Walk bottom-up so that a prefix row always meets the longer rows
        # that contain it before it is considered itself.
        for row_idx in reversed(range(len(row_sets))):
            cur_set = row_sets[row_idx]
            if any(cur_set <= kept for kept in kept_sets):
                continue
            kept_sets.append(cur_set)
            leaf_rows.append(row_idx)

        # Restore the original top-down order of the surviving rows.
        leaf_rows.reverse()
        return [
            [tok for tok, flag in zip(tokens, tree_mask[row_idx]) if flag == 1]
            for row_idx in leaf_rows
        ]

    def debug_result(
        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
    ):
        """Log the draft paths contained in a ``batch_get`` result.

        ``decoding_ids`` is reshaped to ``(-1, draft_token_num)`` and
        ``decoding_masks`` to ``(-1, draft_token_num, draft_token_num)``.
        When ``tokenizer`` is given, every leaf path is additionally decoded
        with ``tokenizer.decode(path, ensure_ascii=False)``.
        """
        decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
        decoding_masks = decoding_masks.reshape(
            -1, self.draft_token_num, self.draft_token_num
        )
        # Lazy %-formatting: the (potentially large) array reprs are only
        # rendered when INFO logging is actually enabled.
        logger.info(
            "\ndecoding_ids=%r\ndecoding_masks=%r", decoding_ids, decoding_masks
        )
        for i in range(decoding_ids.shape[0]):
            leaf_paths = self.leaf_paths_from_mask(
                decoding_ids[i].tolist(), decoding_masks[i].tolist()
            )
            if tokenizer is None:
                logger.info("draft path %s: %s", i, leaf_paths)
                continue
            logger.info("result %s:", i)
            for leaf_path in leaf_paths:
                logger.info(
                    "draft path %s: %s -> %s",
                    i,
                    leaf_path,
                    tokenizer.decode(leaf_path, ensure_ascii=False),
                )
|
116
|
+
|
117
|
+
|
118
|
+
# main function
|
119
|
+
if __name__ == "__main__":
    # Small manual smoke test: insert two token sequences, wait for the cache
    # to settle, query three prefixes and log the resulting draft paths.
    # Named `log_format` so the builtin `format` is not shadowed; the value is
    # a plain string because it contains no f-string placeholders.
    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    token_ids = [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
    ]
    cache = NgramCache(branch_length=12, draft_token_num=8)
    cache.batch_put(token_ids)

    # batch_put goes through asyncInsert, so synchronize before querying.
    cache.synchronize()
    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])

    cache.debug_result(decoding_ids, decoding_masks)
|
@@ -0,0 +1,43 @@
|
|
1
|
+
#include <pybind11/pybind11.h>
|
2
|
+
#include <pybind11/stl.h>
|
3
|
+
|
4
|
+
#include "ngram.h"
|
5
|
+
|
6
|
+
// Python bindings for the n-gram cache: exposes ngram::Ngram, ngram::Param and
// ngram::Ngram::Result under the module name `ngram_cache_cpp`.
PYBIND11_MODULE(ngram_cache_cpp, m) {
  namespace py = pybind11;
  using ngram::Ngram;
  using ngram::Param;
  using Result = Ngram::Result;

  m.doc() = "";

  // Ngram: constructor plus the four operations used from Python.
  py::class_<Ngram> ngram_cls(m, "Ngram");
  ngram_cls.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"));
  ngram_cls.def("asyncInsert", &Ngram::asyncInsert, "");
  ngram_cls.def("batchMatch", &Ngram::batchMatch, "");
  ngram_cls.def("reset", &Ngram::reset, "");
  ngram_cls.def("synchronize", &Ngram::synchronize, "");

  // Param: default-constructible, every field readable and writable.
  py::class_<Param> param_cls(m, "Param");
  param_cls.def(py::init<>());
  param_cls.def_readwrite("enable", &Param::enable);
  param_cls.def_readwrite("enable_router_mode", &Param::enable_router_mode);
  param_cls.def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth);
  param_cls.def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth);
  param_cls.def_readwrite("min_match_window_size", &Param::min_match_window_size);
  param_cls.def_readwrite("max_match_window_size", &Param::max_match_window_size);
  param_cls.def_readwrite("branch_length", &Param::branch_length);
  param_cls.def_readwrite("draft_token_num", &Param::draft_token_num);
  param_cls.def_readwrite("match_type", &Param::match_type);
  param_cls.def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size);
  param_cls.def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num);
  param_cls.def("get_draft_token_num", &Param::get_draft_token_num, "");
  param_cls.def("get_min_match_window_size", &Param::get_min_match_window_size, "");
  param_cls.def("parse", &Param::parse, "");
  param_cls.def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "");
  param_cls.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "");
  param_cls.def("detail", &Param::detail, "");

  // Result: the flat token / mask buffers returned by batchMatch.
  py::class_<Result> result_cls(m, "Result");
  result_cls.def(py::init<>());
  result_cls.def_readwrite("token", &Result::token);
  result_cls.def_readwrite("mask", &Result::mask);
  result_cls.def("truncate", &Result::truncate);
}
|
@@ -0,0 +1,125 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <cstddef>
|
4
|
+
#include <iostream>
|
5
|
+
#include <limits>
|
6
|
+
#include <regex>
|
7
|
+
#include <sstream>
|
8
|
+
#include <stdexcept>
|
9
|
+
#include <string>
|
10
|
+
#include <vector>
|
11
|
+
|
12
|
+
namespace ngram {
|
13
|
+
|
14
|
+
// Tunable parameters of the n-gram cache, plus optional per-batch-size
// overrides for two of them.
struct Param {
  // Scalar settings. They are zero-initialised so that a default-initialised
  // Param never holds indeterminate values.
  bool enable = false;
  bool enable_router_mode = false;
  size_t min_bfs_breadth = 0;
  size_t max_bfs_breadth = 0;
  size_t min_match_window_size = 0;
  size_t max_match_window_size = 0;
  size_t branch_length = 0;
  size_t draft_token_num = 0;
  std::string match_type;

  // Per-batch-size overrides indexed by batch size. An entry equal to
  // numeric_limits<size_t>::max() means "no override for this batch size".
  std::vector<size_t> batch_min_match_window_size;
  std::vector<size_t> batch_draft_token_num;

  // Returns the override for `batch_size` when one is configured, otherwise
  // `draft_token_num - 1` (clamped at 0 so that an unset `draft_token_num`
  // cannot wrap around to a huge value).
  size_t get_draft_token_num(size_t batch_size) const {
    constexpr size_t kUnset = std::numeric_limits<size_t>::max();
    if (batch_size < batch_draft_token_num.size() && batch_draft_token_num[batch_size] != kUnset) {
      return batch_draft_token_num[batch_size];
    }
    return draft_token_num > 0 ? draft_token_num - 1 : 0;
  }

  // Returns the override for `batch_size` when one is configured, otherwise
  // `min_match_window_size`.
  size_t get_min_match_window_size(size_t batch_size) const {
    constexpr size_t kUnset = std::numeric_limits<size_t>::max();
    if (batch_size < batch_min_match_window_size.size() && batch_min_match_window_size[batch_size] != kUnset) {
      return batch_min_match_window_size[batch_size];
    }
    return min_match_window_size;
  }

  // Parses a per-batch-size override string of the form
  //   "L-R|value,L-R|value,..."        e.g. "0-1|10,2-3|20"
  // into a vector indexed by batch size. Every index in the inclusive range
  // [L, R] receives `value`; indices not covered by any range hold
  // numeric_limits<size_t>::max().
  //
  // An empty string yields an empty vector. Throws std::runtime_error when a
  // segment is malformed, a number is not a plain non-negative decimal,
  // L > R, R > 128, or an index is covered twice.
  std::vector<size_t> parse(const std::string& value) const {
    constexpr size_t kUnset = std::numeric_limits<size_t>::max();
    std::vector<size_t> result;
    if (value.empty()) {
      return result;
    }

    // Splits `text` on `delimiter` (regex token iterator, submatch -1).
    const auto split = [](const std::string& text, const std::regex& delimiter) {
      std::sregex_token_iterator first{text.begin(), text.end(), delimiter, -1}, last;
      return std::vector<std::string>(first, last);
    };

    // Strict decimal conversion. Surrounding blanks are tolerated; anything
    // else (empty text, sign, letters, overflow) is rejected rather than
    // silently turned into 0 or a wrapped-around value.
    const auto to_number = [&value, kUnset](const std::string& text) -> size_t {
      size_t begin = 0;
      size_t end = text.size();
      while (begin < end && (text[begin] == ' ' || text[begin] == '\t')) {
        ++begin;
      }
      while (end > begin && (text[end - 1] == ' ' || text[end - 1] == '\t')) {
        --end;
      }
      if (begin == end) {
        throw std::runtime_error("invalid number '" + text + "', config: " + value);
      }
      size_t number = 0;
      for (size_t i = begin; i < end; ++i) {
        const char c = text[i];
        if (c < '0' || c > '9') {
          throw std::runtime_error("invalid number '" + text + "', config: " + value);
        }
        const size_t digit = static_cast<size_t>(c - '0');
        if (number > (kUnset - digit) / 10) {
          throw std::runtime_error("number out of range '" + text + "', config: " + value);
        }
        number = number * 10 + digit;
      }
      return number;
    };

    // The delimiters never change, so build the regexes once per call.
    const std::regex comma_re(",");
    const std::regex pipe_re("\\|");
    const std::regex dash_re("-");

    // mark[i] is true once index i has been assigned by some segment.
    std::vector<bool> mark;
    for (const auto& seg : split(value, comma_re)) {
      const std::vector<std::string> part = split(seg, pipe_re);
      if (part.size() != 2) {
        throw std::runtime_error(
            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
      }
      const std::vector<std::string> range = split(part[0], dash_re);
      if (range.size() != 2) {
        throw std::runtime_error("failed to get range, invalid config: " + value);
      }
      size_t L = to_number(range[0]);
      const size_t R = to_number(range[1]);
      if (L > R || R > 128) {
        throw std::runtime_error("invalid range, config: " + value);
      }
      if (R >= result.size()) {
        result.resize(R + 1, kUnset);
        mark.resize(result.size(), false);
      }
      const size_t config = to_number(part[1]);
      do {
        if (mark[L]) {
          throw std::runtime_error("repeated position " + std::to_string(L) + ", config : " + value);
        }
        mark[L] = true;
        result[L] = config;
      } while (++L <= R);
    }
    return result;
  }

  // Replaces the per-batch-size min-match-window overrides with parse(value).
  void resetBatchMinMatchWindowSize(const std::string& value) {
    batch_min_match_window_size = parse(value);
  }

  // Replaces the per-batch-size draft-token-count overrides with parse(value).
  void resetBatchReturnTokenNum(const std::string& value) {
    batch_draft_token_num = parse(value);
  }

  // Renders every field as a single human-readable line. Const so it can be
  // called through a `const Param&`.
  std::string detail() const {
    std::stringstream ss;
    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
       << ", match_type = " << match_type;
    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
    for (size_t i = 0; i < batch_min_match_window_size.size(); ++i) {
      ss << i << "|" << batch_min_match_window_size[i] << ",";
    }
    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
    for (size_t i = 0; i < batch_draft_token_num.size(); ++i) {
      ss << i << "|" << batch_draft_token_num[i] << ",";
    }
    return ss.str();
  }
};
|
124
|
+
|
125
|
+
} // namespace ngram
|
@@ -0,0 +1,71 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <condition_variable>
|
4
|
+
#include <queue>
|
5
|
+
|
6
|
+
namespace utils {
|
7
|
+
|
8
|
+
// Minimal blocking FIFO queue guarded by a mutex and a condition variable.
//
// Once close() has been called the queue accepts no further items and
// dequeue() returns false immediately -- items that are still queued at that
// point are NOT handed out any more (they remain counted by size()).
template <typename T>
class Queue {
 public:
  // Appends `rhs` by move. Returns false (and leaves `rhs` untouched) when the
  // queue has been closed.
  bool enqueue(T&& rhs) {
    return push(std::move(rhs));
  }

  // Appends a copy of `rhs`. Returns false when the queue has been closed.
  bool enqueue(const T& rhs) {
    return push(rhs);
  }

  // Blocks until an item is available or the queue is closed. On success the
  // front item is moved into `rhs` and true is returned; after close() the
  // call returns false without touching `rhs`, even if items are still
  // queued.
  bool dequeue(T& rhs) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return queue_.size() || closed_; });
    if (closed_) {
      return false;
    }
    rhs = std::move(queue_.front());
    queue_.pop();
    return true;
  }

  // Number of items currently queued.
  size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  // True when no item is currently queued.
  bool empty() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.empty();
  }

  // Marks the queue as closed and wakes every thread blocked in dequeue().
  void close() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  // Shared implementation of both enqueue overloads: push under the lock,
  // then notify one waiter after the lock has been released.
  template <typename U>
  bool push(U&& item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (closed_) {
        return false;
      }
      queue_.emplace(std::forward<U>(item));
    }
    cv_.notify_one();
    return true;
  }

  std::queue<T> queue_;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool closed_{false};
};
|
70
|
+
|
71
|
+
} // namespace utils
|
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
20
20
|
ForwardBatch,
|
21
21
|
ForwardMode,
|
22
22
|
)
|
23
|
-
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
23
|
+
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
24
24
|
from sglang.srt.utils import (
|
25
25
|
require_attn_tp_gather,
|
26
26
|
require_gathered_buffer,
|
@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
|
|
302
302
|
if bs != raw_bs:
|
303
303
|
self.seq_lens.fill_(self.seq_len_fill_value)
|
304
304
|
self.out_cache_loc.zero_()
|
305
|
+
self.positions.zero_()
|
305
306
|
|
306
307
|
num_tokens = bs * self.num_tokens_per_bs
|
307
308
|
|
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
21
21
|
ForwardBatch,
|
22
22
|
ForwardMode,
|
23
23
|
)
|
24
|
-
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
24
|
+
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
25
|
+
from sglang.srt.speculative.spec_utils import fast_topk
|
25
26
|
from sglang.srt.utils import (
|
26
27
|
require_attn_tp_gather,
|
27
28
|
require_gathered_buffer,
|
@@ -331,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
331
332
|
if bs * self.num_tokens_per_bs != num_tokens:
|
332
333
|
self.seq_lens.fill_(self.seq_len_fill_value)
|
333
334
|
self.out_cache_loc.zero_()
|
335
|
+
self.positions.zero_()
|
334
336
|
self.accept_length.fill_(1)
|
335
337
|
self.extend_seq_lens.fill_(1)
|
336
338
|
|