sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,24 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import time
3
5
  from collections import defaultdict
4
- from typing import List, Optional
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
7
+
8
+ import torch
5
9
 
6
10
  from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
7
11
  from sglang.srt.disaggregation.utils import DisaggregationMode
12
+ from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
8
13
  from sglang.srt.managers.schedule_policy import PrefillAdder
9
14
  from sglang.srt.managers.scheduler import Req, ScheduleBatch
15
+ from sglang.srt.managers.utils import DPBalanceMeta
10
16
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
11
17
  from sglang.srt.utils import get_bool_env_var
12
18
 
19
+ if TYPE_CHECKING:
20
+ from sglang.srt.managers.scheduler import Scheduler
21
+
13
22
  logger = logging.getLogger(__name__)
14
23
 
15
24
  RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:
28
37
 
29
38
 
30
39
  class SchedulerMetricsMixin:
31
- def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
40
+ def init_metrics(
41
+ self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
42
+ ):
32
43
  self.last_gen_throughput: float = 0.0
33
44
  self.last_input_throughput: float = 0.0
34
45
  self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
50
61
  labels["dp_rank"] = dp_rank
51
62
  self.metrics_collector = SchedulerMetricsCollector(labels=labels)
52
63
 
53
- def init_kv_events(self, kv_events_config: Optional[str]):
64
+ def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
65
+ self.balance_meta = dp_balance_meta
66
+ if (
67
+ self.server_args.enable_dp_attention
68
+ and self.server_args.load_balance_method == "minimum_tokens"
69
+ ):
70
+ assert dp_balance_meta is not None
71
+
72
+ self.recv_dp_balance_id_this_term = []
73
+
74
+ def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
54
75
  if self.enable_kv_cache_events:
55
76
  self.kv_event_publisher = EventPublisherFactory.create(
56
77
  kv_events_config, self.attn_dp_rank
57
78
  )
58
79
 
59
80
  def log_prefill_stats(
60
- self,
81
+ self: Scheduler,
61
82
  adder: PrefillAdder,
62
83
  can_run_list: List[Req],
63
84
  running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
138
159
  self._publish_kv_events()
139
160
 
140
161
  def log_decode_stats(
141
- self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
162
+ self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
142
163
  ):
143
164
  batch = running_batch or self.running_batch
144
165
 
@@ -193,7 +214,7 @@ class SchedulerMetricsMixin:
193
214
  msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
194
215
 
195
216
  msg += (
196
- f"cuda graph: {can_run_cuda_graph}, "
217
+ f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
197
218
  f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
198
219
  f"#queue-req: {len(self.waiting_queue)}, "
199
220
  )
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
220
241
  self._emit_kv_metrics()
221
242
  self._publish_kv_events()
222
243
 
223
- def _emit_kv_metrics(self):
244
+ def _emit_kv_metrics(self: Scheduler):
224
245
  kv_metrics = KvMetrics()
225
246
  kv_metrics.request_active_slots = self.stats.num_running_reqs
226
247
  kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
236
257
  if not self.send_metrics_from_scheduler.closed:
237
258
  self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
238
259
 
239
- def _publish_kv_events(self):
260
+ def _publish_kv_events(self: Scheduler):
240
261
  if self.enable_kv_cache_events:
241
262
  events = self.tree_cache.take_events()
242
263
  if events:
243
264
  batch = KVEventBatch(ts=time.time(), events=events)
244
265
  self.kv_event_publisher.publish(batch)
266
+
267
+ def maybe_update_dp_balance_data(
268
+ self: Scheduler, recv_req: TokenizedGenerateReqInput
269
+ ):
270
+ if (
271
+ self.server_args.enable_dp_attention
272
+ and self.server_args.load_balance_method == "minimum_tokens"
273
+ ):
274
+ self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
275
+
276
+ def maybe_handle_dp_balance_data(self: Scheduler):
277
+ if (
278
+ self.server_args.load_balance_method == "minimum_tokens"
279
+ and self.forward_ct % 40 == 0
280
+ ):
281
+ holding_tokens = self.get_load()
282
+
283
+ new_recv_dp_balance_id_list, holding_token_list = (
284
+ self.gather_dp_balance_info(holding_tokens)
285
+ )
286
+
287
+ self.recv_dp_balance_id_this_term.clear()
288
+ if self.tp_rank == 0: # only first worker write info
289
+ self.write_shared_dp_balance_info(
290
+ new_recv_dp_balance_id_list, holding_token_list
291
+ )
292
+
293
+ def gather_dp_balance_info(
294
+ self: Scheduler, holding_tokens_list
295
+ ) -> Union[None, List[List[int]]]:
296
+ """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
297
+ recv_list = self.recv_dp_balance_id_this_term
298
+ assert len(recv_list) <= 511, (
299
+ "The number of requests received this round is too large. "
300
+ "Please increase gather_tensor_size and onfly_info_size."
301
+ )
302
+ # The maximum size of the tensor used for gathering data from all workers.
303
+ gather_tensor_size = 512
304
+
305
+ # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
306
+ recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
307
+ recv_tensor[0] = holding_tokens_list
308
+ recv_tensor[1] = len(recv_list) # The first element is the length of the list.
309
+ recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
310
+
311
+ if self.tp_rank == 0:
312
+ gathered_list = [
313
+ torch.zeros(gather_tensor_size, dtype=torch.int32)
314
+ for _ in range(self.balance_meta.num_workers)
315
+ ]
316
+ else:
317
+ gathered_list = None
318
+
319
+ torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
320
+
321
+ gathered_id_list_per_worker = None
322
+ if self.tp_rank == 0:
323
+ gathered_id_list_per_worker = []
324
+ holding_tokens_list = []
325
+ for tensor in gathered_list:
326
+ holding_tokens_list.append(tensor[0].item())
327
+ list_length = tensor[1].item()
328
+ gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
329
+
330
+ return gathered_id_list_per_worker, holding_tokens_list
331
+
332
+ def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
333
+ meta = self.balance_meta
334
+
335
+ with meta.mutex:
336
+ onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
337
+ assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
338
+ # 1.Check if the rid received by each worker this round is present in onfly.
339
+ # If it is, remove the corresponding onfly item.
340
+ worker_id = 0
341
+ for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
342
+ for new_recv_rid in new_recv_rids:
343
+ assert (
344
+ new_recv_rid in on_fly_reqs
345
+ ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
346
+ del on_fly_reqs[new_recv_rid]
347
+ worker_id += 1
348
+ # 2. Atomically write local_tokens and onfly into shm under the mutex
349
+ meta.set_shared_onfly_info(onfly_list)
350
+ meta.set_shared_local_tokens(local_tokens)
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
93
93
  # This updates radix so others can match
94
94
  self.tree_cache.cache_unfinished_req(req)
95
95
 
96
- if req.return_logprob:
96
+ if batch.return_logprob:
97
97
  assert extend_logprob_start_len_per_req is not None
98
98
  assert extend_input_len_per_req is not None
99
99
  extend_logprob_start_len = extend_logprob_start_len_per_req[i]
100
100
  extend_input_len = extend_input_len_per_req[i]
101
101
  num_input_logprobs = extend_input_len - extend_logprob_start_len
102
- self.add_logprob_return_values(
103
- i,
104
- req,
105
- logprob_pt,
106
- next_token_ids,
107
- num_input_logprobs,
108
- logits_output,
109
- )
102
+ if req.return_logprob:
103
+ self.add_logprob_return_values(
104
+ i,
105
+ req,
106
+ logprob_pt,
107
+ next_token_ids,
108
+ num_input_logprobs,
109
+ logits_output,
110
+ )
110
111
  logprob_pt += num_input_logprobs
111
112
 
112
113
  if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
146
147
  skip_stream_req = req
147
148
 
148
149
  # Incrementally update input logprobs.
149
- if req.return_logprob:
150
+ if batch.return_logprob:
150
151
  extend_logprob_start_len = extend_logprob_start_len_per_req[i]
151
152
  extend_input_len = extend_input_len_per_req[i]
152
153
  if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
154
155
  num_input_logprobs = (
155
156
  extend_input_len - extend_logprob_start_len
156
157
  )
157
- self.add_input_logprob_return_values(
158
- i,
159
- req,
160
- logits_output,
161
- logprob_pt,
162
- num_input_logprobs,
163
- last_prefill_chunk=False,
164
- )
158
+ if req.return_logprob:
159
+ self.add_input_logprob_return_values(
160
+ i,
161
+ req,
162
+ logits_output,
163
+ logprob_pt,
164
+ num_input_logprobs,
165
+ last_prefill_chunk=False,
166
+ )
165
167
  logprob_pt += num_input_logprobs
166
168
 
167
169
  self.set_next_batch_sampling_info_done(batch)
@@ -698,6 +700,8 @@ class SchedulerOutputProcessorMixin:
698
700
  output_token_ids_logprobs_val,
699
701
  output_token_ids_logprobs_idx,
700
702
  output_hidden_states,
703
+ placeholder_tokens_idx=None,
704
+ placeholder_tokens_val=None,
701
705
  )
702
706
  )
703
707
 
@@ -717,6 +721,12 @@ class SchedulerOutputProcessorMixin:
717
721
  cached_tokens.append(req.cached_tokens)
718
722
  self.send_to_detokenizer.send_pyobj(
719
723
  BatchEmbeddingOut(
720
- rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
724
+ rids,
725
+ finished_reasons,
726
+ embeddings,
727
+ prompt_tokens,
728
+ cached_tokens,
729
+ placeholder_tokens_idx=None,
730
+ placeholder_tokens_val=None,
721
731
  )
722
732
  )
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
26
26
 
27
27
  class SchedulerProfilerMixin:
28
28
 
29
- def init_profier(self):
29
+ def init_profiler(self):
30
30
  self.torch_profiler = None
31
31
  self.torch_profiler_output_dir: Optional[str] = None
32
32
  self.profiler_activities: Optional[List[str]] = None
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
121
121
  url = params["url"]
122
122
 
123
123
  worker = self.tp_worker.worker
124
-
125
124
  worker.model_runner.save_remote_model(url)
126
125
 
126
+ if self.draft_worker is not None:
127
+ draft_url = params.get("draft_url", None)
128
+ assert (
129
+ draft_url is not None
130
+ ), "draft_url must be provided when draft model is enabled"
131
+ draft_worker = self.draft_worker.worker
132
+ draft_worker.model_runner.save_remote_model(draft_url)
133
+
127
134
  def save_sharded_model(self, params):
128
135
  worker = self.tp_worker.worker
129
136
 
@@ -24,20 +24,20 @@ import os
24
24
  import re
25
25
  from typing import Optional
26
26
 
27
- from sglang.srt.code_completion_parser import (
27
+ from sglang.srt.parser.code_completion_parser import (
28
28
  CompletionTemplate,
29
29
  FimPosition,
30
30
  completion_template_exists,
31
31
  register_completion_template,
32
32
  )
33
- from sglang.srt.conversation import (
33
+ from sglang.srt.parser.conversation import (
34
34
  Conversation,
35
35
  SeparatorStyle,
36
36
  chat_template_exists,
37
37
  get_conv_template_by_model_path,
38
38
  register_conv_template,
39
39
  )
40
- from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
40
+ from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format
41
41
 
42
42
  logger = logging.getLogger(__name__)
43
43