sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py CHANGED
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -38,9 +45,10 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
+from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +63,7 @@ class Llama4MoE(nn.Module):
         topk: int,
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        router_scores_aK, router_indices_aK = torch.topk(gating_output, topk, dim=-1)
+        router_scores_aK, router_indices_aK = fast_topk(gating_output, topk, dim=-1)
         router_scores_aK = torch.sigmoid(router_scores_aK.float()).to(
             hidden_states.dtype
         )
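For reference, a minimal sketch of the routing math this hunk touches: pick the top-k experts per token, then squash the selected logits with a sigmoid. The hunk only swaps `torch.topk` for sglang's `fast_topk` helper, so `torch.topk` stands in for it here and the shapes are illustrative:

```python
import torch

def llama4_style_routing(gating_output: torch.Tensor, topk: int):
    # Select the top-k experts per token, then apply a sigmoid to the selected
    # logits (the diff applies sigmoid rather than a softmax over all experts).
    router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
    router_scores = torch.sigmoid(router_scores.float())
    return router_scores, router_indices

# Toy usage: 4 tokens routed over 16 experts, top-1 routing.
scores, indices = llama4_style_routing(torch.randn(4, 16), topk=1)
print(scores.shape, indices.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```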
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
         self.hidden_size = hidden_size
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
-        tp_size = get_tensor_model_parallel_world_size()
+
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = config.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
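A worked example of the head-partitioning arithmetic in the hunk above, with illustrative head counts (not taken from any real config):

```python
def split_heads(total_num_heads: int, total_num_kv_heads: int, attn_tp_size: int):
    # Query heads are always sharded evenly across the attention TP group.
    assert total_num_heads % attn_tp_size == 0
    num_heads = total_num_heads // attn_tp_size
    # KV heads are partitioned when there are enough of them, replicated otherwise.
    if total_num_kv_heads >= attn_tp_size:
        assert total_num_kv_heads % attn_tp_size == 0
    else:
        assert attn_tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // attn_tp_size)
    return num_heads, num_kv_heads

print(split_heads(40, 8, 4))  # (10, 2): 8 KV heads split across 4 ranks
print(split_heads(40, 2, 4))  # (10, 1): 2 KV heads replicated, 1 per rank
```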
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.o_proj = RowParallelLinear(
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
             bias=bias_o_proj,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
         )
         is_neox_style = True
         is_gguf = quant_config and quant_config.get_name() == "gguf"
@@ -223,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
-
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -233,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
-            # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
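The attn_temperature_tuning branch above reduces to a per-position, logarithmically growing scale on q for NoPE layers. A standalone sketch of that formula (the floor_scale and attn_scale defaults here are illustrative; the real values come from the model config):

```python
import torch

def temperature_tuned_q(
    q: torch.Tensor,
    positions: torch.Tensor,
    floor_scale: float = 8192.0,
    attn_scale: float = 0.1,
) -> torch.Tensor:
    # scale = log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1
    # stays ~1.0 for short contexts and grows slowly at very long range.
    floor = torch.floor((positions + 1.0) / floor_scale)
    scale = torch.log(floor + 1.0) * attn_scale + 1.0
    return (q * scale.unsqueeze(-1)).to(q.dtype)

q = torch.randn(4, 128)  # 4 tokens, flattened head dims
positions = torch.tensor([0.0, 10.0, 50_000.0, 500_000.0])
print(temperature_tuned_q(q, positions).shape)  # torch.Size([4, 128])
```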
@@ -274,6 +297,9 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+        self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         self.self_attn = Llama4Attention(
             config=config,
@@ -316,21 +342,58 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Self Attention
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            # Self Attention
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+            # Gather
+            if get_tensor_model_parallel_world_size() > 1:
+                # all gather and all reduce
+                if self.dp_size != 1:
+                    if self.attn_tp_rank == 0:
+                        hidden_states += residual
+                    hidden_states, local_hidden_states = (
+                        forward_batch.gathered_buffer,
+                        hidden_states,
+                    )
+                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                    dp_scatter(residual, hidden_states, forward_batch)
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.feed_forward(hidden_states)
+
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
 
 
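The new gather/scatter path moves tokens between the per-rank attention layout and the global layout the MLP operates on. Below is a loose, single-process analogy for just the token bookkeeping; dp_gather_partial and dp_scatter are real collective ops in sglang, and the rank/token counts here are made up:

```python
import torch

hidden_size = 8
tokens_per_rank = [3, 5]  # two attention-DP ranks with different batch sizes
local_hidden = [torch.randn(n, hidden_size) for n in tokens_per_rank]

# "Gather": the MLP conceptually sees the concatenation of every rank's tokens
# (the role played by forward_batch.gathered_buffer).
gathered_buffer = torch.cat(local_hidden, dim=0)   # [8, hidden_size]
mlp_out = gathered_buffer * 2.0                    # stand-in for feed_forward

# "Scatter": each rank keeps only its own contiguous slice of the gathered
# result, which is why the diff slices gathered_buffer[: input_ids.shape[0]].
offsets = [0, tokens_per_rank[0], sum(tokens_per_rank)]
rank0_out = mlp_out[offsets[0] : offsets[1]]       # [3, hidden_size]
rank1_out = mlp_out[offsets[1] : offsets[2]]       # [5, hidden_size]
print(rank0_out.shape, rank1_out.shape)
```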
@@ -350,13 +413,14 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Llama4DecoderLayer(
                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
-            prefix="model.layers",
+            prefix=add_prefix("layers", prefix),
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -385,7 +449,8 @@ class Llama4Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
@@ -394,7 +459,6 @@ class Llama4Model(nn.Module):
 
 
 class Llama4ForCausalLM(LlamaForCausalLM):
-
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -408,6 +472,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     ):
         super().__init__(config, quant_config, prefix)
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def _init_model(
         self,
         config: Llama4TextConfig,
sglang/srt/models/minicpm.py CHANGED
@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/minicpm3.py CHANGED
@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/mixtral.py CHANGED
@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/mixtral_quant.py CHANGED
@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/mllama.py CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
 
 
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 
 
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size
 
         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
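The next hunk unpacks a tuple from the projector call, which is consistent with the (output, extra) return convention of the parallel linear layers that ReplicatedLinear swaps in for nn.Linear here. A tiny shim showing just that calling convention (TupleLinear is a made-up name for illustration, not an sglang class):

```python
import torch
from torch import nn

class TupleLinear(nn.Module):
    """Minimal stand-in mimicking an (output, extra) return convention."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor):
        return self.inner(x), None  # second slot mirrors the unused bias output

proj = TupleLinear(16, 32)
out, _ = proj(torch.randn(4, 16))  # call sites unpack, as in the hunk below
print(out.shape)  # torch.Size([4, 32])
```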
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )
 
             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
sglang/srt/models/mllama4.py CHANGED
@@ -1,13 +1,19 @@
-# TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -16,6 +22,7 @@ from sglang.srt.utils import add_prefix
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
     def __init__(
@@ -28,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
@@ -39,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         self.logits_processor = LogitsProcessor(config.text_config)
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
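The get_image_feature path above is: concatenate each item's pixel values, run the vision tower, flatten the tile/patch dimensions, then project into the text hidden size. A shape-only sketch with dummy modules (all dimensions are illustrative, and nn.Identity/nn.Linear stand in for Llama4VisionModel and Llama4MultiModalProjector):

```python
import torch
from torch import nn

vision_hidden, text_hidden = 64, 128
vision_tower = nn.Identity()                      # pretend it already yields features
projector = nn.Linear(vision_hidden, text_hidden)

pixel_values = [
    torch.randn(2, 5, vision_hidden),  # item 1: 2 tiles x 5 patches
    torch.randn(3, 5, vision_hidden),  # item 2: 3 tiles x 5 patches
]
features = vision_tower(torch.concat(pixel_values))  # [5, 5, vision_hidden]
vision_flat = features.view(-1, features.size(-1))   # [25, vision_hidden]
projected = projector(vision_flat)                   # [25, text_hidden]
print(projected.shape)  # torch.Size([25, 128])
```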
@@ -47,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
-        return self.language_model(input_ids, positions, forward_batch)
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
 
     def permute_qk_weight_for_rotary(
         self,
@@ -96,18 +137,27 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         num_experts = self.config.text_config.num_local_experts
 
-        for name, loaded_weight in weights:
-
-            if name.startswith("vision_model") or name.startswith(
-                "multi_modal_projector"
-            ):
-                continue
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )
 
-            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+        for name, loaded_weight in weights:
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -115,31 +165,54 @@ class Llama4ForConditionalGeneration(nn.Module):
                 break
             else:
                 if ".experts" in name:
-                    if ".gate_up_proj" in name:
-                        name_list = [
-                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
-                        ] * 2
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
                     ):
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        for expert_id in range(num_experts):
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
                             weight_loader(
                                 param,
-                                loaded_weight[expert_id].T,
+                                loaded_weight,
                                 name,
                                 shard_id=shard_id,
                                 expert_id=expert_id,
                             )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
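The non-fp8 expert branch above rewrites per-expert checkpoint names onto the fused MoE parameters through an expert-params mapping. A hand-rolled illustration of that idea, following the (param_name, weight_name, expert_id, shard_id) layout named in the hunk's comment (the real mapping comes from FusedMoE.make_expert_params_mapping and may differ in detail):

```python
def toy_expert_params_mapping(num_experts: int):
    # (param_name, weight_name, expert_id, shard_id), per the hunk's comment.
    mapping = []
    for expert_id in range(num_experts):
        mapping.append(("experts.w13_weight", f"experts.{expert_id}.gate_proj", expert_id, "w1"))
        mapping.append(("experts.w13_weight", f"experts.{expert_id}.up_proj", expert_id, "w3"))
        mapping.append(("experts.w2_weight", f"experts.{expert_id}.down_proj", expert_id, "w2"))
    return mapping

ckpt_name = "model.layers.0.feed_forward.experts.5.down_proj.weight"
for param_name, weight_name, expert_id, shard_id in toy_expert_params_mapping(16):
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(fused_name, expert_id, shard_id)
        # model.layers.0.feed_forward.experts.w2_weight.weight 5 w2
        break
```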
sglang/srt/models/olmo.py CHANGED
@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 