sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/custom_logit_processor.py CHANGED
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

 import dill
 import orjson
@@ -126,3 +126,69 @@ class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
     THINKING_START_TOKEN_ID: int = 128798
     THINKING_END_TOKEN_ID: int = 128799
     NEW_LINE_TOKEN_ID: int = 201
+
+
+# Adapted from DeepSeek's implementation: https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py
+class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
+    """Block n-gram repetitions within a sliding window for DeepSeek-OCR outputs."""
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        if not custom_param_list:
+            return logits
+
+        for batch_idx, params in enumerate(custom_param_list):
+            if not params:
+                continue
+
+            req = params.get("__req__")
+            if req is None:
+                continue
+
+            try:
+                ngram_size = int(params.get("ngram_size") or 0)
+                window_size = int(params.get("window_size") or 0)
+            except (TypeError, ValueError):
+                continue
+
+            if ngram_size <= 0 or window_size <= 0:
+                continue
+
+            sequence: List[int] = req.origin_input_ids + req.output_ids
+            if len(sequence) < ngram_size:
+                continue
+
+            search_start = max(0, len(sequence) - window_size)
+            search_end = len(sequence) - ngram_size + 1
+            if search_end <= search_start:
+                continue
+
+            if ngram_size > 1:
+                current_prefix = tuple(sequence[-(ngram_size - 1) :])
+            else:
+                current_prefix = tuple()
+
+            banned_tokens: Set[int] = set()
+            for idx in range(search_start, search_end):
+                ngram = sequence[idx : idx + ngram_size]
+                if ngram_size == 1 or tuple(ngram[:-1]) == current_prefix:
+                    banned_tokens.add(ngram[-1])
+
+            whitelist_ids = params.get("whitelist_token_ids") or []
+            try:
+                whitelist = {int(token_id) for token_id in whitelist_ids}
+            except (TypeError, ValueError):
+                whitelist = set()
+
+            banned_tokens.difference_update(whitelist)
+
+            if not banned_tokens:
+                continue
+
+            indices = list(banned_tokens)
+            logits[batch_idx, indices] = -float("inf")
+
+        return logits
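For reference, the same sliding-window n-gram ban can be exercised outside of SGLang's CustomLogitProcessor plumbing. The sketch below is illustrative only (the function name and toy inputs are not part of the package); it restates the rule above: collect every token that would complete an n-gram already seen in the last window_size tokens and force its logit to -inf.

import torch

def ban_repeated_ngrams(sequence, logits, ngram_size, window_size):
    # Ban tokens that would complete an n-gram already present in the window.
    if ngram_size <= 0 or window_size <= 0 or len(sequence) < ngram_size:
        return logits
    search_start = max(0, len(sequence) - window_size)
    search_end = len(sequence) - ngram_size + 1
    prefix = tuple(sequence[-(ngram_size - 1):]) if ngram_size > 1 else ()
    banned = {
        sequence[i + ngram_size - 1]
        for i in range(search_start, search_end)
        if ngram_size == 1 or tuple(sequence[i : i + ngram_size - 1]) == prefix
    }
    if banned:
        logits[list(banned)] = -float("inf")
    return logits

# Toy check: with ngram_size=2, emitting 7 after 3 would repeat the bigram (3, 7).
logits = ban_repeated_ngrams([3, 7, 5, 3], torch.zeros(10), ngram_size=2, window_size=8)
assert logits[7] == -float("inf")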
sglang/srt/sampling/penaltylib/frequency_penalty.py CHANGED
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
     Frequency penalizer penalizes tokens based on their frequency in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/min_new_tokens.py CHANGED
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
     Min new tokens penalizer penalizes tokens based on the length of the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         self.len_output_tokens = torch.cat(
             [self.len_output_tokens, their.len_output_tokens], dim=0
         )
+
+    # Explicit resource cleanup to aid GC and free CUDA memory promptly
+    def _teardown(self) -> None:
+        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/orchestrator.py CHANGED
@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
             return

         if len(keep_indices) == 0:
-            self.is_required = False
-            for penalizer in self.penalizers.values():
-                penalizer.teardown()
+            # No requests left in the batch, fully release orchestrator resources
+            self.release()
             return

         is_required = False
@@ -92,6 +91,23 @@
                 penalizer.teardown()
         self.is_required = is_required

+    # Resource management helpers
+    def release(self) -> None:
+        """Release all penalizers and break references so GC can reclaim promptly."""
+        for penalizer in self.penalizers.values():
+            penalizer.teardown()
+        self.penalizers.clear()
+        # Break reference to ScheduleBatch
+        self._batch_ref = None
+        self.is_required = False
+
+    # Context manager support
+    def __enter__(self) -> "BatchedPenalizerOrchestrator":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.release()
+
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
         Merge the penalizers of another orchestrator into this one.
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """

+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+            weakref.ref(orchestrator)
+        )
+        self._is_prepared = False
+
+    @property
+    def orchestrator(self) -> BatchedPenalizerOrchestrator:
+        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+        # This should never happen, but we need to handle it gracefully
+        if orch is None:
+            raise RuntimeError(
+                "BatchedPenalizerOrchestrator has been garbage-collected"
+            )
+        return orch
+
     def is_prepared(self) -> bool:
         return self._is_prepared

@@ -135,6 +167,7 @@
         return False

     def teardown(self):
+        self._teardown()
         self._is_prepared = False

     def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@
         Merge the penalizer with another penalizer.
         """
         pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Teardown the penalizer.
+        """
+        pass
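The orchestrator change above replaces each penalizer's direct back-pointer with a weak reference, so penalizers no longer keep the orchestrator (and the ScheduleBatch it references) alive, and it adds an explicit release() path plus context-manager support. A minimal sketch of the weakref pattern, with illustrative class names rather than the SGLang ones:

import weakref

class Owner:
    """Stands in for BatchedPenalizerOrchestrator in this illustration."""

    def __init__(self):
        self.workers = [Worker(self)]  # strong refs forward, weak refs back

    def release(self):
        self.workers.clear()

class Worker:
    """Stands in for a _BatchedPenalizer subclass."""

    def __init__(self, owner):
        self._owner_ref = weakref.ref(owner)  # does not keep the owner alive

    @property
    def owner(self):
        owner = self._owner_ref()
        if owner is None:
            raise RuntimeError("owner has been garbage-collected")
        return owner

owner = Owner()
worker = owner.workers[0]
assert worker.owner is owner
del owner  # last strong reference gone; the weak back-reference does not block collection
# Accessing worker.owner now raises RuntimeError (in CPython the Owner is freed immediately).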
sglang/srt/sampling/penaltylib/presence_penalty.py CHANGED
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
     Presence penalizer penalizes tokens based on their presence in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("presence_penalties", "cumulated_presence_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/server_args.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.utils.common import (
     get_device,
     get_device_memory_capacity,
     get_device_sm,
+    is_blackwell_supported,
     is_cuda,
    is_fa3_default_architecture,
     is_flashinfer_available,
@@ -98,6 +99,7 @@ QUANTIZATION_CHOICES = [
     "qoq",
     "w4afp8",
     "mxfp4",
+    "auto-round",
     "compressed-tensors",  # for Ktransformers
 ]

@@ -133,7 +135,18 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]

 DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]

-NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
+RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND = ["fa3", "triton"]
+
+DEFAULT_LORA_EVICTION_POLICY = "lru"
+
+NSA_CHOICES = [
+    "flashmla_sparse",
+    "flashmla_kv",
+    "flashmla_auto",
+    "fa3",
+    "tilelang",
+    "aiter",
+]

 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]

@@ -179,6 +192,10 @@ def add_deterministic_attention_backend_choices(choices):
     DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)


+def add_radix_supported_deterministic_attention_backend_choices(choices):
+    RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND.extend(choices)
+
+
 def add_radix_eviction_policy_choices(choices):
     RADIX_EVICTION_POLICY_CHOICES.extend(choices)

@@ -288,7 +305,7 @@ class ServerArgs:
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
     enable_trace: bool = False
-    oltp_traces_endpoint: str = "localhost:4317"
+    otlp_traces_endpoint: str = "localhost:4317"

     # API related
     api_key: Optional[str] = None
@@ -329,7 +346,7 @@ class ServerArgs:
     max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_eviction_policy: str = "lru"
-    lora_backend: str = "triton"
+    lora_backend: str = "csgmv"
     max_lora_chunk_size: Optional[int] = 16

     # Kernel backend
@@ -494,6 +511,9 @@ class ServerArgs:

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
+    # -1 mean dump all layers.
+    debug_tensor_dump_layers: int = -1
+    # TODO(guoyuhong): clean the old dumper code.
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False

@@ -522,6 +542,10 @@ class ServerArgs:
     pdmux_config_path: Optional[str] = None
     sm_group_num: int = 8

+    # For Multi-Modal
+    mm_max_concurrent_calls: int = 32
+    mm_per_request_timeout: float = 10.0
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -811,7 +835,7 @@ class ServerArgs:
             capture_bs = (
                 list(range(1, 9, 1))
                 + list(range(10, 33, 2))
-                + list(range(40, 64, 4))
+                + list(range(40, 65, 4))
                 + list(range(72, 257, 8))
                 + list(range(272, self.cuda_graph_max_bs + 1, 16))
             )
@@ -874,7 +898,7 @@ class ServerArgs:
                 logger.info(
                     "Enable FlashInfer AllReduce Fusion on sm100 for DeepseekV3ForCausalLM"
                 )
-                if self.moe_runner_backend == "auto":
+                if self.moe_a2a_backend == "none" and self.moe_runner_backend == "auto":
                     self.moe_runner_backend = "flashinfer_trtllm"
                     logger.info(
                         "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
@@ -912,7 +936,7 @@ class ServerArgs:
                 f"- Decode: {decode_attn_backend}\n"
             )

-        if is_sm100_supported():
+        if is_blackwell_supported():
            if not self.enable_dp_attention:
                 self.enable_flashinfer_allreduce_fusion = True
                 logger.info(
@@ -924,7 +948,7 @@ class ServerArgs:
                 and quantization_config.get("quant_method") == "mxfp4"
             )

-            if is_sm100_supported() and is_mxfp4_quant_format:
+            if is_blackwell_supported() and is_mxfp4_quant_format:
                 self.moe_runner_backend = "flashinfer_mxfp4"
                 logger.warning(
                     "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
@@ -960,6 +984,12 @@ class ServerArgs:
                 logger.warning(
                     "Use trtllm_mha as attention backend on sm100 for Llama4 model"
                 )
+            if is_sm100_supported() and self.moe_runner_backend == "auto":
+                if self.quantization in {"fp8", "modelopt_fp8"}:
+                    self.moe_runner_backend = "flashinfer_trtllm"
+                    logger.info(
+                        "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4"
+                    )
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
@@ -998,6 +1028,11 @@ class ServerArgs:
             logger.info(
                 f"Using {self.attention_backend} as attention backend for {model_arch}."
             )
+        elif model_arch in ["KimiLinearForCausalLM"]:
+            logger.warning(
+                f"Disabling Radix Cache for {model_arch} as it is not yet supported."
+            )
+            self.disable_radix_cache = True

         if is_deepseek_nsa(hf_config):
             if (
@@ -1020,16 +1055,30 @@ class ServerArgs:
             import torch

             major, _ = torch.cuda.get_device_capability()
-            if major >= 10:
-                self.kv_cache_dtype = "fp8_e4m3"
-                logger.warning("Setting KV cache dtype to fp8.")
+            if self.kv_cache_dtype == "auto":
+                self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
+                logger.warning(
+                    f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
+                )
+            if self.kv_cache_dtype == "bf16":
+                self.kv_cache_dtype = "bfloat16"
+            assert self.kv_cache_dtype in [
+                "bfloat16",
+                "fp8_e4m3",
+            ], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"

             if self.kv_cache_dtype == "fp8_e4m3":
-                self.nsa_prefill_backend = "flashmla_kv"
+                # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
+                self.nsa_prefill_backend = "flashmla_auto"
                 self.nsa_decode_backend = "flashmla_kv"
                 logger.warning(
-                    "Setting NSA backend to flashmla_kv for FP8 KV Cache."
+                    "Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
                 )
+            else:
+                # set prefill/decode backends for Blackwell. The default settings are for Hopper.
+                if major >= 10:
+                    self.nsa_prefill_backend = "flashmla_sparse"
+                    self.nsa_decode_backend = "flashmla_sparse"

             # Logging env vars for NSA
             from sglang.srt.layers.attention.nsa.utils import (
@@ -1144,7 +1193,7 @@ class ServerArgs:
             self.attention_backend == "trtllm_mla"
             or self.decode_attention_backend == "trtllm_mla"
         ):
-            if not is_sm100_supported():
+            if not is_blackwell_supported():
                 raise ValueError(
                     "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                 )
@@ -1196,7 +1245,7 @@ class ServerArgs:
         # AMD platforms backends
         if self.attention_backend == "aiter":
             if model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.90
+                self.mem_fraction_static *= 0.85

         # NPU platforms backends
         if is_npu() and self.attention_backend in ["ascend"]:
@@ -1311,8 +1360,10 @@ class ServerArgs:

         if self.moe_runner_backend == "flashinfer_trtllm":
             assert (
-                self.quantization == "modelopt_fp4" or self.quantization == "fp8"
-            ), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
+                self.quantization == "modelopt_fp4"
+                or self.quantization == "modelopt_fp8"
+                or self.quantization == "fp8"
+            ), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
             self.disable_shared_experts_fusion = True
             logger.warning(
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
@@ -1713,13 +1764,17 @@ class ServerArgs:
                 f"but you explicitly specified '{self.attention_backend}'."
             )

-        if self.attention_backend not in ["fa3", "triton"]:
-            if is_deepseek_model:
+        if is_deepseek_model:
+            if self.attention_backend not in ["fa3", "triton"]:
                 raise ValueError(
-                    f"Currently only fa3 and triton attention backends are supported for deterministic inference with DeepSeek models. But you're using {self.attention_backend}."
+                    f"Currently only {RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND} attention backends are supported for deterministic inference with DeepSeek models. But you're using {self.attention_backend}."
                 )

-            # Currently, only FA3 and Triton supports radix cache. Support for other backends is in progress
+        if (
+            self.attention_backend
+            not in RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND
+        ):
+            # Currently, only certain backends support radix cache. Support for other backends is in progress
             self.disable_radix_cache = True
             logger.warning(
                 f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
@@ -1734,7 +1789,13 @@ class ServerArgs:
             )

     def _handle_other_validations(self):
-        pass
+        # Handle model inference tensor dump.
+        if self.debug_tensor_dump_output_folder is not None:
+            logger.warning(
+                "Cuda graph and server warmup are disabled because of using tensor dump mode"
+            )
+            self.disable_cuda_graph = True
+            self.skip_server_warmup = True

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -2315,7 +2376,7 @@ class ServerArgs:
             help="Enable opentelemetry trace",
         )
         parser.add_argument(
-            "--oltp-traces-endpoint",
+            "--otlp-traces-endpoint",
             type=str,
             default="localhost:4317",
             help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
@@ -3325,6 +3386,12 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_output_folder,
             help="The output folder for dumping tensors.",
         )
+        parser.add_argument(
+            "--debug-tensor-dump-layers",
+            type=int,
+            default=-1,
+            help="The layer number for dumping tensors.",
+        )
         parser.add_argument(
             "--debug-tensor-dump-input-file",
             type=str,
@@ -3461,6 +3528,20 @@ class ServerArgs:
             help="Read CLI options from a config file. Must be a YAML file with configuration options.",
         )

+        # For Multi-Modal
+        parser.add_argument(
+            "--mm-max-concurrent-calls",
+            type=int,
+            default=ServerArgs.mm_max_concurrent_calls,
+            help="The max concurrent calls for async mm data processing.",
+        )
+        parser.add_argument(
+            "--mm-per-request-timeout",
+            type=int,
+            default=ServerArgs.mm_per_request_timeout,
+            help="The timeout for each multi-modal request in seconds.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
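Taken together, the DeepSeek NSA hunks above change the KV-cache dtype default and the prefill/decode backend selection. Condensed into a standalone helper for readability (a sketch that only restates the diff's branching; the function name and the "default_*" placeholders are illustrative, not sglang APIs):

def pick_nsa_defaults(kv_cache_dtype, sm_major):
    """Return (kv_cache_dtype, nsa_prefill_backend, nsa_decode_backend)."""
    if kv_cache_dtype == "auto":
        # Blackwell (SM100+) defaults to FP8 KV cache, older GPUs to bfloat16.
        kv_cache_dtype = "fp8_e4m3" if sm_major >= 10 else "bfloat16"
    if kv_cache_dtype == "bf16":
        kv_cache_dtype = "bfloat16"
    assert kv_cache_dtype in ("bfloat16", "fp8_e4m3")

    if kv_cache_dtype == "fp8_e4m3":
        # flashmla_auto dispatches to flashmla_sparse/flashmla_kv by hardware heuristics.
        return kv_cache_dtype, "flashmla_auto", "flashmla_kv"
    if sm_major >= 10:
        # Blackwell with a bf16 KV cache uses the sparse kernel for both phases.
        return kv_cache_dtype, "flashmla_sparse", "flashmla_sparse"
    # Hopper keeps the pre-existing defaults (left as placeholders here).
    return kv_cache_dtype, "default_prefill", "default_decode"

assert pick_nsa_defaults("auto", 10) == ("fp8_e4m3", "flashmla_auto", "flashmla_kv")
assert pick_nsa_defaults("bf16", 10) == ("bfloat16", "flashmla_sparse", "flashmla_sparse")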
sglang/srt/single_batch_overlap.py CHANGED
@@ -98,7 +98,10 @@
     ):
         forward_shared_experts()

-    hidden_states = experts.dispatcher.combine(combine_input=combine_input)
+    hidden_states = experts.dispatcher.combine(
+        combine_input=combine_input,
+        overlap_args=combine_overlap_args,
+    )

     return hidden_states

sglang/srt/speculative/draft_utils.py CHANGED
@@ -49,6 +49,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
             "nsa": self._create_nsa_decode_backend,
+            "ascend": self._create_ascend_decode_backend,
         }

         return self._create_backend(
@@ -72,6 +73,7 @@ class DraftBackendFactory:
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
             "nsa": self._create_nsa_prefill_backend,
+            "ascend": self._create_ascend_prefill_backend,
         }
         backend_name = (
             "decode_attention_backend"
@@ -173,6 +175,15 @@ class DraftBackendFactory:
             self.draft_model_runner, self.topk, self.speculative_num_steps
         )

+    def _create_ascend_decode_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import (
+            AscendAttnMultiStepDraftBackend,
+        )
+
+        return AscendAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
     def _create_flashinfer_prefill_backend(self):
         if not get_global_server_args().use_mla_backend:
             from sglang.srt.layers.attention.flashinfer_backend import (
@@ -219,6 +230,11 @@

         return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

+    def _create_ascend_prefill_backend(self):
+        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+        return AscendAttnBackend(self.draft_model_runner)
+
     def _create_flashmla_prefill_backend(self):
         logger.warning(
             "flashmla prefill backend is not yet supported for draft extend."