sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py CHANGED
@@ -205,6 +205,14 @@ class ModelConfig:
             self.hf_config, "image_token_id", None
         ) or getattr(self.hf_config, "image_token_index", None)
 
+        # matryoshka embeddings
+        self.matryoshka_dimensions = getattr(
+            self.hf_config, "matryoshka_dimensions", None
+        )
+        self.is_matryoshka = self.matryoshka_dimensions or getattr(
+            self.hf_config, "is_matryoshka", False
+        )
+
     @staticmethod
     def from_server_args(
         server_args: ServerArgs,
@@ -358,6 +366,13 @@ class ModelConfig:
             self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
             self.v_head_dim = self.hf_text_config.v_head_dim
             self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
+        elif "KimiLinearForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 72
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
+            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
         else:
             if (
                 "MistralModel" in self.hf_config.architectures
@@ -582,14 +597,20 @@ class ModelConfig:
             return
 
         # Check if ModelOpt quantization is specified
-        modelopt_quantization_specified = self.quantization in [
+        _MODELOPT_QUANTIZATION_METHODS = [
             "modelopt",
             "modelopt_fp8",
             "modelopt_fp4",
         ]
+        modelopt_quantization_specified = (
+            self.quantization in _MODELOPT_QUANTIZATION_METHODS
+        )
 
         if not modelopt_quantization_specified:
-            raise ValueError("quantize_and_serve requires ModelOpt quantization")
+            raise ValueError(
+                "quantize_and_serve requires ModelOpt quantization (set with --quantization "
+                f"{{{', '.join(sorted(_MODELOPT_QUANTIZATION_METHODS))}}})"
+            )
 
         # quantize_and_serve is disabled due to compatibility issues
         raise NotImplementedError(
@@ -613,6 +634,7 @@ class ModelConfig:
             "petit_nvfp4",
             "quark",
             "mxfp4",
+            "auto-round",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -634,6 +656,7 @@ class ModelConfig:
             "petit_nvfp4",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp8": ["modelopt"],
             "modelopt_fp4": ["modelopt"],
             "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
sglang/srt/constants.py CHANGED
@@ -1,3 +1,10 @@
 # GPU Memory Types
 GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
 GPU_MEMORY_TYPE_WEIGHTS = "weights"
+GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph"
+
+GPU_MEMORY_ALL_TYPES = [
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+]
sglang/srt/debug_utils/tensor_dump_forward_hook.py ADDED
@@ -0,0 +1,149 @@
+"""
+This file provides a function `register_forward_hook_for_model` that registers a forward hook on every operator of the model.
+After registration, during model inference, all tensors generated throughout the forward pass will be recorded.
+
+Usage:
+Specify the output directory for dumping tensors using the argument `--debug-tensor-dump-output-folder`.
+A separate directory will be created for each GPU rank, named in the format `f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}"`.
+Each complete forward pass of the model generates a `.pt` file named `f"Pass{pass_num}.pt"`, which can be loaded using `torch.load`.
+The file contains a series of key-value pairs, where the keys correspond to operator names in the model
+(similar to those in model.safetensors.index.json), and the values are the outputs produced by the respective operators.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+logger = logging.getLogger(__name__)
+
+
+class TensorDumper:
+    def __init__(
+        self, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+    ):
+        self._dump_layers = dump_layers
+        self._forward_pass_id = 0
+        self._pid = os.getpid()
+        self._current_tensors = {}
+        self._base_dir = Path(dump_dir)
+        rank = tp_size * pp_rank + tp_rank
+        self._process_dir = (
+            self._base_dir / f"TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{self._pid}"
+        )
+        self._process_dir.mkdir(parents=True, exist_ok=True)
+
+    def get_dump_dir(self):
+        return str(self._process_dir)
+
+    def add_tensor(self, name, tensor_item):
+        if isinstance(tensor_item, (tuple, list)):
+            tensors = [t.cpu() for t in tensor_item if t is not None]
+            if len(tensors) == 1:
+                self._current_tensors[name] = tensors[0]
+            else:
+                self._current_tensors[name] = tensors
+        elif isinstance(tensor_item, torch.Tensor):
+            self._current_tensors[name] = tensor_item.cpu()
+        elif isinstance(tensor_item, LogitsProcessorOutput):
+            self._current_tensors[name] = tensor_item.next_token_logits.cpu()
+        elif isinstance(tensor_item, ForwardBatch):
+            self._current_tensors[name + ".forward_batch_info.input_ids"] = (
+                tensor_item.input_ids.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.seq_lens"] = (
+                tensor_item.seq_lens.cpu()
+            )
+            self._current_tensors[name + ".forward_batch_info.positions"] = (
+                tensor_item.positions.cpu()
+            )
+        elif isinstance(tensor_item, PPProxyTensors):
+            for tensor_name in tensor_item.tensors.keys():
+                self._current_tensors[name + ".pp_proxy_tensors." + tensor_name] = (
+                    tensor_item.tensors[tensor_name].cpu()
+                )
+        else:
+            logger.warning(f"Unsupported type: {type(tensor_item)}: {tensor_item}")
+
+    def dump_current_tensors(self):
+        if len(self._current_tensors) == 0:
+            return
+        tensor_file_for_pass = self._process_dir / f"Pass{self._forward_pass_id:05d}.pt"
+        logger.info(
+            f"Dump {self._forward_pass_id:05d}th pass to {tensor_file_for_pass}"
+        )
+        torch.save(self._current_tensors, str(tensor_file_for_pass))
+        self._current_tensors = {}
+        self._forward_pass_id += 1
+
+    def _add_hook_recursive(
+        self, model, prefix, top_level_module_name, layers_module_name
+    ):
+        model_top_level_module_matched = False
+        layers_prefix = top_level_module_name + "." + layers_module_name
+        for name, module in model._modules.items():
+            top_level_model = False
+            if len(prefix) == 0:
+                cur_name = name
+                if cur_name == top_level_module_name:
+                    model_top_level_module_matched = True
+                    top_level_model = True
+            else:
+                cur_name = prefix + "." + name
+            if self._dump_layers > 0 and name.isdigit() and prefix == layers_prefix:
+                # If we only need n layers, skip the reset layers.
+                # Most models' layout is like model.layers.0.
+                cur_layer = int(name)
+                if cur_layer >= self._dump_layers:
+                    continue
+            if module is not None:
+                _, sub_count = self._add_hook_recursive(
+                    module, cur_name, top_level_module_name, layers_module_name
+                )
+                if sub_count == 0 or top_level_model:
+                    # Avoid duplicated output hooks, e.g. self_attn may contain:
+                    # self_attn.qkv_proj, self_attn.attn & self_attn.o_proj.
+                    # Therefore, we do not need to add output hooks for self_attn,
+                    # since the output of self_attn should be the same to self_attn.o_proj.
+                    module.register_forward_hook(
+                        self._dump_hook(cur_name, top_level_model)
+                    )
+        return model_top_level_module_matched, len(model._modules.items())
+
+    def _dump_hook(self, tensor_name, do_dump):
+        def inner_dump_hook(module, input, output):
+            if do_dump:
+                # This is the top-level model, so we will record the input for it.
+                for item in input:
+                    if isinstance(item, ForwardBatch):
+                        self.add_tensor(tensor_name, item)
+                self.dump_current_tensors()
+            if output is not None:
+                self.add_tensor(tensor_name, output)
+
+        return inner_dump_hook
+
+
+def register_forward_hook_for_model(
+    model, dump_dir: str, dump_layers: int, tp_size: int, tp_rank: int, pp_rank: int
+):
+    tensor_dumper = TensorDumper(dump_dir, dump_layers, tp_size, tp_rank, pp_rank)
+    # Most models have the layerout like:
+    # XxxxForCausalLM
+    #   (model): XxxxModel
+    #     (layers): ModuleList
+    # If the model is not constructed with this layout,
+    # environment variable can be used to specify the module names.
+    top_level_module_name = os.getenv("TENSOR_DUMP_TOP_LEVEL_MODULE_NAME", "model")
+    layers_module_name = os.getenv("TENSOR_DUMP_LAYERS_MODULE_NAME", "layers")
+    model_top_level_module_matched, _ = tensor_dumper._add_hook_recursive(
+        model, "", top_level_module_name, layers_module_name
+    )
+    assert (
+        model_top_level_module_matched
+    ), f"model should have a module named {top_level_module_name}"
+    return tensor_dumper
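
Per the module docstring above, each forward pass is saved as a `Pass{pass_num}.pt` file keyed by operator name. A minimal sketch for inspecting one dump offline; the path below merely follows the documented `TP{tp_rank}_PP{pp_rank}_Rank{rank}_pid{pid}` naming and is a placeholder:

import torch

# Placeholder path following the naming scheme described in the docstring above.
dump_file = "tensor_dumps/TP0_PP0_Rank0_pid12345/Pass00000.pt"
dumped = torch.load(dump_file, map_location="cpu")
for op_name, value in dumped.items():
    # Values are tensors, or lists of tensors when an operator returned several outputs.
    shape = value.shape if hasattr(value, "shape") else [t.shape for t in value]
    print(op_name, shape)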
sglang/srt/disaggregation/decode.py CHANGED
@@ -58,6 +58,11 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import (
+    trace_event_batch,
+    trace_slice_batch,
+    trace_slice_end,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -313,6 +318,7 @@ class DecodePreallocQueue:
         )
 
         req.add_latency(RequestStage.DECODE_PREPARE)
+        trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
@@ -521,13 +527,15 @@ class DecodePreallocQueue:
             decode_req.kv_receiver.init(
                 page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
                 time.perf_counter()
             )
             decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
+            trace_slice_end(
+                RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
+            )
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -765,8 +773,12 @@
                 indices_to_remove.add(i)
                 decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
 
-                # special handling for sampling_params.max_new_tokens == 1
-                if decode_req.req.sampling_params.max_new_tokens == 1:
+                # special handling for corner cases
+                should_finish = (
+                    decode_req.req.sampling_params.max_new_tokens == 1
+                    or output_id in decode_req.req.eos_token_ids
+                )
+                if should_finish:
                     # finish immediately
                     decode_req.req.time_stats.forward_entry_time = (
                         decode_req.req.time_stats.completion_time
@@ -776,8 +788,19 @@
                         [decode_req.req], decode_req.req.return_logprob
                     )
                     self.tree_cache.cache_finished_req(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_QUICK_FINISH,
+                        decode_req.req.rid,
+                        thread_finish_flag=True,
+                    )
                 else:
                     transferred_reqs.append(decode_req.req)
+                    trace_slice_end(
+                        RequestStage.DECODE_TRANSFERRED,
+                        decode_req.req.rid,
+                        auto_next_anon=True,
+                    )
+
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -823,6 +846,7 @@ class SchedulerDisaggregationDecodeMixin:
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
+                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                 if prepare_mlp_sync_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
@@ -872,6 +896,7 @@
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
+                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                 if prepare_mlp_sync_flag:
                     batch_, batch_result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
@@ -954,6 +979,9 @@
         self.running_batch = self.update_running_batch(self.running_batch)
         ret = self.running_batch if not self.running_batch.is_empty() else None
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
         return ret
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
@@ -1009,6 +1037,9 @@
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
+
         # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
         resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
         self.waiting_queue.extend(resumed_reqs)
@@ -1031,6 +1062,3 @@
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
-
-        if self.server_args.disaggregation_decode_enable_offload_kvcache:
-            self.decode_offload_manager.check_offload_progress()
sglang/srt/disaggregation/nixl/conn.py CHANGED
@@ -231,8 +231,8 @@ class NixlKVManager(CommonKVManager):
         ]
         for k in keys_to_remove:
             del self.connection_pool[k]
-        if failed_bootstrap_addr in self.prefill_tp_size_table:
-            del self.prefill_tp_size_table[failed_bootstrap_addr]
+        if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
+            del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
         if failed_bootstrap_addr in self.prefill_dp_size_table:
             del self.prefill_dp_size_table[failed_bootstrap_addr]
         if failed_bootstrap_addr in self.prefill_pp_size_table:
sglang/srt/disaggregation/prefill.py CHANGED
@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     NSATokenToKVPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end
 from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
 
 if TYPE_CHECKING:
@@ -198,6 +199,7 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
+        trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         for req in reqs:
@@ -289,6 +291,10 @@ class PrefillBootstrapQueue:
             req.time_stats.wait_queue_entry_time = time.perf_counter()
             req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
+            trace_slice_end(
+                RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True
+            )
+
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         ]
@@ -316,6 +322,9 @@ class SchedulerDisaggregationPrefillMixin:
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -348,6 +357,9 @@
             )
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()
+            if batch:
+                attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()}
+                trace_event_batch("schedule", batch.reqs, attrs=attrs)
 
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
@@ -423,6 +435,7 @@
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 req.add_latency(RequestStage.PREFILL_FORWARD)
+                trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True)
                 self.disagg_prefill_inflight_queue.append(req)
                 if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
                     req.output_topk_p = batch.spec_info.topk_p[i]
@@ -487,6 +500,9 @@
 
             if self.enable_overlap:
                 self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+            trace_slice(
+                RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True
+            )
 
         self.maybe_send_health_check_signal()
 
@@ -558,6 +574,9 @@
             req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
             self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
             req.metadata_buffer_index = -1
+            trace_slice(
+                RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True
+            )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -569,7 +588,7 @@
         """
         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
-            self.tp_worker.get_tp_group().cpu_group,
+            self.tp_worker.get_attention_tp_cpu_group(),
         )
 
         transferred_rids: List[str] = []
@@ -703,8 +722,11 @@
         else:
             data = None
 
-        if self.tp_size != 1:
+        if self.attn_tp_size != 1:
             data = broadcast_pyobj(
-                data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
+                data,
+                self.attn_tp_group.rank,
+                self.attn_tp_cpu_group,
+                src=self.attn_tp_group.ranks[0],
             )
         return data
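
The `trace_slice*` and `trace_event_batch` calls added above only emit spans when tracing is enabled; as the engine and http_server hunks below show, the endpoint attribute is now `otlp_traces_endpoint`. A minimal sketch of turning tracing on through the Python engine, assuming the flags pass through to ServerArgs as keyword arguments (model path and collector address are placeholders):

import sglang as sgl

engine = sgl.Engine(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    enable_trace=True,
    otlp_traces_endpoint="localhost:4317",  # assumed local OTLP collector
)
print(engine.generate("Hello", sampling_params={"max_new_tokens": 8}))
engine.shutdown()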
sglang/srt/distributed/device_communicators/custom_all_reduce.py CHANGED
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
             self.register_buffer(self.buffer)
 
         self.disabled = False
+        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()
 
     @staticmethod
     def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
             if _is_hip:
                 return self.all_reduce_reg(input)
             else:
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input, registered=not self.tms_cudagraph)
         else:
             # If warm up, mimic the allocation pattern since custom
             # allreduce is out-of-place.
sglang/srt/distributed/parallel_state.py CHANGED
@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.get_device_module().Stream
 
 
 @dataclass
@@ -498,7 +498,7 @@ class GroupCoordinator:
             maybe_pynccl_context = nullcontext()
         else:
             maybe_pynccl_context = pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             )
 
         pymscclpp_comm = self.pymscclpp_comm
@@ -555,7 +555,7 @@
             and input_.symmetric_memory
         ):
             with self.pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             ):
                 self.pynccl_comm.all_reduce(input_)
                 return input_
@@ -655,7 +655,9 @@
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -779,7 +781,9 @@
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
 
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"
sglang/srt/entrypoints/engine.py CHANGED
@@ -143,10 +143,13 @@ class Engine(EngineBase):
 
         # Enable tracing
         if server_args.enable_trace:
-            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-            if server_args.disaggregation_mode == "null":
-                thread_label = "Tokenizer"
-                trace_set_thread_info(thread_label)
+            process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+            thread_label = "Tokenizer"
+            if server_args.disaggregation_mode == "prefill":
+                thread_label = "Prefill Tokenizer"
+            elif server_args.disaggregation_mode == "decode":
+                thread_label = "Decode Tokenizer"
+            trace_set_thread_info(thread_label)
 
         try:
             self.loop = asyncio.get_running_loop()
@@ -312,6 +315,7 @@
         image_data: Optional[MultimodalDataInputFormat] = None,
         audio_data: Optional[MultimodalDataInputFormat] = None,
         video_data: Optional[MultimodalDataInputFormat] = None,
+        dimensions: Optional[int] = None,
     ) -> Dict:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
@@ -322,6 +326,7 @@
             image_data=image_data,
             audio_data=audio_data,
             video_data=video_data,
+            dimensions=dimensions,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
         ret = self.loop.run_until_complete(generator.__anext__())
@@ -333,6 +338,7 @@
         image_data: Optional[MultimodalDataInputFormat] = None,
         audio_data: Optional[MultimodalDataInputFormat] = None,
         video_data: Optional[MultimodalDataInputFormat] = None,
+        dimensions: Optional[int] = None,
     ) -> Dict:
         """
         Asynchronous version of encode method.
@@ -345,6 +351,7 @@
             image_data=image_data,
             audio_data=audio_data,
             video_data=video_data,
+            dimensions=dimensions,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
@@ -670,7 +677,8 @@
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+    if "NCCL_CUMEM_ENABLE" not in os.environ:
+        os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
sglang/srt/entrypoints/http_server.py CHANGED
@@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI):
 
     # Init tracing
     if server_args.enable_trace:
-        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-        if server_args.disaggregation_mode == "null":
-            trace_set_thread_info(thread_label)
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill" + thread_label
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode" + thread_label
+        trace_set_thread_info(thread_label)
 
     # Initialize OpenAI serving handlers
     fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
@@ -1168,6 +1171,8 @@
     """Show available models. OpenAI-compatible endpoint."""
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
+
+    # Add base model
     for served_model_name in served_model_names:
         model_cards.append(
             ModelCard(
@@ -1176,6 +1181,20 @@
                 max_model_len=_global_state.tokenizer_manager.model_config.context_len,
             )
         )
+
+    # Add loaded LoRA adapters
+    if _global_state.tokenizer_manager.server_args.enable_lora:
+        lora_registry = _global_state.tokenizer_manager.lora_registry
+        for _, lora_ref in lora_registry.get_all_adapters().items():
+            model_cards.append(
+                ModelCard(
+                    id=lora_ref.lora_name,
+                    root=lora_ref.lora_path,
+                    parent=served_model_names[0],
+                    max_model_len=None,
+                )
+            )
+
     return ModelList(data=model_cards)
 
 
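
With LoRA enabled (`enable_lora`), `/v1/models` now lists loaded adapters alongside the base model, each carrying a `parent` field pointing at the base model. A quick check against a running server (port is an assumption):

import requests

resp = requests.get("http://localhost:30000/v1/models")
for card in resp.json()["data"]:
    print(card["id"], "->", card.get("parent"))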
sglang/srt/entrypoints/openai/protocol.py CHANGED
@@ -37,7 +37,11 @@ from pydantic import (
     model_validator,
 )
 from typing_extensions import Literal
-from xgrammar import StructuralTag
+
+try:
+    from xgrammar import StructuralTag
+except:
+    StructuralTag = Any
 
 from sglang.utils import convert_json_schema_to_str
 
@@ -54,6 +58,7 @@ class ModelCard(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "sglang"
     root: Optional[str] = None
+    parent: Optional[str] = None
     max_model_len: Optional[int] = None
 
 
@@ -108,6 +113,7 @@ class UsageInfo(BaseModel):
 
 class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
+    continuous_usage_stats: Optional[bool] = False
 
 
 class JsonSchemaResponseFormat(BaseModel):
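
A sketch of requesting per-chunk usage through the new `StreamOptions.continuous_usage_stats` field by posting directly to an OpenAI-compatible sglang server (URL and model name are assumptions; whether every chunk actually carries usage depends on the serving side honoring the flag):

import json
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
    },
    stream=True,
)
for line in resp.iter_lines():
    # SSE lines look like b"data: {...}"; skip keep-alives and the final [DONE] marker.
    if line.startswith(b"data:") and b"[DONE]" not in line:
        print(json.loads(line[len(b"data:"):]).get("usage"))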