sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""
 
+from enum import Enum
 from typing import Optional
 
 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type
 
     def forward(
         self,
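
The new `attn_type` argument is how encoder-only models (such as the `bert.py` model added in this release) can request non-causal attention, while existing decoder models keep the default. A minimal usage sketch follows; every constructor argument other than `attn_type` is illustrative, not copied from the 0.4.5.post2 sources.

# Hedged sketch: constructing an encoder-only attention layer with the new enum.
# Argument names and values besides attn_type are assumptions for illustration.
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

attn = RadixAttention(
    num_heads=12,
    head_dim=64,
    scaling=64**-0.5,
    num_kv_heads=12,
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # defaults to AttentionType.DECODER
)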
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
+
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
sglang/srt/layers/sampler.py CHANGED
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
             ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+                    probs, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                # Check Nan will throw exception, only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
+                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-            if self.use_nan_detection and not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
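
Together with the removal of the 25-line `sglang/srt/lora/backend/__init__.py` (file 92 above) and the import updates in the hunks below, this moves `get_backend_from_name` into `base_backend.py` itself. A minimal usage sketch based only on the code shown in this hunk:

from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")      # -> TritonLoRABackend class
backend_cls = get_backend_from_name("flashinfer")  # -> FlashInferLoRABackend class
get_backend_from_name("exllama")                   # raises ValueError: Invalid backend: exllama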
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
sglang/srt/lora/layers.py CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import dataclasses
-import json
 import logging
 import os
 import signal
sglang/srt/managers/io_struct.py CHANGED
@@ -834,6 +834,7 @@ class ProfileReq:
     activities: Optional[List[str]] = None
     with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
+    profile_id: Optional[str] = None
 
 
 @dataclass
sglang/srt/managers/mm_utils.py CHANGED
@@ -1,7 +1,8 @@
 """
-Multi-modality utils
+Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-from sglang.utils import logger
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
sglang/srt/managers/multimodal_processor.py CHANGED
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache
 
-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
sglang/srt/managers/multimodal_processors/base_processor.py CHANGED
@@ -8,8 +8,6 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
sglang/srt/managers/schedule_batch.py CHANGED
@@ -67,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
@@ -77,12 +76,11 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
@@ -1481,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
sglang/srt/managers/scheduler.py CHANGED
@@ -391,6 +391,7 @@ class Scheduler(
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
+        self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
@@ -484,7 +485,7 @@ class Scheduler(
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
             )
@@ -553,7 +554,7 @@ class Scheduler(
 
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
-                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
             )
@@ -568,7 +569,7 @@ class Scheduler(
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
-                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -597,7 +598,7 @@ class Scheduler(
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
-                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                gloo_group=self.attn_tp_cpu_group,
                 transfer_backend=self.transfer_backend,
                 scheduler=self,
             )
@@ -664,70 +665,6 @@ class Scheduler(
 
         self.last_batch = batch
 
-    @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
-        """A normal scheduler loop for prefill worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
-            )
-            self.process_prefill_chunk()
-            batch = self.get_new_batch_prefill()
-            self.cur_batch = batch
-
-            if batch:
-                result = self.run_batch(batch)
-                self.process_batch_result_disagg_prefill(batch, result)
-
-            if len(self.disagg_prefill_inflight_queue) > 0:
-                self.process_disagg_prefill_inflight_queue()
-
-            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
-            # Otherwise, it hangs under high concurrency
-            self.running_batch.batch_is_full = False
-
-    @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
-        """A normal scheduler loop for decode worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            # polling and allocating kv cache
-            self.process_decode_queue()
-            batch = self.get_next_disagg_decode_batch_to_run()
-            self.cur_batch = batch
-
-            if batch:
-                # Generate fake extend output.
-                if batch.forward_mode.is_extend():
-                    # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(
-                        batch.reqs, [False for _ in range(len(batch.reqs))]
-                    )
-                else:
-                    result = self.run_batch(batch)
-                    self.process_batch_result(batch, result)
-
-            if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -1869,6 +1806,7 @@ class Scheduler(
                 recv_req.activities,
                 recv_req.with_stack,
                 recv_req.record_shapes,
+                recv_req.profile_id,
             )
         else:
             return self.stop_profile()
@@ -1880,6 +1818,7 @@ class Scheduler(
         activities: Optional[List[str]],
         with_stack: Optional[bool],
         record_shapes: Optional[bool],
+        profile_id: Optional[str],
     ) -> None:
         if self.profiler_activities:
             return ProfileReqOutput(
@@ -1894,9 +1833,11 @@ class Scheduler(
 
         self.torch_profiler_output_dir = output_dir
         self.profiler_activities = activities
+        self.profiler_id = profile_id
         logger.info(
-            "Profiling starts. Traces will be saved to: %s",
+            "Profiling starts. Traces will be saved to: %s (with id %s)",
             self.torch_profiler_output_dir,
+            self.profiler_id,
         )
 
         activity_map = {
@@ -1938,14 +1879,14 @@ class Scheduler(
             self.torch_profiler.export_chrome_trace(
                 os.path.join(
                     self.torch_profiler_output_dir,
-                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+                    self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                 )
            )
 
        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
-                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+                self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -650,6 +650,7 @@ class TokenizerManager:
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
         if not result.success:
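
Combined with the scheduler changes above, profiling now threads a single `profile_id` (generated once in `TokenizerManager` as `str(time.time())`) through `ProfileReq`, so every TP rank names its trace and memory dump after the same identifier instead of taking its own timestamp. A small sketch of the resulting file name, with illustrative values:

import os

# Illustrative values; profile_id is the str(time.time()) created by TokenizerManager.
profile_id, tp_rank, output_dir = "1713200000.123", 0, "/tmp/profiles"
trace_path = os.path.join(output_dir, profile_id + f"-TP-{tp_rank}" + ".trace.json.gz")
# -> /tmp/profiles/1713200000.123-TP-0.trace.json.gz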
sglang/srt/mem_cache/hiradix_cache.py CHANGED
@@ -92,7 +92,7 @@ class HiRadixCache(RadixCache):
             self.ongoing_write_through[node.id] = node
             self.inc_lock_ref(node)
         else:
-            return None
+            return 0
 
         return len(host_indices)
 
@@ -153,6 +153,7 @@ class HiRadixCache(RadixCache):
             if x.host_value is None:
                 if self.cache_controller.write_policy == "write_back":
                     num_evicted += self.write_backup(x)
+                    pending_nodes.append(x)
                 elif self.cache_controller.write_policy == "write_through_selective":
                     num_evicted += self._evict_write_through_selective(x)
                 else:
@@ -177,6 +178,9 @@ class HiRadixCache(RadixCache):
         while len(self.ongoing_write_through) > 0:
             self.writing_check()
             time.sleep(0.1)
+        for node in pending_nodes:
+            assert node.host_value is not None
+            self._evict_write_through(node)
 
     def _evict_write_through(self, node: TreeNode):
         # evict a node already written to host
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     # Todo: different memory layout
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
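
The `kv_item_lens` change above makes the advertised per-item size account for the page size (previously it reported the size of a single token slot regardless of `page_size`), which matters for the KV transfer path that consumes these buffer descriptions. Rough arithmetic with illustrative shapes:

# Illustrative numbers, not taken from a real model config.
num_kv_heads, head_dim, elem_size = 8, 128, 2               # fp16 KV cache
per_token_key_bytes = num_kv_heads * head_dim * elem_size   # get_key_buffer(i)[0].nbytes
page_size = 16
kv_item_len = per_token_key_bytes * page_size               # 2048 * 16 = 32768 bytes per page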
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip
 
-_is_hip = is_hip()
-
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+_is_hip = is_hip()
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
sglang/srt/model_executor/model_runner.py CHANGED
@@ -73,6 +73,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
+    get_bool_env_var,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -127,10 +128,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.use_mla_backend = (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        )
+        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
         # Model-specific adjustment
@@ -139,18 +137,12 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
@@ -160,13 +152,12 @@ class ModelRunner:
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
@@ -229,15 +220,7 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if server_args.enable_flashinfer_mla:
-            # TODO: remove this branch after enable_flashinfer_mla is deprecated
-            logger.info("MLA optimization is turned on. Use flashinfer backend.")
-            server_args.attention_backend = "flashinfer"
-        elif server_args.enable_flashmla:
-            # TODO: remove this branch after enable_flashmla is deprecated
-            logger.info("MLA optimization is turned on. Use flashmla decode.")
-            server_args.attention_backend = "flashmla"
-        elif server_args.attention_backend is None:
+        if server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
                 if (
@@ -263,7 +246,12 @@ class ModelRunner:
         elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
+                if server_args.attention_backend in [
+                    "flashinfer",
+                    "fa3",
+                    "triton",
+                    "flashmla",
+                ]:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -320,7 +308,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
@@ -387,10 +374,16 @@ class ModelRunner:
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
-                raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
-                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
-                )
+                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
+                    logger.warning(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
+                else:
+                    raise ValueError(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
 
         logger.info(
             f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"