sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py CHANGED
@@ -205,6 +205,14 @@ class ModelConfig:
             self.hf_config, "image_token_id", None
         ) or getattr(self.hf_config, "image_token_index", None)
 
+        # matryoshka embeddings
+        self.matryoshka_dimensions = getattr(
+            self.hf_config, "matryoshka_dimensions", None
+        )
+        self.is_matryoshka = self.matryoshka_dimensions or getattr(
+            self.hf_config, "is_matryoshka", False
+        )
+
     @staticmethod
     def from_server_args(
         server_args: ServerArgs,
@@ -358,6 +366,13 @@ class ModelConfig:
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
             self.v_head_dim = self.hf_text_config.v_head_dim
             self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
+        elif "KimiLinearForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 72
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
         else:
             if (
                 "MistralModel" in self.hf_config.architectures
@@ -535,7 +550,7 @@ class ModelConfig:
             quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg
 
-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
         """Parse ModelOpt quantization config and return the appropriate quant_method."""
         json_quant_configs = quant_config_dict["quantization"]
         quant_algo = json_quant_configs.get("quant_algo", None)
@@ -547,8 +562,7 @@ class ModelConfig:
         elif quant_algo and "FP8" in quant_algo:
             return {"quant_method": "modelopt_fp8"}
         else:
-            # Default to FP8 for backward compatibility
-            return {"quant_method": "modelopt_fp8"}
+            return None
 
     def _is_already_quantized(self) -> bool:
         """Check if the model is already quantized based on config files."""
@@ -583,14 +597,20 @@ class ModelConfig:
             return
 
         # Check if ModelOpt quantization is specified
-        modelopt_quantization_specified = self.quantization in [
+        _MODELOPT_QUANTIZATION_METHODS = [
            "modelopt",
            "modelopt_fp8",
            "modelopt_fp4",
        ]
+        modelopt_quantization_specified = (
+            self.quantization in _MODELOPT_QUANTIZATION_METHODS
+        )
 
        if not modelopt_quantization_specified:
-            raise ValueError("quantize_and_serve requires ModelOpt quantization")
+            raise ValueError(
+                "quantize_and_serve requires ModelOpt quantization (set with --quantization "
+                f"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})"
+            )
 
        # quantize_and_serve is disabled due to compatibility issues
        raise NotImplementedError(
@@ -614,6 +634,7 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
+            "auto-round",
        ]
        optimized_quantization_methods = [
            "fp8",
@@ -635,6 +656,7 @@ class ModelConfig:
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
+            "modelopt_fp8": ["modelopt"],
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
@@ -806,7 +828,7 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
@@ -915,12 +937,13 @@ multimodal_model_archs = [
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]
 
sglang/srt/constants.py CHANGED
@@ -1,3 +1,10 @@
 # GPU Memory Types
 GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
 GPU_MEMORY_TYPE_WEIGHTS = "weights"
+GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph"
+
+GPU_MEMORY_ALL_TYPES = [
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+]
sglang/srt/debug_utils/tensor_dump_forward_hook.py ADDED
@@ -0,0 +1,149 @@
+"""
+This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
+After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
+
+Usage:
+Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
+A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
+Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
+The file contains a series of key-value pairs, where the keys correspond to operator names in the model
+(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+logger = logging.getLogger(__name__)
+
+
+class TensorDumper:
+    def __init__(
+        self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+    ):
+        self._dump_layers = dump_layers
+        self._forward_pass_id = 0
+        self._pid = os.getpid()
+        self._current_tensors = {}
+        self._base_dir = Path(dump_dir)
+        rank = tp_size * pp_rank + tp_rank
+        self._process_dir = (
+            self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
+        )
+        self._process_dir.mkdir(parents=True, exist_ok=True)
+
+    def get_dump_dir(self):
+        return str(self._process_dir)
+
+    def add_tensor(self, name, tensor_item):
+        if isinstance(tensor_item, (tuple, list)):
+            tensors = [t.cpu() for t in tensor_item if t is not None]
+            if len(tensors) == 1:
+                self._current_tensors[name] = tensors[0]
+            else:
+                self._current_tensors[name] = tensors
+        elif isinstance(tensor_item, torch.Tensor):
+            self._current_tensors[name] = tensor_item.cpu()
+        elif isinstance(tensor_item, LogitsProcessorOutput):
+            self._current_tensors[name] = tensor_item.next_token_logits.cpu()
+        elif isinstance(tensor_item, ForwardBatch):
+            self._current_tensors[name + ".forward_batch_info.input_ids"] = (
+                tensor_item.input_ids.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
+                tensor_item.seq_lens.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.positions"] = (
+                tensor_item.positions.cpu()
+            )
+        elif isinstance(tensor_item, PPProxyTensors):
+            for tensor_name in tensor_item.tensors.keys():
+                self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
+                    tensor_item.tensors[tensor_name].cpu()
+                )
+        else:
+            logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")
+
+    def dump_current_tensors(self):
+        if len(self._current_tensors) == 0:
+            return
+        tensor_file_for_pass = self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
+        logger.info(
+            f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
+        )
+        torch.save(self._current_tensors, str(tensor_file_for_pass))
+        self._current_tensors = {}
+        self._forward_pass_id += 1
+
+    def _add_hook_recursive(
+        self, model, prefix, top_level_module_name, layers_module_name
+    ):
+        model_top_level_module_matched = False
+        layers_prefix = top_level_module_name + "." + layers_module_name
+        for name, module in model._modules.items():
+            top_level_model = False
+            if len(prefix) == 0:
+                cur_name = name
+                if cur_name == top_level_module_name:
+                    model_top_level_module_matched = True
+                    top_level_model = True
+            else:
+                cur_name = prefix + "." + name
+                if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
+                    # If we only need n layers, skip the rest of the layers.
+                    # Most models' layout is like model.layers.0.
+                    cur_layer = int(name)
+                    if cur_layer >= self._dump_layers:
+                        continue
+            if module is not None:
+                _, sub_count = self._add_hook_recursive(
+                    module, cur_name, top_level_module_name, layers_module_name
+                )
+                if sub_count == 0 or top_level_model:
+                    # Avoid duplicated output hooks, e.g. self_attn may contain:
+                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
+                    # Therefore, we do not need to add output hooks for self_attn,
+                    # since the output of self_attn should be the same as self_attn.o_proj.
+                    module.register_forward_hook(
+                        self._dump_hook(cur_name, top_level_model)
+                    )
+        return model_top_level_module_matched, len(model._modules.items())
+
+    def _dump_hook(self, tensor_name, do_dump):
+        def inner_dump_hook(module, input, output):
+            if do_dump:
+                # This is the top-level model, so we will record the input for it.
+                for item in input:
+                    if isinstance(item, ForwardBatch):
+                        self.add_tensor(tensor_name, item)
+                self.dump_current_tensors()
+            if output is not None:
+                self.add_tensor(tensor_name, output)
+
+        return inner_dump_hook
+
+
+def register_forward_hook_for_model(
+    model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+):
+    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
+    # Most models have a layout like:
+    # XxxxForCausalLM
+    #   (model): XxxxModel
+    #     (layers): ModuleList
+    # If the model is not constructed with this layout,
+    # environment variables can be used to specify the module names.
+    top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
+    layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
+    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
+        model, "", top_level_module_name, layers_module_name
+    )
+    assert (
+        model_top_level_module_matched
+    ), f"model should have a module named {top_level_module_name}"
+    return tensor_dumper
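
A minimal usage sketch for the new debug hook, assuming it ships as sglang/srt/debug_utils/tensor_dump_forward_hook.py (entry 17 in the file list above); ToyInner and ToyModel below are illustrative stand-ins for a real sglang model that follows the default `model.layers` layout described in the docstring.

import tempfile

import torch
from torch import nn

from sglang.srt.debug_utils.tensor_dump_forward_hook import (
    register_forward_hook_for_model,
)


class ToyInner(nn.Module):
    # Mimics the "(model): ... (layers): ModuleList" layout the hook expects.
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyInner()

    def forward(self, x):
        return self.model(x)


toy = ToyModel()
dump_dir = tempfile.mkdtemp()

# dump_layers=2 keeps hooks only on model.layers.0 and model.layers.1.
dumper = register_forward_hook_for_model(
    toy, dump_dir, dump_layers=2, tp_size=1, tp_rank=0, pp_rank=0
)

# One forward pass; when the hook on the top-level "model" module fires, the
# per-operator outputs recorded so far are flushed to Pass00000.pt.
toy(torch.randn(3, 8))

# Each Pass*.pt file is a plain dict mapping operator name -> dumped tensor(s).
dumped = torch.load(f"{dumper.get_dump_dir()}/Pass00000.pt")
for op_name, value in dumped.items():
    print(op_name, value.shape if isinstance(value, torch.Tensor) else type(value))

In a real server run, the dump directory comes from --debug-tensor-dump-output-folder, and the same per-rank directory and Pass*.pt naming described in the docstring apply.
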
sglang/srt/disaggregation/decode.py CHANGED
@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import (
+    trace_event_batch,
+    trace_slice_batch,
+    trace_slice_end,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -313,6 +318,7 @@ class DecodePreallocQueue:
         )
 
         req.add_latency(RequestStage.DECODE_PREPARE)
+        trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
@@ -521,13 +527,15 @@ class DecodePreallocQueue:
             decode_req.kv_receiver.init(
                 page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
                 time.perf_counter()
             )
             decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
+            trace_slice_end(
+                RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
+            )
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -765,8 +773,12 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
 
-                # special handling for sampling_params.max_new_tokens == 1
-                if decode_req.req.sampling_params.max_new_tokens == 1:
+                # special handling for corner cases
+                should_finish = (
+                    decode_req.req.sampling_params.max_new_tokens == 1
+                    or output_id in decode_req.req.eos_token_ids
+                )
+                if should_finish:
                     # finish immediately
                     decode_req.req.time_stats.forward_entry_time = (
                         decode_req.req.time_stats.completion_time
@@ -776,8 +788,19 @@ class DecodeTransferQueue:
                         [decode_req.req], decode_req.req.return_logprob
                     )
                     self.tree_cache.cache_finished_req(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_QUICK_FINISH,
+                        decode_req.req.rid,
+                        thread_finish_flag=True,
+                    )
                 else:
                     transferred_reqs.append(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_TRANSFERRED,
+                        decode_req.req.rid,
+                        auto_next_anon=True,
+                    )
+
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -823,6 +846,7 @@ class SchedulerDisaggregationDecodeMixin:
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                    if prepare_mlp_sync_flag:
                        self._prepare_idle_batch_and_run(None)
                else:
@@ -872,6 +896,7 @@ class SchedulerDisaggregationDecodeMixin:
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
+                    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                    if prepare_mlp_sync_flag:
                        batch_, batch_result = self._prepare_idle_batch_and_run(
                            None, delay_process=True
@@ -954,6 +979,9 @@ class SchedulerDisaggregationDecodeMixin:
         self.running_batch = self.update_running_batch(self.running_batch)
         ret = self.running_batch if not self.running_batch.is_empty() else None
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
         return ret
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
@@ -1009,6 +1037,9 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
+
         # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
         resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
         self.waiting_queue.extend(resumed_reqs)
@@ -1031,6 +1062,3 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()
sglang/srt/disaggregation/nixl/conn.py CHANGED
@@ -231,8 +231,8 @@ class NixlKVManager(CommonKVManager):
         ]
         for k in keys_to_remove:
             del self.connection_pool[k]
-        if failed_bootstrap_addr in self.prefill_tp_size_table:
-            del self.prefill_tp_size_table[failed_bootstrap_addr]
+        if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+            del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
         if failed_bootstrap_addr in self.prefill_dp_size_table:
             del self.prefill_dp_size_table[failed_bootstrap_addr]
         if failed_bootstrap_addr in self.prefill_pp_size_table:
sglang/srt/disaggregation/prefill.py CHANGED
@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     NSATokenToKVPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
 from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
 
 if TYPE_CHECKING:
@@ -198,6 +199,7 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
+        trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         for req in reqs:
@@ -289,6 +291,10 @@ class PrefillBootstrapQueue:
             req.time_stats.wait_queue_entry_time = time.perf_counter()
             req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
+            trace_slice_end(
+                RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
+            )
+
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         ]
@@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -348,6 +357,9 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -423,6 +435,7 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 req.add_latency(RequestStage.PREFILL_FORWARD)
+                trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
                 self.disagg_prefill_inflight_queue.append(req)
                 if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
                     req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@ class SchedulerDisaggregationPrefillMixin:
 
             if self.enable_overlap:
                 self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+            trace_slice(
+                RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
+            )
 
         self.maybe_send_health_check_signal()
 
@@ -558,6 +574,9 @@ class SchedulerDisaggregationPrefillMixin:
             req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
             self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
             req.metadata_buffer_index = -1
+            trace_slice(
+                RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
+            )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -569,7 +588,7 @@ class SchedulerDisaggregationPrefillMixin:
         """
         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.tp_worker.get_tp_group().cpu_group,
+            self.tp_worker.get_attention_tp_cpu_group(),
         )
 
         transferred_rids: List[str] = []
@@ -703,8 +722,11 @@ class SchedulerDisaggregationPrefillMixin:
         else:
             data = None
 
-        if self.tp_size != 1:
+        if self.attn_tp_size != 1:
             data = broadcast_pyobj(
-                data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
+                data,
+                self.attn_tp_group.rank,
+                self.attn_tp_cpu_group,
+                src=self.attn_tp_group.ranks[0],
             )
         return data
sglang/srt/distributed/device_communicators/custom_all_reduce.py CHANGED
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
             self.register_buffer(self.buffer)
 
         self.disabled = False
+        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()
 
     @staticmethod
     def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
             if _is_hip:
                 return self.all_reduce_reg(input)
             else:
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input, registered=not self.tms_cudagraph)
         else:
             # If warm up, mimic the allocation pattern since custom
             # allreduce is out-of-place.
sglang/srt/distributed/parallel_state.py CHANGED
@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.get_device_module().Stream
 
 
 @dataclass
@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
@@ -505,7 +498,7 @@ class GroupCoordinator:
             maybe_pynccl_context = nullcontext()
         else:
             maybe_pynccl_context = pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             )
 
         pymscclpp_comm = self.pymscclpp_comm
@@ -562,7 +555,7 @@ class GroupCoordinator:
             and input_.symmetric_memory
         ):
             with self.pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
            ):
                self.pynccl_comm.all_reduce(input_)
                return input_
@@ -662,7 +655,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -786,7 +781,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"