sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from __future__ import annotations
17
17
 
18
18
  from dataclasses import dataclass
19
19
 
20
- from sglang.srt.configs.mamba_utils import Mamba2CacheParams
20
+ from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, Mamba2CacheParams
21
21
  from sglang.srt.layers.attention.nsa import index_buf_accessor
22
22
  from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
23
23
  from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -33,7 +33,7 @@ KVCache actually holds the physical kv cache.
33
33
 
34
34
  import abc
35
35
  import logging
36
- from contextlib import nullcontext
36
+ from contextlib import contextmanager, nullcontext
37
37
  from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
38
38
 
39
39
  import numpy as np
@@ -59,7 +59,9 @@ if _is_npu:
59
59
  import torch_npu
60
60
 
61
61
 
62
- def get_tensor_size_bytes(t: torch.Tensor):
62
+ def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]):
63
+ if isinstance(t, list):
64
+ return sum(get_tensor_size_bytes(x) for x in t)
63
65
  return np.prod(t.shape) * t.dtype.itemsize
64
66
 
65
67
 
@@ -116,10 +118,15 @@ class ReqToTokenPool:
116
118
  class MambaPool:
117
119
  @dataclass(frozen=True, kw_only=True)
118
120
  class State:
119
- conv: torch.Tensor
121
+ conv: Union[torch.Tensor, List[torch.Tensor]]
120
122
  temporal: torch.Tensor
121
123
 
122
124
  def at_layer_idx(self, layer: int):
125
+ if isinstance(self.conv, list):
126
+ return type(self)(
127
+ conv=[v[layer] for v in self.conv],
128
+ temporal=self.temporal[layer],
129
+ )
123
130
  return type(self)(**{k: v[layer] for k, v in vars(self).items()})
124
131
 
125
132
  def mem_usage_bytes(self):
@@ -127,14 +134,14 @@ class MambaPool:
127
134
 
128
135
  @dataclass(frozen=True, kw_only=True)
129
136
  class SpeculativeState(State):
130
- intermediate_ssm: torch.Tensor
137
+ intermediate_ssm: Union[torch.Tensor, List[torch.Tensor]]
131
138
  intermediate_conv_window: torch.Tensor
132
139
 
133
140
  def __init__(
134
141
  self,
135
142
  *,
136
143
  size: int,
137
- cache_params: "Mamba2CacheParams",
144
+ cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
138
145
  device: str,
139
146
  speculative_num_draft_tokens: Optional[int] = None,
140
147
  ):
@@ -157,18 +164,29 @@ class MambaPool:
157
164
  else:
158
165
  self.custom_mem_pool = None
159
166
 
167
+ self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
160
168
  with (
161
169
  torch.cuda.use_mem_pool(self.custom_mem_pool)
162
170
  if self.enable_custom_mem_pool
163
171
  else nullcontext()
164
172
  ):
165
- # assume conv_state = (dim, state_len)
166
- assert conv_state_shape[0] > conv_state_shape[1]
167
- conv_state = torch.zeros(
168
- size=(num_mamba_layers, size + 1) + conv_state_shape,
169
- dtype=conv_dtype,
170
- device=device,
171
- )
173
+ if self.is_kda_cache:
174
+ conv_state = [
175
+ torch.zeros(
176
+ size=(num_mamba_layers, size + 1) + conv_shape,
177
+ dtype=conv_dtype,
178
+ device=device,
179
+ )
180
+ for conv_shape in conv_state_shape
181
+ ]
182
+ else:
183
+ # assume conv_state = (dim, state_len)
184
+ assert conv_state_shape[0] > conv_state_shape[1]
185
+ conv_state = torch.zeros(
186
+ size=(num_mamba_layers, size + 1) + conv_state_shape,
187
+ dtype=conv_dtype,
188
+ device=device,
189
+ )
172
190
  temporal_state = torch.zeros(
173
191
  size=(num_mamba_layers, size + 1) + temporal_state_shape,
174
192
  dtype=ssm_dtype,
@@ -191,17 +209,34 @@ class MambaPool:
191
209
  )
192
210
  # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
193
211
  # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
194
- intermediate_conv_window_cache = torch.zeros(
195
- size=(
196
- num_mamba_layers,
197
- size + 1,
198
- speculative_num_draft_tokens,
199
- conv_state_shape[0],
200
- conv_state_shape[1],
201
- ),
202
- dtype=conv_dtype,
203
- device="cuda",
204
- )
212
+
213
+ if self.is_kda_cache:
214
+ intermediate_conv_window_cache = [
215
+ torch.zeros(
216
+ size=(
217
+ num_mamba_layers,
218
+ size + 1,
219
+ speculative_num_draft_tokens,
220
+ conv_shape[0],
221
+ conv_shape[1],
222
+ ),
223
+ dtype=conv_dtype,
224
+ device="cuda",
225
+ )
226
+ for conv_shape in conv_state_shape
227
+ ]
228
+ else:
229
+ intermediate_conv_window_cache = torch.zeros(
230
+ size=(
231
+ num_mamba_layers,
232
+ size + 1,
233
+ speculative_num_draft_tokens,
234
+ conv_state_shape[0],
235
+ conv_state_shape[1],
236
+ ),
237
+ dtype=conv_dtype,
238
+ device="cuda",
239
+ )
205
240
  self.mamba_cache = self.SpeculativeState(
206
241
  conv=conv_state,
207
242
  temporal=temporal_state,
@@ -255,15 +290,25 @@ class MambaPool:
255
290
  if free_index.numel() == 0:
256
291
  return
257
292
  self.free_slots = torch.cat((self.free_slots, free_index))
258
- self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
259
- :, free_index
260
- ] = 0
293
+ if self.is_kda_cache:
294
+ for i in range(len(self.mamba_cache.conv)):
295
+ self.mamba_cache.conv[i][:, free_index] = 0
296
+ else:
297
+ self.mamba_cache.conv[:, free_index] = 0
298
+ self.mamba_cache.temporal[:, free_index] = 0
261
299
 
262
300
  def clear(self):
263
301
  self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
264
302
 
265
303
  def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
266
- self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
304
+ if self.is_kda_cache:
305
+ for i in range(len(self.mamba_cache.conv)):
306
+ self.mamba_cache.conv[i][:, dst_index] = self.mamba_cache.conv[i][
307
+ :, src_index
308
+ ]
309
+ else:
310
+ self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
311
+
267
312
  self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
268
313
  :, src_index
269
314
  ]
@@ -304,7 +349,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
304
349
  max_context_len: int,
305
350
  device: str,
306
351
  enable_memory_saver: bool,
307
- cache_params: "Mamba2CacheParams",
352
+ cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
308
353
  speculative_num_draft_tokens: int = None,
309
354
  ):
310
355
  super().__init__(
@@ -323,7 +368,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
323
368
  def _init_mamba_pool(
324
369
  self,
325
370
  size: int,
326
- cache_params: "Mamba2CacheParams",
371
+ cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
327
372
  device: str,
328
373
  speculative_num_draft_tokens: int = None,
329
374
  ):
@@ -509,6 +554,7 @@ class MHATokenToKVPool(KVCache):
509
554
  enable_memory_saver: bool,
510
555
  start_layer: Optional[int] = None,
511
556
  end_layer: Optional[int] = None,
557
+ enable_alt_stream: bool = True,
512
558
  enable_kv_cache_copy: bool = False,
513
559
  ):
514
560
  super().__init__(
@@ -527,7 +573,9 @@ class MHATokenToKVPool(KVCache):
527
573
  self._create_buffers()
528
574
 
529
575
  self.device_module = torch.get_device_module(self.device)
530
- self.alt_stream = self.device_module.Stream() if _is_cuda else None
576
+ self.alt_stream = (
577
+ self.device_module.Stream() if _is_cuda and enable_alt_stream else None
578
+ )
531
579
 
532
580
  if enable_kv_cache_copy:
533
581
  self._init_kv_copy_and_warmup()
@@ -809,6 +857,10 @@ class HybridLinearKVPool(KVCache):
809
857
  enable_kvcache_transpose: bool,
810
858
  device: str,
811
859
  mamba_pool: MambaPool,
860
+ # TODO: refactor mla related args
861
+ use_mla: bool = False,
862
+ kv_lora_rank: int = None,
863
+ qk_rope_head_dim: int = None,
812
864
  ):
813
865
  self.size = size
814
866
  self.dtype = dtype
@@ -822,25 +874,42 @@ class HybridLinearKVPool(KVCache):
822
874
  self.mamba_pool = mamba_pool
823
875
  # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
824
876
  assert not enable_kvcache_transpose
825
- if _is_npu:
826
- TokenToKVPoolClass = AscendTokenToKVPool
877
+ self.use_mla = use_mla
878
+ if not use_mla:
879
+ if _is_npu:
880
+ TokenToKVPoolClass = AscendTokenToKVPool
881
+ else:
882
+ TokenToKVPoolClass = MHATokenToKVPool
883
+ self.full_kv_pool = TokenToKVPoolClass(
884
+ size=size,
885
+ page_size=self.page_size,
886
+ dtype=dtype,
887
+ head_num=head_num,
888
+ head_dim=head_dim,
889
+ layer_num=self.full_layer_nums,
890
+ device=device,
891
+ enable_memory_saver=False,
892
+ )
827
893
  else:
828
- TokenToKVPoolClass = MHATokenToKVPool
829
- self.full_kv_pool = TokenToKVPoolClass(
830
- size=size,
831
- page_size=self.page_size,
832
- dtype=dtype,
833
- head_num=head_num,
834
- head_dim=head_dim,
835
- layer_num=self.full_layer_nums,
836
- device=device,
837
- enable_memory_saver=False,
838
- )
894
+ TokenToKVPoolClass = MLATokenToKVPool
895
+ self.full_kv_pool = TokenToKVPoolClass(
896
+ size=size,
897
+ page_size=self.page_size,
898
+ dtype=dtype,
899
+ layer_num=self.full_layer_nums,
900
+ device=device,
901
+ kv_lora_rank=kv_lora_rank,
902
+ qk_rope_head_dim=qk_rope_head_dim,
903
+ enable_memory_saver=False,
904
+ )
839
905
  self.full_attention_layer_id_mapping = {
840
906
  id: i for i, id in enumerate(full_attention_layer_ids)
841
907
  }
842
- k_size, v_size = self.get_kv_size_bytes()
843
- self.mem_usage = (k_size + v_size) / GB
908
+ if use_mla:
909
+ self.mem_usage = self.get_kv_size_bytes() / GB
910
+ else:
911
+ k_size, v_size = self.get_kv_size_bytes()
912
+ self.mem_usage = (k_size + v_size) / GB
844
913
 
845
914
  def get_kv_size_bytes(self):
846
915
  return self.full_kv_pool.get_kv_size_bytes()
@@ -876,6 +945,21 @@ class HybridLinearKVPool(KVCache):
876
945
  layer_id = self._transfer_full_attention_id(layer_id)
877
946
  return self.full_kv_pool.get_kv_buffer(layer_id)
878
947
 
948
+ @contextmanager
949
+ def _transfer_id_context(self, layer: RadixAttention):
950
+
951
+ @contextmanager
952
+ def _patch_layer_id(layer):
953
+ original_layer_id = layer.layer_id
954
+ layer.layer_id = self._transfer_full_attention_id(layer.layer_id)
955
+ try:
956
+ yield
957
+ finally:
958
+ layer.layer_id = original_layer_id
959
+
960
+ with _patch_layer_id(layer):
961
+ yield
962
+
879
963
  def set_kv_buffer(
880
964
  self,
881
965
  layer: RadixAttention,
@@ -886,19 +970,49 @@ class HybridLinearKVPool(KVCache):
886
970
  v_scale: float = 1.0,
887
971
  ):
888
972
  layer_id = self._transfer_full_attention_id(layer.layer_id)
889
- self.full_kv_pool.set_kv_buffer(
890
- None,
891
- loc,
892
- cache_k,
893
- cache_v,
894
- k_scale,
895
- v_scale,
896
- layer_id_override=layer_id,
897
- )
973
+ if not self.use_mla:
974
+ self.full_kv_pool.set_kv_buffer(
975
+ None,
976
+ loc,
977
+ cache_k,
978
+ cache_v,
979
+ k_scale,
980
+ v_scale,
981
+ layer_id_override=layer_id,
982
+ )
983
+ else:
984
+ with self._transfer_id_context(layer):
985
+ self.full_kv_pool.set_kv_buffer(
986
+ layer,
987
+ loc,
988
+ cache_k,
989
+ cache_v,
990
+ )
898
991
 
899
992
  def get_v_head_dim(self):
900
993
  return self.full_kv_pool.get_value_buffer(0).shape[-1]
901
994
 
995
+ def set_mla_kv_buffer(
996
+ self,
997
+ layer: RadixAttention,
998
+ loc: torch.Tensor,
999
+ cache_k_nope: torch.Tensor,
1000
+ cache_k_rope: torch.Tensor,
1001
+ ):
1002
+ assert self.use_mla, "set_mla_kv_buffer called when use_mla is False"
1003
+ with self._transfer_id_context(layer):
1004
+ self.full_kv_pool.set_mla_kv_buffer(layer, loc, cache_k_nope, cache_k_rope)
1005
+
1006
+ def get_mla_kv_buffer(
1007
+ self,
1008
+ layer: RadixAttention,
1009
+ loc: torch.Tensor,
1010
+ dst_dtype: Optional[torch.dtype] = None,
1011
+ ):
1012
+ assert self.use_mla, "get_mla_kv_buffer called when use_mla is False"
1013
+ with self._transfer_id_context(layer):
1014
+ return self.full_kv_pool.get_mla_kv_buffer(layer, loc, dst_dtype)
1015
+
902
1016
 
903
1017
  class SWAKVPool(KVCache):
904
1018
  """KV cache with separate pools for full and SWA attention layers."""
@@ -1137,10 +1251,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
1137
1251
  torch_npu._npu_reshape_and_cache(
1138
1252
  key=cache_k,
1139
1253
  value=cache_v,
1140
- key_cache=self.k_buffer[layer_id].view(
1254
+ key_cache=self.k_buffer[layer_id - self.start_layer].view(
1141
1255
  -1, self.page_size, self.head_num, self.head_dim
1142
1256
  ),
1143
- value_cache=self.v_buffer[layer_id].view(
1257
+ value_cache=self.v_buffer[layer_id - self.start_layer].view(
1144
1258
  -1, self.page_size, self.head_num, self.head_dim
1145
1259
  ),
1146
1260
  slot_indices=loc,
@@ -238,12 +238,16 @@ class MHATokenToKVPoolHost(HostKVCache):
238
238
  raise ValueError(f"Unsupported layout: {self.layout}")
239
239
  self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
240
240
  self.layout_dim = self.token_stride_size * self.layer_num
241
- return torch.empty(
241
+ buffer = torch.empty(
242
242
  dims,
243
243
  dtype=self.dtype,
244
244
  device=self.device,
245
- pin_memory=self.pin_memory,
246
245
  )
246
+ if self.pin_memory:
247
+ torch.cuda.cudart().cudaHostRegister(
248
+ buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
249
+ )
250
+ return buffer
247
251
 
248
252
  @property
249
253
  def k_buffer(self):
@@ -551,13 +555,16 @@ class MLATokenToKVPoolHost(HostKVCache):
551
555
  self.kv_lora_rank + self.qk_rope_head_dim
552
556
  ) * self.dtype.itemsize
553
557
  self.layout_dim = self.token_stride_size * self.layer_num
554
-
555
- return torch.empty(
558
+ buffer = torch.empty(
556
559
  dims,
557
560
  dtype=self.dtype,
558
561
  device=self.device,
559
- pin_memory=self.pin_memory,
560
562
  )
563
+ if self.pin_memory:
564
+ torch.cuda.cudart().cudaHostRegister(
565
+ buffer.data_ptr(), buffer.numel() * buffer.element_size(), 0
566
+ )
567
+ return buffer
561
568
 
562
569
  def load_to_device_per_layer(
563
570
  self, device_pool, host_indices, device_indices, layer_id, io_backend
@@ -533,6 +533,10 @@ class RadixCache(BasePrefixCache):
533
533
  self.protected_size_ -= len(node.key)
534
534
  delta += len(node.key)
535
535
  node.lock_ref -= 1
536
+ if node.parent is None:
537
+ assert (
538
+ node is self.root_node
539
+ ), f"This request holds the node from another tree"
536
540
  node = node.parent
537
541
  return delta
538
542
 
@@ -104,7 +104,7 @@ class MooncakeStoreConfig:
104
104
  device_name=os.getenv("MOONCAKE_DEVICE", ""),
105
105
  master_server_address=os.getenv("MOONCAKE_MASTER"),
106
106
  master_metrics_port=int(
107
- os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_GLOBAL_SEGMENT_SIZE)
107
+ os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_MASTER_METRICS_PORT)
108
108
  ),
109
109
  check_server=bool(os.getenv("MOONCAKE_CHECK_SERVER", DEFAULT_CHECK_SERVER)),
110
110
  )
@@ -811,6 +811,34 @@ class TokenizerMetricsCollector:
811
811
  buckets=bucket_e2e_request_latency,
812
812
  )
813
813
 
814
+ # Retraction count histogram
815
+ self.num_retractions = Histogram(
816
+ name="sglang:num_retractions",
817
+ documentation="Histogram of retraction counts per request.",
818
+ labelnames=labels.keys(),
819
+ buckets=[
820
+ 0,
821
+ 1,
822
+ 2,
823
+ 3,
824
+ 4,
825
+ 5,
826
+ 6,
827
+ 7,
828
+ 8,
829
+ 9,
830
+ 10,
831
+ 15,
832
+ 20,
833
+ 25,
834
+ 30,
835
+ 40,
836
+ 50,
837
+ 75,
838
+ 100,
839
+ ],
840
+ )
841
+
814
842
  def observe_one_finished_request(
815
843
  self,
816
844
  labels: Dict[str, str],
@@ -819,6 +847,7 @@ class TokenizerMetricsCollector:
819
847
  cached_tokens: int,
820
848
  e2e_latency: float,
821
849
  has_grammar: bool,
850
+ retraction_count: int,
822
851
  ):
823
852
  self.prompt_tokens_total.labels(**labels).inc(prompt_tokens)
824
853
  self.generation_tokens_total.labels(**labels).inc(generation_tokens)
@@ -833,6 +862,7 @@ class TokenizerMetricsCollector:
833
862
  self.generation_tokens_histogram.labels(**labels).observe(
834
863
  float(generation_tokens)
835
864
  )
865
+ self.num_retractions.labels(**labels).observe(retraction_count)
836
866
 
837
867
  def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
838
868
  self.histogram_time_to_first_token.labels(**labels).observe(value)
@@ -840,13 +870,13 @@ class TokenizerMetricsCollector:
840
870
  def check_time_to_first_token_straggler(self, value: float) -> bool:
841
871
  his = self.histogram_time_to_first_token.labels(**self.labels)
842
872
  total_observations = sum(bucket._value for bucket in his._buckets)
843
- if total_observations < 1000:
873
+ if total_observations < 100:
844
874
  return False
845
- p999_threshold = total_observations * 0.999
875
+ p99_threshold = total_observations * 0.99
846
876
  cumulative_count = 0
847
877
  for i, bucket in enumerate(his._buckets):
848
878
  cumulative_count += bucket._value
849
- if cumulative_count > p999_threshold:
879
+ if cumulative_count > p99_threshold:
850
880
  return value >= his._upper_bounds[i]
851
881
  return False
852
882
 
@@ -969,3 +999,16 @@ class StorageMetricsCollector:
969
999
  self._log_histogram(self.histogram_prefetch_bandwidth, v)
970
1000
  for v in storage_metrics.backup_bandwidth:
971
1001
  self._log_histogram(self.histogram_backup_bandwidth, v)
1002
+
1003
+
1004
+ class ExpertDispatchCollector:
1005
+ def __init__(self, ep_size: int) -> None:
1006
+ from prometheus_client import Histogram
1007
+
1008
+ ep_size_buckets = [i for i in range(ep_size)]
1009
+ self.eplb_gpu_physical_count = Histogram(
1010
+ name="sglang:eplb_gpu_physical_count",
1011
+ documentation="The selected count of physical experts on each layer and GPU rank.",
1012
+ labelnames={"layer"},
1013
+ buckets=ep_size_buckets,
1014
+ )
@@ -21,12 +21,14 @@ import inspect
21
21
  import logging
22
22
  import os
23
23
  from contextlib import contextmanager
24
+ from functools import partial
24
25
  from typing import TYPE_CHECKING, Callable, Optional, Union
25
26
 
26
27
  import torch
27
28
  import tqdm
28
29
  from torch.profiler import ProfilerActivity, profile
29
30
 
31
+ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
30
32
  from sglang.srt.custom_op import CustomOp
31
33
  from sglang.srt.distributed import get_tensor_model_parallel_rank
32
34
  from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
64
66
  require_mlp_tp_gather,
65
67
  )
66
68
  from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
69
+ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
67
70
 
68
71
  try:
69
72
  from kt_kernel import AMXMoEWrapper
@@ -320,11 +323,11 @@ class CudaGraphRunner:
320
323
  self.pp_proxy_tensors = {
321
324
  "hidden_states": torch.zeros(
322
325
  (self.max_bs, self.model_runner.model_config.hidden_size),
323
- dtype=torch.bfloat16,
326
+ dtype=self.model_runner.model_config.dtype,
324
327
  ),
325
328
  "residual": torch.zeros(
326
329
  (self.max_bs, self.model_runner.model_config.hidden_size),
327
- dtype=torch.bfloat16,
330
+ dtype=self.model_runner.model_config.dtype,
328
331
  ),
329
332
  }
330
333
 
@@ -518,7 +521,16 @@ class CudaGraphRunner:
518
521
  logger.info(log_message)
519
522
 
520
523
  def _capture_graph(self, graph, pool, stream, run_once_fn):
521
- with self.device_module.graph(graph, pool=pool, stream=stream):
524
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
525
+ enable=self.model_runner.server_args.enable_memory_saver
526
+ and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
527
+ )
528
+ graph_fn = (
529
+ partial(memory_saver_adapter.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
530
+ if memory_saver_adapter.enabled
531
+ else self.device_module.graph
532
+ )
533
+ with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
522
534
  out = run_once_fn()
523
535
  return out
524
536
 
@@ -90,12 +90,9 @@ class ForwardMode(IntEnum):
90
90
  self == ForwardMode.EXTEND
91
91
  or self == ForwardMode.MIXED
92
92
  or self == ForwardMode.DRAFT_EXTEND
93
- or (
94
- self == ForwardMode.DRAFT_EXTEND_V2
95
- if include_draft_extend_v2
96
- else False
97
- )
93
+ or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
98
94
  or self == ForwardMode.TARGET_VERIFY
95
+ or self == ForwardMode.SPLIT_PREFILL
99
96
  )
100
97
 
101
98
  def is_decode(self):
@@ -114,22 +111,21 @@ class ForwardMode(IntEnum):
114
111
  return self == ForwardMode.TARGET_VERIFY
115
112
 
116
113
  def is_draft_extend(self, include_v2: bool = False):
117
- if include_v2:
118
- return (
119
- self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND
120
- )
121
- return self == ForwardMode.DRAFT_EXTEND
114
+ return self == ForwardMode.DRAFT_EXTEND or (
115
+ include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
116
+ )
122
117
 
123
118
  def is_draft_extend_v2(self):
124
119
  # For fixed shape logits output in v2 eagle worker
125
120
  return self == ForwardMode.DRAFT_EXTEND_V2
126
121
 
127
- def is_extend_or_draft_extend_or_mixed(self):
122
+ def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
128
123
  return (
129
124
  self == ForwardMode.EXTEND
130
125
  or self == ForwardMode.DRAFT_EXTEND
131
126
  or self == ForwardMode.MIXED
132
127
  or self == ForwardMode.SPLIT_PREFILL
128
+ or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
133
129
  )
134
130
 
135
131
  def is_cuda_graph(self):
@@ -319,6 +315,9 @@ class ForwardBatch:
319
315
  tbo_parent_token_range: Optional[Tuple[int, int]] = None
320
316
  tbo_children: Optional[List[ForwardBatch]] = None
321
317
 
318
+ # For matryoshka embeddings
319
+ dimensions: Optional[list[int]] = None
320
+
322
321
  @classmethod
323
322
  def init_new(
324
323
  cls,
@@ -360,6 +359,7 @@ class ForwardBatch:
360
359
  input_embeds=batch.input_embeds,
361
360
  token_type_ids=batch.token_type_ids,
362
361
  tbo_split_seq_index=batch.tbo_split_seq_index,
362
+ dimensions=batch.dimensions,
363
363
  )
364
364
  device = model_runner.device
365
365