sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py +23 -53
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
 
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge
+        from prometheus_client import Gauge, Histogram
 
         self.labels = labels
         self.last_log_time = time.time()
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
+                0.2,
+                0.4,
+                0.6,
+                0.8,
                 1,
                 2,
                 4,
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
-                120,
-                160,
-            ],
-        )
-
-        self.histogram_time_per_output_token = Histogram(
-            name="sglang:time_per_output_token_seconds",
-            documentation="Histogram of time per output token in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.002,
-                0.005,
-                0.010,
-                0.020,
-                0.030,
-                0.040,
-                0.050,
-                0.060,
-                0.070,
-                0.080,
-                0.090,
-                0.100,
-                0.150,
-                0.200,
-                0.300,
-                0.400,
-                0.600,
-                0.800,
-                1.000,
-                2.000,
+                100,
+                200,
+                400,
             ],
         )
 
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
                 0.030,
                 0.035,
                 0.040,
-                0.050,
-                0.075,
+                0.060,
+                0.080,
                 0.100,
-                0.150,
                 0.200,
-                0.300,
                 0.400,
-                0.500,
-                0.750,
+                0.600,
+                0.800,
                 1.000,
                 2.000,
+                4.000,
+                6.000,
+                8.000,
             ],
         )
 
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
                 0.1,
                 0.2,
                 0.4,
+                0.6,
                 0.8,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
                 60,
                 80,
                 100,
-                150,
                 200,
-                250,
-                300,
-                350,
-                500,
-                1000,
+                400,
+                800,
             ],
         )
 
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
-        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
+        if cached_tokens > 0:
+            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
-        if generation_tokens >= 1:
-            self.histogram_time_per_output_token.labels(**self.labels).observe(
-                e2e_latency / generation_tokens
-            )
 
     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
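For reference, a minimal standalone sketch of the prometheus_client Histogram pattern used by these collectors (the metric name, label, and values below are illustrative, not sglang's):

from prometheus_client import Histogram

# Declare a histogram with explicit bucket upper bounds (in seconds).
example_e2e_latency = Histogram(
    name="example:e2e_request_latency_seconds",
    documentation="Histogram of end-to-end request latency in seconds.",
    labelnames=["model_name"],
    buckets=[0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100, 200, 400, 800],
)

# Each observe() call is tallied against the bucket bounds; a +Inf bucket is added automatically.
example_e2e_latency.labels(model_name="demo-model").observe(1.7)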
sglang/srt/model_executor/cuda_graph_runner.py +16 -13
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+                capture_bs = list(range(1, 33)) + range(40, 161, 16)
             else:
-                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
-            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+            capture_bs = (
+                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+            )
 
-    if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        if _is_hip:
+            capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
@@ -174,6 +176,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +248,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +291,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
 
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +372,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
@@ -471,7 +474,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -488,16 +491,16 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
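As a standalone illustration (not sglang code), the new default capture lists above expand to the following batch sizes:

# Default (non-speculative, padding enabled): 1, 2, 4, 8, then every 8 up to 160.
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
assert capture_bs[:6] == [1, 2, 4, 8, 16, 24]
assert capture_bs[-1] == 160 and len(capture_bs) == 23

# Speculative decoding: denser below 33, then every 16 up to 152.
spec_capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
assert spec_capture_bs[:10] == [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]
assert spec_capture_bs[-1] == 152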
sglang/srt/model_executor/forward_batch_info.py +10 -10
@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
+    def is_extend_or_draft_extend(self):
+        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
@@ -148,6 +151,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +168,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None
 
-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +296,14 @@ class ForwardBatch:
         ):
             ret.positions = ret.spec_info.positions
 
+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
+
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -353,11 +358,6 @@ class ForwardBatch:
         for mm_input in valid_inputs[1:]:
             merged.merge(mm_input)
 
-        if isinstance(merged.pixel_values, np.ndarray):
-            merged.pixel_values = torch.from_numpy(merged.pixel_values)
-        if isinstance(merged.audio_features, np.ndarray):
-            merged.audio_features = torch.from_numpy(merged.audio_features)
-
         return merged
 
     def contains_image_inputs(self) -> bool:
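A minimal sketch (not the sglang definition) of the IntEnum predicate style used by the new ForwardMode.is_extend_or_draft_extend helper:

from enum import IntEnum, auto

class Mode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    DRAFT_EXTEND = auto()

    def is_extend_or_draft_extend(self) -> bool:
        # True for both the regular extend and the draft-extend phases.
        return self in (Mode.EXTEND, Mode.DRAFT_EXTEND)

assert Mode.DRAFT_EXTEND.is_extend_or_draft_extend()
assert not Mode.DECODE.is_extend_or_draft_extend()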
sglang/srt/model_executor/model_runner.py +64 -59
@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -74,6 +75,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
@@ -122,6 +124,10 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.use_mla_backend = (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        )
 
         # Model-specific adjustment
         self.model_specific_adjustment()
@@ -146,15 +152,18 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_mode": server_args.deepep_mode,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "use_mla_backend": self.use_mla_backend,
             }
         )
 
@@ -215,23 +224,38 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        ):
+        if server_args.enable_flashinfer_mla:
+            # TODO: remove this branch after enable_flashinfer_mla is deprecated
+            logger.info("MLA optimization is turned on. Use flashinfer backend.")
+            server_args.attention_backend = "flashinfer"
+        elif server_args.enable_flashmla:
+            # TODO: remove this branch after enable_flashmla is deprecated
+            logger.info("MLA optimization is turned on. Use flashmla decode.")
+            server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend is None:
+            # By default, use flashinfer for non-mla attention and triton for mla attention
+            if not self.use_mla_backend:
+                server_args.attention_backend = (
+                    "flashinfer" if is_flashinfer_available() else "triton"
+                )
+            else:
+                server_args.attention_backend = "triton"
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
+        elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
+                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
                     logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
-                    server_args.attention_backend = "flashinfer_mla"
-                elif server_args.enable_flashmla:
-                    logger.info("MLA optimization is turned on. Use flashmla decode.")
-                    server_args.attention_backend = "flashmla"
                 else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    server_args.attention_backend = "triton"
+                    raise ValueError(
+                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
+                    )
+            else:
+                raise ValueError(f"MLA optimization not supported on CPU.")
 
         if server_args.enable_double_sparsity:
             logger.info(
@@ -246,17 +270,16 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
            )
 
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1
 
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
@@ -264,25 +287,11 @@ class ModelRunner:
             "Qwen2_5_VLForConditionalGeneration"
         ]:
             # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-            )
-            server_args.chunked_prefill_size = -1
+            logger.info("Automatically disable radix cache for qwen-vl series.")
             server_args.disable_radix_cache = True
 
         if server_args.enable_deepep_moe:
-            logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
+            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -644,10 +653,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
@@ -758,10 +764,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -832,14 +835,21 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
+            if not self.use_mla_backend:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
+                )
 
-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
-            self.attn_backend = FlashInferAttnBackend(self)
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                self.attn_backend = FlashInferAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
+                )
+
+                self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -865,12 +875,6 @@ class ModelRunner:
             )
 
             self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
 
@@ -881,7 +885,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -1082,8 +1086,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
 
 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-        return tensor.get(tp_rank)
-    return tensor
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())
 
 
 @dataclass
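A paraphrased sketch of the default attention-backend selection introduced in model_specific_adjustment above (the function and argument names here are simplified placeholders, not the ModelRunner API):

def pick_attention_backend(requested, use_mla_backend, flashinfer_available):
    # An explicitly requested backend is kept as-is (it is validated separately for MLA).
    if requested is not None:
        return requested
    # MLA models default to triton; others prefer flashinfer when it is installed.
    if use_mla_backend:
        return "triton"
    return "flashinfer" if flashinfer_available else "triton"

assert pick_attention_backend(None, False, True) == "flashinfer"
assert pick_attention_backend(None, True, True) == "triton"
assert pick_attention_backend("fa3", True, True) == "fa3"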
sglang/srt/model_loader/loader.py +19 -1
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
-import gguf
 import huggingface_hub
 import numpy as np
 import torch
@@ -490,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()
 
 
@@ -1155,6 +1162,17 @@ class GGUFModelLoader(BaseModelLoader):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
+
+        # only load the gguf module when needed
+        try:
+            import gguf
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install gguf via `pip install gguf` to use gguf quantizer."
+            ) from err
+
         config = model_config.hf_config
         model_type = config.model_type
         # hack: ggufs have a different name than transformers
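A generic sketch of the lazy optional-import pattern applied to gguf above: import the dependency only on the code path that needs it and fail with an actionable message (the wrapper function below is hypothetical):

def open_gguf_reader(path: str):
    try:
        import gguf  # optional dependency, loaded only when a GGUF checkpoint is used
    except ImportError as err:
        raise ImportError(
            "Please install gguf via `pip install gguf` to load GGUF checkpoints."
        ) from err
    return gguf.GGUFReader(path)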
sglang/srt/model_loader/weight_utils.py +6 -3
@@ -22,7 +22,6 @@ from typing import (
 )
 
 import filelock
-import gguf
 import huggingface_hub.constants
 import numpy as np
 import safetensors.torch
@@ -93,7 +92,7 @@ def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
 ) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
     shared = _shared_pointers(loaded)
@@ -381,7 +380,7 @@ def np_cache_weights_iterator(
                 disable=not enable_tqdm,
                 bar_format=_BAR_FORMAT,
             ):
-                state = torch.load(bin_file, map_location="cpu")
+                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
                     with open(param_path, "wb") as f:
@@ -464,6 +463,8 @@ def pt_weights_iterator(
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
     exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
     them to torch tensors
     """
 
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
 
     for tensor in reader.tensors:
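For reference, the torch.load hardening applied above shown in isolation: weights_only=True restricts unpickling to a safe subset of types, which mitigates arbitrary code execution when loading untrusted .bin checkpoints (the path below is hypothetical):

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
for name, param in state_dict.items():
    print(name, tuple(param.shape))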