sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256) hide show
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,11 @@
1
1
  import argparse
2
2
  import functools
3
- import re
4
3
  from pathlib import Path
5
4
 
6
5
  import polars as pl
7
6
  import torch
8
7
 
8
+ from sglang.srt.debug_utils.dump_loader import find_row, read_meta
9
9
  from sglang.srt.debug_utils.dumper import get_truncated_value
10
10
 
11
11
 
@@ -26,66 +26,77 @@ def main(args):
26
26
  print("df_baseline", df_baseline)
27
27
 
28
28
  for row in df_target.iter_rows(named=True):
29
- rows_baseline = df_baseline.filter(
30
- (
31
- pl.col("forward_pass_id")
32
- == row["forward_pass_id"] - args.start_id + args.baseline_start_id
33
- )
34
- & functools.reduce(
35
- lambda a, b: a & b,
36
- [
37
- pl.col(col) == row[col]
38
- for col in row.keys()
39
- if col not in ["forward_pass_id", "dump_index", "filename"]
40
- ],
41
- )
29
+ path_target = Path(args.target_path) / row["filename"]
30
+
31
+ row_baseline = find_row(
32
+ df_baseline,
33
+ conditions=dict(
34
+ forward_pass_id=row["forward_pass_id"]
35
+ - args.start_id
36
+ + args.baseline_start_id,
37
+ **{
38
+ k: v
39
+ for k, v in row.items()
40
+ if k not in ["forward_pass_id", "dump_index", "filename"]
41
+ },
42
+ ),
42
43
  )
43
- assert len(rows_baseline) == 1, f"{rows_baseline=}"
44
- row_baseline = rows_baseline.to_dicts()[0]
44
+
45
+ if row_baseline is None:
46
+ print(f"Skip: target={str(path_target)} since no baseline")
47
+ x_target = _load_object(path_target)
48
+ if x_target is not None:
49
+ print(f"x_target(sample)={get_truncated_value(x_target)}")
50
+ continue
45
51
 
46
52
  path_baseline = Path(args.baseline_path) / row_baseline["filename"]
47
- path_target = Path(args.target_path) / row["filename"]
48
53
  print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
49
- check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
54
+ check_tensor_pair(
55
+ path_baseline=path_baseline, path_target=path_target, name=row["name"]
56
+ )
50
57
  print()
51
58
 
52
59
 
53
- def read_meta(directory):
54
- directory = Path(directory)
55
- assert directory.is_dir(), f"{directory=} should be a directory"
56
-
57
- rows = []
58
- for p in directory.glob("*.pt"):
59
- full_kwargs = {}
60
- for kv in p.stem.split("___"):
61
- k, v = kv.split("=")
62
- full_kwargs[k] = v
63
- rows.append(
64
- {
65
- "filename": str(p.name),
66
- **full_kwargs,
67
- }
68
- )
60
+ def check_tensor_pair(path_baseline, path_target, name=""):
61
+ x_baseline = _load_object(path_baseline)
62
+ x_target = _load_object(path_target)
69
63
 
70
- df = pl.DataFrame(rows)
71
- df = df.with_columns(
72
- pl.col("forward_pass_id").cast(int),
73
- pl.col("rank").cast(int),
64
+ print(
65
+ f"Raw "
66
+ f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
67
+ f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
74
68
  )
75
- return df
76
-
77
69
 
78
- def check_tensor_pair(path_baseline, path_target):
79
- x_baseline = torch.load(path_baseline, weights_only=True)
80
- x_target = torch.load(path_target, weights_only=True)
70
+ x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
71
+ x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
81
72
 
82
73
  print(
74
+ f"After preprocessor "
83
75
  f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
84
76
  f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
85
77
  )
86
78
 
79
+ x_target = x_target.float()
80
+ x_baseline = x_baseline.float()
81
+
82
+ for name, fn in (
83
+ ("mean", torch.mean),
84
+ ("std", torch.std),
85
+ ("min", torch.min),
86
+ ("max", torch.max),
87
+ ("p1", functools.partial(torch.quantile, q=0.01)),
88
+ ("p5", functools.partial(torch.quantile, q=0.05)),
89
+ ("p95", functools.partial(torch.quantile, q=0.95)),
90
+ ("p99", functools.partial(torch.quantile, q=0.99)),
91
+ ):
92
+ value_baseline = fn(x_baseline).item()
93
+ value_target = fn(x_target).item()
94
+ print(
95
+ f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
96
+ )
97
+
87
98
  if x_baseline.shape != x_target.shape:
88
- print(f" Shape mismatch")
99
+ print(f"⚠️ Shape mismatch")
89
100
  return
90
101
 
91
102
  raw_abs_diff = (x_target - x_baseline).abs()
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
112
123
  print(f"x_target(sample)={get_truncated_value(x_target)}")
113
124
 
114
125
 
126
+ def _try_unify_shape(x: torch.Tensor, target_shape):
127
+ x_shape = x.shape
128
+ num_dim_to_remove = len(x_shape) - len(target_shape)
129
+ if (x_shape[num_dim_to_remove:] == target_shape) and all(
130
+ val == 1 for val in x_shape[:num_dim_to_remove]
131
+ ):
132
+ out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
133
+ print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
134
+ return out
135
+
136
+ return x
137
+
138
+
115
139
  # Copied from DeepGEMM
116
140
  def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
117
141
  x, y = x.double(), y.double()
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
120
144
  return 1 - sim
121
145
 
122
146
 
147
+ def _comparison_preprocessor(x_baseline, x_target, name):
148
+ # can insert arbitrary adhoc postprocessing logic here
149
+ return x_baseline, x_target
150
+
151
+
152
+ def _load_object(path):
153
+ x = torch.load(path, weights_only=False)
154
+ if not isinstance(x, torch.Tensor):
155
+ print(f"Skip load {path} since {type(x)=} is not a Tensor")
156
+ return None
157
+ return x.cuda()
158
+
159
+
123
160
  if __name__ == "__main__":
124
161
  parser = argparse.ArgumentParser()
125
162
  parser.add_argument("--baseline-path", type=str)
@@ -0,0 +1,97 @@
1
+ import functools
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Any, Dict
5
+
6
+ import polars as pl
7
+ import torch
8
+
9
+
10
+ class DumpLoader:
11
+ def __init__(self):
12
+ directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
13
+
14
+ self._enable = directory is not None
15
+ if self._enable:
16
+ self._directory = Path(directory)
17
+ self._df = read_meta(directory)
18
+
19
+ @property
20
+ def enable(self):
21
+ return self._enable
22
+
23
+ def load(self, name, **kwargs):
24
+ assert self._enable, "Please call DumpLoader.load only when it is enabled"
25
+
26
+ from sglang.srt.debug_utils.dumper import dumper
27
+
28
+ forward_pass_id = dumper._forward_pass_id
29
+ conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
30
+ row = find_row(self._df, conditions=conditions)
31
+ assert (
32
+ row is not None
33
+ ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
34
+
35
+ path = self._directory / row["filename"]
36
+ output = torch.load(path, weights_only=False)
37
+
38
+ print(
39
+ f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
40
+ )
41
+ return output
42
+
43
+
44
+ def read_meta(directory):
45
+ directory = Path(directory)
46
+ assert directory.is_dir(), f"{directory=} should be a directory"
47
+
48
+ rows = []
49
+ for p in directory.glob("*.pt"):
50
+ full_kwargs = {}
51
+ for kv in p.stem.split("___"):
52
+ k, v = kv.split("=")
53
+ full_kwargs[k] = v
54
+ rows.append(
55
+ {
56
+ "filename": str(p.name),
57
+ **full_kwargs,
58
+ }
59
+ )
60
+
61
+ df = pl.DataFrame(rows)
62
+ df = df.with_columns(
63
+ pl.col("forward_pass_id").cast(int),
64
+ pl.col("rank").cast(int),
65
+ pl.col("dump_index").cast(int),
66
+ )
67
+ return df
68
+
69
+
70
+ def find_row(df, conditions: Dict[str, Any]):
71
+ df_sub = df.filter(
72
+ functools.reduce(
73
+ lambda a, b: a & b,
74
+ [
75
+ pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
76
+ for col in conditions.keys()
77
+ ],
78
+ )
79
+ )
80
+ assert len(df_sub) <= 1
81
+ return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
82
+
83
+
84
+ def _cast_to_polars_dtype(value, target_dtype):
85
+ if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
86
+ return int(value)
87
+ elif target_dtype in (pl.Float64, pl.Float32):
88
+ return float(value)
89
+ elif target_dtype == pl.Boolean:
90
+ return bool(value)
91
+ elif target_dtype == pl.String:
92
+ return str(value)
93
+ else:
94
+ return value
95
+
96
+
97
+ dump_loader = DumpLoader()
@@ -53,7 +53,7 @@ class _Dumper:
53
53
  if self._partial_name is None:
54
54
  self._partial_name = _get_partial_name()
55
55
 
56
- rank = dist.get_rank()
56
+ rank = _get_rank()
57
57
  full_kwargs = dict(
58
58
  forward_pass_id=self._forward_pass_id,
59
59
  rank=rank,
@@ -80,12 +80,20 @@ class _Dumper:
80
80
 
81
81
 
82
82
  def _get_partial_name():
83
- rank = dist.get_rank()
83
+ rank = _get_rank()
84
84
  object_list = [str(time.time()) if rank == 0 else None]
85
- dist.broadcast_object_list(object_list, device="cuda")
85
+ if dist.is_initialized():
86
+ dist.broadcast_object_list(object_list, device="cuda")
86
87
  return object_list[0]
87
88
 
88
89
 
90
+ def _get_rank():
91
+ if dist.is_initialized():
92
+ return dist.get_rank()
93
+ else:
94
+ return 0
95
+
96
+
89
97
  def get_truncated_value(value):
90
98
  if value is None:
91
99
  return None
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import hashlib
2
3
  import json
3
4
  from pathlib import Path
4
5
 
@@ -13,7 +14,11 @@ Supported inputs:
13
14
 
14
15
 
15
16
  def main(args):
16
- df_input = _transform_df_input(_compute_df_raw(args))
17
+ if args.data_type == "simple_evals":
18
+ df_input = _compute_df_input_mode_simple_evals(args)
19
+ else:
20
+ df_input = _transform_df_input(_compute_df_raw(args))
21
+
17
22
  assert all(
18
23
  c in df_input.columns
19
24
  for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
@@ -37,8 +42,9 @@ def main(args):
37
42
  df_meta=df_meta.to_dicts(),
38
43
  df_good_to_bad=df_good_to_bad.to_dicts(),
39
44
  df_bad_to_good=df_bad_to_good.to_dicts(),
40
- )
41
- )
45
+ ),
46
+ indent=4,
47
+ ),
42
48
  )
43
49
 
44
50
  if not args.disable_print_details:
@@ -65,19 +71,70 @@ def main(args):
65
71
  print(df)
66
72
 
67
73
 
74
+ def _compute_df_input_mode_simple_evals(args):
75
+ return pl.concat(
76
+ [
77
+ _compute_df_input_one_mode_simple_evals(**info)
78
+ for info in _get_file_infos(args=args)
79
+ ]
80
+ )
81
+
82
+
83
+ def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
84
+ data = json.loads(Path(path).read_text())
85
+ rows = []
86
+
87
+ for single_eval_result in data["metadata"]["single_eval_results"]:
88
+ prompt = single_eval_result["example_level_metadata"][
89
+ "actual_queried_prompt_messages"
90
+ ]
91
+ score = single_eval_result["score"]
92
+ assert score in {0.0, 1.0}, f"{score=}"
93
+
94
+ row = dict(
95
+ category=category,
96
+ trial_index=trial_index,
97
+ prompt_id=_compute_id_from_object(prompt),
98
+ prompt=json.dumps(prompt),
99
+ output=single_eval_result["example_level_metadata"]["response_text"],
100
+ correct=score == 1.0,
101
+ )
102
+ rows.append(row)
103
+
104
+ return pl.DataFrame(rows)
105
+
106
+
107
+ def _compute_id_from_object(obj):
108
+ if isinstance(obj, pl.Series):
109
+ obj = obj.to_list()
110
+ json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
111
+ return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
112
+
113
+
68
114
  def _compute_df_raw(args):
69
115
  return pl.concat(
70
116
  [
71
- _read_df_raw(p, category=category, trial_index=i)
72
- for category, paths in [
73
- ("baseline", args.baseline_path),
74
- ("target", args.target_path),
75
- ]
76
- for i, p in enumerate(paths)
117
+ _read_df_raw(
118
+ path=info["path"],
119
+ category=info["category"],
120
+ trial_index=info["trial_index"],
121
+ )
122
+ for info in _get_file_infos(args=args)
77
123
  ]
78
124
  )
79
125
 
80
126
 
127
+ def _get_file_infos(args):
128
+ return [
129
+ dict(path=path, category=category, trial_index=trial_index)
130
+ for category, paths in [
131
+ ("baseline", args.baseline_path),
132
+ ("target", args.target_path),
133
+ ]
134
+ for trial_index, path in enumerate(paths)
135
+ ]
136
+
137
+
81
138
  def _read_df_raw(path: str, category: str, trial_index: int):
82
139
  return pl.read_ndjson(path).with_columns(
83
140
  category=pl.lit(category), trial_index=trial_index
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
108
165
  print("Transform mode: SGLang bench")
109
166
  return df
110
167
  else:
111
- raise Exception(f"Unknown data: {df.columns}")
168
+ raise Exception(
169
+ f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
170
+ )
112
171
 
113
172
 
114
173
  def _compute_df_meta(df_input: pl.DataFrame):
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
127
186
 
128
187
 
129
188
  def _handle_one_prompt(df_one_prompt: pl.DataFrame):
130
- assert len(set(df_one_prompt["prompt"])) == 1
189
+ assert (
190
+ len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
191
+ )
131
192
 
132
193
  df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
133
194
  df_target = df_one_prompt.filter(pl.col("category") == "target")
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
162
223
 
163
224
  if __name__ == "__main__":
164
225
  parser = argparse.ArgumentParser(description=_DESCRIPTION)
226
+ parser.add_argument("--data-type", type=str, default="auto")
165
227
  parser.add_argument("--baseline-path", type=str, nargs="+")
166
228
  parser.add_argument("--target-path", type=str, nargs="+")
167
229
  parser.add_argument(
@@ -1,6 +1,12 @@
1
+ import concurrent.futures
1
2
  import logging
3
+ from typing import List, Tuple
4
+
5
+ import numpy as np
6
+ import numpy.typing as npt
2
7
 
3
8
  from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
9
+ from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
4
10
  from sglang.srt.disaggregation.mooncake.conn import (
5
11
  MooncakeKVBootstrapServer,
6
12
  MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
29
35
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
30
36
  )
31
37
 
38
+ def send_kvcache(
39
+ self,
40
+ mooncake_session_id: str,
41
+ prefill_kv_indices: npt.NDArray[np.int32],
42
+ dst_kv_ptrs: list[int],
43
+ dst_kv_indices: npt.NDArray[np.int32],
44
+ executor: concurrent.futures.ThreadPoolExecutor,
45
+ ):
46
+ # Group by indices
47
+ prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
48
+ prefill_kv_indices, dst_kv_indices
49
+ )
50
+
51
+ num_layers = len(self.kv_args.kv_data_ptrs)
52
+ layers_params = [
53
+ (
54
+ self.kv_args.kv_data_ptrs[layer_id],
55
+ dst_kv_ptrs[layer_id],
56
+ self.kv_args.kv_item_lens[layer_id],
57
+ )
58
+ for layer_id in range(num_layers)
59
+ ]
60
+
61
+ def set_transfer_blocks(
62
+ src_ptr: int, dst_ptr: int, item_len: int
63
+ ) -> List[Tuple[int, int, int]]:
64
+ transfer_blocks = []
65
+ for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
66
+ src_addr = src_ptr + int(prefill_index[0]) * item_len
67
+ dst_addr = dst_ptr + int(decode_index[0]) * item_len
68
+ length = item_len * len(prefill_index)
69
+ transfer_blocks.append((src_addr, dst_addr, length))
70
+ return transfer_blocks
71
+
72
+ # Worker function for processing a single layer
73
+ def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
74
+ transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
75
+ return self._transfer_data(mooncake_session_id, transfer_blocks)
76
+
77
+ # Worker function for processing all layers in a batch
78
+ def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
79
+ transfer_blocks = []
80
+ for src_ptr, dst_ptr, item_len in layers_params:
81
+ transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
82
+ return self._transfer_data(mooncake_session_id, transfer_blocks)
83
+
84
+ if self.enable_custom_mem_pool:
85
+ futures = [
86
+ executor.submit(
87
+ process_layer,
88
+ src_ptr,
89
+ dst_ptr,
90
+ item_len,
91
+ )
92
+ for (src_ptr, dst_ptr, item_len) in layers_params
93
+ ]
94
+ for future in concurrent.futures.as_completed(futures):
95
+ status = future.result()
96
+ if status != 0:
97
+ for f in futures:
98
+ f.cancel()
99
+ return status
100
+ else:
101
+ # Combining all layers' params in one batch transfer is more efficient
102
+ # compared to using multiple threads
103
+ return process_layers(layers_params)
104
+
105
+ return 0
106
+
32
107
 
33
108
  class AscendKVSender(MooncakeKVSender):
34
109
  pass
@@ -131,4 +131,4 @@ class BaseKVReceiver(ABC):
131
131
 
132
132
  class BaseKVBootstrapServer(ABC):
133
133
  @abstractmethod
134
- def __init__(self, port: int): ...
134
+ def __init__(self, host: str, port: int): ...
@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
47
47
  self.is_mla_backend = is_mla_backend
48
48
  self.disaggregation_mode = disaggregation_mode
49
49
  # for p/d multi node infer
50
+ self.bootstrap_host = server_args.host
50
51
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
51
52
  self.dist_init_addr = server_args.dist_init_addr
52
53
  self.tp_size = server_args.tp_size
@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
72
73
  def _register_to_bootstrap(self):
73
74
  """Register KVSender to bootstrap server via HTTP POST."""
74
75
  if self.dist_init_addr:
76
+ # multi node: bootstrap server's host is dist_init_addr
75
77
  if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
76
78
  if self.dist_init_addr.endswith("]"):
77
79
  host = self.dist_init_addr
@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
80
82
  else:
81
83
  host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
82
84
  else:
83
- host = get_ip()
85
+ # single node: bootstrap server's host is same as http server's host
86
+ host = self.bootstrap_host
84
87
  host = maybe_wrap_ipv6_address(host)
85
88
 
86
89
  bootstrap_server_url = f"{host}:{self.bootstrap_port}"
@@ -125,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
125
128
  mgr: BaseKVManager,
126
129
  bootstrap_addr: str,
127
130
  bootstrap_room: Optional[int] = None,
128
- data_parallel_rank: Optional[int] = None,
131
+ prefill_dp_rank: Optional[int] = None,
129
132
  ):
130
133
  self.bootstrap_room = bootstrap_room
131
134
  self.bootstrap_addr = bootstrap_addr
132
135
  self.kv_mgr = mgr
133
- self.data_parallel_rank = data_parallel_rank
134
136
 
135
137
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
136
138
  self.prefill_tp_size, self.prefill_dp_size = (
@@ -166,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
166
168
  self.required_dst_info_num = 1
167
169
  self.target_tp_ranks = [self.target_tp_rank]
168
170
  elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
169
- assert (
170
- self.kv_mgr.is_mla_backend
171
- ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
172
171
  self.target_tp_rank = (
173
172
  self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
174
173
  ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
@@ -198,11 +197,14 @@ class CommonKVReceiver(BaseKVReceiver):
198
197
  self.target_tp_rank = self.target_tp_ranks[0]
199
198
  self.required_dst_info_num = 1
200
199
 
201
- if self.data_parallel_rank is not None:
202
- logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
203
- self.target_dp_group = self.data_parallel_rank
200
+ if prefill_dp_rank is not None:
201
+ logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
202
+ self.prefill_dp_rank = prefill_dp_rank
204
203
  else:
205
- self.target_dp_group = bootstrap_room % self.prefill_dp_size
204
+ self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
205
+
206
+ # FIXME: alias here: target_dp_group -> prefill_dp_rank
207
+ self.target_dp_group = self.prefill_dp_rank
206
208
 
207
209
  # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
208
210
  bootstrap_key = (
@@ -308,7 +310,8 @@ class CommonKVReceiver(BaseKVReceiver):
308
310
 
309
311
 
310
312
  class CommonKVBootstrapServer(BaseKVBootstrapServer):
311
- def __init__(self, port: int):
313
+ def __init__(self, host: str, port: int):
314
+ self.host = host
312
315
  self.port = port
313
316
  self.app = web.Application()
314
317
  self.store = dict()
@@ -412,7 +415,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
412
415
  self._runner = web.AppRunner(self.app)
413
416
  self._loop.run_until_complete(self._runner.setup())
414
417
 
415
- site = web.TCPSite(self._runner, port=self.port)
418
+ site = web.TCPSite(self._runner, host=self.host, port=self.port)
416
419
  self._loop.run_until_complete(site.start())
417
420
  self._loop.run_forever()
418
421
  except Exception as e:
@@ -24,7 +24,7 @@ import logging
24
24
  from collections import deque
25
25
  from dataclasses import dataclass
26
26
  from http import HTTPStatus
27
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
27
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
28
28
 
29
29
  import torch
30
30
  from torch.distributed import ProcessGroup
@@ -218,8 +218,10 @@ class DecodePreallocQueue:
218
218
 
219
219
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
220
220
  kv_args.gpu_id = self.scheduler.gpu_id
221
- kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
222
- kv_manager = kv_manager_class(
221
+ kv_manager_class: Type[BaseKVManager] = get_kv_class(
222
+ self.transfer_backend, KVClassType.MANAGER
223
+ )
224
+ kv_manager: BaseKVManager = kv_manager_class(
223
225
  kv_args,
224
226
  DisaggregationMode.DECODE,
225
227
  self.scheduler.server_args,
@@ -248,7 +250,7 @@ class DecodePreallocQueue:
248
250
  mgr=self.kv_manager,
249
251
  bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
250
252
  bootstrap_room=req.bootstrap_room,
251
- data_parallel_rank=req.data_parallel_rank,
253
+ prefill_dp_rank=req.data_parallel_rank,
252
254
  )
253
255
 
254
256
  self.queue.append(
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
62
62
  mgr: BaseKVManager,
63
63
  bootstrap_addr: str,
64
64
  bootstrap_room: Optional[int] = None,
65
- data_parallel_rank: Optional[int] = None,
65
+ prefill_dp_rank: Optional[int] = None,
66
66
  ):
67
67
  self.has_init = False
68
68