sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
@@ -38,12 +38,12 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:


 class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
+    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None

+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None

@@ -189,7 +194,7 @@ class ForwardBatch:

     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None

     # For DP attention
@@ -229,7 +234,6 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu = (
                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
             )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -417,8 +421,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)

-    # TODO: optimize this?
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)

@@ -35,17 +35,13 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -57,9 +53,16 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.loader import (
+    DefaultModelLoader,
+    device_loading_context,
+    get_model_loader,
+)
+from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -77,11 +80,9 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

@@ -118,70 +119,22 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    self.server_args.attention_backend = "flashinfer_mla"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
-
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+        self.model_specific_adjustment()

-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache

             disable_cache()

+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -203,11 +156,17 @@ class ModelRunner:
             }
         )

+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -216,18 +175,6 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -244,9 +191,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

-        # Init memory pool and attention backends
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -260,10 +209,63 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)

+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -304,15 +306,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )

         logger.info(
@@ -352,6 +355,8 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
+        monkey_patch_isinstance_for_vllm_base_layer()
+
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
@@ -359,6 +364,7 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
+        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
@@ -400,17 +406,22 @@ class ModelRunner:
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
-        from sglang.srt.model_loader.loader import (
-            DefaultModelLoader,
-            device_loading_context,
-            get_model_loader,
-        )
-        from sglang.srt.model_loader.utils import set_default_torch_dtype
-
         logger.info(
             f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -420,7 +431,7 @@ class ModelRunner:
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

-        # Only support vllm DefaultModelLoader for now
+        # Only support DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
@@ -694,6 +705,12 @@ class ModelRunner:
             )
             self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -710,21 +727,13 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -735,6 +744,7 @@ class ModelRunner:
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -746,6 +756,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -753,6 +764,26 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
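Two page-size details in the hunks above go together: init_memory_pool first floors max_total_num_tokens to a whole number of pages, and the allocator is then chosen by page size (the original per-token TokenToKVPoolAllocator when page_size == 1, PagedTokenToKVPoolAllocator otherwise). A minimal standalone sketch of that floor arithmetic and the resulting page count, for illustration only and not the sglang implementation:

    def floor_to_page(num_tokens: int, page_size: int) -> int:
        # Same integer-divide-then-multiply trick as in init_memory_pool:
        # drop the trailing partial page so every page is fully usable.
        return num_tokens // page_size * page_size

    budget = floor_to_page(1000, page_size=16)   # 992 tokens
    num_pages = budget // 16                     # 62 full pages
    assert (budget, num_pages) == (992, 62)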
@@ -770,6 +801,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -781,12 +819,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
@@ -878,18 +930,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )

         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
@@ -909,45 +967,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,
@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_pin_memory_available,
     set_weight_attrs,
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):

         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
+        if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
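Replacing the exact string comparison with get_bool_env_var makes the ModelScope switch tolerant of common truthy spellings instead of requiring the literal "True". The real helper lives in sglang/srt/utils.py; a hypothetical stand-in with the behavior assumed here (case-insensitive "true"/"1"), shown only to illustrate the intent:

    import os

    # Illustrative sketch, not the actual sglang.srt.utils.get_bool_env_var;
    # the real helper may accept a different set of values.
    def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
        return os.getenv(name, default).strip().lower() in ("true", "1")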
@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
         torch.cuda.empty_cache()
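The weights_only=True flag added to torch.load in pt_weights_iterator switches PyTorch to its restricted unpickler, which materializes tensors and allow-listed container types but refuses arbitrary pickled objects, so loading a .bin checkpoint no longer executes embedded pickle code. For an ordinary state dict the call pattern is unchanged (the checkpoint path below is a placeholder):

    import torch

    # weights_only=True raises on non-allow-listed pickled objects instead of executing them.
    state = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
    for name, tensor in state.items():
        print(name, tuple(tensor.shape))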