sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py

@@ -17,13 +17,14 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -88,7 +90,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,21 +277,31 @@
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix="model.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
         self.layers_to_capture = []
 
     def forward(
@@ -298,14 +310,23 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
+            assert pp_proxy_tensors is not None
+            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+            deferred_norm = None
+
         aux_hidden_states = []
-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             if i in self.layers_to_capture:
                 aux_hidden_states.append(hidden_states + residual)
             layer = self.layers[i]
@@ -315,7 +336,16 @@
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
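Note on the pipeline-parallel change above: every rank except the last now returns the running hidden_states and residual wrapped in a PPProxyTensors, and the next stage unpacks that bundle instead of re-embedding tokens. A minimal single-process sketch of the hand-off contract follows; ToyProxy, ToyLayer, and stage_forward are illustrative stand-ins, not sglang's implementation, and in the real engine the proxy travels between ranks rather than within one process.

    import torch
    from torch import nn

    class ToyProxy(dict):
        """Stand-in for PPProxyTensors: a named bundle of tensors passed between stages."""

    class ToyLayer(nn.Module):
        def __init__(self, hidden: int):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, hidden_states, residual):
            # Mimic the (hidden_states, residual) contract of the decoder layers.
            residual = hidden_states if residual is None else residual + hidden_states
            return self.proj(hidden_states), residual

    def stage_forward(layers, hidden_states, residual, is_last_rank, norm=None):
        for layer in layers:  # only this rank's slice of layers
            hidden_states, residual = layer(hidden_states, residual)
        if not is_last_rank:
            # Intermediate ranks hand both tensors to the next stage.
            return ToyProxy(hidden_states=hidden_states, residual=residual)
        return norm(hidden_states + residual)  # last rank finishes the forward

    x = torch.randn(2, 16)
    stage0 = [ToyLayer(16), ToyLayer(16)]
    stage1 = [ToyLayer(16), ToyLayer(16)]
    proxy = stage_forward(stage0, x, None, is_last_rank=False)
    out = stage_forward(stage1, proxy["hidden_states"], proxy["residual"],
                        is_last_rank=True, norm=nn.LayerNorm(16))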
@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
@@ -419,23 +450,41 @@
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
         aux_hidden_states = None
         if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            return hidden_states
 
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -491,6 +540,16 @@
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -637,6 +696,9 @@
         self.model.load_kv_cache_scales(quantization_param_path)
 
     def set_eagle3_layers_to_capture(self):
+        if not self.pp_group.is_last_rank:
+            return
+
         self.capture_aux_hidden_states = True
         num_layers = self.config.num_hidden_layers
         self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
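Note on the load_weights change above: checkpoint tensors whose layer index falls outside this rank's [start_layer, end_layer) slice are now skipped. A rough sketch of that filter is below; the regex-based get_layer_id is an assumption for illustration, not sglang's exact helper.

    import re
    from typing import Iterable, Optional, Tuple

    import torch

    def get_layer_id(name: str) -> Optional[int]:
        # Assumed behaviour: extract <N> from "model.layers.<N>." style parameter names.
        match = re.search(r"layers\.(\d+)\.", name)
        return int(match.group(1)) if match else None

    def filter_weights_for_rank(
        weights: Iterable[Tuple[str, torch.Tensor]],
        start_layer: int,
        end_layer: int,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield only the weights owned by this pipeline rank's layer slice."""
        for name, tensor in weights:
            layer_id = get_layer_id(name)
            if layer_id is not None and not (start_layer <= layer_id < end_layer):
                continue  # another rank loads this layer
            yield name, tensor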
sglang/srt/models/llama4.py

@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):
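Note on _forward_core_shared_routed_overlap above: for small batches the shared expert runs on the current stream while the router and routed experts run on a lazily created alternate stream, and the two streams rejoin before the outputs are combined. A minimal sketch of that overlap pattern follows; shared_fn and routed_fn are placeholder callables (the real code uses self.shared_expert and self.experts), and a CUDA-like device module is assumed.

    import torch

    _alt_stream = None

    def _get_or_create_alt_stream(device_module):
        # One alternate stream is created lazily and reused process-wide.
        global _alt_stream
        if _alt_stream is None:
            _alt_stream = device_module.Stream()
        return _alt_stream

    def overlapped_moe(hidden_states, shared_fn, routed_fn):
        """Run shared_fn on the current stream and routed_fn on a side stream."""
        device_module = torch.get_device_module()
        alt_stream = _get_or_create_alt_stream(device_module)
        # The side stream must not start before hidden_states is ready.
        alt_stream.wait_stream(device_module.current_stream())
        shared_out = shared_fn(hidden_states)
        with device_module.stream(alt_stream):
            routed_out = routed_fn(hidden_states)
        # Rejoin before the results are combined on the main stream.
        device_module.current_stream().wait_stream(alt_stream)
        return shared_out + routed_out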
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -431,6 +477,7 @@ class Llama4Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
sglang/srt/models/llama_eagle.py

@@ -25,13 +25,14 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = LlamaModel(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
sglang/srt/models/llama_eagle3.py

@@ -25,6 +25,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             embeds = self.embed_tokens(input_ids)
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
 
         if self.config.num_hidden_layers != 1:
             raise ValueError("EAGLE3 currently only supports 1 layer")
sglang/srt/models/minicpmv.py

@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),

sglang/srt/models/mllama.py

@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/phi3_small.py

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +320,8 @@
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
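Note on the make_layers calls above: Llama and Phi-3-Small now pass pp_rank/pp_size and unpack layers, start_layer, end_layer (the Phi-3-Small hunk also flips the previous start_layer, end_layer, layers order), so each rank builds only its own slice of decoder layers. A toy sketch of that contract, assuming a near-even split and PPMissingLayer-style placeholders; this is not sglang's actual make_layers.

    from torch import nn

    class MissingLayer(nn.Identity):
        """Placeholder for layers owned by other pipeline ranks (cf. PPMissingLayer)."""

    def make_layers_sketch(num_layers, layer_fn, pp_rank, pp_size):
        """Return a full-length ModuleList in which only this rank's slice is real."""
        per_rank = (num_layers + pp_size - 1) // pp_size  # assumed near-even split
        start_layer = pp_rank * per_rank
        end_layer = min(start_layer + per_rank, num_layers)
        layers = nn.ModuleList(
            layer_fn(i) if start_layer <= i < end_layer else MissingLayer()
            for i in range(num_layers)
        )
        return layers, start_layer, end_layer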
 
sglang/srt/models/qwen2_5_vl.py

@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True
 
         self.attn = VisionAttention(
@@ -142,7 +146,7 @@
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,
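Note on the vision attention changes in this diff (minicpmv, mllama, qwen2_5_vl, qwen2_vl): the boolean use_context_forward flag is replaced by an explicit qkv_backend string, and the Qwen2.5-VL block gains a flash_attention_3 branch. The selection reduces to a small lookup; the sketch below uses values taken from these hunks, and the helper name and dict are illustrative only.

    # attn_implementation -> (qkv_backend, softmax_in_single_precision)
    _VISION_ATTN_BACKENDS = {
        "sdpa": ("sdpa", False),
        "flash_attention_2": ("triton_attn", False),
        "eager": ("sdpa", True),
        "flash_attention_3": ("fa3", False),  # new branch in the Qwen2.5-VL blocks
    }

    def select_vision_attn_backend(attn_implementation: str):
        """Mirror the per-model if/elif chains shown in the diff."""
        return _VISION_ATTN_BACKENDS[attn_implementation]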
sglang/srt/models/qwen2_moe.py

@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
sglang/srt/models/qwen2_vl.py

@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
 
         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
@@ -442,18 +442,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         "up_proj": ("gate_up_proj", 1),
     }
 
-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def __init__(
         self,
         config: Qwen2VLConfig,
sglang/srt/models/qwen3_moe.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
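
Note on the MoE changes above: Qwen2-MoE and Qwen3-MoE now pick their expert implementation from the server arguments (EPMoE when expert parallelism is enabled, FusedMoE otherwise) and drop reduce_results=False from the constructor; the same selection drives make_expert_params_mapping at load time, so the checkpoint-name mapping always matches the chosen implementation. A condensed sketch of that pattern with stand-in classes (the real ones live under sglang.srt.layers.moe):

    class FusedMoE:  # stand-in for sglang.srt.layers.moe.fused_moe_triton.FusedMoE
        @classmethod
        def make_expert_params_mapping(cls, **ckpt_names):
            return [("fused_moe", ckpt_names)]

    class EPMoE(FusedMoE):  # stand-in for sglang.srt.layers.moe.ep_moe.layer.EPMoE
        @classmethod
        def make_expert_params_mapping(cls, **ckpt_names):
            return [("expert_parallel_moe", ckpt_names)]

    global_server_args_dict = {"enable_ep_moe": True}  # toy value for illustration

    # One switch decides both the experts module and the weight-name mapping.
    MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
    expert_params_mapping = MoEImpl.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
    )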