sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -17,13 +17,14 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -88,7 +90,7 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config,
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix="model.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
         self.layers_to_capture = []
 
     def forward(
@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
+            assert pp_proxy_tensors is not None
+            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+            deferred_norm = None
+
         aux_hidden_states = []
-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             if i in self.layers_to_capture:
                 aux_hidden_states.append(hidden_states + residual)
             layer = self.layers[i]
@@ -315,7 +336,16 @@ class LlamaModel(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
@@ -419,23 +450,41 @@
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
         aux_hidden_states = None
         if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            return hidden_states
 
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
         self.model.load_kv_cache_scales(quantization_param_path)
 
     def set_eagle3_layers_to_capture(self):
+        if not self.pp_group.is_last_rank:
+            return
+
         self.capture_aux_hidden_states = True
         num_layers = self.config.num_hidden_layers
         self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
sglang/srt/models/llama4.py
CHANGED
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -431,6 +477,7 @@ class Llama4Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
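
The new Llama4MoE path runs the shared expert and the router + routed experts concurrently, putting the routed side on a lazily created alternate stream; the overlapped path is used only for very small batches (hidden_states.shape[0] < 4), otherwise the sequential path runs. Below is a hedged sketch of the same two-stream pattern, not part of the diff: it uses torch.cuda directly (the real code goes through torch.get_device_module() so other accelerators work too), falls back to sequential execution on CPU, and uses plain matmuls as stand-ins for the expert modules.

import torch

_alt_stream = None


def _get_alt_stream():
    # One extra CUDA stream per process, created lazily (mirrors _get_or_create_alt_stream).
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = torch.cuda.Stream()
    return _alt_stream


def overlapped_moe(hidden_states, shared_expert, routed_experts):
    """Run the shared expert and the routed experts concurrently on two streams."""
    if not torch.cuda.is_available():
        # No streams without a GPU; just run sequentially.
        return shared_expert(hidden_states) + routed_experts(hidden_states)

    alt_stream = _get_alt_stream()
    alt_stream.wait_stream(torch.cuda.current_stream())

    shared_out = shared_expert(hidden_states)       # default stream
    with torch.cuda.stream(alt_stream):
        routed_out = routed_experts(hidden_states)  # alternate stream
    torch.cuda.current_stream().wait_stream(alt_stream)

    return shared_out + routed_out


# Tiny smoke test with matmuls standing in for the expert modules.
dev = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 8, device=dev)
w_shared = torch.randn(8, 8, device=dev)
w_routed = torch.randn(8, 8, device=dev)
print(overlapped_moe(x, lambda h: h @ w_shared, lambda h: h @ w_routed).shape)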
sglang/srt/models/llama_eagle.py
CHANGED
@@ -25,13 +25,14 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = LlamaModel(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -25,6 +25,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             embeds = self.embed_tokens(input_ids)
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
 
         if self.config.num_hidden_layers != 1:
             raise ValueError("EAGLE3 currently only supports 1 layer")
sglang/srt/models/minicpmv.py
CHANGED
@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/mllama.py
CHANGED
@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/phi3_small.py
CHANGED
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Phi3SmallDecoderLayer(
                config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
 
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True
 
         self.attn = VisionAttention(
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,
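
Across the vision encoders in this release (minicpmv, mllama, qwen2_vl, qwen2_5_vl), the old use_context_forward flag is replaced by an explicit qkv_backend string passed to VisionAttention, and qwen2_5_vl gains a flash_attention_3 branch. The helper below is a hedged illustration of the mapping visible in the qwen2_5_vl hunk above (not part of the diff; the function name is hypothetical, and the real code inlines these branches in each vision block):

def select_vision_attn_kwargs(attn_implementation: str) -> dict:
    """Map an HF-style attn_implementation string to VisionAttention keyword arguments."""
    table = {
        "sdpa": {"qkv_backend": "sdpa", "softmax_in_single_precision": False},
        "flash_attention_2": {"qkv_backend": "triton_attn", "softmax_in_single_precision": False},
        "eager": {"qkv_backend": "sdpa", "softmax_in_single_precision": True},
        "flash_attention_3": {"qkv_backend": "fa3", "softmax_in_single_precision": False},
    }
    kwargs = dict(table[attn_implementation])
    # Every branch in the qwen2_5_vl block also flattens the batch.
    kwargs["flatten_batch"] = True
    return kwargs


print(select_vision_attn_kwargs("flash_attention_2"))
# {'qkv_backend': 'triton_attn', 'softmax_in_single_precision': False, 'flatten_batch': True}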
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
 
         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
@@ -442,18 +442,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         "up_proj": ("gate_up_proj", 1),
     }
 
-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def __init__(
         self,
         config: Qwen2VLConfig,
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
|