sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,420 +1,6 @@
1
- """
2
- Minimal HTTP load balancer for prefill and decode servers for testing.
3
- """
4
-
5
- import asyncio
6
- import dataclasses
7
- import logging
8
- import random
9
- import urllib
10
- from itertools import chain
11
- from typing import List, Optional
12
-
13
- import aiohttp
14
- import orjson
15
- import uvicorn
16
- from fastapi import FastAPI, HTTPException
17
- from fastapi.responses import ORJSONResponse, Response, StreamingResponse
18
-
19
- from sglang.srt.disaggregation.utils import PDRegistryRequest
20
- from sglang.srt.utils import maybe_wrap_ipv6_address
21
-
22
- AIOHTTP_STREAM_READ_CHUNK_SIZE = (
23
- 1024 * 64
24
- ) # 64KB, to prevent aiohttp's "Chunk too big" error
25
-
26
-
27
- def setup_logger():
28
- logger = logging.getLogger("pdlb")
29
- logger.setLevel(logging.INFO)
30
-
31
- formatter = logging.Formatter(
32
- "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
33
- datefmt="%Y-%m-%d %H:%M:%S",
34
- )
35
-
36
- handler = logging.StreamHandler()
37
- handler.setFormatter(formatter)
38
- logger.addHandler(handler)
39
-
40
- return logger
41
-
42
-
43
- logger = setup_logger()
44
-
45
-
46
- @dataclasses.dataclass
47
- class PrefillConfig:
48
- url: str
49
- bootstrap_port: Optional[int] = None
50
-
51
-
52
- class MiniLoadBalancer:
53
- def __init__(
54
- self,
55
- prefill_configs: List[PrefillConfig],
56
- decode_servers: List[str],
57
- timeout: int,
58
- ):
59
- self.prefill_configs = prefill_configs
60
- self.prefill_servers = [p.url for p in prefill_configs]
61
- self.decode_servers = decode_servers
62
- self.timeout = timeout
63
-
64
- def add_prefill_server(self, new_prefill_config: PrefillConfig):
65
- self.prefill_configs.append(new_prefill_config)
66
- self.prefill_servers.append(new_prefill_config.url)
67
-
68
- def add_decode_server(self, new_decode_server: str):
69
- self.decode_servers.append(new_decode_server)
70
-
71
- def select_pair(self):
72
- # TODO: return some message instead of panic
73
- assert len(self.prefill_configs) > 0, "No prefill servers available"
74
- assert len(self.decode_servers) > 0, "No decode servers available"
75
-
76
- prefill_config = random.choice(self.prefill_configs)
77
- decode_server = random.choice(self.decode_servers)
78
- return prefill_config.url, prefill_config.bootstrap_port, decode_server
79
-
80
- async def generate(
81
- self, modified_request, prefill_server, decode_server, endpoint
82
- ) -> ORJSONResponse:
83
- assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
84
-
85
- async with aiohttp.ClientSession(
86
- timeout=aiohttp.ClientTimeout(
87
- total=self.timeout
88
- ) # Add timeout for request reliability
89
- ) as session:
90
- tasks = [
91
- session.post(f"{prefill_server}/{endpoint}", json=modified_request),
92
- session.post(f"{decode_server}/{endpoint}", json=modified_request),
93
- ]
94
-
95
- # Wait for both responses to complete. Prefill should end first.
96
- prefill_response, decode_response = await asyncio.gather(*tasks)
97
-
98
- if "return_logprob" in modified_request:
99
-
100
- prefill_json = await prefill_response.json()
101
- ret_json = await decode_response.json()
102
-
103
- # merge `meta_info.input_token_logprobs` from prefill to decode
104
- if "meta_info" in ret_json:
105
- if "input_token_logprobs" in ret_json["meta_info"]:
106
- ret_json["meta_info"]["input_token_logprobs"] = (
107
- prefill_json["meta_info"]["input_token_logprobs"]
108
- + ret_json["meta_info"]["input_token_logprobs"]
109
- )
110
- else:
111
- ret_json = await decode_response.json()
112
-
113
- return ORJSONResponse(
114
- content=ret_json,
115
- status_code=decode_response.status,
116
- )
117
-
118
- async def generate_stream(
119
- self, modified_request, prefill_server, decode_server, endpoint="generate"
120
- ):
121
- assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
122
-
123
- async def stream_results():
124
- async with aiohttp.ClientSession(
125
- timeout=aiohttp.ClientTimeout(
126
- total=self.timeout
127
- ) # Add timeout for request reliability
128
- ) as session:
129
- # Create the tasks for both prefill and decode requests
130
- tasks = [
131
- session.post(f"{prefill_server}/{endpoint}", json=modified_request),
132
- session.post(f"{decode_server}/{endpoint}", json=modified_request),
133
- ]
134
- # Wait for both responses to complete. Since this is streaming, they return immediately.
135
- prefill_response, decode_response = await asyncio.gather(*tasks)
136
-
137
- if modified_request.get("return_logprob", False):
138
- prefill_chunks = []
139
- async for chunk in prefill_response.content:
140
- prefill_chunks.append(chunk)
141
-
142
- first_prefill_chunk = (
143
- prefill_chunks[0].decode("utf-8")[5:].strip("\n")
144
- )
145
- first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
146
-
147
- async for chunk in decode_response.content:
148
- # Note: This is inefficient
149
- # merge prefill input_token_logprobs, output_token_logprobs to decode
150
- decoded_chunk = chunk.decode("utf-8")
151
- if (
152
- decoded_chunk
153
- and decoded_chunk.startswith("data:")
154
- and "[DONE]" not in decoded_chunk
155
- ):
156
- ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
157
- ret_json["meta_info"]["input_token_logprobs"] = (
158
- first_prefill_chunk_json["meta_info"][
159
- "input_token_logprobs"
160
- ]
161
- + ret_json["meta_info"]["input_token_logprobs"]
162
- )
163
-
164
- yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
165
- else:
166
- yield chunk
167
- else:
168
- async for chunk in decode_response.content.iter_chunked(
169
- AIOHTTP_STREAM_READ_CHUNK_SIZE
170
- ):
171
- yield chunk
172
-
173
- return StreamingResponse(
174
- stream_results(),
175
- media_type="text/event-stream",
176
- )
177
-
178
-
179
- app = FastAPI()
180
- load_balancer: Optional[MiniLoadBalancer] = None
181
-
182
-
183
- @app.get("/health")
184
- async def health_check():
185
- return Response(status_code=200)
186
-
187
-
188
- @app.get("/health_generate")
189
- async def health_check():
190
- prefill_servers, decode_servers = (
191
- load_balancer.prefill_servers,
192
- load_balancer.decode_servers,
193
- )
194
- async with aiohttp.ClientSession() as session:
195
- # Create the tasks
196
- tasks = []
197
- for server in chain(prefill_servers, decode_servers):
198
- tasks.append(session.post(f"{server}/health_generate"))
199
- for i, response in enumerate(asyncio.as_completed(tasks)):
200
- await response
201
- return Response(status_code=200)
202
-
203
-
204
- @app.post("/flush_cache")
205
- async def flush_cache():
206
- prefill_servers, decode_servers = (
207
- load_balancer.prefill_servers,
208
- load_balancer.decode_servers,
209
- )
210
- async with aiohttp.ClientSession() as session:
211
- # Create the tasks
212
- tasks = []
213
- for server in chain(prefill_servers, decode_servers):
214
- tasks.append(session.post(f"{server}/flush_cache"))
215
- for i, response in enumerate(asyncio.as_completed(tasks)):
216
- await response
217
- return Response(status_code=200)
218
-
219
-
220
- @app.get("/get_server_info")
221
- async def get_server_info():
222
- prefill_servers, decode_servers = (
223
- load_balancer.prefill_servers,
224
- load_balancer.decode_servers,
225
- )
226
- prefill_infos = []
227
- decode_infos = []
228
- all_internal_states = []
229
-
230
- async with aiohttp.ClientSession() as session:
231
- for server in chain(prefill_servers):
232
- server_info = await session.get(f"{server}/get_server_info")
233
- prefill_infos.append(await server_info.json())
234
- for server in chain(decode_servers):
235
- server_info = await session.get(f"{server}/get_server_info")
236
- info_json = await server_info.json()
237
- decode_infos.append(info_json)
238
- # Extract internal_states from decode servers
239
- if "internal_states" in info_json:
240
- all_internal_states.extend(info_json["internal_states"])
241
-
242
- # Return format expected by bench_one_batch_server.py
243
- if all_internal_states:
244
- return {
245
- "internal_states": all_internal_states,
246
- "prefill": prefill_infos,
247
- "decode": decode_infos,
248
- }
249
- else:
250
- # Fallback with dummy data if no internal states found
251
- return {
252
- "internal_states": [
253
- {
254
- "last_gen_throughput": 0.0,
255
- "avg_spec_accept_length": None,
256
- }
257
- ],
258
- "prefill": prefill_infos,
259
- "decode": decode_infos,
260
- }
261
-
262
-
263
- @app.get("/get_model_info")
264
- async def get_model_info():
265
- # Dummy model information
266
- model_info = {
267
- "model_path": "/path/to/dummy/model",
268
- "tokenizer_path": "/path/to/dummy/tokenizer",
269
- "is_generation": True,
270
- "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
271
- }
272
- return ORJSONResponse(content=model_info)
273
-
274
-
275
- @app.post("/generate")
276
- async def handle_generate_request(request_data: dict):
277
- prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
278
-
279
- # Parse and transform prefill_server for bootstrap data
280
- parsed_url = urllib.parse.urlparse(prefill_server)
281
- hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
282
- modified_request = request_data.copy()
283
-
284
- batch_size = _get_request_batch_size(modified_request)
285
- if batch_size is not None:
286
- modified_request.update(
287
- {
288
- "bootstrap_host": [hostname] * batch_size,
289
- "bootstrap_port": [bootstrap_port] * batch_size,
290
- "bootstrap_room": [
291
- _generate_bootstrap_room() for _ in range(batch_size)
292
- ],
293
- }
294
- )
295
- else:
296
- modified_request.update(
297
- {
298
- "bootstrap_host": hostname,
299
- "bootstrap_port": bootstrap_port,
300
- "bootstrap_room": _generate_bootstrap_room(),
301
- }
302
- )
303
-
304
- if request_data.get("stream", False):
305
- return await load_balancer.generate_stream(
306
- modified_request, prefill_server, decode_server, "generate"
307
- )
308
- else:
309
- return await load_balancer.generate(
310
- modified_request, prefill_server, decode_server, "generate"
311
- )
312
-
313
-
314
- async def _forward_to_backend(request_data: dict, endpoint_name: str):
315
- prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
316
-
317
- # Parse and transform prefill_server for bootstrap data
318
- parsed_url = urllib.parse.urlparse(prefill_server)
319
- hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
320
- modified_request = request_data.copy()
321
- modified_request.update(
322
- {
323
- "bootstrap_host": hostname,
324
- "bootstrap_port": bootstrap_port,
325
- "bootstrap_room": _generate_bootstrap_room(),
326
- }
327
- )
328
-
329
- if request_data.get("stream", False):
330
- return await load_balancer.generate_stream(
331
- modified_request,
332
- prefill_server,
333
- decode_server,
334
- endpoint=endpoint_name,
335
- )
336
- else:
337
- return await load_balancer.generate(
338
- modified_request,
339
- prefill_server,
340
- decode_server,
341
- endpoint=endpoint_name,
342
- )
343
-
344
-
345
- @app.post("/v1/chat/completions")
346
- async def handle_chat_completion_request(request_data: dict):
347
- return await _forward_to_backend(request_data, "v1/chat/completions")
348
-
349
-
350
- @app.post("/v1/completions")
351
- async def handle_completion_request(request_data: dict):
352
- return await _forward_to_backend(request_data, "v1/completions")
353
-
354
-
355
- def _generate_bootstrap_room():
356
- return random.randint(0, 2**63 - 1)
357
-
358
-
359
- # We may utilize `GenerateReqInput`'s logic later
360
- def _get_request_batch_size(request):
361
- if (text := request.get("text")) is not None:
362
- return None if isinstance(text, str) else len(text)
363
- if (input_ids := request.get("input_ids")) is not None:
364
- return None if isinstance(input_ids[0], int) else len(input_ids)
365
- return None
366
-
367
-
368
- @app.get("/v1/models")
369
- async def get_models():
370
- prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
371
- async with aiohttp.ClientSession() as session:
372
- try:
373
- response = await session.get(f"{prefill_server}/v1/models")
374
- if response.status != 200:
375
- raise HTTPException(
376
- status_code=response.status,
377
- detail=f"Prefill server error: Status {response.status}",
378
- )
379
- return ORJSONResponse(content=await response.json())
380
- except Exception as e:
381
- raise HTTPException(status_code=500, detail=str(e))
382
-
383
-
384
- @app.post("/register")
385
- async def register(obj: PDRegistryRequest):
386
- if obj.mode == "prefill":
387
- load_balancer.add_prefill_server(
388
- PrefillConfig(obj.registry_url, obj.bootstrap_port)
389
- )
390
- logger.info(
391
- f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
392
- )
393
- elif obj.mode == "decode":
394
- load_balancer.add_decode_server(obj.registry_url)
395
- logger.info(f"Registered decode server: {obj.registry_url}")
396
- else:
397
- raise HTTPException(
398
- status_code=400,
399
- detail="Invalid mode. Must be either PREFILL or DECODE.",
400
- )
401
-
402
- logger.info(
403
- f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
404
- f"#Decode servers: {len(load_balancer.decode_servers)}"
405
- )
406
-
407
- return Response(status_code=200)
408
-
409
-
410
- def run(prefill_configs, decode_addrs, host, port, timeout):
411
- global load_balancer
412
- load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
413
- uvicorn.run(app, host=host, port=port)
414
-
415
-
416
- if __name__ == "__main__":
417
- # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
418
- from sglang.srt.disaggregation.launch_lb import main
419
-
420
- main()
1
+ raise RuntimeError(
2
+ """The 'mini_lb' module has been relocated to the 'sglang_router' package.
3
+ We recommend installing 'sglang-router' with Rust support for optimal performance.
4
+ If you encounter issues building the router with Rust, set the environment variable
5
+ 'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
6
+ )
@@ -175,6 +175,7 @@ class MooncakeKVManager(BaseKVManager):
175
175
  self.disaggregation_mode = disaggregation_mode
176
176
  self.init_engine()
177
177
  # for p/d multi node infer
178
+ self.bootstrap_host = server_args.host
178
179
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
179
180
  self.dist_init_addr = server_args.dist_init_addr
180
181
  self.attn_tp_size = get_attention_tp_size()
@@ -458,7 +459,9 @@ class MooncakeKVManager(BaseKVManager):
458
459
  dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
459
460
  else:
460
461
  # Send KVCache from 1 prefill instance to multiple decode instances
461
- src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
462
+ src_head_start_offset = (
463
+ dst_tp_rank_in_group * dst_heads_per_rank
464
+ ) % src_heads_per_rank
462
465
  num_heads_to_send = dst_heads_per_rank
463
466
  dst_head_start_offset = 0
464
467
 
@@ -1020,6 +1023,7 @@ class MooncakeKVManager(BaseKVManager):
1020
1023
  def _register_to_bootstrap(self):
1021
1024
  """Register KVSender to bootstrap server via HTTP POST."""
1022
1025
  if self.dist_init_addr:
1026
+ # multi node case: bootstrap server's host is dist_init_addr
1023
1027
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
1024
1028
  if self.dist_init_addr.endswith("]"):
1025
1029
  host = self.dist_init_addr
@@ -1028,7 +1032,8 @@ class MooncakeKVManager(BaseKVManager):
1028
1032
  else:
1029
1033
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
1030
1034
  else:
1031
- host = get_ip()
1035
+ # single node case: bootstrap server's host is same as http server's host
1036
+ host = self.bootstrap_host
1032
1037
  host = maybe_wrap_ipv6_address(host)
1033
1038
 
1034
1039
  bootstrap_server_url = f"{host}:{self.bootstrap_port}"
@@ -1209,7 +1214,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
1209
1214
  mgr: MooncakeKVManager,
1210
1215
  bootstrap_addr: str,
1211
1216
  bootstrap_room: Optional[int] = None,
1212
- data_parallel_rank: Optional[int] = None,
1217
+ prefill_dp_rank: Optional[int] = None,
1213
1218
  ):
1214
1219
  self.bootstrap_room = bootstrap_room
1215
1220
  self.bootstrap_addr = bootstrap_addr
@@ -1218,7 +1223,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1218
1223
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
1219
1224
  self.conclude_state = None
1220
1225
  self.init_time = None
1221
- self.data_parallel_rank = data_parallel_rank
1222
1226
 
1223
1227
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
1224
1228
  (
@@ -1317,11 +1321,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
1317
1321
  self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1318
1322
  ) * (self.prefill_pp_size // self.kv_mgr.pp_size)
1319
1323
 
1320
- if self.data_parallel_rank is not None:
1321
- logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
1322
- self.target_dp_group = self.data_parallel_rank
1324
+ if prefill_dp_rank is not None:
1325
+ logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
1326
+ self.prefill_dp_rank = prefill_dp_rank
1323
1327
  else:
1324
- self.target_dp_group = bootstrap_room % self.prefill_dp_size
1328
+ self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
1329
+
1330
+ # FIXME: alias here: target_dp_group -> prefill_dp_rank
1331
+ self.target_dp_group = self.prefill_dp_rank
1325
1332
 
1326
1333
  self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
1327
1334
  self.required_prefill_response_num
@@ -1545,7 +1552,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
1545
1552
 
1546
1553
 
1547
1554
  class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1548
- def __init__(self, port: int):
1555
+ def __init__(self, host: str, port: int):
1556
+ self.host = host
1549
1557
  self.port = port
1550
1558
  self.app = web.Application()
1551
1559
  self.store = dict()
@@ -1673,7 +1681,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
1673
1681
  self._runner = web.AppRunner(self.app, access_log=access_log)
1674
1682
  self._loop.run_until_complete(self._runner.setup())
1675
1683
 
1676
- site = web.TCPSite(self._runner, port=self.port)
1684
+ site = web.TCPSite(self._runner, host=self.host, port=self.port)
1677
1685
  self._loop.run_until_complete(site.start())
1678
1686
  self._loop.run_forever()
1679
1687
  except Exception as e: