sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,420 +1,6 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
import logging
|
8
|
-
import random
|
9
|
-
import urllib
|
10
|
-
from itertools import chain
|
11
|
-
from typing import List, Optional
|
12
|
-
|
13
|
-
import aiohttp
|
14
|
-
import orjson
|
15
|
-
import uvicorn
|
16
|
-
from fastapi import FastAPI, HTTPException
|
17
|
-
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
18
|
-
|
19
|
-
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
20
|
-
from sglang.srt.utils import maybe_wrap_ipv6_address
|
21
|
-
|
22
|
-
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
23
|
-
1024 * 64
|
24
|
-
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
25
|
-
|
26
|
-
|
27
|
-
def setup_logger():
|
28
|
-
logger = logging.getLogger("pdlb")
|
29
|
-
logger.setLevel(logging.INFO)
|
30
|
-
|
31
|
-
formatter = logging.Formatter(
|
32
|
-
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
33
|
-
datefmt="%Y-%m-%d %H:%M:%S",
|
34
|
-
)
|
35
|
-
|
36
|
-
handler = logging.StreamHandler()
|
37
|
-
handler.setFormatter(formatter)
|
38
|
-
logger.addHandler(handler)
|
39
|
-
|
40
|
-
return logger
|
41
|
-
|
42
|
-
|
43
|
-
logger = setup_logger()
|
44
|
-
|
45
|
-
|
46
|
-
@dataclasses.dataclass
|
47
|
-
class PrefillConfig:
|
48
|
-
url: str
|
49
|
-
bootstrap_port: Optional[int] = None
|
50
|
-
|
51
|
-
|
52
|
-
class MiniLoadBalancer:
|
53
|
-
def __init__(
|
54
|
-
self,
|
55
|
-
prefill_configs: List[PrefillConfig],
|
56
|
-
decode_servers: List[str],
|
57
|
-
timeout: int,
|
58
|
-
):
|
59
|
-
self.prefill_configs = prefill_configs
|
60
|
-
self.prefill_servers = [p.url for p in prefill_configs]
|
61
|
-
self.decode_servers = decode_servers
|
62
|
-
self.timeout = timeout
|
63
|
-
|
64
|
-
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
65
|
-
self.prefill_configs.append(new_prefill_config)
|
66
|
-
self.prefill_servers.append(new_prefill_config.url)
|
67
|
-
|
68
|
-
def add_decode_server(self, new_decode_server: str):
|
69
|
-
self.decode_servers.append(new_decode_server)
|
70
|
-
|
71
|
-
def select_pair(self):
|
72
|
-
# TODO: return some message instead of panic
|
73
|
-
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
74
|
-
assert len(self.decode_servers) > 0, "No decode servers available"
|
75
|
-
|
76
|
-
prefill_config = random.choice(self.prefill_configs)
|
77
|
-
decode_server = random.choice(self.decode_servers)
|
78
|
-
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
79
|
-
|
80
|
-
async def generate(
|
81
|
-
self, modified_request, prefill_server, decode_server, endpoint
|
82
|
-
) -> ORJSONResponse:
|
83
|
-
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
84
|
-
|
85
|
-
async with aiohttp.ClientSession(
|
86
|
-
timeout=aiohttp.ClientTimeout(
|
87
|
-
total=self.timeout
|
88
|
-
) # Add timeout for request reliability
|
89
|
-
) as session:
|
90
|
-
tasks = [
|
91
|
-
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
92
|
-
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
93
|
-
]
|
94
|
-
|
95
|
-
# Wait for both responses to complete. Prefill should end first.
|
96
|
-
prefill_response, decode_response = await asyncio.gather(*tasks)
|
97
|
-
|
98
|
-
if "return_logprob" in modified_request:
|
99
|
-
|
100
|
-
prefill_json = await prefill_response.json()
|
101
|
-
ret_json = await decode_response.json()
|
102
|
-
|
103
|
-
# merge `meta_info.input_token_logprobs` from prefill to decode
|
104
|
-
if "meta_info" in ret_json:
|
105
|
-
if "input_token_logprobs" in ret_json["meta_info"]:
|
106
|
-
ret_json["meta_info"]["input_token_logprobs"] = (
|
107
|
-
prefill_json["meta_info"]["input_token_logprobs"]
|
108
|
-
+ ret_json["meta_info"]["input_token_logprobs"]
|
109
|
-
)
|
110
|
-
else:
|
111
|
-
ret_json = await decode_response.json()
|
112
|
-
|
113
|
-
return ORJSONResponse(
|
114
|
-
content=ret_json,
|
115
|
-
status_code=decode_response.status,
|
116
|
-
)
|
117
|
-
|
118
|
-
async def generate_stream(
|
119
|
-
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
120
|
-
):
|
121
|
-
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
122
|
-
|
123
|
-
async def stream_results():
|
124
|
-
async with aiohttp.ClientSession(
|
125
|
-
timeout=aiohttp.ClientTimeout(
|
126
|
-
total=self.timeout
|
127
|
-
) # Add timeout for request reliability
|
128
|
-
) as session:
|
129
|
-
# Create the tasks for both prefill and decode requests
|
130
|
-
tasks = [
|
131
|
-
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
132
|
-
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
133
|
-
]
|
134
|
-
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
135
|
-
prefill_response, decode_response = await asyncio.gather(*tasks)
|
136
|
-
|
137
|
-
if modified_request.get("return_logprob", False):
|
138
|
-
prefill_chunks = []
|
139
|
-
async for chunk in prefill_response.content:
|
140
|
-
prefill_chunks.append(chunk)
|
141
|
-
|
142
|
-
first_prefill_chunk = (
|
143
|
-
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
144
|
-
)
|
145
|
-
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
146
|
-
|
147
|
-
async for chunk in decode_response.content:
|
148
|
-
# Note: This is inefficient
|
149
|
-
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
150
|
-
decoded_chunk = chunk.decode("utf-8")
|
151
|
-
if (
|
152
|
-
decoded_chunk
|
153
|
-
and decoded_chunk.startswith("data:")
|
154
|
-
and "[DONE]" not in decoded_chunk
|
155
|
-
):
|
156
|
-
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
157
|
-
ret_json["meta_info"]["input_token_logprobs"] = (
|
158
|
-
first_prefill_chunk_json["meta_info"][
|
159
|
-
"input_token_logprobs"
|
160
|
-
]
|
161
|
-
+ ret_json["meta_info"]["input_token_logprobs"]
|
162
|
-
)
|
163
|
-
|
164
|
-
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
165
|
-
else:
|
166
|
-
yield chunk
|
167
|
-
else:
|
168
|
-
async for chunk in decode_response.content.iter_chunked(
|
169
|
-
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
170
|
-
):
|
171
|
-
yield chunk
|
172
|
-
|
173
|
-
return StreamingResponse(
|
174
|
-
stream_results(),
|
175
|
-
media_type="text/event-stream",
|
176
|
-
)
|
177
|
-
|
178
|
-
|
179
|
-
app = FastAPI()
|
180
|
-
load_balancer: Optional[MiniLoadBalancer] = None
|
181
|
-
|
182
|
-
|
183
|
-
@app.get("/health")
|
184
|
-
async def health_check():
|
185
|
-
return Response(status_code=200)
|
186
|
-
|
187
|
-
|
188
|
-
@app.get("/health_generate")
|
189
|
-
async def health_check():
|
190
|
-
prefill_servers, decode_servers = (
|
191
|
-
load_balancer.prefill_servers,
|
192
|
-
load_balancer.decode_servers,
|
193
|
-
)
|
194
|
-
async with aiohttp.ClientSession() as session:
|
195
|
-
# Create the tasks
|
196
|
-
tasks = []
|
197
|
-
for server in chain(prefill_servers, decode_servers):
|
198
|
-
tasks.append(session.post(f"{server}/health_generate"))
|
199
|
-
for i, response in enumerate(asyncio.as_completed(tasks)):
|
200
|
-
await response
|
201
|
-
return Response(status_code=200)
|
202
|
-
|
203
|
-
|
204
|
-
@app.post("/flush_cache")
|
205
|
-
async def flush_cache():
|
206
|
-
prefill_servers, decode_servers = (
|
207
|
-
load_balancer.prefill_servers,
|
208
|
-
load_balancer.decode_servers,
|
209
|
-
)
|
210
|
-
async with aiohttp.ClientSession() as session:
|
211
|
-
# Create the tasks
|
212
|
-
tasks = []
|
213
|
-
for server in chain(prefill_servers, decode_servers):
|
214
|
-
tasks.append(session.post(f"{server}/flush_cache"))
|
215
|
-
for i, response in enumerate(asyncio.as_completed(tasks)):
|
216
|
-
await response
|
217
|
-
return Response(status_code=200)
|
218
|
-
|
219
|
-
|
220
|
-
@app.get("/get_server_info")
|
221
|
-
async def get_server_info():
|
222
|
-
prefill_servers, decode_servers = (
|
223
|
-
load_balancer.prefill_servers,
|
224
|
-
load_balancer.decode_servers,
|
225
|
-
)
|
226
|
-
prefill_infos = []
|
227
|
-
decode_infos = []
|
228
|
-
all_internal_states = []
|
229
|
-
|
230
|
-
async with aiohttp.ClientSession() as session:
|
231
|
-
for server in chain(prefill_servers):
|
232
|
-
server_info = await session.get(f"{server}/get_server_info")
|
233
|
-
prefill_infos.append(await server_info.json())
|
234
|
-
for server in chain(decode_servers):
|
235
|
-
server_info = await session.get(f"{server}/get_server_info")
|
236
|
-
info_json = await server_info.json()
|
237
|
-
decode_infos.append(info_json)
|
238
|
-
# Extract internal_states from decode servers
|
239
|
-
if "internal_states" in info_json:
|
240
|
-
all_internal_states.extend(info_json["internal_states"])
|
241
|
-
|
242
|
-
# Return format expected by bench_one_batch_server.py
|
243
|
-
if all_internal_states:
|
244
|
-
return {
|
245
|
-
"internal_states": all_internal_states,
|
246
|
-
"prefill": prefill_infos,
|
247
|
-
"decode": decode_infos,
|
248
|
-
}
|
249
|
-
else:
|
250
|
-
# Fallback with dummy data if no internal states found
|
251
|
-
return {
|
252
|
-
"internal_states": [
|
253
|
-
{
|
254
|
-
"last_gen_throughput": 0.0,
|
255
|
-
"avg_spec_accept_length": None,
|
256
|
-
}
|
257
|
-
],
|
258
|
-
"prefill": prefill_infos,
|
259
|
-
"decode": decode_infos,
|
260
|
-
}
|
261
|
-
|
262
|
-
|
263
|
-
@app.get("/get_model_info")
|
264
|
-
async def get_model_info():
|
265
|
-
# Dummy model information
|
266
|
-
model_info = {
|
267
|
-
"model_path": "/path/to/dummy/model",
|
268
|
-
"tokenizer_path": "/path/to/dummy/tokenizer",
|
269
|
-
"is_generation": True,
|
270
|
-
"preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
|
271
|
-
}
|
272
|
-
return ORJSONResponse(content=model_info)
|
273
|
-
|
274
|
-
|
275
|
-
@app.post("/generate")
|
276
|
-
async def handle_generate_request(request_data: dict):
|
277
|
-
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
278
|
-
|
279
|
-
# Parse and transform prefill_server for bootstrap data
|
280
|
-
parsed_url = urllib.parse.urlparse(prefill_server)
|
281
|
-
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
282
|
-
modified_request = request_data.copy()
|
283
|
-
|
284
|
-
batch_size = _get_request_batch_size(modified_request)
|
285
|
-
if batch_size is not None:
|
286
|
-
modified_request.update(
|
287
|
-
{
|
288
|
-
"bootstrap_host": [hostname] * batch_size,
|
289
|
-
"bootstrap_port": [bootstrap_port] * batch_size,
|
290
|
-
"bootstrap_room": [
|
291
|
-
_generate_bootstrap_room() for _ in range(batch_size)
|
292
|
-
],
|
293
|
-
}
|
294
|
-
)
|
295
|
-
else:
|
296
|
-
modified_request.update(
|
297
|
-
{
|
298
|
-
"bootstrap_host": hostname,
|
299
|
-
"bootstrap_port": bootstrap_port,
|
300
|
-
"bootstrap_room": _generate_bootstrap_room(),
|
301
|
-
}
|
302
|
-
)
|
303
|
-
|
304
|
-
if request_data.get("stream", False):
|
305
|
-
return await load_balancer.generate_stream(
|
306
|
-
modified_request, prefill_server, decode_server, "generate"
|
307
|
-
)
|
308
|
-
else:
|
309
|
-
return await load_balancer.generate(
|
310
|
-
modified_request, prefill_server, decode_server, "generate"
|
311
|
-
)
|
312
|
-
|
313
|
-
|
314
|
-
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
315
|
-
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
316
|
-
|
317
|
-
# Parse and transform prefill_server for bootstrap data
|
318
|
-
parsed_url = urllib.parse.urlparse(prefill_server)
|
319
|
-
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
320
|
-
modified_request = request_data.copy()
|
321
|
-
modified_request.update(
|
322
|
-
{
|
323
|
-
"bootstrap_host": hostname,
|
324
|
-
"bootstrap_port": bootstrap_port,
|
325
|
-
"bootstrap_room": _generate_bootstrap_room(),
|
326
|
-
}
|
327
|
-
)
|
328
|
-
|
329
|
-
if request_data.get("stream", False):
|
330
|
-
return await load_balancer.generate_stream(
|
331
|
-
modified_request,
|
332
|
-
prefill_server,
|
333
|
-
decode_server,
|
334
|
-
endpoint=endpoint_name,
|
335
|
-
)
|
336
|
-
else:
|
337
|
-
return await load_balancer.generate(
|
338
|
-
modified_request,
|
339
|
-
prefill_server,
|
340
|
-
decode_server,
|
341
|
-
endpoint=endpoint_name,
|
342
|
-
)
|
343
|
-
|
344
|
-
|
345
|
-
@app.post("/v1/chat/completions")
|
346
|
-
async def handle_chat_completion_request(request_data: dict):
|
347
|
-
return await _forward_to_backend(request_data, "v1/chat/completions")
|
348
|
-
|
349
|
-
|
350
|
-
@app.post("/v1/completions")
|
351
|
-
async def handle_completion_request(request_data: dict):
|
352
|
-
return await _forward_to_backend(request_data, "v1/completions")
|
353
|
-
|
354
|
-
|
355
|
-
def _generate_bootstrap_room():
|
356
|
-
return random.randint(0, 2**63 - 1)
|
357
|
-
|
358
|
-
|
359
|
-
# We may utilize `GenerateReqInput`'s logic later
|
360
|
-
def _get_request_batch_size(request):
|
361
|
-
if (text := request.get("text")) is not None:
|
362
|
-
return None if isinstance(text, str) else len(text)
|
363
|
-
if (input_ids := request.get("input_ids")) is not None:
|
364
|
-
return None if isinstance(input_ids[0], int) else len(input_ids)
|
365
|
-
return None
|
366
|
-
|
367
|
-
|
368
|
-
@app.get("/v1/models")
|
369
|
-
async def get_models():
|
370
|
-
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
371
|
-
async with aiohttp.ClientSession() as session:
|
372
|
-
try:
|
373
|
-
response = await session.get(f"{prefill_server}/v1/models")
|
374
|
-
if response.status != 200:
|
375
|
-
raise HTTPException(
|
376
|
-
status_code=response.status,
|
377
|
-
detail=f"Prefill server error: Status {response.status}",
|
378
|
-
)
|
379
|
-
return ORJSONResponse(content=await response.json())
|
380
|
-
except Exception as e:
|
381
|
-
raise HTTPException(status_code=500, detail=str(e))
|
382
|
-
|
383
|
-
|
384
|
-
@app.post("/register")
|
385
|
-
async def register(obj: PDRegistryRequest):
|
386
|
-
if obj.mode == "prefill":
|
387
|
-
load_balancer.add_prefill_server(
|
388
|
-
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
389
|
-
)
|
390
|
-
logger.info(
|
391
|
-
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
392
|
-
)
|
393
|
-
elif obj.mode == "decode":
|
394
|
-
load_balancer.add_decode_server(obj.registry_url)
|
395
|
-
logger.info(f"Registered decode server: {obj.registry_url}")
|
396
|
-
else:
|
397
|
-
raise HTTPException(
|
398
|
-
status_code=400,
|
399
|
-
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
400
|
-
)
|
401
|
-
|
402
|
-
logger.info(
|
403
|
-
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
404
|
-
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
405
|
-
)
|
406
|
-
|
407
|
-
return Response(status_code=200)
|
408
|
-
|
409
|
-
|
410
|
-
def run(prefill_configs, decode_addrs, host, port, timeout):
|
411
|
-
global load_balancer
|
412
|
-
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
|
413
|
-
uvicorn.run(app, host=host, port=port)
|
414
|
-
|
415
|
-
|
416
|
-
if __name__ == "__main__":
|
417
|
-
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
|
418
|
-
from sglang.srt.disaggregation.launch_lb import main
|
419
|
-
|
420
|
-
main()
|
1
|
+
raise RuntimeError(
|
2
|
+
"""The 'mini_lb' module has been relocated to the 'sglang_router' package.
|
3
|
+
We recommend installing 'sglang-router' with Rust support for optimal performance.
|
4
|
+
If you encounter issues building the router with Rust, set the environment variable
|
5
|
+
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
|
6
|
+
)
|
@@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
175
175
|
self.disaggregation_mode = disaggregation_mode
|
176
176
|
self.init_engine()
|
177
177
|
# for p/d multi node infer
|
178
|
+
self.bootstrap_host = server_args.host
|
178
179
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
179
180
|
self.dist_init_addr = server_args.dist_init_addr
|
180
181
|
self.attn_tp_size = get_attention_tp_size()
|
@@ -458,7 +459,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
458
459
|
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
459
460
|
else:
|
460
461
|
# Send KVCache from 1 prefill instance to multiple decode instances
|
461
|
-
src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
|
462
|
+
src_head_start_offset = (
|
463
|
+
dst_tp_rank_in_group * dst_heads_per_rank
|
464
|
+
) % src_heads_per_rank
|
462
465
|
num_heads_to_send = dst_heads_per_rank
|
463
466
|
dst_head_start_offset = 0
|
464
467
|
|
@@ -1020,6 +1023,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
1020
1023
|
def _register_to_bootstrap(self):
|
1021
1024
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
1022
1025
|
if self.dist_init_addr:
|
1026
|
+
# multi node case: bootstrap server's host is dist_init_addr
|
1023
1027
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
1024
1028
|
if self.dist_init_addr.endswith("]"):
|
1025
1029
|
host = self.dist_init_addr
|
@@ -1028,7 +1032,8 @@ class MooncakeKVManager(BaseKVManager):
|
|
1028
1032
|
else:
|
1029
1033
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
1030
1034
|
else:
|
1031
|
-
host
|
1035
|
+
# single node case: bootstrap server's host is same as http server's host
|
1036
|
+
host = self.bootstrap_host
|
1032
1037
|
host = maybe_wrap_ipv6_address(host)
|
1033
1038
|
|
1034
1039
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
@@ -1209,7 +1214,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1209
1214
|
mgr: MooncakeKVManager,
|
1210
1215
|
bootstrap_addr: str,
|
1211
1216
|
bootstrap_room: Optional[int] = None,
|
1212
|
-
|
1217
|
+
prefill_dp_rank: Optional[int] = None,
|
1213
1218
|
):
|
1214
1219
|
self.bootstrap_room = bootstrap_room
|
1215
1220
|
self.bootstrap_addr = bootstrap_addr
|
@@ -1218,7 +1223,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1218
1223
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
1219
1224
|
self.conclude_state = None
|
1220
1225
|
self.init_time = None
|
1221
|
-
self.data_parallel_rank = data_parallel_rank
|
1222
1226
|
|
1223
1227
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
1224
1228
|
(
|
@@ -1317,11 +1321,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1317
1321
|
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
1318
1322
|
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
1319
1323
|
|
1320
|
-
if self.data_parallel_rank is not None:
|
1321
|
-
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
1322
|
-
self.target_dp_group = self.data_parallel_rank
|
1324
|
+
if prefill_dp_rank is not None:
|
1325
|
+
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
1326
|
+
self.prefill_dp_rank = prefill_dp_rank
|
1323
1327
|
else:
|
1324
|
-
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
1328
|
+
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
1329
|
+
|
1330
|
+
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
1331
|
+
self.target_dp_group = self.prefill_dp_rank
|
1325
1332
|
|
1326
1333
|
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
1327
1334
|
self.required_prefill_response_num
|
@@ -1545,7 +1552,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1545
1552
|
|
1546
1553
|
|
1547
1554
|
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
1548
|
-
def __init__(self, port: int):
|
1555
|
+
def __init__(self, host: str, port: int):
|
1556
|
+
self.host = host
|
1549
1557
|
self.port = port
|
1550
1558
|
self.app = web.Application()
|
1551
1559
|
self.store = dict()
|
@@ -1673,7 +1681,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
1673
1681
|
self._runner = web.AppRunner(self.app, access_log=access_log)
|
1674
1682
|
self._loop.run_until_complete(self._runner.setup())
|
1675
1683
|
|
1676
|
-
site = web.TCPSite(self._runner, port=self.port)
|
1684
|
+
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
1677
1685
|
self._loop.run_until_complete(site.start())
|
1678
1686
|
self._loop.run_forever()
|
1679
1687
|
except Exception as e:
|