sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import hashlib
2
3
  import json
3
4
  from pathlib import Path
4
5
 
@@ -13,7 +14,11 @@ Supported inputs:
13
14
 
14
15
 
15
16
  def main(args):
16
- df_input = _transform_df_input(_compute_df_raw(args))
17
+ if args.data_type == "simple_evals":
18
+ df_input = _compute_df_input_mode_simple_evals(args)
19
+ else:
20
+ df_input = _transform_df_input(_compute_df_raw(args))
21
+
17
22
  assert all(
18
23
  c in df_input.columns
19
24
  for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
@@ -37,8 +42,9 @@ def main(args):
37
42
  df_meta=df_meta.to_dicts(),
38
43
  df_good_to_bad=df_good_to_bad.to_dicts(),
39
44
  df_bad_to_good=df_bad_to_good.to_dicts(),
40
- )
41
- )
45
+ ),
46
+ indent=4,
47
+ ),
42
48
  )
43
49
 
44
50
  if not args.disable_print_details:
@@ -65,19 +71,70 @@ def main(args):
65
71
  print(df)
66
72
 
67
73
 
74
+ def _compute_df_input_mode_simple_evals(args):
75
+ return pl.concat(
76
+ [
77
+ _compute_df_input_one_mode_simple_evals(**info)
78
+ for info in _get_file_infos(args=args)
79
+ ]
80
+ )
81
+
82
+
83
+ def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
84
+ data = json.loads(Path(path).read_text())
85
+ rows = []
86
+
87
+ for single_eval_result in data["metadata"]["single_eval_results"]:
88
+ prompt = single_eval_result["example_level_metadata"][
89
+ "actual_queried_prompt_messages"
90
+ ]
91
+ score = single_eval_result["score"]
92
+ assert score in {0.0, 1.0}, f"{score=}"
93
+
94
+ row = dict(
95
+ category=category,
96
+ trial_index=trial_index,
97
+ prompt_id=_compute_id_from_object(prompt),
98
+ prompt=json.dumps(prompt),
99
+ output=single_eval_result["example_level_metadata"]["response_text"],
100
+ correct=score == 1.0,
101
+ )
102
+ rows.append(row)
103
+
104
+ return pl.DataFrame(rows)
105
+
106
+
107
+ def _compute_id_from_object(obj):
108
+ if isinstance(obj, pl.Series):
109
+ obj = obj.to_list()
110
+ json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
111
+ return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
112
+
113
+
68
114
  def _compute_df_raw(args):
69
115
  return pl.concat(
70
116
  [
71
- _read_df_raw(p, category=category, trial_index=i)
72
- for category, paths in [
73
- ("baseline", args.baseline_path),
74
- ("target", args.target_path),
75
- ]
76
- for i, p in enumerate(paths)
117
+ _read_df_raw(
118
+ path=info["path"],
119
+ category=info["category"],
120
+ trial_index=info["trial_index"],
121
+ )
122
+ for info in _get_file_infos(args=args)
77
123
  ]
78
124
  )
79
125
 
80
126
 
127
+ def _get_file_infos(args):
128
+ return [
129
+ dict(path=path, category=category, trial_index=trial_index)
130
+ for category, paths in [
131
+ ("baseline", args.baseline_path),
132
+ ("target", args.target_path),
133
+ ]
134
+ for trial_index, path in enumerate(paths)
135
+ ]
136
+
137
+
81
138
  def _read_df_raw(path: str, category: str, trial_index: int):
82
139
  return pl.read_ndjson(path).with_columns(
83
140
  category=pl.lit(category), trial_index=trial_index
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
108
165
  print("Transform mode: SGLang bench")
109
166
  return df
110
167
  else:
111
- raise Exception(f"Unknown data: {df.columns}")
168
+ raise Exception(
169
+ f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
170
+ )
112
171
 
113
172
 
114
173
  def _compute_df_meta(df_input: pl.DataFrame):
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
127
186
 
128
187
 
129
188
  def _handle_one_prompt(df_one_prompt: pl.DataFrame):
130
- assert len(set(df_one_prompt["prompt"])) == 1
189
+ assert (
190
+ len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
191
+ )
131
192
 
132
193
  df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
133
194
  df_target = df_one_prompt.filter(pl.col("category") == "target")
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
162
223
 
163
224
  if __name__ == "__main__":
164
225
  parser = argparse.ArgumentParser(description=_DESCRIPTION)
226
+ parser.add_argument("--data-type", type=str, default="auto")
165
227
  parser.add_argument("--baseline-path", type=str, nargs="+")
166
228
  parser.add_argument("--target-path", type=str, nargs="+")
167
229
  parser.add_argument(
@@ -131,4 +131,4 @@ class BaseKVReceiver(ABC):
131
131
 
132
132
  class BaseKVBootstrapServer(ABC):
133
133
  @abstractmethod
134
- def __init__(self, port: int): ...
134
+ def __init__(self, host: str, port: int): ...
@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
47
47
  self.is_mla_backend = is_mla_backend
48
48
  self.disaggregation_mode = disaggregation_mode
49
49
  # for p/d multi node infer
50
+ self.bootstrap_host = server_args.host
50
51
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
51
52
  self.dist_init_addr = server_args.dist_init_addr
52
53
  self.tp_size = server_args.tp_size
@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
72
73
  def _register_to_bootstrap(self):
73
74
  """Register KVSender to bootstrap server via HTTP POST."""
74
75
  if self.dist_init_addr:
76
+ # multi node: bootstrap server's host is dist_init_addr
75
77
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
76
78
  if self.dist_init_addr.endswith("]"):
77
79
  host = self.dist_init_addr
@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
80
82
  else:
81
83
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
82
84
  else:
83
- host = get_ip()
85
+ # single node: bootstrap server's host is same as http server's host
86
+ host = self.bootstrap_host
84
87
  host = maybe_wrap_ipv6_address(host)
85
88
 
86
89
  bootstrap_server_url = f"{host}:{self.bootstrap_port}"
@@ -125,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
125
128
  mgr: BaseKVManager,
126
129
  bootstrap_addr: str,
127
130
  bootstrap_room: Optional[int] = None,
128
- data_parallel_rank: Optional[int] = None,
131
+ prefill_dp_rank: Optional[int] = None,
129
132
  ):
130
133
  self.bootstrap_room = bootstrap_room
131
134
  self.bootstrap_addr = bootstrap_addr
132
135
  self.kv_mgr = mgr
133
- self.data_parallel_rank = data_parallel_rank
134
136
 
135
137
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
136
138
  self.prefill_tp_size, self.prefill_dp_size = (
@@ -166,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
166
168
  self.required_dst_info_num = 1
167
169
  self.target_tp_ranks = [self.target_tp_rank]
168
170
  elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
169
- assert (
170
- self.kv_mgr.is_mla_backend
171
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
172
171
  self.target_tp_rank = (
173
172
  self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
174
173
  ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
@@ -198,11 +197,14 @@ class CommonKVReceiver(BaseKVReceiver):
198
197
  self.target_tp_rank = self.target_tp_ranks[0]
199
198
  self.required_dst_info_num = 1
200
199
 
201
- if self.data_parallel_rank is not None:
202
- logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
203
- self.target_dp_group = self.data_parallel_rank
200
+ if prefill_dp_rank is not None:
201
+ logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
202
+ self.prefill_dp_rank = prefill_dp_rank
204
203
  else:
205
- self.target_dp_group = bootstrap_room % self.prefill_dp_size
204
+ self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
205
+
206
+ # FIXME: alias here: target_dp_group -> prefill_dp_rank
207
+ self.target_dp_group = self.prefill_dp_rank
206
208
 
207
209
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
208
210
  bootstrap_key = (
@@ -308,7 +310,8 @@ class CommonKVReceiver(BaseKVReceiver):
308
310
 
309
311
 
310
312
  class CommonKVBootstrapServer(BaseKVBootstrapServer):
311
- def __init__(self, port: int):
313
+ def __init__(self, host: str, port: int):
314
+ self.host = host
312
315
  self.port = port
313
316
  self.app = web.Application()
314
317
  self.store = dict()
@@ -412,7 +415,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
412
415
  self._runner = web.AppRunner(self.app)
413
416
  self._loop.run_until_complete(self._runner.setup())
414
417
 
415
- site = web.TCPSite(self._runner, port=self.port)
418
+ site = web.TCPSite(self._runner, host=self.host, port=self.port)
416
419
  self._loop.run_until_complete(site.start())
417
420
  self._loop.run_forever()
418
421
  except Exception as e:
@@ -24,7 +24,7 @@ import logging
24
24
  from collections import deque
25
25
  from dataclasses import dataclass
26
26
  from http import HTTPStatus
27
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
28
28
 
29
29
  import torch
30
30
  from torch.distributed import ProcessGroup
@@ -218,8 +218,10 @@ class DecodePreallocQueue:
218
218
 
219
219
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
220
220
  kv_args.gpu_id = self.scheduler.gpu_id
221
- kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
222
- kv_manager = kv_manager_class(
221
+ kv_manager_class: Type[BaseKVManager] = get_kv_class(
222
+ self.transfer_backend, KVClassType.MANAGER
223
+ )
224
+ kv_manager: BaseKVManager = kv_manager_class(
223
225
  kv_args,
224
226
  DisaggregationMode.DECODE,
225
227
  self.scheduler.server_args,
@@ -248,7 +250,7 @@ class DecodePreallocQueue:
248
250
  mgr=self.kv_manager,
249
251
  bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
250
252
  bootstrap_room=req.bootstrap_room,
251
- data_parallel_rank=req.data_parallel_rank,
253
+ prefill_dp_rank=req.data_parallel_rank,
252
254
  )
253
255
 
254
256
  self.queue.append(
@@ -884,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin:
884
886
  # if there are still retracted requests, we do not allocate new requests
885
887
  return
886
888
 
887
- req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
888
- self.disagg_decode_transfer_queue.extend(req_conns)
889
- alloc_reqs = (
890
- self.disagg_decode_transfer_queue.pop_transferred()
891
- ) # the requests which kv has arrived
892
- self.waiting_queue.extend(alloc_reqs)
889
+ if not hasattr(self, "polling_count"):
890
+ self.polling_count = 0
891
+ self.polling_interval = (
892
+ self.server_args.disaggregation_decode_polling_interval
893
+ )
894
+
895
+ self.polling_count = (self.polling_count + 1) % self.polling_interval
896
+
897
+ if self.polling_count % self.polling_interval == 0:
898
+ req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
899
+ self.disagg_decode_transfer_queue.extend(req_conns)
900
+ alloc_reqs = (
901
+ self.disagg_decode_transfer_queue.pop_transferred()
902
+ ) # the requests which kv has arrived
903
+ self.waiting_queue.extend(alloc_reqs)
@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
110
110
  if req.grammar is not None:
111
111
  # FIXME: this try-except block is for handling unexpected xgrammar issue.
112
112
  try:
113
- req.grammar.accept_token(req.output_ids[-1])
113
+ # if it is not None, then the grammar is from a retracted request, and we should not
114
+ # accept the token as it's already accepted
115
+ if req.grammar.current_token is None:
116
+ req.grammar.accept_token(req.output_ids[-1])
114
117
  except ValueError as e:
115
118
  # Grammar accept_token can raise ValueError if the token is not in the grammar.
116
119
  # This can happen if the grammar is not set correctly or the token is invalid.
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
62
62
  mgr: BaseKVManager,
63
63
  bootstrap_addr: str,
64
64
  bootstrap_room: Optional[int] = None,
65
- data_parallel_rank: Optional[int] = None,
65
+ prefill_dp_rank: Optional[int] = None,
66
66
  ):
67
67
  self.has_init = False
68
68