sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -250,7 +250,7 @@ class HiCacheController:
|
|
250
250
|
storage_backend: Optional[str] = None,
|
251
251
|
prefetch_threshold: int = 256,
|
252
252
|
model_name: Optional[str] = None,
|
253
|
-
storage_backend_extra_config: Optional[
|
253
|
+
storage_backend_extra_config: Optional[dict] = None,
|
254
254
|
):
|
255
255
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
256
256
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
@@ -275,43 +275,17 @@ class HiCacheController:
|
|
275
275
|
and self.storage_config.tp_rank != 0
|
276
276
|
)
|
277
277
|
|
278
|
-
|
279
|
-
|
278
|
+
# Use storage backend factory for dynamic backend creation
|
279
|
+
from sglang.srt.mem_cache.storage import StorageBackendFactory
|
280
280
|
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
self.storage_backend = HiCacheNixl()
|
286
|
-
elif storage_backend == "mooncake":
|
287
|
-
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
288
|
-
MooncakeStore,
|
281
|
+
try:
|
282
|
+
self.storage_backend = StorageBackendFactory.create_backend(
|
283
|
+
storage_backend, self.storage_config, self.mem_pool_host
|
289
284
|
)
|
285
|
+
except ValueError as e:
|
286
|
+
raise ValueError(f"Failed to create storage backend: {e}") from e
|
290
287
|
|
291
|
-
|
292
|
-
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
293
|
-
assert self.mem_pool_host.layout == "page_first"
|
294
|
-
elif storage_backend == "hf3fs":
|
295
|
-
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
296
|
-
HiCacheHF3FS,
|
297
|
-
)
|
298
|
-
|
299
|
-
if self.mem_pool_host.layout == "page_first":
|
300
|
-
bytes_per_page = (
|
301
|
-
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
302
|
-
)
|
303
|
-
elif self.mem_pool_host.layout == "layer_first":
|
304
|
-
bytes_per_page = (
|
305
|
-
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
306
|
-
)
|
307
|
-
dtype = mem_pool_host.dtype
|
308
|
-
self.storage_backend = HiCacheHF3FS.from_env_config(
|
309
|
-
bytes_per_page, dtype, self.storage_config
|
310
|
-
)
|
311
|
-
else:
|
312
|
-
raise NotImplementedError(
|
313
|
-
f"Unsupported storage backend: {storage_backend}"
|
314
|
-
)
|
288
|
+
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
|
315
289
|
|
316
290
|
self.enable_storage = True
|
317
291
|
# todo: threshold policy for prefetching
|
@@ -335,18 +309,10 @@ class HiCacheController:
|
|
335
309
|
# Select the get and set functions
|
336
310
|
self.page_get_func = self._generic_page_get
|
337
311
|
self.page_set_func = self._generic_page_set
|
338
|
-
|
339
|
-
self.
|
340
|
-
self.
|
341
|
-
|
342
|
-
)
|
343
|
-
if self.storage_backend_type == "mooncake":
|
344
|
-
self.page_get_func = self._mooncake_page_get
|
345
|
-
self.page_set_func = self._mooncake_page_set
|
346
|
-
elif self.is_3fs_zerocopy:
|
347
|
-
self.page_get_func = self._3fs_zero_copy_page_get
|
348
|
-
self.page_set_func = self._3fs_zero_copy_page_set
|
349
|
-
self.batch_exists_func = self._3fs_zero_copy_batch_exists
|
312
|
+
|
313
|
+
if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
|
314
|
+
self.page_get_func = self._page_get_zero_copy
|
315
|
+
self.page_set_func = self._page_set_zero_copy
|
350
316
|
|
351
317
|
self.device = self.mem_pool_device.device
|
352
318
|
self.layer_num = self.mem_pool_device.layer_num
|
@@ -395,7 +361,7 @@ class HiCacheController:
|
|
395
361
|
def _generate_storage_config(
|
396
362
|
self,
|
397
363
|
model_name: Optional[str] = None,
|
398
|
-
storage_backend_extra_config: Optional[
|
364
|
+
storage_backend_extra_config: Optional[dict] = None,
|
399
365
|
):
|
400
366
|
|
401
367
|
if is_dp_attention_enabled():
|
@@ -410,23 +376,13 @@ class HiCacheController:
|
|
410
376
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
411
377
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
412
378
|
|
413
|
-
# Parse extra config JSON if provided
|
414
|
-
extra_config = None
|
415
|
-
if storage_backend_extra_config:
|
416
|
-
try:
|
417
|
-
import json
|
418
|
-
|
419
|
-
extra_config = json.loads(storage_backend_extra_config)
|
420
|
-
except Exception as e:
|
421
|
-
logger.error(f"Invalid backend extra config JSON: {e}")
|
422
|
-
|
423
379
|
return HiCacheStorageConfig(
|
424
380
|
tp_rank=self.tp_rank,
|
425
381
|
tp_size=self.tp_size,
|
426
382
|
is_mla_model=is_mla_backend,
|
427
383
|
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
428
384
|
model_name=model_name,
|
429
|
-
extra_config=
|
385
|
+
extra_config=storage_backend_extra_config,
|
430
386
|
)
|
431
387
|
|
432
388
|
def reset(self):
|
@@ -470,7 +426,6 @@ class HiCacheController:
|
|
470
426
|
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
471
427
|
if host_indices is None:
|
472
428
|
return None
|
473
|
-
self.mem_pool_host.protect_write(host_indices)
|
474
429
|
self.write_queue.append(
|
475
430
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
476
431
|
)
|
@@ -494,7 +449,6 @@ class HiCacheController:
|
|
494
449
|
self.mem_pool_host.backup_from_device_all_layer(
|
495
450
|
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
496
451
|
)
|
497
|
-
self.mem_pool_host.complete_io(op.host_indices)
|
498
452
|
finish_event.record()
|
499
453
|
# NOTE: We must save the host indices and device indices here,
|
500
454
|
# this is because we need to guarantee that these tensors are
|
@@ -518,7 +472,6 @@ class HiCacheController:
|
|
518
472
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
519
473
|
if device_indices is None:
|
520
474
|
return None
|
521
|
-
self.mem_pool_host.protect_load(host_indices)
|
522
475
|
self.load_queue.append(
|
523
476
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
524
477
|
)
|
@@ -563,7 +516,6 @@ class HiCacheController:
|
|
563
516
|
self.io_backend,
|
564
517
|
)
|
565
518
|
producer_event.complete(i)
|
566
|
-
self.mem_pool_host.complete_io(op.host_indices)
|
567
519
|
# NOTE: We must save the host indices and device indices here,
|
568
520
|
# this is because we need to guarantee that these tensors are
|
569
521
|
# still alive when the load stream is executing.
|
@@ -581,29 +533,16 @@ class HiCacheController:
|
|
581
533
|
)
|
582
534
|
return producer_id
|
583
535
|
|
584
|
-
def evict_device(
|
585
|
-
self
|
586
|
-
|
587
|
-
if self.mem_pool_host.is_synced(host_indices):
|
588
|
-
self.mem_pool_device_allocator.free(device_indices)
|
589
|
-
self.mem_pool_host.update_backup(host_indices)
|
590
|
-
return len(device_indices)
|
591
|
-
else:
|
592
|
-
raise ValueError(
|
593
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
594
|
-
)
|
536
|
+
def evict_device(self, device_indices: torch.Tensor) -> int:
|
537
|
+
self.mem_pool_device_allocator.free(device_indices)
|
538
|
+
return len(device_indices)
|
595
539
|
|
596
540
|
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
597
541
|
if not backup_only:
|
598
542
|
raise ValueError("Other eviction policies are not supported yet.")
|
599
543
|
|
600
|
-
|
601
|
-
|
602
|
-
return len(host_indices)
|
603
|
-
else:
|
604
|
-
raise ValueError(
|
605
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
606
|
-
)
|
544
|
+
self.mem_pool_host.free(host_indices)
|
545
|
+
return len(host_indices)
|
607
546
|
|
608
547
|
def prefetch(
|
609
548
|
self,
|
@@ -626,46 +565,25 @@ class HiCacheController:
|
|
626
565
|
return operation.completed_tokens, operation.hash_value
|
627
566
|
|
628
567
|
def append_host_mem_release(self, host_indices: torch.Tensor):
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
|
635
|
-
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
|
636
|
-
return hit_page_num
|
637
|
-
|
638
|
-
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
639
|
-
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
|
640
|
-
hash_values, host_indices
|
641
|
-
)
|
642
|
-
page_data = self.storage_backend.batch_get(hashes, dsts)
|
643
|
-
if page_data:
|
644
|
-
inc = self.page_size * len(hashes) // factor
|
645
|
-
operation.increment(inc)
|
646
|
-
else:
|
647
|
-
logger.warning(
|
648
|
-
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
649
|
-
)
|
568
|
+
if host_indices.numel() == 0:
|
569
|
+
return
|
570
|
+
pages = host_indices.split(self.mem_pool_host.page_size)
|
571
|
+
for page in pages:
|
572
|
+
self.host_mem_release_queue.put(page)
|
650
573
|
|
651
|
-
def
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
)
|
662
|
-
if get_result != len(hash_values):
|
663
|
-
logger.warning(
|
664
|
-
f"Prefetch operation {operation.request_id} failed or partially failed."
|
665
|
-
)
|
666
|
-
if get_result != 0:
|
667
|
-
operation.increment(get_result * self.page_size)
|
574
|
+
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
575
|
+
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
576
|
+
inc = 0
|
577
|
+
for i in range(len(hash_values)):
|
578
|
+
if not results[i]:
|
579
|
+
logger.warning(
|
580
|
+
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
581
|
+
)
|
582
|
+
break
|
583
|
+
inc += self.page_size
|
584
|
+
operation.increment(inc)
|
668
585
|
|
586
|
+
# todo: deprecate
|
669
587
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
670
588
|
dummy_page_dst = [
|
671
589
|
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
@@ -755,7 +673,7 @@ class HiCacheController:
|
|
755
673
|
batch_tokens[i : i + self.page_size], last_hash
|
756
674
|
)
|
757
675
|
batch_hashes.append(last_hash)
|
758
|
-
hit_page_num = self.
|
676
|
+
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
759
677
|
hash_value.extend(batch_hashes[:hit_page_num])
|
760
678
|
storage_query_count += hit_page_num * self.page_size
|
761
679
|
if hit_page_num < len(batch_hashes):
|
@@ -824,34 +742,16 @@ class HiCacheController:
|
|
824
742
|
self.backup_queue.put(operation)
|
825
743
|
return operation.id
|
826
744
|
|
827
|
-
#
|
745
|
+
# todo: deprecate
|
828
746
|
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
829
747
|
data = [
|
830
|
-
self.mem_pool_host.
|
748
|
+
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
831
749
|
for i in range(len(hash_values))
|
832
750
|
]
|
833
751
|
return self.storage_backend.batch_set(hash_values, data)
|
834
752
|
|
835
|
-
|
836
|
-
|
837
|
-
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
838
|
-
hash_values,
|
839
|
-
host_indices,
|
840
|
-
self.storage_config.tp_rank,
|
841
|
-
)
|
842
|
-
success = self.storage_backend.batch_set(
|
843
|
-
key_strs,
|
844
|
-
target_locations=buffer_ptrs,
|
845
|
-
target_sizes=buffer_sizes,
|
846
|
-
)
|
847
|
-
return success
|
848
|
-
|
849
|
-
# zero copy
|
850
|
-
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
851
|
-
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
|
852
|
-
hash_values, host_indices
|
853
|
-
)
|
854
|
-
return self.storage_backend.batch_set(hashes, dsts)
|
753
|
+
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
754
|
+
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
855
755
|
|
856
756
|
# Backup batch by batch
|
857
757
|
def _page_backup(self, operation):
|
@@ -17,14 +17,11 @@ import faulthandler
|
|
17
17
|
import logging
|
18
18
|
import multiprocessing as mp
|
19
19
|
import signal
|
20
|
-
import struct
|
21
|
-
import sys
|
22
20
|
import threading
|
23
21
|
import time
|
24
22
|
from collections import deque
|
25
23
|
from enum import Enum, auto
|
26
|
-
from
|
27
|
-
from typing import Dict, List
|
24
|
+
from typing import List
|
28
25
|
|
29
26
|
import psutil
|
30
27
|
import setproctitle
|
@@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import (
|
|
39
36
|
)
|
40
37
|
from sglang.srt.managers.schedule_batch import Req
|
41
38
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
42
|
-
from sglang.srt.managers.utils import DPBalanceMeta
|
43
39
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
44
40
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
45
41
|
from sglang.srt.utils import (
|
@@ -108,15 +104,9 @@ class DPBudget:
|
|
108
104
|
class DataParallelController:
|
109
105
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
110
106
|
|
111
|
-
def __init__(
|
112
|
-
self,
|
113
|
-
server_args: ServerArgs,
|
114
|
-
port_args: PortArgs,
|
115
|
-
dp_balance_meta: DPBalanceMeta,
|
116
|
-
) -> None:
|
107
|
+
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
|
117
108
|
# for dp balance
|
118
109
|
self.global_balance_id = 0
|
119
|
-
self.balance_meta = dp_balance_meta
|
120
110
|
|
121
111
|
# Parse args
|
122
112
|
self.max_total_num_tokens = None
|
@@ -219,7 +209,9 @@ class DataParallelController:
|
|
219
209
|
args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
|
220
210
|
)
|
221
211
|
threads.append(thread)
|
222
|
-
base_gpu_id +=
|
212
|
+
base_gpu_id += (
|
213
|
+
server_args.tp_size * server_args.pp_size * server_args.gpu_id_step
|
214
|
+
)
|
223
215
|
|
224
216
|
# Free all sockets before starting the threads to launch TP workers
|
225
217
|
for sock in sockets:
|
@@ -322,7 +314,6 @@ class DataParallelController:
|
|
322
314
|
pp_rank,
|
323
315
|
dp_rank,
|
324
316
|
writer,
|
325
|
-
self.balance_meta,
|
326
317
|
),
|
327
318
|
)
|
328
319
|
with memory_saver_adapter.configure_subprocess():
|
@@ -370,31 +361,11 @@ class DataParallelController:
|
|
370
361
|
if self.maybe_external_dp_rank_routing(req):
|
371
362
|
return
|
372
363
|
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
|
379
|
-
return current_id
|
380
|
-
|
381
|
-
req.dp_balance_id = get_next_global_balance_id()
|
382
|
-
with self.balance_meta.mutex:
|
383
|
-
# 1. local_tokens represents the tokens currently inferring on the worker,
|
384
|
-
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
|
385
|
-
onfly_info = self.balance_meta.get_shared_onfly()
|
386
|
-
local_tokens = self.balance_meta.get_shared_local_tokens()
|
387
|
-
total_tokens = [
|
388
|
-
local_token + sum(onfly_dict.values())
|
389
|
-
for local_token, onfly_dict in zip(local_tokens, onfly_info)
|
390
|
-
]
|
391
|
-
target_worker = total_tokens.index(min(total_tokens))
|
392
|
-
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
|
393
|
-
# 2. write the new onfly info to the shm
|
394
|
-
self.balance_meta.set_shared_onfly_info(onfly_info)
|
395
|
-
|
396
|
-
# logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
|
397
|
-
self.workers[target_worker].send_pyobj(req)
|
364
|
+
logger.warning(
|
365
|
+
"The 'minimum_tokens' load balancing method is deprecated for now and will introduced later."
|
366
|
+
"Fall back to 'round_robin_scheduler'"
|
367
|
+
)
|
368
|
+
self.round_robin_scheduler(req)
|
398
369
|
|
399
370
|
def event_loop(self):
|
400
371
|
while True:
|
@@ -416,12 +387,9 @@ def run_data_parallel_controller_process(
|
|
416
387
|
faulthandler.enable()
|
417
388
|
configure_logger(server_args)
|
418
389
|
parent_process = psutil.Process().parent()
|
419
|
-
balance_meta = DPBalanceMeta(server_args.dp_size)
|
420
390
|
|
421
391
|
try:
|
422
|
-
controller = DataParallelController(
|
423
|
-
server_args, port_args, dp_balance_meta=balance_meta
|
424
|
-
)
|
392
|
+
controller = DataParallelController(server_args, port_args)
|
425
393
|
pipe_writer.send(
|
426
394
|
{
|
427
395
|
"status": "ready",
|
@@ -440,6 +408,3 @@ def run_data_parallel_controller_process(
|
|
440
408
|
traceback = get_exception_traceback()
|
441
409
|
logger.error(f"DataParallelController hit an exception: {traceback}")
|
442
410
|
parent_process.send_signal(signal.SIGQUIT)
|
443
|
-
finally:
|
444
|
-
# we need to destruct mp.Manager() in balance_meta
|
445
|
-
balance_meta.destructor()
|
@@ -24,13 +24,12 @@ import psutil
|
|
24
24
|
import setproctitle
|
25
25
|
import zmq
|
26
26
|
|
27
|
-
from sglang.srt.hf_transformers_utils import get_tokenizer
|
28
27
|
from sglang.srt.managers.io_struct import (
|
29
|
-
|
28
|
+
BatchEmbeddingOutput,
|
30
29
|
BatchMultimodalDecodeReq,
|
31
|
-
|
32
|
-
|
33
|
-
|
30
|
+
BatchMultimodalOutput,
|
31
|
+
BatchStrOutput,
|
32
|
+
BatchTokenIDOutput,
|
34
33
|
FreezeGCReq,
|
35
34
|
MultiTokenizerRegisterReq,
|
36
35
|
)
|
@@ -42,6 +41,7 @@ from sglang.srt.utils import (
|
|
42
41
|
get_zmq_socket,
|
43
42
|
kill_itself_when_parent_died,
|
44
43
|
)
|
44
|
+
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
|
45
45
|
from sglang.utils import (
|
46
46
|
TypeBasedDispatcher,
|
47
47
|
find_printable_text,
|
@@ -101,8 +101,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
101
101
|
|
102
102
|
self._request_dispatcher = TypeBasedDispatcher(
|
103
103
|
[
|
104
|
-
(
|
105
|
-
(
|
104
|
+
(BatchEmbeddingOutput, self.handle_batch_embedding_out),
|
105
|
+
(BatchTokenIDOutput, self.handle_batch_token_id_out),
|
106
106
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
107
107
|
(MultiTokenizerRegisterReq, lambda x: x),
|
108
108
|
(FreezeGCReq, self.handle_freeze_gc_req),
|
@@ -145,11 +145,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
145
145
|
return output[:-1]
|
146
146
|
return output
|
147
147
|
|
148
|
-
def handle_batch_embedding_out(self, recv_obj:
|
148
|
+
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOutput):
|
149
149
|
# If it is embedding model, no detokenization is needed.
|
150
150
|
return recv_obj
|
151
151
|
|
152
|
-
def handle_batch_token_id_out(self, recv_obj:
|
152
|
+
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
|
153
153
|
bs = len(recv_obj.rids)
|
154
154
|
|
155
155
|
# Initialize decode status
|
@@ -224,7 +224,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
224
224
|
s.sent_offset = len(output_str)
|
225
225
|
output_strs.append(incremental_output)
|
226
226
|
|
227
|
-
return
|
227
|
+
return BatchStrOutput(
|
228
228
|
rids=recv_obj.rids,
|
229
229
|
finished_reasons=recv_obj.finished_reasons,
|
230
230
|
output_strs=output_strs,
|
@@ -252,7 +252,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
|
|
252
252
|
|
253
253
|
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
254
254
|
outputs = self.tokenizer.detokenize(recv_obj)
|
255
|
-
return
|
255
|
+
return BatchMultimodalOutput(
|
256
256
|
rids=recv_obj.rids,
|
257
257
|
finished_reasons=recv_obj.finished_reasons,
|
258
258
|
outputs=outputs,
|