sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/olmo2.py CHANGED
@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/olmoe.py CHANGED
@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
             self.scaling,
             layer_id=layer_id,
             num_kv_heads=self.num_kv_heads,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/phi3_small.py CHANGED
@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/qwen.py CHANGED
@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/qwen2.py CHANGED
@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -137,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x
 
 
-class Qwen2_5_VisionPatchEmbed(nn.Module):
-
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
 
     def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out
 
 
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
 
     def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )

@@ -363,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.gate_proj.weight.dtype
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
@@ -467,9 +429,28 @@ cached_get_processor = lru_cache(get_processor)
 
 
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
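The stacked-params mapping added above tells the BitsAndBytes loading path how per-shard checkpoint tensors fold into the fused projections sglang uses at runtime (q/k/v into qkv_proj, gate/up into gate_up_proj). Below is a minimal sketch of how a mapping of this shape is typically consumed; the helper resolve_param_name is hypothetical and only illustrates the lookup, it is not part of sglang.

# Sketch only: map a per-shard checkpoint weight name to the fused parameter
# name plus its shard index, using the mapping shape shown in the diff.
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve_param_name(checkpoint_name):
    # Hypothetical helper, not an sglang API.
    for shard_name, (fused_name, shard_id) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, fused_name), shard_id
    return checkpoint_name, None

print(resolve_param_name("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)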
@@ -479,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )

@@ -500,6 +481,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -553,14 +535,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
@@ -610,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
sglang/srt/models/qwen2_moe.py CHANGED
@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/qwen2_vl.py CHANGED
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
+        return self.patch_embed.proj.weight.dtype
 
     @property
    def device(self) -> torch.device:
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
 
 
 class Qwen2VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )

@@ -467,6 +486,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -521,14 +541,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
@@ -577,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
 
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
sglang/srt/models/stablelm.py CHANGED
@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/xverse.py CHANGED
@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/models/xverse_moe.py CHANGED
@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

sglang/srt/openai_api/adapter.py CHANGED
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
                 ):
                     encoded = encoded[1:]
                 prompt_ids += encoded
+            if tokenizer_manager.model_config.is_multimodal:
+                prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
             stop = request.stop
             image_data = None
             audio_data = None
@@ -993,7 +995,8 @@ def v1_chat_generate_request(
             image_data = conv.image_data
             audio_data = conv.audio_data
             modalities = conv.modalities
-            stop = conv.stop_str or []
+            stop = conv.stop_str or [] if not request.ignore_eos else []
+
             if request.stop:
                 if isinstance(request.stop, str):
                     stop.append(request.stop)
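One subtlety in the new stop handling: Python's conditional expression binds looser than `or`, so the added line is equivalent to `(conv.stop_str or []) if not request.ignore_eos else []`, i.e. ignore_eos discards the chat template's stop strings entirely. A tiny self-contained check with stand-in objects (illustrative only, not sglang types):

class _Conv:
    stop_str = ["</s>"]

class _Req:
    ignore_eos = True

conv, request = _Conv(), _Req()
stop = conv.stop_str or [] if not request.ignore_eos else []
assert stop == []  # ignore_eos drops the conversation stop strings

request.ignore_eos = False
stop = conv.stop_str or [] if not request.ignore_eos else []
assert stop == ["</s>"]  # otherwise they are kept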
sglang/srt/patch_torch.py CHANGED
@@ -14,6 +14,7 @@
 from typing import Callable, Union
 
 import torch
+from packaging import version
 from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
 def _modify_tuple(t, index: int, modifier: Callable):
     return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+def monkey_patch_torch_compile():
+    if version.parse(torch.__version__) < version.parse("2.8.0"):
+        # These things are cacheable by torch.compile. torch.compile just doesn't know it.
+        # This was fixed in PyTorch 2.8, but until then, we monkey patch.
+        import torch._higher_order_ops.auto_functionalize as af
+
+        af.auto_functionalized_v2._cacheable = True
+        af.auto_functionalized._cacheable = True
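The new helper only flips the `_cacheable` flags on PyTorch's auto-functionalize higher-order ops when torch is older than 2.8.0, so it needs to run before any torch.compile call whose artifacts should be cacheable. A minimal usage sketch; the compiled function is a toy example, and the import path simply follows the file shown above:

import torch

from sglang.srt.patch_torch import monkey_patch_torch_compile

monkey_patch_torch_compile()  # no-op on torch >= 2.8.0

@torch.compile
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 2.0 * x + y

print(scaled_add(torch.ones(4), torch.zeros(4)))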
sglang/srt/server_args.py CHANGED
@@ -156,6 +156,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    enable_llama4_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -185,6 +186,7 @@ class ServerArgs:
     warmups: Optional[str] = None
     n_share_experts_fusion: int = 0
     disable_shared_experts_fusion: bool = False
+    disable_chunked_prefix_cache: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -194,6 +196,10 @@ class ServerArgs:
     # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
     disaggregation_mode: str = "null"
     disaggregation_bootstrap_port: int = 8998
+    disaggregation_transfer_backend: str = "mooncake"
+
+    # multimodal
+    disable_fast_image_processor: bool = False
 
     def __post_init__(self):
         # Expert parallelism
@@ -294,6 +300,8 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
+        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -495,6 +503,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "modelopt_fp4",
                 "w8a8_int8",
                 "w8a8_fp8",
                 "moe_wna16",
@@ -973,6 +982,12 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
         )
+        parser.add_argument(
+            "--enable-llama4-multimodal",
+            default=ServerArgs.enable_llama4_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for Llama-4.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
@@ -1100,6 +1115,7 @@ class ServerArgs:
             "--deepep-mode",
             type=str,
             choices=["normal", "low_latency", "auto"],
+            default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )

@@ -1115,6 +1131,11 @@ class ServerArgs:
             action="store_true",
             help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
         )
+        parser.add_argument(
+            "--disable-chunked-prefix-cache",
+            action="store_true",
+            help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
+        )
 
         # Server warmups
         parser.add_argument(
@@ -1159,6 +1180,19 @@ class ServerArgs:
             default=ServerArgs.disaggregation_bootstrap_port,
             help="Bootstrap server port on the prefill server. Default is 8998.",
         )
+        parser.add_argument(
+            "--disaggregation-transfer-backend",
+            type=str,
+            default=ServerArgs.disaggregation_transfer_backend,
+            help="The backend for disaggregation transfer. Default is mooncake.",
+        )
+
+        # Multimodal
+        parser.add_argument(
+            "--disable-fast-image-processor",
+            action="store_true",
+            help="Adopt base image processor instead of fast image processor.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
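The new dataclass fields and CLI flags line up one-to-one, so they can be exercised through the same argparse path the server uses. A hedged sketch, assuming ServerArgs.add_cli_args(parser) registers these flags as it does for the existing ones, and using a placeholder model path:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed to exist alongside from_cli_args above
args = parser.parse_args([
    "--model-path", "dummy/model",  # placeholder value, required by sglang
    "--disaggregation-mode", "prefill",
    "--disaggregation-transfer-backend", "mooncake",
    "--disable-chunked-prefix-cache",
    "--disable-fast-image-processor",
])
print(args.disaggregation_transfer_backend)  # "mooncake"
print(args.deepep_mode)                      # now defaults to "auto"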
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -84,10 +84,10 @@ class EAGLEDraftCudaGraphRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. specify --dtype to the same dtype (e.g. bfloat16)\n"
+                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+                "2. disable torch compile by not using --enable-torch-compile\n"
+                "3. specify --dtype to the same dtype (e.g. bfloat16)\n"
+                "4. disable cuda graph by --disable-cuda-graph\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

sglang/srt/speculative/eagle_utils.py CHANGED
@@ -19,7 +19,7 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
 
 if is_cuda_available():
     from sgl_kernel import (
@@ -772,16 +772,6 @@ def select_top_k_tokens(
     return input_ids, hidden_states, scores, tree_info
 
 
-def fast_topk(values, topk, dim):
-    if topk == 1:
-        # Use max along the specified dimension to get both value and index
-        max_value, max_index = torch.max(values, dim=dim)
-        return max_value.unsqueeze(1), max_index.unsqueeze(1)
-    else:
-        # Use topk for efficiency with larger k values
-        return torch.topk(values, topk, dim=dim)
-
-
 def _generate_simulated_accept_index(
     accept_index,
     predict,
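fast_topk itself is unchanged by the move into sglang.srt.utils; per the removed definition above it simply special-cases k == 1 with torch.max and falls back to torch.topk otherwise. A standalone sketch of the same behavior, runnable outside sglang:

import torch

def fast_topk(values, topk, dim):
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        max_value, max_index = torch.max(values, dim=dim)
        return max_value.unsqueeze(1), max_index.unsqueeze(1)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)

scores = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.4, 0.1]])
print(fast_topk(scores, 1, dim=-1))  # (values, indices), each of shape (2, 1)
print(fast_topk(scores, 2, dim=-1))  # plain torch.topk result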
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -31,11 +31,15 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+from sglang.srt.utils import (
+    empty_context,
+    fast_topk,
+    get_available_gpu_memory,
+    is_cuda_available,
+)
 
 if is_cuda_available():
     from sgl_kernel import segment_packbits